查看原文
其他

kmeans算法python代码——可直接运行


在安装了相应依赖包情况下,以下代码可直接运行。


 1# -*- coding:utf-8 -*-
 2
 3import numpy as np
 4import random as rd
 5import matplotlib.pyplot as plt
 6import math
 7
 8def printLine():
 9   print '----------------------------------------------------------------------------'
10
11#计算聚类中心
12def cent(x):
13   return(sum(x)/len(x))
14
15#距离, 返回s,C,分别是距离平方和与聚类方案
16def f(center):
17    # c0 = []
18    # c1 = []
19    # c2 = []
20    c = [[] for i in range(k)]
21    D = np.arange(k*n).reshape(k,n)
22    d = np.array([center[i]-dat.T for i in range(k)])
23    for i in range(k):
24        D[i] = sum((d[i]**2).T)
25    for i in range(n):
26        ind = D.T[i].argmin()
27        c[ind].append(i)
28    C = [np.array([dat.T[i] for i in j]) for j in c]
29    print(c)
30    s = 0
31    for i in C:
32        s+=dist(i)
33    return(s,C)
34
35#计算各点到聚类中心的距离之和
36def dist(x):
37    #聚类中心
38    m0 = cent(x)
39    dis = sum(sum((x-m0)**2))
40    return dis
41
42def run():
43    # 存储距离和
44    s_sum = []
45    #---随机产生聚类中心----#
46    center = rd.sample(range(n),k)
47    center = np.array([dat.T[i] for i in center])
48    print '初始化聚类中心为:'.decode('utf-8')
49    print(center)
50    printLine()
51    #初始距离和
52    print '第1次计算!'.decode('utf-8')
53    dd,C = f(center)
54    s_sum.append(dd)
55    print ('距离和为'+str(dd)).decode('utf-8')
56    printLine()
57    print('第2次计算!'.decode('utf-8'))
58    center = [cent(i) for i in C]
59    Dd,C = f(center)
60    s_sum.append(Dd)
61    print ('距离和为'+str(Dd)).decode('utf-8')
62    # 前面已经计算2次了,所以这里从第三次开始计算
63    K = 3
64    while(K<n_max):
65       printLine()
66       #两次差值很小并且计算了一定次数
67       if(math.sqrt(abs(dd-Dd)) < 0 and K>20):
68          break;
69       print ('第'+str(K)+'次计算!').decode('utf-8')
70       dd = Dd
71       print ('距离和为'+str(dd)).decode('utf-8')
72       #当前聚类中心
73       center = [cent(i) for i in C]
74       Dd,C = f(center)
75       s_sum.append(Dd)
76       K+=1
77
78    #-----------------聚类结果可视化部分--------------------#
79    j = 0
80    for i in C:
81       if(j == 0):
82          plt.plot(i.T[0],i.T[1],'ro')
83       if(j == 1):
84          plt.plot(i.T[0],i.T[1],'b+')
85       if(j == 2):
86          plt.plot(i.T[0],i.T[1],'g*')
87       if(j == 3):
88          plt.plot(i.T[0],i.T[1], 'c<')
89       j+=1
90    plt.show()
91    x = range(len(s_sum))
92    plt.plot(x, s_sum)
93    plt.plot(x, s_sum, 'ro')
94    plt.show()
95
96
97print '==============================================================================='
98#数据
99dat = np.array([[14,22,15,20,30,20,32,13,23,20,21,22,23,24,35,18,20,31,14]
100    ,[15,28,18,30,35,15,30,15,25,23,24,25,26,27,30,15,24,33,12]])
101dat = np.random.randint(030, (240))
102print(dat)
103#=========================聚类中心======================#
104n = len(dat[0])
105N = len(dat)*n
106k = 4
107n_max = 50
108
109# 程序入口
110if __name__ == '__main__':
111    print '==============================================================================='
112    run()


可以通过修改k和n_max的值,改变聚类数量和测试样本数量。


点击【阅读原文】购买python spark大数据课程,限时优惠。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存