1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
| import numpy as np
import matplotlib.pyplot as plt
def distance(x1, x2):
    """Return the Euclidean distance between ``x1`` and ``x2``.

    The squared element-wise differences are summed over *all* elements and
    the square root of that sum is returned, i.e. the L2 norm of
    ``x1 - x2`` for two equally shaped point vectors.
    """
    diff = x1 - x2
    squared = np.power(diff, 2)
    return np.sqrt(np.sum(squared))
def sse(centroids1, centroids2):
    """Return the total distance moved between two aligned sets of centroids.

    Row ``j`` of ``centroids1`` is paired with row ``j`` of ``centroids2``.
    The Euclidean length of each row difference is computed, and those
    lengths are summed. Despite the name, the per-row distances are *not*
    squared, so this is a sum of distances rather than a sum of squared
    errors. It is used as a convergence measure.
    """
    delta = centroids1 - centroids2
    per_row = np.sqrt(np.sum(np.power(delta, 2), axis=1))
    return np.sum(per_row)
def update_centroid(centroids, data):
    """Run one k-means update step and return the new centroids.

    Every row of ``data`` is assigned to its nearest centroid by Euclidean
    distance; on a tie the lowest centroid index wins. Each centroid is then
    moved to the mean of the rows assigned to it. A centroid that receives no
    rows keeps its previous position, so the result always has one row per
    input centroid.

    Args:
        centroids: sequence of centroid vectors, one per cluster (stacked
            into a ``(k, d)`` array here).
        data: 2-D array-like of points, one point per row.

    Returns:
        ``(k, d)`` numpy array of updated centroids.

    Raises:
        ValueError: if ``data`` is not 2-D.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError("data must be a 2-D array with one point per row")
    centroids = np.asarray(centroids)
    # All point-to-centroid distances in one broadcast, shape (n, k),
    # instead of a Python-level loop over every (point, centroid) pair.
    diffs = data[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    dists = np.sqrt(np.sum(np.power(diffs, 2), axis=2))
    # argmin yields the first (lowest-index) minimum without a full sort.
    labels = np.argmin(dists, axis=1)

    new_centroids = []
    for j, centroid in enumerate(centroids):
        members = data[labels == j]
        if len(members) == 0:
            # Empty cluster: keep the old centroid rather than produce NaNs.
            new_centroids.append(centroid)
        else:
            new_centroids.append(np.mean(members, axis=0))
    return np.array(new_centroids)
def initial_centroids(k, data, rng=None):
    """Pick ``k`` distinct rows of ``data`` at random as starting centroids.

    Args:
        k: number of centroids; must satisfy ``0 <= k <= number of rows``.
        data: 2-D array of points, one point per row.
        rng: optional ``numpy.random.Generator`` or ``RandomState`` used for
            the random permutation. When omitted, NumPy's global random
            state is used, so ``np.random.seed`` controls the result.

    Returns:
        ``(k, d)`` array holding copies of the chosen rows.

    Raises:
        ValueError: if ``k`` is negative or exceeds the number of rows.
    """
    r, _ = data.shape
    if not 0 <= k <= r:
        # Slicing idxs[:k] would otherwise silently return the wrong number
        # of centroids (r + k rows for negative k, only r rows for k > r).
        raise ValueError("k must be between 0 and %d, got %r" % (r, k))
    # permutation(r) is arange(r) followed by an in-place shuffle, so the
    # global-state path consumes the same random stream as before.
    source = np.random if rng is None else rng
    idxs = source.permutation(r)
    return data[idxs[:k]]
def cluster(centroids, data):
    """Group the row indices of ``data`` by nearest centroid.

    Args:
        centroids: sequence of centroid vectors, one per cluster (stacked
            into a ``(k, d)`` array here).
        data: 2-D array-like of points, one point per row.

    Returns:
        A list with one entry per centroid. Entry ``j`` is an ascending list
        of the row indices (Python ints) whose nearest centroid, by Euclidean
        distance, is ``j``. On a tie the lowest centroid index wins.
        Centroids with no members get an empty list.

    Raises:
        ValueError: if ``data`` is not 2-D.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError("data must be a 2-D array with one point per row")
    centroids = np.asarray(centroids)
    # All point-to-centroid distances in one broadcast, shape (n, k).
    diffs = data[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    dists = np.sqrt(np.sum(np.power(diffs, 2), axis=2))
    # argmin yields the first (lowest-index) minimum without a full sort.
    labels = np.argmin(dists, axis=1)
    return [np.flatnonzero(labels == j).tolist() for j in range(len(centroids))]
# Demo: run k-means on random 2-D points and plot snapshots of the progress.
# 400 points drawn uniformly from [5, 10) in each coordinate.
data = np.random.uniform(5, 10, size=(400, 2))
k = 5
colors = ['#4e9e9d', '#86cc7f', '#506798', '#4f1b63', '#fbe85a']
tol = 1e-6       # stop once the summed centroid movement is at most this
iteration = 12   # countdown of remaining passes (at most iteration + 1 passes)

# One 3x2 grid of snapshot panels.
fig = plt.figure(figsize=(10, 15))
axes = fig.subplots(nrows=3, ncols=2)
panels = axes.ravel()
centroids = initial_centroids(k, data)
i = 0      # index of the next unused panel
step = 0   # number of centroid updates applied so far
while iteration >= 0:
    # Snapshot on every other pass (odd countdown values) while panels remain,
    # so changing `iteration` cannot index past the grid.
    if iteration % 2 == 1 and i < len(panels):
        ax = panels[i]
        cluster_idxs = cluster(centroids, data)
        for color_idx, cluster_idx in enumerate(cluster_idxs):
            members = data[cluster_idx]
            # Cycle the palette so k > len(colors) does not raise IndexError.
            ax.scatter(members[:, 0], members[:, 1],
                       c=colors[color_idx % len(colors)])
        ax.scatter(centroids[:, 0], centroids[:, 1], s=30, marker='*', c='red')
        # Title with the number of updates actually performed (not the
        # remaining countdown), so titles increase across the panels.
        ax.set_title('iter %d' % step)
        i = i + 1
    new_centroids = update_centroid(centroids, data)
    converged = sse(new_centroids, centroids) <= tol
    centroids = new_centroids
    step = step + 1
    if converged:
        break
    iteration = iteration - 1

fig.tight_layout()
# Without this, nothing is displayed when the file is run as a script.
plt.show()
|