2000字范文,分享全网优秀范文,学习好帮手!
2000字范文 > 机器学习——k近邻算法——性别预测

机器学习——k近邻算法——性别预测

时间:2020-01-29 11:17:17

相关推荐

机器学习——k近邻算法——性别预测

假设用人的身高和体重来预测人的性别,下表是一组采集数据

有一个预测对象,身高155cm,体重70千克,问该对象的标签是什么?要求用k近邻算法预测,距离采用欧氏距,假设k=3。

1.画出近邻的散点图

我们已经知道k-近邻算法根据特征比较,然后提取样本集中特征最相似数据(最邻近)的分类标签。那么我们先把图画出来分析测量距离。

2.性别分类

我们可以从散点图大致推断,这个预测五角星点标记的性别可能属于男生,因为距离已知的那两个男的点更近。

k-近邻算法用什么方法进行判断呢?没错,就是欧氏距离。这个性别分类的例子有2个特征,也就是在2维实数向量空间,可以欧氏距离进行测量

1.计算已知类别数据集中的点与当前点之间的距离

通过计算,我们可以得到如下结果(约等于):

男1= 6.70

男2= 15.52

男3= 31.30

男4 = 37.36

女1= 21.00

女2= 13.60

女3= 25.17

女4= 16.27

女5= 24.04

2.按照距离递增次序排序;

男1<女2<男2<女4<女1<女5<女3<男3<男4

确定前k个点所在类别的出现频率;

3.K=3,既取最近的三个邻点,两男一女

4.返回前k个点所出现频率最高的类别作为当前点的预测分类。

男多女少预测对象为男生。

这个判别过程就是k-近邻算法。

3.用jupyter来实现

import numpy as npimport matplotlib.pyplot as pltimport pandas as pdurl="sg.csv"dataset=pd.read_csv(url)X=dataset.iloc[:,:-2].valuesy=dataset.iloc[:,2].valuesfrom sklearn.neighbors import KNeighborsClassifierclassifier = KNeighborsClassifier()classifier.fit(X, y)y_pred=classifier.predict(X)knn = KNeighborsClassifier(n_neighbors=3)knn.fit(X, y)cat=np.array([[155,70]])print(knn.predict(cat))

画图的代码

# scatter画出来的是散点图, 取数据使用 .values,二维数组中, 一维全部取出, 二维取0,表示出来就是[:,0]plt.scatter(X[:,0],X[:,1])# scatter可以有一些属性, 下边的color可以自定义显示的颜色plt.scatter(cat[:,0],cat[:,1],color='red')

让图更美观可添加添加样式:

import numpy as npimport matplotlib.pyplot as pltimport pandas as pdplt.figure(figsize=(5, 5))x1=np.array([[158,64],[170,66],[183,84],[191,80]])plt.scatter(x1[:,0],x1[:,1],color='g',marker='s',s=90,label='男')x2=np.array([[155,49],[163,59],[180,67],[158,54],[178,77]])plt.scatter(x2[:,0],x2[:,1],color='m',marker='^',s=100,label='女')cat=np.array([[155,70]])plt.scatter(cat[:,0],cat[:,1],color='red',marker='*',s=200,label='预测')plt.legend(('男','女','预测') ,loc = 'best')plt.legend()# 用黑体显示中文plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置x,y,标题plt.title("性别分类")plt.ylabel('体重/kg')plt.xlabel('身高/cm')plt.grid()plt.savefig('squares.png',bbox_inches='tight')

求距离的代码:

i=0errer=[]while i<9:z, c = 155, 70a = int(input("输入身高:"))b = int(input("输入体重:"))d = math.sqrt((z-a)** 2 + (c-b) ** 2)print(d)errer.append(d)i+=1errer.sort()print(errer)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。