mirror of
https://gitee.com/TheAlgorithms/Statistical-Learning-Method_Code.git
synced 2025-01-03 01:12:20 +08:00
Update KNN.py
This commit is contained in:
parent
d19687aba3
commit
5efeaa9b8f
@ -120,7 +120,7 @@ def getClosest(trainDataMat, trainLabelMat, x, topK):
|
|||||||
return labelList.index(max(labelList))
|
return labelList.index(max(labelList))
|
||||||
|
|
||||||
|
|
||||||
def test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK):
|
def model_test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK):
|
||||||
'''
|
'''
|
||||||
测试正确率
|
测试正确率
|
||||||
:param trainDataArr:训练集数据集
|
:param trainDataArr:训练集数据集
|
||||||
@ -166,7 +166,7 @@ if __name__ == "__main__":
|
|||||||
#获取测试集
|
#获取测试集
|
||||||
testDataArr, testLabelArr = loadData('../Mnist/mnist_test.csv')
|
testDataArr, testLabelArr = loadData('../Mnist/mnist_test.csv')
|
||||||
#计算测试集正确率
|
#计算测试集正确率
|
||||||
accur = test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, 25)
|
accur = model_test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, 25)
|
||||||
#打印正确率
|
#打印正确率
|
||||||
print('accur is:%d'%(accur * 100), '%')
|
print('accur is:%d'%(accur * 100), '%')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user