K-means实现聚类

        最近的数据挖掘讲座课上讲到了聚类算法K-means,数据挖掘在近一两年内很火,本人对这方面的技术也比较感兴趣,索性写篇博客记录一下,同时加深一下理解。

简介

        k-平均算法源于信号处理中的一种向量量化方法,现在则更多地作为一种聚类分析方法流行于数据挖掘领域。k-平均聚类的目的是:把 个点(可以是样本的一次观察或一个实例)划分到k个聚类中,使得每个点都属于离他最近的均值(此即聚类中心)对应的聚类,以之作为聚类的标准。这个问题将归结为一个把数据空间划分为Voronoi cells的问题。
        这个问题在计算上是困难的(NP困难),不过存在高效的启发式算法。一般情况下,都使用效率比较高的启发式算法,它们能够快速收敛于一个局部最优解。这些算法通常类似于通过迭代优化方法处理高斯混合分布的最大期望算法(EM算法)。而且,它们都使用聚类中心来为数据建模;然而k-平均聚类倾向于在可比较的空间范围内寻找聚类,期望-最大化技术却允许聚类有不同的形状。

                                                                                                   
中文名 K均值 优   点 确定的K个划分到达平方误差最小
外文名 K-means 缺   点 K 值的选定是非常难以估计的
相似度测度 欧氏距离

算法描述

        已知观测集,其中每个观测都是一个d-维实向量,k-平均聚类要把这n个观测划分到k个集合中(k≤n),使得组内平方和(WCSS within-cluster sum of squares)最小。换句话说,它的目标是找到使得下式满足的聚类




其中中所有点的均值。


算法实现

        我是采用Java实现的,不得不说多维数组表示起来各种ArrayList>的嵌套,写起来十分麻烦,而且可读性差(果断C/C++更适合写算法啊,Java没有 #define 很难受)。
        好了,直接上代码吧,代码中有比较详细的注释,可读性还是比较高的。另外,由于算法本身效率不高(需要根据具体情况优化)和测试数据比较多,且为浮点数,所以在这里限制了算法秩代的次数为100,000次,防止执行时间过长。

K-means:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import java.util.ArrayList;
public class Kmeans {
int dimension; //维数
ArrayList<Float> point; //集合中的点
ArrayList<ArrayList<Float>> clust; //簇
ArrayList<ArrayList<ArrayList<Float>>> set; //集合
int K; //需要聚类成簇的个数
ArrayList<ArrayList<Float>> meanValues; //均值点
ArrayList<ArrayList<Float>> data; //初始数据
public Kmeans(int dimension, int K){
this.dimension = dimension;
this.K = K;
set = new ArrayList<>();
for(int i=0; i<K; i++){
set.add(new ArrayList<ArrayList<Float>>());
}
}
void initData(ArrayList<ArrayList<Float>> data){
this.data = data;
randomInitmManValues();
}
/**
* 随机初始化均值点(分别取下标为0,50,100的点作为三个初始均值点)
*/
private void randomInitmManValues(){
meanValues = new ArrayList<>();
int index=0;
for(int i=0; i<K; i++){
meanValues.add(data.get(index));
index += 50;
}
}
/**
* 求欧式距离,判断后放入所属的簇
* @param dataPoint 需要计算欧式距离的数据点
* @param index 该数据点所属的簇在原set中的序号
*/
private void putPointIntoClust(ArrayList<Float> dataPoint, int index){
double dis = Double.MAX_VALUE;
int idex = -1;
for(int i=0; i<K; i++){ //分别求出到K个点的距离
point = meanValues.get(i);
double sum = 0;
for(int j=0; j<point.size(); j++){ //算维度差的平方和
sum += (dataPoint.get(j) - point.get(j)) * (dataPoint.get(j) - point.get(j));
}
double sqrtSum = Math.sqrt(sum);
if(sqrtSum < dis){
dis = sqrtSum;
idex = i; //记录最近的平均点所在簇的index
}
}
if(idex != -1 && !set.get(idex).contains(dataPoint)){ //如果第i个簇中不包含这个数据点
set.get(index).remove(dataPoint); //从原簇中删除这个数据点
set.get(idex).add(dataPoint); //将该数据点放入第i个簇中
}
}
/**
* 求出簇的均值点
* @param index 所计算那个簇对于set的index
* @return 均值点
*/
private ArrayList<Float> getMeanValue(int index){
ArrayList<Float> meanValue = new ArrayList<>();
for(int i=0; i<point.size(); i++){ //初始化 meanValue各个维度的值
meanValue.add((float) 0);
}
for(ArrayList<Float> point : set.get(index)){ //对簇中的各个点求和
for(int i=0; i<point.size(); i++){
meanValue.set(i, meanValue.get(i) + point.get(i));
}
}
for(int i=0; i<meanValue.size(); i++){ //求该簇的均值点
meanValue.set(i, meanValue.get(i) / set.get(index).size());
}
return meanValue;
}
/**
* K-means实现
*/
void k_means(){
int mod = data.size()%K==0 ? data.size()/K : data.size()/K+1;
int idex = 0;
for(int j=0; j<data.size(); j++){ //把data中的数据放入平均的放入簇中
set.get(idex).add(data.get(j));
if(j+1 != data.size() && (j+1) % mod == 0) idex++;
}
int times = 0; //限定循环的次数,超过1000次则停止
while(true){ //K-means算法
ArrayList<ArrayList<Float>> meanValues = new ArrayList<>();
for(int i=0; i<set.size(); i++){
for(int j=0; j<set.get(i).size(); j++){
putPointIntoClust(set.get(i).get(j), i);
}
meanValues.add(getMeanValue(i));
}
if(meanValues.equals(this.meanValues) || times >= 100000){
break;
}
times++;
System.out.println("均值点发生移动,继续进行计算!");
}
}
/**
* 打印set
*/
void printf(){
for(int i=0; i<set.size(); i++){
System.out.println("第"+i+"个簇包含的点:");
for(int j=0; j<set.get(i).size(); j++){
System.out.println((j+1) + " "+ set.get(i).get(j));
}
}
}
}

Test程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import java.util.ArrayList;
import java.util.Scanner;
public class Test {
public static void main(String[] ss){
ArrayList<ArrayList<Float>> data = new ArrayList<ArrayList<Float>>();
Scanner scanner = new Scanner(System.in);
for(int i=0; i<150; i++){
data.add(new ArrayList<Float>());
for(int j=0; j<4; j++){
data.get(i).add(scanner.nextFloat());
}
}
scanner.close();
Kmeans kmeans = new Kmeans(4, 3);
kmeans.initData(data);
kmeans.k_means();
kmeans.printf();
}
}

测试数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
5.1 3.5 1.4 0.2
4.9 3.0 1.4 0.2
4.7 3.2 1.3 0.2
4.6 3.1 1.5 0.2
5.0 3.6 1.4 0.2
5.4 3.9 1.7 0.4
4.6 3.4 1.4 0.3
5.0 3.4 1.5 0.2
4.4 2.9 1.4 0.2
4.9 3.1 1.5 0.1
5.4 3.7 1.5 0.2
4.8 3.4 1.6 0.2
4.8 3.0 1.4 0.1
4.3 3.0 1.1 0.1
5.8 4.0 1.2 0.2
5.7 4.4 1.5 0.4
5.4 3.9 1.3 0.4
5.1 3.5 1.4 0.3
5.7 3.8 1.7 0.3
5.1 3.8 1.5 0.3
5.4 3.4 1.7 0.2
5.1 3.7 1.5 0.4
4.6 3.6 1.0 0.2
5.1 3.3 1.7 0.5
4.8 3.4 1.9 0.2
5.0 3.0 1.6 0.2
5.0 3.4 1.6 0.4
5.2 3.5 1.5 0.2
5.2 3.4 1.4 0.2
4.7 3.2 1.6 0.2
4.8 3.1 1.6 0.2
5.4 3.4 1.5 0.4
5.2 4.1 1.5 0.1
5.5 4.2 1.4 0.2
4.9 3.1 1.5 0.2
5.0 3.2 1.2 0.2
5.5 3.5 1.3 0.2
4.9 3.6 1.4 0.1
4.4 3.0 1.3 0.2
5.1 3.4 1.5 0.2
5.0 3.5 1.3 0.3
4.5 2.3 1.3 0.3
4.4 3.2 1.3 0.2
5.0 3.5 1.6 0.6
5.1 3.8 1.9 0.4
4.8 3.0 1.4 0.3
5.1 3.8 1.6 0.2
4.6 3.2 1.4 0.2
5.3 3.7 1.5 0.2
5.0 3.3 1.4 0.2
7.0 3.2 4.7 1.4
6.4 3.2 4.5 1.5
6.9 3.1 4.9 1.5
5.5 2.3 4.0 1.3
6.5 2.8 4.6 1.5
5.7 2.8 4.5 1.3
6.3 3.3 4.7 1.6
4.9 2.4 3.3 1.0
6.6 2.9 4.6 1.3
5.2 2.7 3.9 1.4
5.0 2.0 3.5 1.0
5.9 3.0 4.2 1.5
6.0 2.2 4.0 1.0
6.1 2.9 4.7 1.4
5.6 2.9 3.6 1.3
6.7 3.1 4.4 1.4
5.6 3.0 4.5 1.5
5.8 2.7 4.1 1.0
6.2 2.2 4.5 1.5
5.6 2.5 3.9 1.1
5.9 3.2 4.8 1.8
6.1 2.8 4.0 1.3
6.3 2.5 4.9 1.5
6.1 2.8 4.7 1.2
6.4 2.9 4.3 1.3
6.6 3.0 4.4 1.4
6.8 2.8 4.8 1.4
6.7 3.0 5.0 1.7
6.0 2.9 4.5 1.5
5.7 2.6 3.5 1.0
5.5 2.4 3.8 1.1
5.5 2.4 3.7 1.0
5.8 2.7 3.9 1.2
6.0 2.7 5.1 1.6
5.4 3.0 4.5 1.5
6.0 3.4 4.5 1.6
6.7 3.1 4.7 1.5
6.3 2.3 4.4 1.3
5.6 3.0 4.1 1.3
5.5 2.5 4.0 1.3
5.5 2.6 4.4 1.2
6.1 3.0 4.6 1.4
5.8 2.6 4.0 1.2
5.0 2.3 3.3 1.0
5.6 2.7 4.2 1.3
5.7 3.0 4.2 1.2
5.7 2.9 4.2 1.3
6.2 2.9 4.3 1.3
5.1 2.5 3.0 1.1
5.7 2.8 4.1 1.3
6.3 3.3 6.0 2.5
5.8 2.7 5.1 1.9
7.1 3.0 5.9 2.1
6.3 2.9 5.6 1.8
6.5 3.0 5.8 2.2
7.6 3.0 6.6 2.1
4.9 2.5 4.5 1.7
7.3 2.9 6.3 1.8
6.7 2.5 5.8 1.8
7.2 3.6 6.1 2.5
6.5 3.2 5.1 2.0
6.4 2.7 5.3 1.9
6.8 3.0 5.5 2.1
5.7 2.5 5.0 2.0
5.8 2.8 5.1 2.4
6.4 3.2 5.3 2.3
6.5 3.0 5.5 1.8
7.7 3.8 6.7 2.2
7.7 2.6 6.9 2.3
6.0 2.2 5.0 1.5
6.9 3.2 5.7 2.3
5.6 2.8 4.9 2.0
7.7 2.8 6.7 2.0
6.3 2.7 4.9 1.8
6.7 3.3 5.7 2.1
7.2 3.2 6.0 1.8
6.2 2.8 4.8 1.8
6.1 3.0 4.9 1.8
6.4 2.8 5.6 2.1
7.2 3.0 5.8 1.6
7.4 2.8 6.1 1.9
7.9 3.8 6.4 2.0
6.4 2.8 5.6 2.2
6.3 2.8 5.1 1.5
6.1 2.6 5.6 1.4
7.7 3.0 6.1 2.3
6.3 3.4 5.6 2.4
6.4 3.1 5.5 1.8
6.0 3.0 4.8 1.8
6.9 3.1 5.4 2.1
6.7 3.1 5.6 2.4
6.9 3.1 5.1 2.3
5.8 2.7 5.1 1.9
6.8 3.2 5.9 2.3
6.7 3.3 5.7 2.5
6.7 3.0 5.2 2.3
6.3 2.5 5.0 1.9
6.5 3.0 5.2 2.0
6.2 3.4 5.4 2.3
5.9 3.0 5.1 1.8

运行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
第0个簇包含的点:
1 [5.1, 3.5, 1.4, 0.2]
2 [4.9, 3.0, 1.4, 0.2]
3 [4.7, 3.2, 1.3, 0.2]
4 [4.6, 3.1, 1.5, 0.2]
5 [5.0, 3.6, 1.4, 0.2]
6 [5.4, 3.9, 1.7, 0.4]
7 [4.6, 3.4, 1.4, 0.3]
8 [5.0, 3.4, 1.5, 0.2]
9 [4.4, 2.9, 1.4, 0.2]
10 [4.9, 3.1, 1.5, 0.1]
11 [5.4, 3.7, 1.5, 0.2]
12 [4.8, 3.4, 1.6, 0.2]
13 [4.8, 3.0, 1.4, 0.1]
14 [4.3, 3.0, 1.1, 0.1]
15 [5.8, 4.0, 1.2, 0.2]
16 [5.7, 4.4, 1.5, 0.4]
17 [5.4, 3.9, 1.3, 0.4]
18 [5.1, 3.5, 1.4, 0.3]
19 [5.7, 3.8, 1.7, 0.3]
20 [5.1, 3.8, 1.5, 0.3]
21 [5.4, 3.4, 1.7, 0.2]
22 [5.1, 3.7, 1.5, 0.4]
23 [4.6, 3.6, 1.0, 0.2]
24 [5.1, 3.3, 1.7, 0.5]
25 [4.8, 3.4, 1.9, 0.2]
26 [5.0, 3.0, 1.6, 0.2]
27 [5.0, 3.4, 1.6, 0.4]
28 [5.2, 3.5, 1.5, 0.2]
29 [5.2, 3.4, 1.4, 0.2]
30 [4.7, 3.2, 1.6, 0.2]
31 [4.8, 3.1, 1.6, 0.2]
32 [5.4, 3.4, 1.5, 0.4]
33 [5.2, 4.1, 1.5, 0.1]
34 [5.5, 4.2, 1.4, 0.2]
35 [4.9, 3.1, 1.5, 0.2]
36 [5.0, 3.2, 1.2, 0.2]
37 [5.5, 3.5, 1.3, 0.2]
38 [4.9, 3.6, 1.4, 0.1]
39 [4.4, 3.0, 1.3, 0.2]
40 [5.1, 3.4, 1.5, 0.2]
41 [5.0, 3.5, 1.3, 0.3]
42 [4.5, 2.3, 1.3, 0.3]
43 [4.4, 3.2, 1.3, 0.2]
44 [5.0, 3.5, 1.6, 0.6]
45 [5.1, 3.8, 1.9, 0.4]
46 [4.8, 3.0, 1.4, 0.3]
47 [5.1, 3.8, 1.6, 0.2]
48 [4.6, 3.2, 1.4, 0.2]
49 [5.3, 3.7, 1.5, 0.2]
50 [5.0, 3.3, 1.4, 0.2]
51 [4.9, 2.4, 3.3, 1.0]
52 [5.0, 2.3, 3.3, 1.0]
53 [5.1, 2.5, 3.0, 1.1]
第1个簇包含的点:
1 [7.0, 3.2, 4.7, 1.4]
2 [6.4, 3.2, 4.5, 1.5]
3 [6.9, 3.1, 4.9, 1.5]
4 [5.5, 2.3, 4.0, 1.3]
5 [6.5, 2.8, 4.6, 1.5]
6 [5.7, 2.8, 4.5, 1.3]
7 [6.3, 3.3, 4.7, 1.6]
8 [6.6, 2.9, 4.6, 1.3]
9 [5.2, 2.7, 3.9, 1.4]
10 [5.0, 2.0, 3.5, 1.0]
11 [5.9, 3.0, 4.2, 1.5]
12 [6.0, 2.2, 4.0, 1.0]
13 [6.1, 2.9, 4.7, 1.4]
14 [5.6, 2.9, 3.6, 1.3]
15 [6.7, 3.1, 4.4, 1.4]
16 [5.6, 3.0, 4.5, 1.5]
17 [5.8, 2.7, 4.1, 1.0]
18 [6.2, 2.2, 4.5, 1.5]
19 [5.6, 2.5, 3.9, 1.1]
20 [5.9, 3.2, 4.8, 1.8]
21 [6.1, 2.8, 4.0, 1.3]
22 [6.3, 2.5, 4.9, 1.5]
23 [6.1, 2.8, 4.7, 1.2]
24 [6.4, 2.9, 4.3, 1.3]
25 [6.6, 3.0, 4.4, 1.4]
26 [6.8, 2.8, 4.8, 1.4]
27 [6.7, 3.0, 5.0, 1.7]
28 [6.0, 2.9, 4.5, 1.5]
29 [5.7, 2.6, 3.5, 1.0]
30 [5.5, 2.4, 3.8, 1.1]
31 [5.5, 2.4, 3.7, 1.0]
32 [5.8, 2.7, 3.9, 1.2]
33 [6.0, 2.7, 5.1, 1.6]
34 [5.4, 3.0, 4.5, 1.5]
35 [6.0, 3.4, 4.5, 1.6]
36 [6.7, 3.1, 4.7, 1.5]
37 [6.3, 2.3, 4.4, 1.3]
38 [5.6, 3.0, 4.1, 1.3]
39 [5.5, 2.5, 4.0, 1.3]
40 [5.5, 2.6, 4.4, 1.2]
41 [6.1, 3.0, 4.6, 1.4]
42 [5.8, 2.6, 4.0, 1.2]
43 [5.6, 2.7, 4.2, 1.3]
44 [5.7, 3.0, 4.2, 1.2]
45 [5.7, 2.9, 4.2, 1.3]
46 [6.2, 2.9, 4.3, 1.3]
47 [5.7, 2.8, 4.1, 1.3]
48 [4.9, 2.5, 4.5, 1.7]
49 [6.5, 3.2, 5.1, 2.0]
50 [6.0, 2.2, 5.0, 1.5]
51 [6.3, 2.7, 4.9, 1.8]
52 [6.2, 2.8, 4.8, 1.8]
53 [7.2, 3.0, 5.8, 1.6]
54 [6.3, 2.8, 5.1, 1.5]
55 [6.0, 3.0, 4.8, 1.8]
56 [6.9, 3.1, 5.1, 2.3]
57 [6.3, 2.5, 5.0, 1.9]
58 [6.1, 3.0, 4.9, 1.8]
59 [6.5, 3.0, 5.2, 2.0]
第2个簇包含的点:
1 [6.3, 3.3, 6.0, 2.5]
2 [5.8, 2.7, 5.1, 1.9]
3 [7.1, 3.0, 5.9, 2.1]
4 [6.3, 2.9, 5.6, 1.8]
5 [6.5, 3.0, 5.8, 2.2]
6 [7.6, 3.0, 6.6, 2.1]
7 [7.3, 2.9, 6.3, 1.8]
8 [6.7, 2.5, 5.8, 1.8]
9 [7.2, 3.6, 6.1, 2.5]
10 [6.4, 2.7, 5.3, 1.9]
11 [6.8, 3.0, 5.5, 2.1]
12 [5.7, 2.5, 5.0, 2.0]
13 [5.8, 2.8, 5.1, 2.4]
14 [6.4, 3.2, 5.3, 2.3]
15 [6.5, 3.0, 5.5, 1.8]
16 [7.7, 3.8, 6.7, 2.2]
17 [7.7, 2.6, 6.9, 2.3]
18 [6.9, 3.2, 5.7, 2.3]
19 [5.6, 2.8, 4.9, 2.0]
20 [7.7, 2.8, 6.7, 2.0]
21 [6.7, 3.3, 5.7, 2.1]
22 [7.2, 3.2, 6.0, 1.8]
23 [6.4, 2.8, 5.6, 2.1]
24 [7.4, 2.8, 6.1, 1.9]
25 [7.9, 3.8, 6.4, 2.0]
26 [6.4, 2.8, 5.6, 2.2]
27 [6.1, 2.6, 5.6, 1.4]
28 [7.7, 3.0, 6.1, 2.3]
29 [6.3, 3.4, 5.6, 2.4]
30 [6.4, 3.1, 5.5, 1.8]
31 [6.9, 3.1, 5.4, 2.1]
32 [6.7, 3.1, 5.6, 2.4]
33 [5.8, 2.7, 5.1, 1.9]
34 [6.8, 3.2, 5.9, 2.3]
35 [6.7, 3.3, 5.7, 2.5]
36 [6.7, 3.0, 5.2, 2.3]
37 [6.2, 3.4, 5.4, 2.3]
38 [5.9, 3.0, 5.1, 1.8]

结果可能会存在一定误差。