From a942da2a5c24830479793b96a406d2b4a6098b0e Mon Sep 17 00:00:00 2001 From: shzhulin3 Date: Wed, 23 Jun 2021 14:47:30 +0800 Subject: [PATCH] code optimization --- src/comm/segment/bmm_segment.cc | 6 +- src/comm/segment/dag_segment.cc | 27 ++---- src/comm/segment/hmm_manager.cc | 138 ++++++++++++++++++++++++------ src/comm/segment/hmm_manager.h | 13 ++- src/comm/segment/ngram_segment.cc | 7 +- src/comm/segment/segment.cc | 6 -- 6 files changed, 131 insertions(+), 66 deletions(-) diff --git a/src/comm/segment/bmm_segment.cc b/src/comm/segment/bmm_segment.cc index 5fbdf25..c6f8567 100644 --- a/src/comm/segment/bmm_segment.cc +++ b/src/comm/segment/bmm_segment.cc @@ -23,15 +23,13 @@ void BmmSegment::ConcreteSplit(iutf8string& phrase, uint32_t appid, vector::iterator iter = bmm_list.begin(); - bmm_list.insert(iter, key.stlstring()); + bmm_list.insert(bmm_list.begin(), key.stlstring()); i -= key.length(); break; } } if (j == phrase_sub.length() - 1) { - vector::iterator iter = bmm_list.begin(); - bmm_list.insert(iter, "" + phrase_sub[j]); + bmm_list.insert(bmm_list.begin(), "" + phrase_sub[j]); i--; } } diff --git a/src/comm/segment/dag_segment.cc b/src/comm/segment/dag_segment.cc index 23f4494..aec7dde 100644 --- a/src/comm/segment/dag_segment.cc +++ b/src/comm/segment/dag_segment.cc @@ -15,29 +15,13 @@ void DagSegment::ConcreteSplit(iutf8string& sentence, uint32_t appid, vector route; calc(sentence, dag_map, route, appid); - iutf8string utf8_str(sentence.stlstring()); - uint32_t N = utf8_str.length(); + uint32_t N = sentence.length(); uint32_t i = 0; - string buf = ""; while (i < N) { uint32_t j = route[i].idx + 1; - string l_word = utf8_str.substr(i, j - i); - if (isAllAlphaOrDigit(l_word)) { - buf += l_word; - i = j; - } - else { - if (!buf.empty()) { - vec.push_back(buf); - buf = ""; - } - vec.push_back(l_word); - i = j; - } - } - if (!buf.empty()) { - vec.push_back(buf); - buf = ""; + string l_word = sentence.substr(i, j - i); + vec.push_back(l_word); + i = j; } return; @@ -81,12 +65,11 @@ void DagSegment::calc(iutf8string& utf8_str, const map wordInfo = word_dict_[word]; if (wordInfo.find(0) != wordInfo.end()) { word_info = wordInfo[0]; - word_freq = word_info.word_freq; } if (wordInfo.find(appid) != wordInfo.end()) { word_info = wordInfo[appid]; - word_freq = word_info.word_freq; } + word_freq = word_info.word_freq; } double route_value = log(word_freq) - logtotal + route[vec[t] + 1].max_route; if (route_value > max_route) { diff --git a/src/comm/segment/hmm_manager.cc b/src/comm/segment/hmm_manager.cc index df99218..3415c38 100644 --- a/src/comm/segment/hmm_manager.cc +++ b/src/comm/segment/hmm_manager.cc @@ -5,13 +5,26 @@ #include "../utf8_str.h" HmmManager::HmmManager(){ - train_corpus_ = new TrainCorpus(); + min_emit_ = 1.0; + char state[4] = { 'B','M','E','S' }; + vector state_list(state, state + 4); + for (size_t i = 0; i < state_list.size(); i++) { + map trans_map; + trans_dict_.insert(make_pair(state_list[i], trans_map)); + for (size_t j = 0; j < state_list.size(); j++) { + trans_dict_[state_list[i]].insert(make_pair(state_list[j], 0.0)); + } + } + for (size_t i = 0; i < state_list.size(); i++) { + map emit_map; + emit_dict_.insert(make_pair(state_list[i], emit_map)); + start_dict_.insert(make_pair(state_list[i],0.0)); + count_dict_.insert(make_pair(state_list[i], 0)); + } + line_num_ = 0; } HmmManager::~HmmManager(){ - if(NULL != train_corpus_){ - delete train_corpus_; - } } bool HmmManager::Init(string train_path, const set& punct_set) { @@ -61,29 +74,86 @@ bool HmmManager::Init(string train_path, const set& punct_set) { next_dict_[word1][word2] += 1; } } + + line_num_++; + vector word_list; // 保存除空格以外的字符 + iutf8string utf8_str(str); + for (int i = 0; i < utf8_str.length(); i++) { + if (utf8_str[i] != " ") { + word_list.push_back(utf8_str[i]); + } + } + vector line_state; + for (size_t i = 0; i < str_vec.size(); i++) { + if (str_vec[i] == "") { + continue; + } + iutf8string utf8_str_item(str_vec[i]); + vector item_state; + getList(utf8_str_item.length(), item_state); + line_state.insert(line_state.end(), item_state.begin(), item_state.end()); + } + if (word_list.size() != line_state.size()) { + log_error("[line = %s]\n", str.c_str()); + } + else { + for (size_t i = 0; i < line_state.size(); i++) { + if (i == 0) { + start_dict_[line_state[0]] += 1; // 记录句子第一个字的状态,用于计算初始状态概率 + count_dict_[line_state[0]] ++; // 记录每个状态的出现次数 + } + else { + trans_dict_[line_state[i - 1]][line_state[i]] += 1; + count_dict_[line_state[i]] ++; + if (emit_dict_[line_state[i]].find(word_list[i]) == emit_dict_[line_state[i]].end()) { + emit_dict_[line_state[i]].insert(make_pair(word_list[i], 0.0)); + } + else { + emit_dict_[line_state[i]][word_list[i]] += 1; + } + } + } + } } train_infile.close(); - - bool ret = train_corpus_->Init(train_path); - if (ret == false) { - log_error("train_corpus init error."); - return ret; - } log_info("total training words length is: %u, next_dict count: %d.", train_cnt_, (int)next_dict_.size()); + map::iterator start_iter = start_dict_.begin(); + for (; start_iter != start_dict_.end(); start_iter++) { // 状态的初始概率 + start_dict_[start_iter->first] = start_dict_[start_iter->first] * 1.0 / line_num_; + } + + map >::iterator trans_iter = trans_dict_.begin(); + for (; trans_iter != trans_dict_.end(); trans_iter++) { // 状态转移概率 + map trans_map = trans_iter->second; + map::iterator trans_iter2 = trans_map.begin(); + for (; trans_iter2 != trans_map.end(); trans_iter2++) { + trans_dict_[trans_iter->first][trans_iter2->first] = trans_dict_[trans_iter->first][trans_iter2->first] / count_dict_[trans_iter->first]; + } + } + + map >::iterator emit_iter = emit_dict_.begin(); + for (; emit_iter != emit_dict_.end(); emit_iter++) { // 发射概率(状态->词语的条件概率) + map emit_map = emit_iter->second; + map::iterator emit_iter2 = emit_map.begin(); + for (; emit_iter2 != emit_map.end(); emit_iter2++) { + double emit_value = emit_dict_[emit_iter->first][emit_iter2->first] / count_dict_[emit_iter->first]; + if (emit_value < min_emit_ && emit_value != 0.0) { + min_emit_ = emit_value; + } + emit_dict_[emit_iter->first][emit_iter2->first] = emit_value; + } + } return true; } void HmmManager::HmmSplit(string str, vector& vec){ - vector pos_list = viterbi(str); + vector pos_list; + viterbi(str, pos_list); string result; iutf8string utf8_str(str); for (size_t i = 0; i < pos_list.size(); i++) { result += utf8_str[i]; - if (pos_list[i] == 'E') { - std::size_t found = result.find_last_of(" "); - string new_word = result.substr(found + 1); - } if (pos_list[i] == 'E' || pos_list[i] == 'S') { result += ' '; } @@ -95,7 +165,7 @@ void HmmManager::HmmSplit(string str, vector& vec){ vec = splitEx(result, " "); } -vector HmmManager::viterbi(string sentence) { +void HmmManager::viterbi(string sentence, vector& output) { iutf8string utf8_str(sentence); vector > viterbi_vec; map > path; @@ -103,11 +173,11 @@ vector HmmManager::viterbi(string sentence) { map prob_map; for (size_t i = 0; i < sizeof(states); i++) { char y = states[i]; - double emit_value = train_corpus_->MinEmit(); - if (train_corpus_->emit_dict[y].find(utf8_str[0]) != train_corpus_->emit_dict[y].end()) { - emit_value = train_corpus_->emit_dict[y].at(utf8_str[0]); + double emit_value = min_emit_; + if (emit_dict_[y].find(utf8_str[0]) != emit_dict_[y].end()) { + emit_value = emit_dict_[y].at(utf8_str[0]); } - prob_map[y] = train_corpus_->start_dict[y] * emit_value; // 在位置0,以y状态为末尾的状态序列的最大概率 + prob_map[y] = start_dict_[y] * emit_value; // 在位置0,以y状态为末尾的状态序列的最大概率 path[y].push_back(y); } viterbi_vec.push_back(prob_map); @@ -120,12 +190,12 @@ vector HmmManager::viterbi(string sentence) { char state = ' '; for (size_t m = 0; m < sizeof(states); m++) { char y0 = states[m]; // 从y0 -> y状态的递归 - //cout << j << " " << y0 << " " << y << " " << V[j - 1][y0] << " " << train_corpus.trans_dict[y0][y] << " " << train_corpus.emit_dict[y].at(utf8_str[j]) << endl; - double emit_value = train_corpus_->MinEmit(); - if (train_corpus_->emit_dict[y].find(utf8_str[j]) != train_corpus_->emit_dict[y].end()) { - emit_value = train_corpus_->emit_dict[y].at(utf8_str[j]); + //cout << j << " " << y0 << " " << y << " " << V[j - 1][y0] << " " << trans_dict_[y0][y] << " " << emit_dict_[y].at(utf8_str[j]) << endl; + double emit_value = min_emit_; + if (emit_dict_[y].find(utf8_str[j]) != emit_dict_[y].end()) { + emit_value = emit_dict_[y].at(utf8_str[j]); } - double prob = viterbi_vec[j - 1][y0] * train_corpus_->trans_dict[y0][y] * emit_value; + double prob = viterbi_vec[j - 1][y0] * trans_dict_[y0][y] * emit_value; if (prob > max_prob) { max_prob = prob; state = y0; @@ -147,7 +217,23 @@ vector HmmManager::viterbi(string sentence) { state = y; } } - return path[state]; + output.assign(path[state].begin(), path[state].end()); +} + +void HmmManager::getList(uint32_t str_len, vector& output) { + if (str_len == 1) { + output.push_back('S'); + } + else if (str_len == 2) { + output.push_back('B'); + output.push_back('E'); + } + else { + vector M_list(str_len - 2, 'M'); + output.push_back('B'); + output.insert(output.end(), M_list.begin(), M_list.end()); + output.push_back('E'); + } } uint32_t HmmManager::TrainCnt(){ diff --git a/src/comm/segment/hmm_manager.h b/src/comm/segment/hmm_manager.h index 12bd482..216008f 100644 --- a/src/comm/segment/hmm_manager.h +++ b/src/comm/segment/hmm_manager.h @@ -23,7 +23,7 @@ #include #include #include -#include "../trainCorpus.h" +#include using namespace std; class HmmManager{ @@ -35,11 +35,18 @@ public: map >& NextDict(); uint32_t TrainCnt(); private: - vector viterbi(string sentence); + void viterbi(string sentence, vector& output); + void getList(uint32_t str_len, vector& output); private: uint32_t train_cnt_; - TrainCorpus* train_corpus_; map > next_dict_; + + map > trans_dict_; + map > emit_dict_; + map start_dict_; + map count_dict_; + uint32_t line_num_; + double min_emit_; }; diff --git a/src/comm/segment/ngram_segment.cc b/src/comm/segment/ngram_segment.cc index d024c9a..ebf26b5 100644 --- a/src/comm/segment/ngram_segment.cc +++ b/src/comm/segment/ngram_segment.cc @@ -187,16 +187,13 @@ double NgramSegment::calSegProbability(const vector& vec) { } } // 乘以第一个词的概率 - if ((pos == 0 && vec[pos] != "") || (pos == 1 && vec[0] == "")) { + if ((pos == 0 && vec[0] != "") || (pos == 1 && vec[0] == "")) { uint32_t word_freq = 0; WordInfo word_info; if (getWordInfo(vec[pos], 0, word_info)) { word_freq = word_info.word_freq; - p += log(word_freq + 1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt()); - } - else { - p += log(1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt()); } + p += log(word_freq + 1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt()); } } diff --git a/src/comm/segment/segment.cc b/src/comm/segment/segment.cc index 041d9d7..1a6ef3a 100644 --- a/src/comm/segment/segment.cc +++ b/src/comm/segment/segment.cc @@ -176,9 +176,6 @@ void Segment::dealByHmmMgr(uint32_t appid, const vector& res_all, vector hmm_manager_->HmmSplit(buf, vec); new_res_all.insert(new_res_all.end(), vec.begin(), vec.end()); } - else { // 是否有这种情况 - new_res_all.push_back(buf); - } } buf = ""; new_res_all.push_back(res_all[i]); @@ -195,9 +192,6 @@ void Segment::dealByHmmMgr(uint32_t appid, const vector& res_all, vector hmm_manager_->HmmSplit(buf, vec); new_res_all.insert(new_res_all.end(), vec.begin(), vec.end()); } - else { // 是否有这种情况 - new_res_all.push_back(buf); - } buf = ""; } }