129 lines
4.0 KiB
C++
129 lines
4.0 KiB
C++
|
#include "trainCorpus.h"
|
|||
|
#include "log.h"
|
|||
|
#include "utf8_str.h"
|
|||
|
#include <fstream>
|
|||
|
#include <iostream>
|
|||
|
#include <vector>
|
|||
|
#include <string>
|
|||
|
#include <stdint.h>
|
|||
|
using namespace std;
|
|||
|
|
|||
|
/**
 * Build the per-character tag sequence for a single word.
 *
 * Tags used: 'S' for a one-character word, otherwise 'B' for the first
 * character, 'E' for the last character and 'M' for every character in
 * between.
 *
 * @param str_len Number of characters in the word.
 * @return A vector of exactly str_len tags; empty when str_len is 0.
 */
std::vector<char> getList(uint32_t str_len) {
    // Guard the empty word explicitly: with an unsigned length, "str_len - 2"
    // would wrap around to ~4 billion and request a gigantic run of 'M' tags.
    if (str_len == 0) {
        return std::vector<char>();
    }
    if (str_len == 1) {
        return std::vector<char>(1, 'S');
    }

    // Two or more characters: fill with 'M', then overwrite both ends.
    std::vector<char> output(str_len, 'M');
    output.front() = 'B';
    output.back() = 'E';
    return output;
}
|
|||
|
|
|||
|
/**
 * Construct an untrained model.
 *
 * Sets min_emit to 1.0 and line_num to 0, and seeds every table with an
 * entry for each of the four tags 'B', 'M', 'E', 'S':
 *   - trans_dict[from][to] = 0.0 for every ordered tag pair,
 *   - emit_dict[tag]       = empty map,
 *   - start_dict[tag]      = 0.0,
 *   - count_dict[tag]      = 0.
 */
TrainCorpus::TrainCorpus() {
    static const char kTags[] = { 'B', 'M', 'E', 'S' };
    const size_t kTagCount = sizeof(kTags) / sizeof(kTags[0]);

    min_emit = 1.0;
    line_num = 0;

    for (size_t row = 0; row < kTagCount; ++row) {
        const char from_tag = kTags[row];

        // Build the fully zeroed transition row before storing it.
        map<char, double> zero_row;
        for (size_t col = 0; col < kTagCount; ++col) {
            zero_row.insert(make_pair(kTags[col], 0.0));
        }
        trans_dict.insert(make_pair(from_tag, zero_row));

        // The remaining tables get one entry per tag.
        emit_dict.insert(make_pair(from_tag, map<string, double>()));
        start_dict.insert(make_pair(from_tag, 0.0));
        count_dict.insert(make_pair(from_tag, 0));
    }
}
|
|||
|
|
|||
|
/**
 * Accumulate tag statistics from a pre-segmented corpus file and turn them
 * into relative frequencies.
 *
 * Each line is split on single spaces into words; every word contributes a
 * B/M/E/S tag run (see getList). Per line, the first tag updates start_dict
 * and count_dict; every later tag updates trans_dict, count_dict and
 * emit_dict. After the whole file is read:
 *   - start_dict[s]    is divided by line_num,
 *   - trans_dict[a][b] is divided by count_dict[a],
 *   - emit_dict[s][c]  is divided by count_dict[s], and min_emit is lowered
 *     to the smallest non-zero emission value seen.
 *
 * A zero denominator (empty corpus, or a tag that never occurred) leaves the
 * affected entries at their current value instead of producing NaN.
 *
 * @param train_path Path of the training file.
 * @return false if the file cannot be opened, true otherwise.
 */
bool TrainCorpus::Init(string train_path) {
    ifstream train_infile;  // training file; words on a line are separated by spaces
    train_infile.open(train_path.c_str());
    if (train_infile.is_open() == false) {
        log_error("open file error: %s.\n", train_path.c_str());
        return false;
    }

    string str;
    while (getline(train_infile, str)) {
        line_num++;

        // Every element of the line except the space separators.
        // NOTE(review): presumably iutf8string indexes by UTF-8 character - confirm in utf8_str.h.
        vector<string> word_list;
        iutf8string utf8_str(str);
        for (int i = 0; i < utf8_str.length(); i++) {
            if (utf8_str[i] != " ") {
                word_list.push_back(utf8_str[i]);
            }
        }

        // One tag per character, built word by word.
        vector<char> line_state;
        vector<string> str_vec = splitEx(str, " ");
        for (size_t i = 0; i < str_vec.size(); i++) {
            if (str_vec[i] == "") {
                continue;  // skip empty pieces
            }
            iutf8string utf8_str_item(str_vec[i]);
            vector<char> item_state = getList(utf8_str_item.length());
            line_state.insert(line_state.end(), item_state.begin(), item_state.end());
        }

        // Characters and tags must line up one-to-one; otherwise report the line and skip it.
        if (word_list.size() != line_state.size()) {
            log_error("[line = %s]\n", str.c_str());
            continue;
        }

        for (size_t i = 0; i < line_state.size(); i++) {
            if (i == 0) {
                start_dict[line_state[0]] += 1;  // tag of the first character, used for the start probabilities
                count_dict[line_state[0]]++;     // number of occurrences of each tag
            }
            else {
                trans_dict[line_state[i - 1]][line_state[i]] += 1;
                count_dict[line_state[i]]++;
                // NOTE(review): the first sighting of a (tag, character) pair is stored as 0.0
                // and only later sightings add 1; the first character of a line records no
                // emission at all. Kept as-is because it changes the trained values - confirm
                // whether this undercount is intended.
                if (emit_dict[line_state[i]].find(word_list[i]) == emit_dict[line_state[i]].end()) {
                    emit_dict[line_state[i]].insert(make_pair(word_list[i], 0.0));
                }
                else {
                    emit_dict[line_state[i]][word_list[i]] += 1;
                }
            }
        }
    }
    train_infile.close();

    // Start probabilities. With no lines read there is nothing to normalise,
    // and dividing by zero would store NaN.
    if (line_num != 0) {
        map<char, double>::iterator start_iter = start_dict.begin();
        for (; start_iter != start_dict.end(); ++start_iter) {
            start_iter->second = start_iter->second * 1.0 / line_num;
        }
    }

    // Transition probabilities, normalised in place by the count of the source tag.
    map<char, map<char, double> >::iterator trans_iter = trans_dict.begin();
    for (; trans_iter != trans_dict.end(); ++trans_iter) {
        const double state_total = count_dict[trans_iter->first];
        if (state_total == 0.0) {
            continue;  // tag never seen: keep the row as counted instead of writing NaN
        }
        map<char, double>& trans_map = trans_iter->second;
        map<char, double>::iterator trans_iter2 = trans_map.begin();
        for (; trans_iter2 != trans_map.end(); ++trans_iter2) {
            trans_iter2->second = trans_iter2->second / state_total;
        }
    }

    // Emission probabilities (tag -> character), normalised in place.
    map<char, map<string, double> >::iterator emit_iter = emit_dict.begin();
    for (; emit_iter != emit_dict.end(); ++emit_iter) {
        const double state_total = count_dict[emit_iter->first];
        if (state_total == 0.0) {
            continue;  // tag never seen: avoid dividing by zero
        }
        map<string, double>& emit_map = emit_iter->second;
        map<string, double>::iterator emit_iter2 = emit_map.begin();
        for (; emit_iter2 != emit_map.end(); ++emit_iter2) {
            const double emit_value = emit_iter2->second / state_total;
            // Track the smallest non-zero emission value.
            if (emit_value < min_emit && emit_value != 0.0) {
                min_emit = emit_value;
            }
            emit_iter2->second = emit_value;
        }
    }

    return true;
}
|