code optimization
This commit is contained in:
parent
25d938f3e6
commit
a942da2a5c
@ -23,15 +23,13 @@ void BmmSegment::ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<strin
|
||||
break;
|
||||
iutf8string key = phrase_sub.utf8substr(j, phrase_sub.length()-j);
|
||||
if (wordValid(key.stlstring(), appid) == true) {
|
||||
vector<string>::iterator iter = bmm_list.begin();
|
||||
bmm_list.insert(iter, key.stlstring());
|
||||
bmm_list.insert(bmm_list.begin(), key.stlstring());
|
||||
i -= key.length();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (j == phrase_sub.length() - 1) {
|
||||
vector<string>::iterator iter = bmm_list.begin();
|
||||
bmm_list.insert(iter, "" + phrase_sub[j]);
|
||||
bmm_list.insert(bmm_list.begin(), "" + phrase_sub[j]);
|
||||
i--;
|
||||
}
|
||||
}
|
||||
|
@ -15,29 +15,13 @@ void DagSegment::ConcreteSplit(iutf8string& sentence, uint32_t appid, vector<str
|
||||
getDag(sentence, appid, dag_map);
|
||||
map<uint32_t, RouteValue> route;
|
||||
calc(sentence, dag_map, route, appid);
|
||||
iutf8string utf8_str(sentence.stlstring());
|
||||
uint32_t N = utf8_str.length();
|
||||
uint32_t N = sentence.length();
|
||||
uint32_t i = 0;
|
||||
string buf = "";
|
||||
while (i < N) {
|
||||
uint32_t j = route[i].idx + 1;
|
||||
string l_word = utf8_str.substr(i, j - i);
|
||||
if (isAllAlphaOrDigit(l_word)) {
|
||||
buf += l_word;
|
||||
i = j;
|
||||
}
|
||||
else {
|
||||
if (!buf.empty()) {
|
||||
vec.push_back(buf);
|
||||
buf = "";
|
||||
}
|
||||
vec.push_back(l_word);
|
||||
i = j;
|
||||
}
|
||||
}
|
||||
if (!buf.empty()) {
|
||||
vec.push_back(buf);
|
||||
buf = "";
|
||||
string l_word = sentence.substr(i, j - i);
|
||||
vec.push_back(l_word);
|
||||
i = j;
|
||||
}
|
||||
|
||||
return;
|
||||
@ -81,12 +65,11 @@ void DagSegment::calc(iutf8string& utf8_str, const map<uint32_t, vector<uint32_t
|
||||
map<uint32_t, WordInfo> wordInfo = word_dict_[word];
|
||||
if (wordInfo.find(0) != wordInfo.end()) {
|
||||
word_info = wordInfo[0];
|
||||
word_freq = word_info.word_freq;
|
||||
}
|
||||
if (wordInfo.find(appid) != wordInfo.end()) {
|
||||
word_info = wordInfo[appid];
|
||||
word_freq = word_info.word_freq;
|
||||
}
|
||||
word_freq = word_info.word_freq;
|
||||
}
|
||||
double route_value = log(word_freq) - logtotal + route[vec[t] + 1].max_route;
|
||||
if (route_value > max_route) {
|
||||
|
@ -5,13 +5,26 @@
|
||||
#include "../utf8_str.h"
|
||||
|
||||
HmmManager::HmmManager(){
|
||||
train_corpus_ = new TrainCorpus();
|
||||
min_emit_ = 1.0;
|
||||
char state[4] = { 'B','M','E','S' };
|
||||
vector<char> state_list(state, state + 4);
|
||||
for (size_t i = 0; i < state_list.size(); i++) {
|
||||
map<char, double> trans_map;
|
||||
trans_dict_.insert(make_pair(state_list[i], trans_map));
|
||||
for (size_t j = 0; j < state_list.size(); j++) {
|
||||
trans_dict_[state_list[i]].insert(make_pair(state_list[j], 0.0));
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < state_list.size(); i++) {
|
||||
map<string, double> emit_map;
|
||||
emit_dict_.insert(make_pair(state_list[i], emit_map));
|
||||
start_dict_.insert(make_pair(state_list[i],0.0));
|
||||
count_dict_.insert(make_pair(state_list[i], 0));
|
||||
}
|
||||
line_num_ = 0;
|
||||
}
|
||||
|
||||
HmmManager::~HmmManager(){
|
||||
if(NULL != train_corpus_){
|
||||
delete train_corpus_;
|
||||
}
|
||||
}
|
||||
|
||||
bool HmmManager::Init(string train_path, const set<string>& punct_set) {
|
||||
@ -61,29 +74,86 @@ bool HmmManager::Init(string train_path, const set<string>& punct_set) {
|
||||
next_dict_[word1][word2] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
line_num_++;
|
||||
vector<string> word_list; // 保存除空格以外的字符
|
||||
iutf8string utf8_str(str);
|
||||
for (int i = 0; i < utf8_str.length(); i++) {
|
||||
if (utf8_str[i] != " ") {
|
||||
word_list.push_back(utf8_str[i]);
|
||||
}
|
||||
}
|
||||
vector<char> line_state;
|
||||
for (size_t i = 0; i < str_vec.size(); i++) {
|
||||
if (str_vec[i] == "") {
|
||||
continue;
|
||||
}
|
||||
iutf8string utf8_str_item(str_vec[i]);
|
||||
vector<char> item_state;
|
||||
getList(utf8_str_item.length(), item_state);
|
||||
line_state.insert(line_state.end(), item_state.begin(), item_state.end());
|
||||
}
|
||||
if (word_list.size() != line_state.size()) {
|
||||
log_error("[line = %s]\n", str.c_str());
|
||||
}
|
||||
else {
|
||||
for (size_t i = 0; i < line_state.size(); i++) {
|
||||
if (i == 0) {
|
||||
start_dict_[line_state[0]] += 1; // 记录句子第一个字的状态,用于计算初始状态概率
|
||||
count_dict_[line_state[0]] ++; // 记录每个状态的出现次数
|
||||
}
|
||||
else {
|
||||
trans_dict_[line_state[i - 1]][line_state[i]] += 1;
|
||||
count_dict_[line_state[i]] ++;
|
||||
if (emit_dict_[line_state[i]].find(word_list[i]) == emit_dict_[line_state[i]].end()) {
|
||||
emit_dict_[line_state[i]].insert(make_pair(word_list[i], 0.0));
|
||||
}
|
||||
else {
|
||||
emit_dict_[line_state[i]][word_list[i]] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
train_infile.close();
|
||||
|
||||
bool ret = train_corpus_->Init(train_path);
|
||||
if (ret == false) {
|
||||
log_error("train_corpus init error.");
|
||||
return ret;
|
||||
}
|
||||
log_info("total training words length is: %u, next_dict count: %d.", train_cnt_, (int)next_dict_.size());
|
||||
map<char, double>::iterator start_iter = start_dict_.begin();
|
||||
for (; start_iter != start_dict_.end(); start_iter++) { // 状态的初始概率
|
||||
start_dict_[start_iter->first] = start_dict_[start_iter->first] * 1.0 / line_num_;
|
||||
}
|
||||
|
||||
map<char, map<char, double> >::iterator trans_iter = trans_dict_.begin();
|
||||
for (; trans_iter != trans_dict_.end(); trans_iter++) { // 状态转移概率
|
||||
map<char, double> trans_map = trans_iter->second;
|
||||
map<char, double>::iterator trans_iter2 = trans_map.begin();
|
||||
for (; trans_iter2 != trans_map.end(); trans_iter2++) {
|
||||
trans_dict_[trans_iter->first][trans_iter2->first] = trans_dict_[trans_iter->first][trans_iter2->first] / count_dict_[trans_iter->first];
|
||||
}
|
||||
}
|
||||
|
||||
map<char, map<string, double> >::iterator emit_iter = emit_dict_.begin();
|
||||
for (; emit_iter != emit_dict_.end(); emit_iter++) { // 发射概率(状态->词语的条件概率)
|
||||
map<string, double> emit_map = emit_iter->second;
|
||||
map<string, double>::iterator emit_iter2 = emit_map.begin();
|
||||
for (; emit_iter2 != emit_map.end(); emit_iter2++) {
|
||||
double emit_value = emit_dict_[emit_iter->first][emit_iter2->first] / count_dict_[emit_iter->first];
|
||||
if (emit_value < min_emit_ && emit_value != 0.0) {
|
||||
min_emit_ = emit_value;
|
||||
}
|
||||
emit_dict_[emit_iter->first][emit_iter2->first] = emit_value;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void HmmManager::HmmSplit(string str, vector<string>& vec){
|
||||
vector<char> pos_list = viterbi(str);
|
||||
vector<char> pos_list;
|
||||
viterbi(str, pos_list);
|
||||
string result;
|
||||
iutf8string utf8_str(str);
|
||||
for (size_t i = 0; i < pos_list.size(); i++) {
|
||||
result += utf8_str[i];
|
||||
if (pos_list[i] == 'E') {
|
||||
std::size_t found = result.find_last_of(" ");
|
||||
string new_word = result.substr(found + 1);
|
||||
}
|
||||
if (pos_list[i] == 'E' || pos_list[i] == 'S') {
|
||||
result += ' ';
|
||||
}
|
||||
@ -95,7 +165,7 @@ void HmmManager::HmmSplit(string str, vector<string>& vec){
|
||||
vec = splitEx(result, " ");
|
||||
}
|
||||
|
||||
vector<char> HmmManager::viterbi(string sentence) {
|
||||
void HmmManager::viterbi(string sentence, vector<char>& output) {
|
||||
iutf8string utf8_str(sentence);
|
||||
vector<map<char, double> > viterbi_vec;
|
||||
map<char, vector<char> > path;
|
||||
@ -103,11 +173,11 @@ vector<char> HmmManager::viterbi(string sentence) {
|
||||
map<char, double> prob_map;
|
||||
for (size_t i = 0; i < sizeof(states); i++) {
|
||||
char y = states[i];
|
||||
double emit_value = train_corpus_->MinEmit();
|
||||
if (train_corpus_->emit_dict[y].find(utf8_str[0]) != train_corpus_->emit_dict[y].end()) {
|
||||
emit_value = train_corpus_->emit_dict[y].at(utf8_str[0]);
|
||||
double emit_value = min_emit_;
|
||||
if (emit_dict_[y].find(utf8_str[0]) != emit_dict_[y].end()) {
|
||||
emit_value = emit_dict_[y].at(utf8_str[0]);
|
||||
}
|
||||
prob_map[y] = train_corpus_->start_dict[y] * emit_value; // 在位置0,以y状态为末尾的状态序列的最大概率
|
||||
prob_map[y] = start_dict_[y] * emit_value; // 在位置0,以y状态为末尾的状态序列的最大概率
|
||||
path[y].push_back(y);
|
||||
}
|
||||
viterbi_vec.push_back(prob_map);
|
||||
@ -120,12 +190,12 @@ vector<char> HmmManager::viterbi(string sentence) {
|
||||
char state = ' ';
|
||||
for (size_t m = 0; m < sizeof(states); m++) {
|
||||
char y0 = states[m]; // 从y0 -> y状态的递归
|
||||
//cout << j << " " << y0 << " " << y << " " << V[j - 1][y0] << " " << train_corpus.trans_dict[y0][y] << " " << train_corpus.emit_dict[y].at(utf8_str[j]) << endl;
|
||||
double emit_value = train_corpus_->MinEmit();
|
||||
if (train_corpus_->emit_dict[y].find(utf8_str[j]) != train_corpus_->emit_dict[y].end()) {
|
||||
emit_value = train_corpus_->emit_dict[y].at(utf8_str[j]);
|
||||
//cout << j << " " << y0 << " " << y << " " << V[j - 1][y0] << " " << trans_dict_[y0][y] << " " << emit_dict_[y].at(utf8_str[j]) << endl;
|
||||
double emit_value = min_emit_;
|
||||
if (emit_dict_[y].find(utf8_str[j]) != emit_dict_[y].end()) {
|
||||
emit_value = emit_dict_[y].at(utf8_str[j]);
|
||||
}
|
||||
double prob = viterbi_vec[j - 1][y0] * train_corpus_->trans_dict[y0][y] * emit_value;
|
||||
double prob = viterbi_vec[j - 1][y0] * trans_dict_[y0][y] * emit_value;
|
||||
if (prob > max_prob) {
|
||||
max_prob = prob;
|
||||
state = y0;
|
||||
@ -147,7 +217,23 @@ vector<char> HmmManager::viterbi(string sentence) {
|
||||
state = y;
|
||||
}
|
||||
}
|
||||
return path[state];
|
||||
output.assign(path[state].begin(), path[state].end());
|
||||
}
|
||||
|
||||
void HmmManager::getList(uint32_t str_len, vector<char>& output) {
|
||||
if (str_len == 1) {
|
||||
output.push_back('S');
|
||||
}
|
||||
else if (str_len == 2) {
|
||||
output.push_back('B');
|
||||
output.push_back('E');
|
||||
}
|
||||
else {
|
||||
vector<char> M_list(str_len - 2, 'M');
|
||||
output.push_back('B');
|
||||
output.insert(output.end(), M_list.begin(), M_list.end());
|
||||
output.push_back('E');
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t HmmManager::TrainCnt(){
|
||||
|
@ -23,7 +23,7 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "../trainCorpus.h"
|
||||
#include <stdint.h>
|
||||
using namespace std;
|
||||
|
||||
class HmmManager{
|
||||
@ -35,11 +35,18 @@ public:
|
||||
map<string, map<string, int> >& NextDict();
|
||||
uint32_t TrainCnt();
|
||||
private:
|
||||
vector<char> viterbi(string sentence);
|
||||
void viterbi(string sentence, vector<char>& output);
|
||||
void getList(uint32_t str_len, vector<char>& output);
|
||||
private:
|
||||
uint32_t train_cnt_;
|
||||
TrainCorpus* train_corpus_;
|
||||
map<string, map<string, int> > next_dict_;
|
||||
|
||||
map<char, map<char, double> > trans_dict_;
|
||||
map<char, map<string, double> > emit_dict_;
|
||||
map<char, double> start_dict_;
|
||||
map<char, uint32_t> count_dict_;
|
||||
uint32_t line_num_;
|
||||
double min_emit_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -187,16 +187,13 @@ double NgramSegment::calSegProbability(const vector<string>& vec) {
|
||||
}
|
||||
}
|
||||
// 乘以第一个词的概率
|
||||
if ((pos == 0 && vec[pos] != "<BEG>") || (pos == 1 && vec[0] == "<BEG>")) {
|
||||
if ((pos == 0 && vec[0] != "<BEG>") || (pos == 1 && vec[0] == "<BEG>")) {
|
||||
uint32_t word_freq = 0;
|
||||
WordInfo word_info;
|
||||
if (getWordInfo(vec[pos], 0, word_info)) {
|
||||
word_freq = word_info.word_freq;
|
||||
p += log(word_freq + 1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt());
|
||||
}
|
||||
else {
|
||||
p += log(1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt());
|
||||
}
|
||||
p += log(word_freq + 1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,9 +176,6 @@ void Segment::dealByHmmMgr(uint32_t appid, const vector<string>& res_all, vector
|
||||
hmm_manager_->HmmSplit(buf, vec);
|
||||
new_res_all.insert(new_res_all.end(), vec.begin(), vec.end());
|
||||
}
|
||||
else { // 是否有这种情况
|
||||
new_res_all.push_back(buf);
|
||||
}
|
||||
}
|
||||
buf = "";
|
||||
new_res_all.push_back(res_all[i]);
|
||||
@ -195,9 +192,6 @@ void Segment::dealByHmmMgr(uint32_t appid, const vector<string>& res_all, vector
|
||||
hmm_manager_->HmmSplit(buf, vec);
|
||||
new_res_all.insert(new_res_all.end(), vec.begin(), vec.end());
|
||||
}
|
||||
else { // 是否有这种情况
|
||||
new_res_all.push_back(buf);
|
||||
}
|
||||
buf = "";
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user