Statistical-Learning-Method.../LDA/LDA.ipynb

287 lines
12 KiB
Plaintext
Raw Normal View History

2021-01-26 17:06:12 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Topics:\n",
"['tech', 'business', 'sport', 'entertainment', 'politics']\n",
"1/10\n",
"2/10\n",
"3/10\n",
"4/10\n",
"5/10\n",
"6/10\n",
"7/10\n",
"8/10\n",
"9/10\n",
"10/10\n",
"Topic 1: said game england would time first back play last good\n",
"Topic 2: said year would economy growth also economic bank government could\n",
"Topic 3: said year games sales company also market last firm 2004\n",
"Topic 4: film said music best also people year show number digital\n",
"Topic 5: said would people government labour election party blair could also\n",
"Time: 7620.509902954102\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import string\n",
"from nltk.corpus import stopwords\n",
"import time\n",
"\n",
"\n",
"#定义加载数据的函数\n",
"def load_data(file, K):\n",
" '''\n",
" INPUT:\n",
" file - (str) 数据文件的路径\n",
" K - (int) 设定的话题数\n",
" \n",
" OUTPUT:\n",
" org_topics - (list) 原始话题标签列表\n",
" text - (list) 文本列表\n",
" words - (list) 单词列表\n",
" alpha - (list) 话题概率分布,模型超参数\n",
" beta - (list) 单词概率分布,模型超参数\n",
" \n",
" '''\n",
" df = pd.read_csv(file) #读取文件\n",
" org_topics = df['category'].unique().tolist() #保存文本原始的话题标签\n",
" M = df.shape[0] #文本数\n",
" alpha = np.zeros(K) #alpha是LDA模型的一个超参数是对话题概率的预估计这里取文本数据中各话题的比例作为alpha值实际可以通过模型训练得到\n",
" beta = np.zeros(1000) #beta是LDA模型的另一个超参数是词汇表中单词的概率分布这里取各单词在所有文本中的比例作为beta值实际也可以通过模型训练得到\n",
" #计算各话题的比例作为alpha值\n",
" for k, topic in enumerate(org_topics):\n",
" alpha[k] = df[df['category'] == topic].shape[0] / M\n",
" df.drop('category', axis=1, inplace=True)\n",
" n = df.shape[0] #n为文本数量\n",
" text = []\n",
" words = []\n",
" for i in df['text'].values:\n",
" t = i.translate(str.maketrans('', '', string.punctuation)) #去除文本中的标点符号\n",
" t = [j for j in t.split() if j not in stopwords.words('english')] #去除文本中的停止词\n",
" t = [j for j in t if len(j) > 3] #长度小于等于3的单词大多是无意义的直接去除\n",
" text.append(t) #将处理后的文本保存到文本列表中\n",
" words.extend(set(t)) #将文本中所包含的单词保存到单词列表中\n",
" words = list(set(words)) #去除单词列表中的重复单词\n",
" words_cnt = np.zeros(len(words)) #用来保存单词的出现频次\n",
" #循环计算words列表中各单词出现的词频\n",
" for i in range(len(text)):\n",
" t = text[i] #取出第i条文本\n",
" for w in t:\n",
" ind = words.index(w) #取出第i条文本中的第t个单词在单词列表中的索引\n",
" words_cnt[ind] += 1 #对应位置的单词出现频次加一\n",
" sort_inds = np.argsort(words_cnt)[::-1] #对单词出现频次降序排列后取出其索引值\n",
" words = [words[ind] for ind in sort_inds[:1000]] #将出现频次前1000的单词保存到words列表\n",
" #去除文本text中不在词汇表words中的单词\n",
" for i in range(len(text)):\n",
" t = []\n",
" for w in text[i]:\n",
" if w in words:\n",
" ind = words.index(w)\n",
" t.append(w)\n",
" beta[ind] += 1 #统计各单词在文本中的出现频次\n",
" text[i] = t\n",
" beta /= np.sum(beta) #除以文本的总单词数得到各单词所占比例作为beta值\n",
" return org_topics, text, words, alpha, beta\n",
"\n",
"\n",
"#定义潜在狄利克雷分配函数采用收缩的吉布斯抽样算法估计模型的参数theta和phi\n",
"def do_lda(text, words, alpha, beta, K, iters):\n",
" '''\n",
" INPUT:\n",
" text - (list) 文本列表\n",
" words - (list) 单词列表\n",
" alpha - (list) 话题概率分布,模型超参数\n",
" beta - (list) 单词概率分布,模型超参数\n",
" K - (int) 设定的话题数\n",
" iters - (int) 设定的迭代次数\n",
" \n",
" OUTPUT:\n",
" theta - (array) 话题的条件概率分布p(zk|dj)这里写成p(zk|dj)是为了和PLSA模型那一章的符号统一一下方便对照着看\n",
" phi - (array) 单词的条件概率分布p(wi|zk)\n",
" \n",
" '''\n",
" M = len(text) #文本数\n",
" V = len(words) #单词数\n",
" N_MK = np.zeros((M, K)) #文本-话题计数矩阵\n",
" N_KV = np.zeros((K, V)) #话题-单词计数矩阵\n",
" N_M = np.zeros(M) #文本计数向量\n",
" N_K = np.zeros(K) #话题计数向量\n",
" Z_MN = [] #用来保存每条文本的每个单词所在位置处抽样得到的话题\n",
" #算法20.2的步骤(2),对每个文本的所有单词抽样产生话题,并进行计数\n",
" for m in range(M):\n",
" zm = []\n",
" t = text[m]\n",
" for n, w in enumerate(t):\n",
" v = words.index(w)\n",
" z = np.random.randint(K)\n",
" zm.append(z)\n",
" N_MK[m, z] += 1\n",
" N_M[m] += 1\n",
" N_KV[z, v] += 1\n",
" N_K[z] += 1\n",
" Z_MN.append(zm)\n",
" #算法20.2的步骤(3),多次迭代进行吉布斯抽样\n",
" for i in range(iters):\n",
" print('{}/{}'.format(i+1, iters))\n",
" for m in range(M):\n",
" t = text[m]\n",
" for n, w in enumerate(t):\n",
" v = words.index(w)\n",
" z = Z_MN[m][n]\n",
" N_MK[m, z] -= 1\n",
" N_M[m] -= 1\n",
" N_KV[z][v] -= 1\n",
" N_K[z] -= 1\n",
" p = [] #用来保存对K个话题的条件分布p(zi|z_i,w,alpha,beta)的计算结果\n",
" sums_k = 0 \n",
" for k in range(K):\n",
" p_zk = (N_KV[k][v] + beta[v]) * (N_MK[m][k] + alpha[k]) #话题zi=k的条件分布p(zi|z_i,w,alpha,beta)的分子部分\n",
" sums_v = 0\n",
" sums_k += N_MK[m][k] + alpha[k] #累计(nmk + alpha_k)在K个话题上的和\n",
" for t in range(V):\n",
" sums_v += N_KV[k][t] + beta[t] #累计(nkv + beta_v)在V个单词上的和\n",
" p_zk /= sums_v\n",
" p.append(p_zk)\n",
" p = p / sums_k\n",
" p = p / np.sum(p) #对条件分布p(zi|z_i,w,alpha,beta)进行归一化保证概率的总和为1\n",
" new_z = np.random.choice(a=K, p=p) #根据以上计算得到的概率进行抽样,得到新的话题\n",
" Z_MN[m][n] = new_z #更新当前位置处的话题为上面抽样得到的新话题\n",
" #更新计数\n",
" N_MK[m, new_z] += 1\n",
" N_M[m] += 1\n",
" N_KV[new_z, v] += 1\n",
" N_K[new_z] += 1\n",
" #算法20.2的步骤(4)利用得到的样本计数估计模型的参数theta和phi\n",
" theta = np.zeros((M, K))\n",
" phi = np.zeros((K, V))\n",
" for m in range(M):\n",
" sums_k = 0\n",
" for k in range(K):\n",
" theta[m, k] = N_MK[m][k] + alpha[k] #参数theta的分子部分\n",
" sums_k += theta[m, k] #累计(nmk + alpha_k)在K个话题上的和参数theta的分母部分\n",
" theta[m] /= sums_k #计算参数theta\n",
" for k in range(K):\n",
" sums_v = 0\n",
" for v in range(V):\n",
" phi[k, v] = N_KV[k][v] + beta[v] #参数phi的分子部分\n",
" sums_v += phi[k][v] #累计(nkv + beta_v)在V个单词上的和参数phi的分母部分\n",
" phi[k] /= sums_v #计算参数phi\n",
" return theta, phi\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" K = 5 #设定话题数为5\n",
" org_topics, text, words, alpha, beta = load_data('bbc_text.csv', K) #加载数据\n",
" print('Original Topics:')\n",
" print(org_topics) #打印原始的话题标签列表\n",
" start = time.time() #保存开始时间\n",
" iters = 10 #为了避免运行时间过长这里只迭代10次实际上10次是不够的要迭代足够的次数保证吉布斯抽样进入燃烧期这样得到的参数才能尽可能接近样本的实际概率分布\n",
" theta, phi = do_lda(text, words, alpha, beta, K, iters) #LDA的吉布斯抽样\n",
" #打印出每个话题zk条件下出现概率最大的前10个单词即P(wi|zk)在话题zk中最大的10个值对应的单词作为对话题zk的文本描述\n",
" for k in range(K):\n",
" sort_inds = np.argsort(phi[k])[::-1] #对话题zk条件下的P(wi|zk)的值进行降序排列后取出对应的索引值\n",
" topic = [] #定义一个空列表用于保存话题zk概率最大的前10个单词\n",
" for i in range(10):\n",
" topic.append(words[sort_inds[i]]) \n",
" topic = ' '.join(topic) #将10个单词以空格分隔构成对话题zk的文本表述\n",
" print('Topic {}: {}'.format(k+1, topic)) #打印话题zk\n",
" end = time.time()\n",
" print('Time:', end-start)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.04935605, 0.10338287, 0.06575088, 0.46867464, 0.31283557],\n",
" [0.18828892, 0.33184591, 0.36042376, 0.00247833, 0.11696308],\n",
" [0.64543178, 0.13184591, 0.06042376, 0.15962119, 0.00267737],\n",
" ...,\n",
" [0.41026611, 0.05564755, 0.31881135, 0.21280899, 0.002466 ],\n",
" [0.34581233, 0.01506225, 0.12993015, 0.06198299, 0.44721227],\n",
" [0.74515492, 0.00347293, 0.15499489, 0.09353762, 0.00283963]])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.99768905e-02, 1.06570534e-02, 5.88055310e-03, ...,\n",
" 1.18881649e-03, 8.27178858e-09, 1.50724726e-03],\n",
" [3.50778918e-02, 1.05511733e-02, 8.79264523e-03, ...,\n",
" 3.00802648e-04, 9.01573088e-09, 9.01573088e-09],\n",
" [3.44183618e-02, 7.06465729e-03, 1.14162405e-02, ...,\n",
" 4.52128900e-04, 1.69555219e-04, 1.10105081e-08],\n",
" [1.81454758e-02, 3.13112016e-03, 1.39941969e-02, ...,\n",
" 1.18602254e-04, 1.99237185e-03, 9.24197417e-09],\n",
" [4.45921371e-02, 1.96021164e-02, 8.00255656e-03, ...,\n",
" 6.17454557e-09, 6.17454557e-09, 3.01086896e-04]])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"phi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}