mirror of
https://gitee.com/TheAlgorithms/Statistical-Learning-Method_Code.git
synced 2025-01-03 17:32:21 +08:00
166 lines
6.0 KiB
Python
166 lines
6.0 KiB
Python
|
# coding=utf-8
|
|||
|
# Author:Dodo
|
|||
|
# Date:2018-11-27
|
|||
|
# Email:lvtengchao@pku.edu.cn
|
|||
|
# Blog:www.pkudodo.com
|
|||
|
|
|||
|
'''
|
|||
|
数据集:Mnist
|
|||
|
训练集数量:60000
|
|||
|
测试集数量:10000
|
|||
|
------------------------------
|
|||
|
运行结果:
|
|||
|
正确率:98.91%
|
|||
|
运行时长:59s
|
|||
|
'''
|
|||
|
|
|||
|
import time
|
|||
|
import numpy as np
|
|||
|
|
|||
|
|
|||
|
def loadData(fileName):
|
|||
|
'''
|
|||
|
加载Mnist数据集
|
|||
|
:param fileName:要加载的数据集路径
|
|||
|
:return: list形式的数据集及标记
|
|||
|
'''
|
|||
|
# 存放数据及标记的list
|
|||
|
dataList = []; labelList = []
|
|||
|
# 打开文件
|
|||
|
fr = open(fileName, 'r')
|
|||
|
# 将文件按行读取
|
|||
|
for line in fr.readlines():
|
|||
|
# 对每一行数据按切割福','进行切割,返回字段列表
|
|||
|
curLine = line.strip().split(',')
|
|||
|
|
|||
|
# Mnsit有0-9是个标记,由于是二分类任务,所以将标记0的作为1,其余为0
|
|||
|
# 验证过<5为1 >5为0时正确率在90%左右,猜测是因为数多了以后,可能不同数的特征较乱,不能有效地计算出一个合理的超平面
|
|||
|
# 查看了一下之前感知机的结果,以5为分界时正确率81,重新修改为0和其余数时正确率98.91%
|
|||
|
# 看来如果样本标签比较杂的话,对于是否能有效地划分超平面确实存在很大影响
|
|||
|
if int(curLine[0]) == 0:
|
|||
|
labelList.append(1)
|
|||
|
else:
|
|||
|
labelList.append(0)
|
|||
|
#存放标记
|
|||
|
#[int(num) for num in curLine[1:]] -> 遍历每一行中除了以第一哥元素(标记)外将所有元素转换成int类型
|
|||
|
#[int(num)/255 for num in curLine[1:]] -> 将所有数据除255归一化(非必须步骤,可以不归一化)
|
|||
|
dataList.append([int(num)/255 for num in curLine[1:]])
|
|||
|
# dataList.append([int(num) for num in curLine[1:]])
|
|||
|
|
|||
|
#返回data和label
|
|||
|
return dataList, labelList
|
|||
|
|
|||
|
def predict(w, x):
|
|||
|
'''
|
|||
|
预测标签
|
|||
|
:param w:训练过程中学到的w
|
|||
|
:param x: 要预测的样本
|
|||
|
:return: 预测结果
|
|||
|
'''
|
|||
|
#dot为两个向量的点积操作,计算得到w * x
|
|||
|
wx = np.dot(w, x)
|
|||
|
#计算标签为1的概率
|
|||
|
#该公式参考“6.1.2 二项逻辑斯蒂回归模型”中的式6.5
|
|||
|
P1 = np.exp(wx) / (1 + np.exp(wx))
|
|||
|
#如果为1的概率大于0.5,返回1
|
|||
|
if P1 >= 0.5:
|
|||
|
return 1
|
|||
|
#否则返回0
|
|||
|
return 0
|
|||
|
|
|||
|
def logisticRegression(trainDataList, trainLabelList, iter = 200):
|
|||
|
'''
|
|||
|
逻辑斯蒂回归训练过程
|
|||
|
:param trainDataList:训练集
|
|||
|
:param trainLabelList: 标签集
|
|||
|
:param iter: 迭代次数
|
|||
|
:return: 习得的w
|
|||
|
'''
|
|||
|
#按照书本“6.1.2 二项逻辑斯蒂回归模型”中式6.5的规则,将w与b合在一起,
|
|||
|
#此时x也需要添加一维,数值为1
|
|||
|
#循环遍历每一个样本,并在其最后添加一个1
|
|||
|
for i in range(len(trainDataList)):
|
|||
|
trainDataList[i].append(1)
|
|||
|
|
|||
|
#将数据集由列表转换为数组形式,主要是后期涉及到向量的运算,统一转换成数组形式比较方便
|
|||
|
trainDataList = np.array(trainDataList)
|
|||
|
#初始化w,维数为样本x维数+1,+1的那一位是b,初始为0
|
|||
|
w = np.zeros(trainDataList.shape[1])
|
|||
|
|
|||
|
#设置步长
|
|||
|
h = 0.001
|
|||
|
|
|||
|
#迭代iter次进行随机梯度下降
|
|||
|
for i in range(iter):
|
|||
|
#每次迭代冲遍历一次所有样本,进行随机梯度下降
|
|||
|
for j in range(trainDataList.shape[0]):
|
|||
|
#随机梯度上升部分
|
|||
|
#在“6.1.3 模型参数估计”一章中给出了似然函数,我们需要极大化似然函数
|
|||
|
#但是似然函数由于有求和项,并不能直接对w求导得出最优w,所以针对似然函数求和
|
|||
|
#部分中每一项进行单独地求导w,得到针对该样本的梯度,并进行梯度上升(因为是
|
|||
|
#要求似然函数的极大值,所以是梯度上升,如果是极小值就梯度下降。梯度上升是
|
|||
|
#加号,下降是减号)
|
|||
|
#求和式中每一项单独对w求导结果为:xi * yi - (exp(w * xi) * xi) / (1 + exp(w * xi))
|
|||
|
#如果对于该求导式有疑问可查看我的博客 www.pkudodo.com
|
|||
|
|
|||
|
#计算w * xi,因为后式中要计算两次该值,为了节约时间这里提前算出
|
|||
|
#其实也可直接算出exp(wx),为了读者能看得方便一点就这么写了,包括yi和xi都提前列出了
|
|||
|
wx = np.dot(w, trainDataList[j])
|
|||
|
yi = trainLabelList[j]
|
|||
|
xi = trainDataList[j]
|
|||
|
#梯度上升
|
|||
|
w += h * (xi * yi - (np.exp(wx) * xi) / ( 1 + np.exp(wx)))
|
|||
|
|
|||
|
#返回学到的w
|
|||
|
return w
|
|||
|
|
|||
|
def test(testDataList, testLabelList, w):
|
|||
|
'''
|
|||
|
验证
|
|||
|
:param testDataList:测试集
|
|||
|
:param testLabelList: 测试集标签
|
|||
|
:param w: 训练过程中学到的w
|
|||
|
:return: 正确率
|
|||
|
'''
|
|||
|
|
|||
|
#与训练过程一致,先将所有的样本添加一维,值为1,理由请查看训练函数
|
|||
|
for i in range(len(testDataList)):
|
|||
|
testDataList[i].append(1)
|
|||
|
|
|||
|
#错误值计数
|
|||
|
errorCnt = 0
|
|||
|
#对于测试集中每一个测试样本进行验证
|
|||
|
for i in range(len(testDataList)):
|
|||
|
#如果标记与预测不一致,错误值加1
|
|||
|
if testLabelList[i] != predict(w, testDataList[i]):
|
|||
|
errorCnt += 1
|
|||
|
#返回准确率
|
|||
|
return 1 - errorCnt / len(testDataList)
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
start = time.time()
|
|||
|
|
|||
|
# 获取训练集及标签
|
|||
|
print('start read transSet')
|
|||
|
trainData, trainLabel = loadData('../Mnist/mnist_train.csv')
|
|||
|
|
|||
|
# 获取测试集及标签
|
|||
|
print('start read testSet')
|
|||
|
testData, testLabel = loadData('../Mnist/mnist_test.csv')
|
|||
|
|
|||
|
# 开始训练,学习w
|
|||
|
print('start to train')
|
|||
|
w = logisticRegression(trainData, trainLabel)
|
|||
|
|
|||
|
#验证正确率
|
|||
|
print('start to test')
|
|||
|
accuracy = test(testData, testLabel, w)
|
|||
|
|
|||
|
# 打印准确率
|
|||
|
print('the accuracy is:', accuracy)
|
|||
|
# 打印时间
|
|||
|
print('time span:', time.time() - start)
|
|||
|
|