mirror of
https://gitee.com/TheAlgorithms/Statistical-Learning-Method_Code.git
synced 2025-01-03 01:12:20 +08:00
Update KNN.py
This commit is contained in:
parent
d19687aba3
commit
5efeaa9b8f
@ -120,7 +120,7 @@ def getClosest(trainDataMat, trainLabelMat, x, topK):
|
||||
return labelList.index(max(labelList))
|
||||
|
||||
|
||||
def test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK):
|
||||
def model_test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK):
|
||||
'''
|
||||
测试正确率
|
||||
:param trainDataArr:训练集数据集
|
||||
@ -166,7 +166,7 @@ if __name__ == "__main__":
|
||||
#获取测试集
|
||||
testDataArr, testLabelArr = loadData('../Mnist/mnist_test.csv')
|
||||
#计算测试集正确率
|
||||
accur = test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, 25)
|
||||
accur = model_test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, 25)
|
||||
#打印正确率
|
||||
print('accur is:%d'%(accur * 100), '%')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user