from hmm_segment import HMM
import pickle
def train(hmm, path):
    """Accumulate HMM counts from a pre-segmented corpus, then finalize the model.

    Every non-blank line of the file at ``path`` is treated as one sentence
    whose words are separated by whitespace. Each word is labeled with
    ``hmm.make_label(word)`` (one label per character, e.g.
    ``['B', 'M', 'M', 'M', 'E', 'S']``) and the following counters on ``hmm``
    are updated in place:

    * ``hmm.Count_dic[state]``   -- occurrences of each state
    * ``hmm.Pi_dic[state]``      -- state of the first character of a sentence
    * ``hmm.A_dic[prev][state]`` -- state-to-state transitions
    * ``hmm.B_dic[state][char]`` -- characters emitted by each state

    Afterwards ``hmm.line_num`` is set to the number of sentences that were
    counted and ``calculate_probability(hmm)`` is called to normalize the
    counts.

    Args:
        hmm: model object exposing ``make_label`` and the count dictionaries
            listed above.
        path: path of the UTF-8 encoded, whitespace-segmented training corpus.

    Raises:
        AssertionError: if ``hmm.make_label`` does not yield exactly one label
            per character of a line.
    """
    sentence_count = 0
    with open(path, encoding='utf8') as f:
        for line_no, line in enumerate(f, start=1):
            # split() with no argument drops leading/trailing whitespace and
            # splits on every kind of whitespace (space, tab, U+3000, ...).
            tokens = line.split()
            if not tokens:
                continue
            sentence_count += 1

            # Take the characters from the tokens themselves so the character
            # list and the label list are always built from the same text.
            chars = list(''.join(tokens))
            line_state = []
            for token in tokens:
                line_state.extend(hmm.make_label(token))
            assert len(chars) == len(line_state), (
                'line %d: %d characters but %d labels'
                % (line_no, len(chars), len(line_state)))

            for k, (state, char) in enumerate(zip(line_state, chars)):
                hmm.Count_dic[state] += 1  # how often each state occurs
                if k == 0:
                    # State of the first character: initial-state counts.
                    hmm.Pi_dic[state] += 1
                else:
                    # Transition from the previous character's state.
                    hmm.A_dic[line_state[k - 1]][state] += 1
                # Emission count for every character, including the first one
                # of the sentence (Count_dic, the later denominator, counts it
                # too).
                emissions = hmm.B_dic[state]
                emissions[char] = emissions.get(char, 0) + 1.0

    # Number of sentences actually counted (blank lines excluded); this is the
    # denominator for the initial-state probabilities.
    hmm.line_num = sentence_count
    calculate_probability(hmm)
def calculate_probability(hmm):
    """Convert the raw counts stored on ``hmm`` into probabilities and save them.

    Expected shape of the counts (examples):

    * ``hmm.A_dic``:     ``{'B': {'B': 0.0, 'M': 162066.0, 'E': 1226466.0, 'S': 0.0}, ...}``
    * ``hmm.B_dic``:     ``{'B': {'中': 12812.0, '儿': 464.0, '踏': 62.0}, ...}``
    * ``hmm.Count_dic``: ``{'B': 1388532, 'M': 224398, 'E': 1388532, 'S': 1609916}``
    * ``hmm.Pi_dic``:    per-state count of sentence-initial characters

    The three dictionaries ``Pi_dic``, ``A_dic`` and ``B_dic`` are replaced on
    ``hmm`` by their normalized versions and then pickled, in the order
    ``A_dic``, ``B_dic``, ``Pi_dic``, into ``hmm.model_file``.

    A zero denominator (no sentences counted, or a state that never occurred)
    yields probability 0.0 instead of raising ``ZeroDivisionError``.

    Args:
        hmm: model object carrying ``Pi_dic``, ``A_dic``, ``B_dic``,
            ``Count_dic`` and ``model_file``.
    """

    def _ratio(numerator, denominator):
        # Guarded division: states that never occur get probability 0.0.
        return numerator / denominator if denominator else 0.0

    # Initial-state probabilities. Normalizing by the total of the counts
    # themselves makes the result a proper distribution (sums to 1) no matter
    # how the caller counted corpus lines.
    sentence_total = sum(hmm.Pi_dic.values())
    hmm.Pi_dic = {
        state: _ratio(count * 1.0, sentence_total)
        for state, count in hmm.Pi_dic.items()
    }

    # Transition probabilities: count(prev -> cur) / count(prev).
    hmm.A_dic = {
        prev: {
            cur: _ratio(count, hmm.Count_dic[prev])
            for cur, count in row.items()
        }
        for prev, row in hmm.A_dic.items()
    }

    # Emission probabilities with add-one on the numerator.
    hmm.B_dic = {
        state: {
            char: _ratio(count + 1, hmm.Count_dic[state])
            for char, count in row.items()
        }
        for state, row in hmm.B_dic.items()
    }

    with open(hmm.model_file, 'wb') as f:
        pickle.dump(hmm.A_dic, f)
        pickle.dump(hmm.B_dic, f)
        pickle.dump(hmm.Pi_dic, f)
def viterbi(text, states, Pi_dic, A_dic, B_dic):
    """Return the most probable state path for ``text`` (Viterbi decoding).

    The score being maximized is
    ``P = p(y1) * p(x1|y1) * prod_i p(yi|yi-1) * p(xi|yi)``.

    Args:
        text: the character sequence to label, e.g. ``'这是一个'``.
        states: iterable of state names, e.g. ``['B', 'M', 'E', 'S']``.
        Pi_dic: initial-state probabilities, ``{state: p}``.
        A_dic: transition probabilities, ``{prev_state: {state: p}}``.
        B_dic: emission probabilities, ``{state: {char: p}}``. Must contain
            the keys ``'M'`` and ``'S'`` (used to pick the final state).

    Returns:
        A tuple ``(prob, path)`` where ``path`` is a list with one state per
        character of ``text`` and ``prob`` is the score of that path.
        Empty ``text`` yields ``(0.0, [])``.

    Notes:
        * A character found in no state's emission table gets emission
          probability 1.0 for every state, so only the transition
          probabilities decide its label. This also applies to the first
          character.
        * If no state at the previous position has a positive score (for
          example after an impossible start), all previous states are
          considered, so the function still returns a path (with score 0)
          instead of raising ``ValueError`` from ``max()`` on an empty list.
        * Scores are plain products of probabilities and can underflow to 0
          on long inputs.
    """
    if not text:
        return 0.0, []

    def _emissions(char):
        # p(char | state) for every state, with the "never seen" rule applied.
        if all(char not in B_dic[s] for s in states):
            return {s: 1.0 for s in states}
        return {s: B_dic[s].get(char, 0) for s in states}

    # Position 0: V[y] = p(y1) * p(x1|y1)
    first_emit = _emissions(text[0])
    prev_scores = {y: Pi_dic[y] * first_emit[y] for y in states}
    path = {y: [y] for y in states}

    for t in range(1, len(text)):
        emit = _emissions(text[t])
        # Only predecessors that are still reachable take part; fall back to
        # every state when none is, so max() always has candidates.
        reachable = [y0 for y0 in states if prev_scores[y0] > 0]
        if not reachable:
            reachable = list(states)

        scores = {}
        new_path = {}
        for y in states:
            # best predecessor: V[t-1][y0] * p(y|y0) * p(x_t|y)
            prob, state = max(
                (prev_scores[y0] * A_dic[y0].get(y, 0) * emit[y], y0)
                for y0 in reachable
            )
            scores[y] = prob
            new_path[y] = path[state] + [y]
        prev_scores = scores
        path = new_path

    # Choose the final state. When the last character is more likely to sit
    # inside a word ('M') than to stand alone ('S'), only word-internal/end
    # states are allowed; otherwise every state competes.
    if B_dic['M'].get(text[-1], 0) > B_dic['S'].get(text[-1], 0):
        final_states = ('E', 'M')
    else:
        final_states = states
    prob, state = max((prev_scores[y], y) for y in final_states)
    return prob, path[state]
if __name__ == '__main__':
    # Demo: train a model on the bundled corpus and decode one short sentence.
    hmm = HMM()
    hmm.try_load_model(True)
    # Per the original note, this initializes the state-transition matrix
    # before the counts are accumulated.
    hmm.init_parameters()
    train(hmm, './data/trainCorpus.txt_utf8')

    text = '这是一个'
    # Keep and show the decoded result instead of discarding it.
    prob, state_path = viterbi(text, hmm.state_list, hmm.Pi_dic, hmm.A_dic, hmm.B_dic)
    print(text)
    print(prob, state_path)
    # NOTE(review): the original also carried a commented-out `hmm.cut(text)`
    # call -- presumably the class-level segmentation entry point; confirm
    # against hmm_segment before relying on it.
# Page-footer residue from the source web page ("评论0" = "0 comments"); not part of the program.