code optimization

This commit is contained in:
shzhulin3 2021-06-23 14:47:30 +08:00
parent 25d938f3e6
commit a942da2a5c
6 changed files with 131 additions and 66 deletions

View File

@ -23,15 +23,13 @@ void BmmSegment::ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<strin
break;
iutf8string key = phrase_sub.utf8substr(j, phrase_sub.length()-j);
if (wordValid(key.stlstring(), appid) == true) {
vector<string>::iterator iter = bmm_list.begin();
bmm_list.insert(iter, key.stlstring());
bmm_list.insert(bmm_list.begin(), key.stlstring());
i -= key.length();
break;
}
}
if (j == phrase_sub.length() - 1) {
vector<string>::iterator iter = bmm_list.begin();
bmm_list.insert(iter, "" + phrase_sub[j]);
bmm_list.insert(bmm_list.begin(), "" + phrase_sub[j]);
i--;
}
}

View File

@ -15,29 +15,13 @@ void DagSegment::ConcreteSplit(iutf8string& sentence, uint32_t appid, vector<str
getDag(sentence, appid, dag_map);
map<uint32_t, RouteValue> route;
calc(sentence, dag_map, route, appid);
iutf8string utf8_str(sentence.stlstring());
uint32_t N = utf8_str.length();
uint32_t N = sentence.length();
uint32_t i = 0;
string buf = "";
while (i < N) {
uint32_t j = route[i].idx + 1;
string l_word = utf8_str.substr(i, j - i);
if (isAllAlphaOrDigit(l_word)) {
buf += l_word;
i = j;
}
else {
if (!buf.empty()) {
vec.push_back(buf);
buf = "";
}
vec.push_back(l_word);
i = j;
}
}
if (!buf.empty()) {
vec.push_back(buf);
buf = "";
string l_word = sentence.substr(i, j - i);
vec.push_back(l_word);
i = j;
}
return;
@ -81,12 +65,11 @@ void DagSegment::calc(iutf8string& utf8_str, const map<uint32_t, vector<uint32_t
map<uint32_t, WordInfo> wordInfo = word_dict_[word];
if (wordInfo.find(0) != wordInfo.end()) {
word_info = wordInfo[0];
word_freq = word_info.word_freq;
}
if (wordInfo.find(appid) != wordInfo.end()) {
word_info = wordInfo[appid];
word_freq = word_info.word_freq;
}
word_freq = word_info.word_freq;
}
double route_value = log(word_freq) - logtotal + route[vec[t] + 1].max_route;
if (route_value > max_route) {

View File

@ -5,13 +5,26 @@
#include "../utf8_str.h"
HmmManager::HmmManager(){
train_corpus_ = new TrainCorpus();
min_emit_ = 1.0;
char state[4] = { 'B','M','E','S' };
vector<char> state_list(state, state + 4);
for (size_t i = 0; i < state_list.size(); i++) {
map<char, double> trans_map;
trans_dict_.insert(make_pair(state_list[i], trans_map));
for (size_t j = 0; j < state_list.size(); j++) {
trans_dict_[state_list[i]].insert(make_pair(state_list[j], 0.0));
}
}
for (size_t i = 0; i < state_list.size(); i++) {
map<string, double> emit_map;
emit_dict_.insert(make_pair(state_list[i], emit_map));
start_dict_.insert(make_pair(state_list[i],0.0));
count_dict_.insert(make_pair(state_list[i], 0));
}
line_num_ = 0;
}
HmmManager::~HmmManager(){
if(NULL != train_corpus_){
delete train_corpus_;
}
}
bool HmmManager::Init(string train_path, const set<string>& punct_set) {
@ -61,29 +74,86 @@ bool HmmManager::Init(string train_path, const set<string>& punct_set) {
next_dict_[word1][word2] += 1;
}
}
line_num_++;
vector<string> word_list; // 保存除空格以外的字符
iutf8string utf8_str(str);
for (int i = 0; i < utf8_str.length(); i++) {
if (utf8_str[i] != " ") {
word_list.push_back(utf8_str[i]);
}
}
vector<char> line_state;
for (size_t i = 0; i < str_vec.size(); i++) {
if (str_vec[i] == "") {
continue;
}
iutf8string utf8_str_item(str_vec[i]);
vector<char> item_state;
getList(utf8_str_item.length(), item_state);
line_state.insert(line_state.end(), item_state.begin(), item_state.end());
}
if (word_list.size() != line_state.size()) {
log_error("[line = %s]\n", str.c_str());
}
else {
for (size_t i = 0; i < line_state.size(); i++) {
if (i == 0) {
start_dict_[line_state[0]] += 1; // 记录句子第一个字的状态,用于计算初始状态概率
count_dict_[line_state[0]] ++; // 记录每个状态的出现次数
}
else {
trans_dict_[line_state[i - 1]][line_state[i]] += 1;
count_dict_[line_state[i]] ++;
if (emit_dict_[line_state[i]].find(word_list[i]) == emit_dict_[line_state[i]].end()) {
emit_dict_[line_state[i]].insert(make_pair(word_list[i], 0.0));
}
else {
emit_dict_[line_state[i]][word_list[i]] += 1;
}
}
}
}
}
train_infile.close();
bool ret = train_corpus_->Init(train_path);
if (ret == false) {
log_error("train_corpus init error.");
return ret;
}
log_info("total training words length is: %u, next_dict count: %d.", train_cnt_, (int)next_dict_.size());
map<char, double>::iterator start_iter = start_dict_.begin();
for (; start_iter != start_dict_.end(); start_iter++) { // 状态的初始概率
start_dict_[start_iter->first] = start_dict_[start_iter->first] * 1.0 / line_num_;
}
map<char, map<char, double> >::iterator trans_iter = trans_dict_.begin();
for (; trans_iter != trans_dict_.end(); trans_iter++) { // 状态转移概率
map<char, double> trans_map = trans_iter->second;
map<char, double>::iterator trans_iter2 = trans_map.begin();
for (; trans_iter2 != trans_map.end(); trans_iter2++) {
trans_dict_[trans_iter->first][trans_iter2->first] = trans_dict_[trans_iter->first][trans_iter2->first] / count_dict_[trans_iter->first];
}
}
map<char, map<string, double> >::iterator emit_iter = emit_dict_.begin();
for (; emit_iter != emit_dict_.end(); emit_iter++) { // 发射概率(状态->词语的条件概率)
map<string, double> emit_map = emit_iter->second;
map<string, double>::iterator emit_iter2 = emit_map.begin();
for (; emit_iter2 != emit_map.end(); emit_iter2++) {
double emit_value = emit_dict_[emit_iter->first][emit_iter2->first] / count_dict_[emit_iter->first];
if (emit_value < min_emit_ && emit_value != 0.0) {
min_emit_ = emit_value;
}
emit_dict_[emit_iter->first][emit_iter2->first] = emit_value;
}
}
return true;
}
void HmmManager::HmmSplit(string str, vector<string>& vec){
vector<char> pos_list = viterbi(str);
vector<char> pos_list;
viterbi(str, pos_list);
string result;
iutf8string utf8_str(str);
for (size_t i = 0; i < pos_list.size(); i++) {
result += utf8_str[i];
if (pos_list[i] == 'E') {
std::size_t found = result.find_last_of(" ");
string new_word = result.substr(found + 1);
}
if (pos_list[i] == 'E' || pos_list[i] == 'S') {
result += ' ';
}
@ -95,7 +165,7 @@ void HmmManager::HmmSplit(string str, vector<string>& vec){
vec = splitEx(result, " ");
}
vector<char> HmmManager::viterbi(string sentence) {
void HmmManager::viterbi(string sentence, vector<char>& output) {
iutf8string utf8_str(sentence);
vector<map<char, double> > viterbi_vec;
map<char, vector<char> > path;
@ -103,11 +173,11 @@ vector<char> HmmManager::viterbi(string sentence) {
map<char, double> prob_map;
for (size_t i = 0; i < sizeof(states); i++) {
char y = states[i];
double emit_value = train_corpus_->MinEmit();
if (train_corpus_->emit_dict[y].find(utf8_str[0]) != train_corpus_->emit_dict[y].end()) {
emit_value = train_corpus_->emit_dict[y].at(utf8_str[0]);
double emit_value = min_emit_;
if (emit_dict_[y].find(utf8_str[0]) != emit_dict_[y].end()) {
emit_value = emit_dict_[y].at(utf8_str[0]);
}
prob_map[y] = train_corpus_->start_dict[y] * emit_value; // 在位置0以y状态为末尾的状态序列的最大概率
prob_map[y] = start_dict_[y] * emit_value; // 在位置0以y状态为末尾的状态序列的最大概率
path[y].push_back(y);
}
viterbi_vec.push_back(prob_map);
@ -120,12 +190,12 @@ vector<char> HmmManager::viterbi(string sentence) {
char state = ' ';
for (size_t m = 0; m < sizeof(states); m++) {
char y0 = states[m]; // 从y0 -> y状态的递归
//cout << j << " " << y0 << " " << y << " " << V[j - 1][y0] << " " << train_corpus.trans_dict[y0][y] << " " << train_corpus.emit_dict[y].at(utf8_str[j]) << endl;
double emit_value = train_corpus_->MinEmit();
if (train_corpus_->emit_dict[y].find(utf8_str[j]) != train_corpus_->emit_dict[y].end()) {
emit_value = train_corpus_->emit_dict[y].at(utf8_str[j]);
//cout << j << " " << y0 << " " << y << " " << V[j - 1][y0] << " " << trans_dict_[y0][y] << " " << emit_dict_[y].at(utf8_str[j]) << endl;
double emit_value = min_emit_;
if (emit_dict_[y].find(utf8_str[j]) != emit_dict_[y].end()) {
emit_value = emit_dict_[y].at(utf8_str[j]);
}
double prob = viterbi_vec[j - 1][y0] * train_corpus_->trans_dict[y0][y] * emit_value;
double prob = viterbi_vec[j - 1][y0] * trans_dict_[y0][y] * emit_value;
if (prob > max_prob) {
max_prob = prob;
state = y0;
@ -147,7 +217,23 @@ vector<char> HmmManager::viterbi(string sentence) {
state = y;
}
}
return path[state];
output.assign(path[state].begin(), path[state].end());
}
void HmmManager::getList(uint32_t str_len, vector<char>& output) {
if (str_len == 1) {
output.push_back('S');
}
else if (str_len == 2) {
output.push_back('B');
output.push_back('E');
}
else {
vector<char> M_list(str_len - 2, 'M');
output.push_back('B');
output.insert(output.end(), M_list.begin(), M_list.end());
output.push_back('E');
}
}
uint32_t HmmManager::TrainCnt(){

View File

@ -23,7 +23,7 @@
#include <set>
#include <vector>
#include <map>
#include "../trainCorpus.h"
#include <stdint.h>
using namespace std;
class HmmManager{
@ -35,11 +35,18 @@ public:
map<string, map<string, int> >& NextDict();
uint32_t TrainCnt();
private:
vector<char> viterbi(string sentence);
void viterbi(string sentence, vector<char>& output);
void getList(uint32_t str_len, vector<char>& output);
private:
uint32_t train_cnt_;
TrainCorpus* train_corpus_;
map<string, map<string, int> > next_dict_;
map<char, map<char, double> > trans_dict_;
map<char, map<string, double> > emit_dict_;
map<char, double> start_dict_;
map<char, uint32_t> count_dict_;
uint32_t line_num_;
double min_emit_;
};

View File

@ -187,16 +187,13 @@ double NgramSegment::calSegProbability(const vector<string>& vec) {
}
}
// 乘以第一个词的概率
if ((pos == 0 && vec[pos] != "<BEG>") || (pos == 1 && vec[0] == "<BEG>")) {
if ((pos == 0 && vec[0] != "<BEG>") || (pos == 1 && vec[0] == "<BEG>")) {
uint32_t word_freq = 0;
WordInfo word_info;
if (getWordInfo(vec[pos], 0, word_info)) {
word_freq = word_info.word_freq;
p += log(word_freq + 1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt());
}
else {
p += log(1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt());
}
p += log(word_freq + 1.0 / hmm_manager_->NextDict().size() + hmm_manager_->TrainCnt());
}
}

View File

@ -176,9 +176,6 @@ void Segment::dealByHmmMgr(uint32_t appid, const vector<string>& res_all, vector
hmm_manager_->HmmSplit(buf, vec);
new_res_all.insert(new_res_all.end(), vec.begin(), vec.end());
}
else { // 是否有这种情况
new_res_all.push_back(buf);
}
}
buf = "";
new_res_all.push_back(res_all[i]);
@ -195,9 +192,6 @@ void Segment::dealByHmmMgr(uint32_t appid, const vector<string>& res_all, vector
hmm_manager_->HmmSplit(buf, vec);
new_res_all.insert(new_res_all.end(), vec.begin(), vec.end());
}
else { // 是否有这种情况
new_res_all.push_back(buf);
}
buf = "";
}
}