Statistical-Learning-Method.../LSA/LSA.ipynb

172 lines
6.8 KiB
Plaintext
Raw Normal View History

2021-01-26 16:52:19 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Topics:\n",
"['tech', 'business', 'sport', 'entertainment', 'politics']\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\zengh\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:59: ComplexWarning: Casting complex values to real discards the imaginary part\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated Topics:\n",
"Topic 1: said people would music blair government best year film howard\n",
"Topic 2: said would labour party people last could years kilroysilk show\n",
"Topic 3: music microsoft year best urban industry record software email think\n",
"Topic 4: wales first games lord government play house public control prime\n",
"Topic 5: said mobile england people phone dallaglio rugby blair election would\n",
"Time: 212.96439409255981\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import string\n",
"from nltk.corpus import stopwords\n",
"import time \n",
"\n",
"\n",
"#定义加载数据的函数\n",
"def load_data(file):\n",
" '''\n",
" INPUT:\n",
" file - (str) 数据文件的路径\n",
" \n",
" OUTPUT:\n",
" org_topics - (list) 原始话题标签列表\n",
" text - (list) 文本列表\n",
" words - (list) 单词列表\n",
" \n",
" '''\n",
" df = pd.read_csv(file) #读取文件\n",
" org_topics = df['category'].unique().tolist() #保存文本原始的话题标签\n",
" df.drop('category', axis=1, inplace=True)\n",
" n = df.shape[0] #n为文本数量\n",
" text = []\n",
" words = []\n",
" for i in df['text'].values:\n",
" t = i.translate(str.maketrans('', '', string.punctuation)) #去除文本中的标点符号\n",
" t = [j for j in t.split() if j not in stopwords.words('english')] #去除文本中的停止词\n",
" t = [j for j in t if len(j) > 3] #长度小于等于3的单词大多是无意义的直接去除\n",
" text.append(t) #将处理后的文本保存到文本列表中\n",
" words.extend(set(t)) #将文本中所包含的单词保存到单词列表中\n",
" words = list(set(words)) #去除单词列表中的重复单词\n",
" return org_topics, text, words\n",
"\n",
"\n",
"#定义构建单词-文本矩阵的函数这里矩阵的每一项表示单词在文本中的出现频次也可以用TF-IDF来表示\n",
"def frequency_counter(text, words):\n",
" '''\n",
" INPUT:\n",
" text - (list) 文本列表\n",
" words - (list) 单词列表\n",
" \n",
" OUTPUT:\n",
" X - (array) 单词-文本矩阵\n",
" \n",
" '''\n",
" X = np.zeros((len(words), len(text))) #定义m*n的矩阵其中m为单词列表中的单词个数n为文本个数\n",
" for i in range(len(text)):\n",
" t = text[i] #读取文本列表中的第i条文本\n",
" for w in t:\n",
" ind = words.index(w) #取出第i条文本中的第t个单词在单词列表中的索引\n",
" X[ind][i] += 1 #对应位置的单词出现频次加一\n",
" return X\n",
"\n",
"\n",
"#定义潜在语义分析函数\n",
"def do_lsa(X, k, words):\n",
" '''\n",
" INPUT:\n",
" X - (array) 单词-文本矩阵\n",
" k - (int) 设定的话题数\n",
" words - (list) 单词列表\n",
" \n",
" OUTPUT:\n",
" topics - (list) 生成的话题列表\n",
" \n",
" '''\n",
" w, v = np.linalg.eig(np.matmul(X.T, X)) #计算Sx的特征值和特征向量其中Sx=X.T*XSx的特征值w即为X的奇异值分解的奇异值v即为对应的奇异向量\n",
" sort_inds = np.argsort(w)[::-1] #对特征值降序排列后取出对应的索引值\n",
" w = np.sort(w)[::-1] #对特征值降序排列\n",
" V_T = [] #用来保存矩阵V的转置\n",
" for ind in sort_inds:\n",
" V_T.append(v[ind]/np.linalg.norm(v[ind])) #将降序排列后各特征值对应的特征向量单位化后保存到V_T中\n",
" V_T = np.array(V_T) #将V_T转换为数组方便之后的操作\n",
" Sigma = np.diag(np.sqrt(w)) #将特征值数组w转换为对角矩阵即得到SVD分解中的Sigma\n",
" U = np.zeros((len(words), k)) #用来保存SVD分解中的矩阵U\n",
" for i in range(k):\n",
" ui = np.matmul(X, V_T.T[:, i]) / Sigma[i][i] #计算矩阵U的第i个列向量\n",
" U[:, i] = ui #保存到矩阵U中\n",
" topics = [] #用来保存k个话题\n",
" for i in range(k):\n",
" inds = np.argsort(U[:, i])[::-1] #U的每个列向量表示一个话题向量话题向量的长度为m其中每个值占向量值之和的比重表示对应单词在当前话题中所占的比重这里对第i个话题向量的值降序排列后取出对应的索引值\n",
" topic = [] #用来保存第i个话题\n",
" for j in range(10):\n",
" topic.append(words[inds[j]]) #根据索引inds取出当前话题中比重最大的10个单词作为第i个话题\n",
" topics.append(' '.join(topic)) #保存话题i\n",
" return topics\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" org_topics, text, words = load_data('bbc_text.csv') #加载数据\n",
" print('Original Topics:')\n",
" print(org_topics) #打印原始的话题标签列表\n",
" start = time.time() #保存开始时间\n",
" X = frequency_counter(text, words) #构建单词-文本矩阵\n",
" k = 5 #设定话题数为5\n",
" topics = do_lsa(X, k, words) #进行潜在语义分析\n",
" print('Generated Topics:')\n",
" for i in range(k):\n",
" print('Topic {}: {}'.format(i+1, topics[i])) #打印分析后得到的每个话题\n",
" end = time.time() #保存结束时间\n",
" print('Time:', end-start)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}