isearch/comm/trainCorpus.cc

129 lines
4.0 KiB
C++
Raw Normal View History

2021-03-19 18:38:34 +08:00
#include "trainCorpus.h"
#include "log.h"
#include "utf8_str.h"
#include <fstream>
#include <iostream>
#include <vector>
#include <string>
#include <stdint.h>
using namespace std;
vector<char> getList(uint32_t str_len) {
vector<char> output;
if (str_len == 1) {
output.push_back('S');
}
else if (str_len == 2) {
output.push_back('B');
output.push_back('E');
}
else {
vector<char> M_list(str_len - 2, 'M');
output.push_back('B');
output.insert(output.end(), M_list.begin(), M_list.end());
output.push_back('E');
}
return output;
}
TrainCorpus::TrainCorpus() {
min_emit = 1.0;
char state[4] = { 'B','M','E','S' };
vector<char> state_list(state, state + 4);
for (size_t i = 0; i < state_list.size(); i++) {
map<char, double> trans_map;
trans_dict.insert(make_pair(state_list[i], trans_map));
for (size_t j = 0; j < state_list.size(); j++) {
trans_dict[state_list[i]].insert(make_pair(state_list[j], 0.0));
}
}
for (size_t i = 0; i < state_list.size(); i++) {
map<string, double> emit_map;
emit_dict.insert(make_pair(state_list[i], emit_map));
start_dict.insert(make_pair(state_list[i],0.0));
count_dict.insert(make_pair(state_list[i], 0));
}
line_num = 0;
}
bool TrainCorpus::Init(string train_path) {
ifstream train_infile; // ѵ<><D1B5><EFBFBD>ļ<EFBFBD><C4BC>Կո<D4BF>Ϊ<EFBFBD>ָ<EFBFBD><D6B8><EFBFBD>
train_infile.open(train_path.c_str());
if (train_infile.is_open() == false) {
log_error("open file error: %s.\n", train_path.c_str());
return false;
}
string str;
while (getline(train_infile, str))
{
line_num++;
vector<string> word_list; // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ո<EFBFBD><D5B8><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ַ<EFBFBD>
iutf8string utf8_str(str);
for (int i = 0; i < utf8_str.length(); i++) {
if (utf8_str[i] != " ") {
word_list.push_back(utf8_str[i]);
}
}
vector<char> line_state;
vector<string> str_vec = splitEx(str, " ");
for (size_t i = 0; i < str_vec.size(); i++) {
if (str_vec[i] == "") {
continue;
}
iutf8string utf8_str_item(str_vec[i]);
vector<char> item_state = getList(utf8_str_item.length());
line_state.insert(line_state.end(), item_state.begin(), item_state.end());
}
if (word_list.size() != line_state.size()) {
log_error("[line = %s]\n", str.c_str());
}
else {
for (size_t i = 0; i < line_state.size(); i++) {
if (i == 0) {
start_dict[line_state[0]] += 1; // <20><>¼<EFBFBD><C2BC><EFBFBD>ӵ<EFBFBD>һ<EFBFBD><D2BB><EFBFBD>ֵ<EFBFBD>״̬<D7B4><CCAC><EFBFBD><EFBFBD><EFBFBD>ڼ<EFBFBD><DABC><EFBFBD><EFBFBD><EFBFBD>ʼ״̬<D7B4><CCAC><EFBFBD><EFBFBD>
count_dict[line_state[0]] ++; // <20><>¼ÿ<C2BC><C3BF>״̬<D7B4>ij<EFBFBD><C4B3>ִ<EFBFBD><D6B4><EFBFBD>
}
else {
trans_dict[line_state[i - 1]][line_state[i]] += 1;
count_dict[line_state[i]] ++;
if (emit_dict[line_state[i]].find(word_list[i]) == emit_dict[line_state[i]].end()) {
emit_dict[line_state[i]].insert(make_pair(word_list[i], 0.0));
}
else {
emit_dict[line_state[i]][word_list[i]] += 1;
}
}
}
}
}
train_infile.close();
map<char, double>::iterator start_iter = start_dict.begin();
for (; start_iter != start_dict.end(); start_iter++) { // ״̬<D7B4>ij<EFBFBD>ʼ<EFBFBD><CABC><EFBFBD><EFBFBD>
start_dict[start_iter->first] = start_dict[start_iter->first] * 1.0 / line_num;
}
map<char, map<char, double> >::iterator trans_iter = trans_dict.begin();
for (; trans_iter != trans_dict.end(); trans_iter++) { // ״̬ת<CCAC>Ƹ<EFBFBD><C6B8><EFBFBD>
map<char, double> trans_map = trans_iter->second;
map<char, double>::iterator trans_iter2 = trans_map.begin();
for (; trans_iter2 != trans_map.end(); trans_iter2++) {
trans_dict[trans_iter->first][trans_iter2->first] = trans_dict[trans_iter->first][trans_iter2->first] / count_dict[trans_iter->first];
}
}
map<char, map<string, double> >::iterator emit_iter = emit_dict.begin();
for (; emit_iter != emit_dict.end(); emit_iter++) { // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>(״̬-><3E><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>)
map<string, double> emit_map = emit_iter->second;
map<string, double>::iterator emit_iter2 = emit_map.begin();
for (; emit_iter2 != emit_map.end(); emit_iter2++) {
double emit_value = emit_dict[emit_iter->first][emit_iter2->first] / count_dict[emit_iter->first];
if (emit_value < min_emit && emit_value != 0.0) {
min_emit = emit_value;
}
emit_dict[emit_iter->first][emit_iter2->first] = emit_value;
}
}
return true;
}