KMeans and MeanShift Clustering in Python


This article is about clustering using Python. We will look at two clustering methods: KMeans clustering and MeanShift clustering. KMeans clustering is a data mining technique that partitions n observations into k clusters, where each observation belongs to the cluster with the nearest mean. In KMeans clustering you specify the number of clusters to generate, whereas in MeanShift clustering the number of clusters is detected automatically from the number of density centers found in the data. The MeanShift algorithm iteratively shifts data points towards the mode, that is, the region where the density of data points is highest. For this reason it is also called a mode-seeking algorithm.
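To make the "nearest mean" rule concrete, here is a minimal NumPy sketch of Lloyd's algorithm, the classic iteration behind KMeans: assign each point to its nearest centroid, move each centroid to the mean of its points, and repeat until the assignments stop changing. The function name kmeans_lloyd and its parameters are illustrative only; scikit-learn's KMeans adds k-means++ initialization, several restarts and other optimizations.

```python
import numpy as np


def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Illustrative Lloyd's algorithm (not scikit-learn's implementation)."""
    rng = np.random.default_rng(seed)
    # Start from k distinct observations picked at random as the initial centroids.
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = np.full(len(X), -1)
    for _ in range(n_iter):
        # Assignment step: squared Euclidean distance from every point to every centroid.
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = dists.argmin(axis=1)
        if np.array_equal(new_labels, labels):
            break  # converged: no point changed cluster
        labels = new_labels
        # Update step: move each centroid to the mean of the points assigned to it
        # (a centroid with no points keeps its previous position).
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:
                centroids[j] = members.mean(axis=0)
    return centroids, labels


# Usage: two well-separated groups of points end up in two different clusters.
X = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]])
centroids, labels = kmeans_lloyd(X, k=2)
print(labels)
```

The cluster numbers themselves are arbitrary; only the grouping matters.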


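KMeans clustering is implemented by the KMeans class in sklearn.cluster. Some of the parameters of KMeans are as follows:
  • n_clusters: The number of clusters (and therefore centroids) to generate. Default is 8.
  • n_jobs: The number of jobs to run in parallel; -1 means use all processors. Default is None. This parameter was deprecated in scikit-learn 0.23 and removed in 1.0; newer versions parallelize KMeans automatically with OpenMP threads.
  • n_init: The number of times the algorithm is run with different centroid seeds; the run with the lowest inertia is kept. Default is 10 (changed to "auto" in scikit-learn 1.4).
  • verbose: Verbosity mode. Setting it to 1 prints information about the progress of the estimation. Default is 0.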
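As a sketch of these parameters in use, assuming a recent scikit-learn release (1.0 or later, where n_jobs is no longer accepted), the call below passes n_init explicitly and adds random_state, a KMeans parameter not listed above that fixes the centroid seeds so that repeated runs give the same result.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X = load_iris().data

# n_init=10 is passed explicitly because its default changed to "auto" in
# scikit-learn 1.4; random_state fixes the centroid seeds for reproducibility.
km = KMeans(n_clusters=3, n_init=10, random_state=0, verbose=0).fit(X)

print(km.cluster_centers_.shape)  # (3, 4): one row per cluster, one column per feature
print(km.inertia_)                # sum of squared distances of samples to their closest centroid
print(km.n_iter_)                 # iterations used by the best of the n_init runs
```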
MeanShift clustering is implemented by the MeanShift class in sklearn.cluster. Some of the parameters of MeanShift are as follows:
  • n_jobs: The number of jobs to run in parallel; -1 means use all processors. Default is None.
  • bandwidth: The bandwidth of the kernel. If not specified, it is estimated using sklearn.cluster.estimate_bandwidth.
  • bin_seeding: If True, the initial kernel locations are the points of a coarse grid (whose coarseness corresponds to the bandwidth) instead of every data point, which speeds up the algorithm. Default is False.
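To see how the bandwidth drives the number of clusters, the following sketch calls sklearn.cluster.estimate_bandwidth directly with a few values of its quantile parameter (not discussed above) and passes the result to MeanShift. The quantile values are arbitrary choices for illustration, and the exact cluster counts depend on the data.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris

X = load_iris().data[:, 2:4]  # petal length and petal width

# MeanShift calls estimate_bandwidth itself when bandwidth is not given.
# A larger quantile gives a wider kernel, which usually merges more points
# into fewer clusters.
for q in (0.1, 0.2, 0.3):
    bw = estimate_bandwidth(X, quantile=q)
    ms = MeanShift(bandwidth=bw).fit(X)
    print(f"quantile={q}: bandwidth={bw:.3f}, clusters={len(ms.cluster_centers_)}")
```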
To demonstrate clustering, we can use the sample data provided by the iris dataset in the sklearn.datasets package. The iris dataset consists of 150 samples (50 of each species) of three types of iris flower (Setosa, Versicolor and Virginica), stored as a 150x4 numpy.ndarray. The rows represent the samples and the columns represent the sepal length, sepal width, petal length and petal width, in centimetres.
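A quick way to confirm the structure described above is to inspect the dataset object returned by sklearn.datasets.load_iris:

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
print(type(iris.data).__name__, iris.data.shape)  # ndarray (150, 4)
print(iris.feature_names)        # column names, in column order
print(iris.target_names)         # ['setosa' 'versicolor' 'virginica']
print(np.bincount(iris.target))  # [50 50 50]: 50 samples per species
```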

Using the Code

We start with the implementation of KMeans clustering. First, we load the iris dataset:
  from sklearn import datasets
  iris = datasets.load_iris()
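Then, we extract the sepal data (the first two columns of iris.data) and the petal data (the last two columns):
  sepal_data = iris.data[:, :2]   # sepal length, sepal width
  petal_data = iris.data[:, 2:4]  # petal length, petal width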
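Next, we create two KMeans objects and fit them to the sepal and petal data:
  from sklearn.cluster import KMeans
  # n_jobs is not passed: KMeans no longer accepts it in scikit-learn 1.0 and later
  km1 = KMeans(n_clusters=3, n_init=10).fit(sepal_data)
  km2 = KMeans(n_clusters=3, n_init=10).fit(petal_data)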
The next step is to retrieve the centroids and cluster labels of the sepals and petals from the fitted models:
  centroids_sepals = km1.cluster_centers_
  labels_sepals = km1.labels_
  centroids_petals = km2.cluster_centers_
  labels_petals = km2.labels_
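Because cluster numbers are arbitrary, the labels cannot be compared with the species directly. One way to check how well the petal clusters match the true species, assuming scikit-learn's adjusted_rand_score metric (not used elsewhere in this article), is sketched below; random_state=0 is added only to make the run reproducible.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import adjusted_rand_score

iris = load_iris()
petal_data = iris.data[:, 2:4]
km2 = KMeans(n_clusters=3, n_init=10, random_state=0).fit(petal_data)
labels_petals = km2.labels_

# A permutation-invariant score: 1.0 means identical partitions,
# values near 0.0 mean chance-level agreement.
ari = adjusted_rand_score(iris.target, labels_petals)
print(f"adjusted Rand index: {ari:.3f}")

# Contingency table: one row per species, one column per cluster label.
for species_id, name in enumerate(iris.target_names):
    counts = np.bincount(labels_petals[iris.target == species_id], minlength=3)
    print(f"{name:>10}: {counts}")
```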
To visualize the clusters, we can create scatter plots of the sepal and petal clusters.
First, we create a figure object:
  import matplotlib.pyplot as plt
  from mpl_toolkits.mplot3d import Axes3D  # registers the "3d" projection on Matplotlib older than 3.2
  fig = plt.figure()
We can create four subplots to show the sepal and petal data in two and three dimensions. The subplots are arranged as a 2 by 2 grid: the first row shows the sepal information and the second row the petal information. In each row, the first column shows a 2-dimensional scatter chart and the second column a 3-dimensional scatter chart. In the first argument of the add_subplot() function, the first two digits give the number of rows and columns, and the third digit gives the position of the current subplot (counting from 1). The second, optional argument, projection, sets the projection mode. Because only two features are passed to scatter(), the points in the 3-dimensional subplots all lie in the z = 0 plane.
  ax1 = fig.add_subplot(221)
  ax2 = fig.add_subplot(222, projection="3d")
  ax3 = fig.add_subplot(223)
  ax4 = fig.add_subplot(224, projection="3d")
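The three-digit code passed to add_subplot() is shorthand for three separate arguments. The sketch below is illustrative and not part of the article's program: it uses the explicit form and gives the 3-D axes a real third coordinate (petal length, an assumption made for this example) so that the points are spread along the z axis. Points are colored by true species here, and the non-interactive Agg backend is selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers "3d" on old Matplotlib)
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data

fig = plt.figure()
# add_subplot(221) is shorthand for add_subplot(2, 2, 1): 2 rows, 2 columns, position 1.
ax_a = fig.add_subplot(2, 2, 1)
ax_a.scatter(X[:, 0], X[:, 1], c=iris.target, s=20)
# A "3d" projection accepts a z coordinate; petal length is used here,
# so the points are not all drawn in the z = 0 plane.
ax_b = fig.add_subplot(2, 2, 2, projection="3d")
ax_b.scatter(X[:, 0], X[:, 1], X[:, 2], c=iris.target, s=20)
ax_b.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1],
         zlabel=iris.feature_names[2])
print(ax_a.name, ax_b.name)  # rectilinear 3d
```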
To plot the scatter chart (data and centroids), we can use the following code:
  ax1.scatter(sepal_data[:, 0], sepal_data[:, 1], c=labels_sepals, s=50)
  ax1.scatter(centroids_sepals[:, 0], centroids_sepals[:, 1], c="red", s=100)
  ax2.scatter(sepal_data[:, 0], sepal_data[:, 1], c=labels_sepals, s=50)
  ax2.scatter(centroids_sepals[:, 0], centroids_sepals[:, 1], c="red", s=100)
  ax3.scatter(petal_data[:, 0], petal_data[:, 1], c=labels_petals, s=50)
  ax3.scatter(centroids_petals[:, 0], centroids_petals[:, 1], c="red", s=100)
  ax4.scatter(petal_data[:, 0], petal_data[:, 1], c=labels_petals, s=50)
  ax4.scatter(centroids_petals[:, 0], centroids_petals[:, 1], c="red", s=100)
The labels for the x and y axes of the subplots can be set using the feature_names property of the iris dataset as follows:
  ax1.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
  ax2.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
  ax3.set(xlabel=iris.feature_names[2], ylabel=iris.feature_names[3])
  ax4.set(xlabel=iris.feature_names[2], ylabel=iris.feature_names[3])
The following code can be used to set the background color of the subplots to green:
  ax1.set_facecolor("green")
  ax2.set_facecolor("green")
  ax3.set_facecolor("green")
  ax4.set_facecolor("green")
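The scatter, label and background calls above repeat the same pattern for each subplot. As an optional refactoring, the sketch below wraps them in a helper, plot_clusters (a hypothetical name, not a Matplotlib function), and loops over the two feature pairs, fitting one KMeans model per pair. It draws the same four panels; the Agg backend is selected so that it runs without a display, and in an interactive session you would call plt.show() at the end.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers "3d" on old Matplotlib)
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris


def plot_clusters(ax, data, labels, centroids, xlabel, ylabel):
    """Hypothetical helper: points colored by cluster, red centroids,
    axis labels and a green background on one subplot."""
    ax.scatter(data[:, 0], data[:, 1], c=labels, s=50)
    ax.scatter(centroids[:, 0], centroids[:, 1], c="red", s=100)
    ax.set(xlabel=xlabel, ylabel=ylabel)
    ax.set_facecolor("green")


iris = load_iris()
names = iris.feature_names
fig = plt.figure()
# Row 0 holds the sepal columns (0, 1), row 1 the petal columns (2, 3).
for row, first_col in enumerate((0, 2)):
    data = iris.data[:, first_col:first_col + 2]
    km = KMeans(n_clusters=3, n_init=10).fit(data)
    # Column 0 is the 2-D chart, column 1 the 3-D chart (points at z = 0).
    for col, projection in enumerate((None, "3d")):
        ax = fig.add_subplot(2, 2, 2 * row + col + 1, projection=projection)
        plot_clusters(ax, data, km.labels_, km.cluster_centers_,
                      names[first_col], names[first_col + 1])
```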
Finally, we can display the charts as follows:
  plt.show()
Running the above code shows the following output:
MeanShift clustering is implemented in the same way.
We create two MeanShift objects and fit them to the sepal and petal data:
  from sklearn.cluster import MeanShift
  ms1 = MeanShift(n_jobs=-1).fit(sepal_data)
  centroids_sepals = ms1.cluster_centers_
  labels_sepals = ms1.labels_
  ms2 = MeanShift(n_jobs=-1).fit(petal_data)
  centroids_petals = ms2.cluster_centers_
  labels_petals = ms2.labels_
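Because MeanShift chooses the number of clusters itself, it is worth checking how many it found. The sketch below reads this from cluster_centers_ and uses the fitted model's predict() method to assign two made-up petal measurements (chosen only for illustration) to the nearest cluster center. n_jobs is left at its default here.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import load_iris

petal_data = load_iris().data[:, 2:4]
ms2 = MeanShift().fit(petal_data)  # bandwidth estimated automatically

n_clusters = len(ms2.cluster_centers_)  # chosen by the algorithm, not by the caller
print("clusters found:", n_clusters)
print("points per cluster:", np.bincount(ms2.labels_, minlength=n_clusters))

# New samples are assigned to the nearest cluster center.
new_labels = ms2.predict(np.array([[1.5, 0.3], [5.5, 2.0]]))
print(new_labels)
```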
The remaining steps are the same as for KMeans clustering. The output of MeanShift clustering is as follows:
Note that in MeanShift clustering the number of clusters is not specified by the user; the algorithm determines it from the data and the bandwidth.
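To illustrate why no cluster count is needed, here is a minimal NumPy sketch of mean shift with a flat kernel (scikit-learn's MeanShift also uses a flat kernel, but adds seeding options, parallelism and a more careful merging step). Every point is repeatedly moved to the mean of the original points within the bandwidth, so it climbs to a local density peak, and points that end at the same peak form one cluster. The function name mean_shift_modes, the merge threshold of half the bandwidth and the synthetic two-blob data are all assumptions made for this example.

```python
import numpy as np


def mean_shift_modes(X, bandwidth, n_iter=300, tol=1e-4):
    """Illustrative flat-kernel mean shift (not scikit-learn's implementation)."""
    shifted = X.astype(float).copy()
    for _ in range(n_iter):
        # Distance from every current position to every original data point.
        d = np.linalg.norm(shifted[:, None, :] - X[None, :, :], axis=2)
        within = (d <= bandwidth).astype(float)
        # Flat kernel: move to the plain average of the points inside the window.
        new = within @ X / within.sum(axis=1, keepdims=True)
        done = np.abs(new - shifted).max() < tol
        shifted = new
        if done:
            break
    # Positions that stopped close together (here: within half a bandwidth)
    # are treated as the same mode, i.e. the same cluster.
    modes, labels = [], np.empty(len(X), dtype=int)
    for i, p in enumerate(shifted):
        for j, m in enumerate(modes):
            if np.linalg.norm(p - m) < bandwidth / 2:
                labels[i] = j
                break
        else:
            modes.append(p)
            labels[i] = len(modes) - 1
    return np.array(modes), labels


# Usage: two synthetic blobs; the number of clusters is never passed in.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.3, (30, 2)), rng.normal(5.0, 0.3, (30, 2))])
modes, labels = mean_shift_modes(X, bandwidth=1.5)
print(len(modes))  # 2
```

A smaller bandwidth would split the data into more, smaller modes, which is the same effect the bandwidth parameter has in scikit-learn's MeanShift.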
The scipy.cluster.vq module provides the kmeans2 function, which also implements k-means clustering. The SciPy documentation recommends normalizing the features before clustering so that each one has unit variance; the whiten function does this by dividing each feature (column) by its standard deviation. We can implement k-means clustering using the scipy.cluster.vq module as follows:
  # Clustering using KMeans and SciPy
  from sklearn import datasets
  from scipy.cluster.vq import kmeans2, whiten
  import matplotlib.pyplot as plt
  from mpl_toolkits.mplot3d import Axes3D  # registers the "3d" projection on Matplotlib older than 3.2
  iris = datasets.load_iris()
  sepal_data = iris.data[:, :2]   # sepal length, sepal width
  petal_data = iris.data[:, 2:4]  # petal length, petal width
  # Scale each feature to unit variance before clustering
  sepal_data_w = whiten(sepal_data)
  petal_data_w = whiten(petal_data)
  centroids_sepals, labels_sepals = kmeans2(data=sepal_data_w, k=3)
  centroids_petals, labels_petals = kmeans2(data=petal_data_w, k=3)
  fig = plt.figure()
  ax1 = fig.add_subplot(221)
  ax2 = fig.add_subplot(222, projection="3d")
  ax3 = fig.add_subplot(223)
  ax4 = fig.add_subplot(224, projection="3d")
  # Plot the whitened data, which is in the same units as the centroids
  ax1.scatter(sepal_data_w[:, 0], sepal_data_w[:, 1], c=labels_sepals, s=50)
  ax1.scatter(centroids_sepals[:, 0], centroids_sepals[:, 1], c="red", s=100)
  ax2.scatter(sepal_data_w[:, 0], sepal_data_w[:, 1], c=labels_sepals, s=50)
  ax2.scatter(centroids_sepals[:, 0], centroids_sepals[:, 1], c="red", s=100)
  ax3.scatter(petal_data_w[:, 0], petal_data_w[:, 1], c=labels_petals, s=50)
  ax3.scatter(centroids_petals[:, 0], centroids_petals[:, 1], c="red", s=100)
  ax4.scatter(petal_data_w[:, 0], petal_data_w[:, 1], c=labels_petals, s=50)
  ax4.scatter(centroids_petals[:, 0], centroids_petals[:, 1], c="red", s=100)
  ax1.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
  ax2.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
  ax3.set(xlabel=iris.feature_names[2], ylabel=iris.feature_names[3])
  ax4.set(xlabel=iris.feature_names[2], ylabel=iris.feature_names[3])
  ax1.set_facecolor("green")
  ax2.set_facecolor("green")
  ax3.set_facecolor("green")
  ax4.set_facecolor("green")
  plt.show()
The above code produces the following output:
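Since kmeans2 was run on whitened data, the centroids in the SciPy version are in standard-deviation units rather than centimetres. The sketch below checks that whiten() simply divides each column by its standard deviation and then multiplies the centroids by the same factors to express them in the original units. minit="++" (k-means++ initialization) is an addition not used in the code above, and results vary between runs because no seed is fixed.

```python
import numpy as np
from scipy.cluster.vq import kmeans2, whiten
from sklearn.datasets import load_iris

iris = load_iris()
petal_data = iris.data[:, 2:4]

# whiten() divides each column by its (population) standard deviation;
# it does not subtract the mean.
std = petal_data.std(axis=0)
petal_data_w = whiten(petal_data)
print(np.allclose(petal_data_w, petal_data / std))  # True

centroids_w, labels = kmeans2(petal_data_w, 3, minit="++")

# The centroids are in whitened units; multiplying by the same per-feature
# standard deviations expresses them in centimetres again.
centroids_cm = centroids_w * std
print(np.round(centroids_cm, 2))
```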


Data clustering is a widely used data mining technique with many practical applications, for example in data classification and image processing. I hope readers find this article useful for understanding the concepts of data clustering.