Skip to main content

K-means Clustering

Discover Hidden Patterns in Your Data! 🎯

K-means clustering is one of the most popular unsupervised learning algorithms: it automatically groups similar data points without requiring labeled examples. From customer segmentation to image compression, K-means helps uncover natural groupings in data. Master this fundamental algorithm, including initialization strategies, selection of the optimal K, and advanced variants.

K-means Algorithm Overview

```mermaid
graph TD
    A[K-means Clustering] --> B[Algorithm Steps]
    A --> C[Key Concepts]
    A --> D[Applications]
    B --> E[1. Initialize K centroids]
    B --> F[2. Assign points to nearest centroid]
    B --> G[3. Update centroids]
    B --> H[4. Repeat until convergence]
    C --> I[Centroids]
    C --> J[Inertia/SSE]
    C --> K[Euclidean Distance]
    C --> L[Convergence]
    D --> M[Customer Segmentation]
    D --> N[Image Compression]
    D --> O[Document Clustering]
    D --> P[Anomaly Detection]
    style A fill:#f9f,stroke:#333,stroke-width:2px
    style B fill:#bbf,stroke:#333,stroke-width:2px
    style C fill:#fbf,stroke:#333,stroke-width:2px
    style D fill:#bfb,stroke:#333,stroke-width:2px
```

K-means Implementation from Scratch

Understanding the Algorithm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, silhouette_samples, calinski_harabasz_score
from sklearn.metrics import davies_bouldin_score, adjusted_rand_score, adjusted_mutual_info_score
import warnings
warnings.filterwarnings('ignore')

# Shared look for every figure drawn further down
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Reproducible toy dataset: 300 two-dimensional points around 4 blob centers
# (centers drawn inside the box [-10, 10], per-blob standard deviation 1.5).
np.random.seed(42)
blob_params = dict(
    n_samples=300,
    centers=4,
    n_features=2,
    center_box=(-10.0, 10.0),
    cluster_std=1.5,
    random_state=42,
)
X, y_true = make_blobs(**blob_params)

class KMeansFromScratch:
    """K-means clustering (Lloyd's algorithm) implemented from scratch with NumPy.

    Parameters
    ----------
    n_clusters : int
        Number of clusters K.
    max_iters : int
        Maximum number of assignment/update iterations per ``fit`` call.
    random_state : int or None
        Seed for centroid initialization. A private ``RandomState`` is used,
        so the global ``np.random`` state of the caller is left untouched.

    Attributes set by ``fit``
    -------------------------
    centroids : ndarray of shape (n_clusters, n_features)
    labels : ndarray of shape (n_samples,)
        Index of the nearest centroid for each training point.
    inertia_ : float
        Sum of squared distances from each point to its assigned centroid.
    n_iter_ : int
        Number of iterations run by the most recent ``fit``.
    history : dict
        Per-iteration snapshots ('centroids', 'labels', 'inertia') of the
        most recent ``fit`` only.
    """

    def __init__(self, n_clusters=3, max_iters=100, random_state=None):
        self.n_clusters = n_clusters
        self.max_iters = max_iters
        self.random_state = random_state
        self.centroids = None
        self.labels = None
        self.inertia_ = None
        self.n_iter_ = 0
        self.history = {'centroids': [], 'labels': [], 'inertia': []}

    def euclidean_distance(self, X, centroids):
        """Return the (n_samples, n_centroids) matrix of Euclidean distances."""
        X = np.asarray(X, dtype=float)
        centroids = np.asarray(centroids, dtype=float)
        # Broadcast (n, 1, d) - (1, k, d) -> (n, k, d), then take the norm over d.
        return np.linalg.norm(X[:, np.newaxis, :] - centroids[np.newaxis, :, :], axis=2)

    def initialize_centroids(self, X, method='random'):
        """Pick ``n_clusters`` initial centroids from the rows of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        method : {'random', 'kmeans++'}
            'random' samples distinct rows uniformly. 'kmeans++' picks the first
            centroid uniformly and each subsequent one with probability
            proportional to its squared distance from the nearest centroid
            chosen so far.

        Returns
        -------
        ndarray of shape (n_clusters, n_features)

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.
        """
        X = np.asarray(X, dtype=float)
        # Private generator: same stream as seeding the legacy global RNG,
        # without resetting np.random's global state behind the caller's back.
        rng = np.random.RandomState(self.random_state)
        n_samples = X.shape[0]

        if method == 'random':
            idx = rng.choice(n_samples, self.n_clusters, replace=False)
            return X[idx]

        if method == 'kmeans++':
            centroids = [X[rng.randint(n_samples)]]
            # Squared distance from every point to its nearest chosen centroid,
            # updated incrementally instead of recomputed for every candidate.
            closest_sq = np.sum((X - centroids[0]) ** 2, axis=1)
            for _ in range(1, self.n_clusters):
                total = closest_sq.sum()
                r = rng.rand()
                if total > 0:
                    cumprobs = np.cumsum(closest_sq) / total
                    # First index whose cumulative probability exceeds r.
                    j = int(np.searchsorted(cumprobs, r, side='right'))
                    if j >= n_samples:
                        # Rounding left cumprobs[-1] just below r: fall back to
                        # the last point that still has positive weight.
                        j = int(np.flatnonzero(closest_sq)[-1])
                else:
                    # Every point coincides with a chosen centroid (e.g. duplicate
                    # data): any pick is equally good, so choose uniformly.
                    j = min(int(r * n_samples), n_samples - 1)
                centroids.append(X[j])
                closest_sq = np.minimum(closest_sq, np.sum((X - X[j]) ** 2, axis=1))
            return np.array(centroids)

        raise ValueError(
            f"Unknown init method {method!r}; expected 'random' or 'kmeans++'"
        )

    def fit(self, X, init_method='kmeans++'):
        """Run Lloyd's algorithm on ``X`` and return ``self``.

        Each iteration assigns every point to its nearest centroid, records the
        inertia of that assignment, then moves each centroid to the mean of its
        points. A cluster that receives no points keeps its previous centroid.
        Iteration stops when the centroids no longer move (``np.allclose``) or
        after ``max_iters`` iterations. On return, ``labels`` and ``inertia_``
        always describe ``centroids``.
        """
        X = np.asarray(X, dtype=float)
        # Start a fresh history so repeated fits don't mix runs together.
        self.history = {'centroids': [], 'labels': [], 'inertia': []}
        self.centroids = self.initialize_centroids(X, method=init_method)
        self.n_iter_ = 0
        converged = False

        for iteration in range(self.max_iters):
            self.n_iter_ = iteration + 1
            self.history['centroids'].append(self.centroids.copy())

            self._assign(X)
            self.history['labels'].append(self.labels.copy())
            self.history['inertia'].append(self.inertia_)

            # Start from the current centroids so an empty cluster stays put
            # instead of collapsing to the origin.
            new_centroids = self.centroids.copy()
            for k in range(self.n_clusters):
                members = X[self.labels == k]
                if len(members) > 0:
                    new_centroids[k] = members.mean(axis=0)

            if np.allclose(self.centroids, new_centroids):
                converged = True
                break
            self.centroids = new_centroids

        if not converged:
            # max_iters was reached right after moving the centroids: refresh
            # labels and inertia so they match the final centroids.
            self._assign(X)
        return self

    def _assign(self, X):
        """Set ``labels`` to each point's nearest centroid and ``inertia_`` to the SSE."""
        distances = self.euclidean_distance(X, self.centroids)
        self.labels = np.argmin(distances, axis=1)
        self.inertia_ = float(np.sum((X - self.centroids[self.labels]) ** 2))

    def predict(self, X):
        """Return the index of the nearest fitted centroid for each row of ``X``.

        Raises
        ------
        ValueError
            If called before ``fit``.
        """
        if self.centroids is None:
            raise ValueError("KMeansFromScratch instance is not fitted yet; call fit() first")
        distances = self.euclidean_distance(X, self.centroids)
        return np.argmin(distances, axis=1)

# --- Run the from-scratch implementation and report its fit ---------------
rule = "=" * 60
print(rule)
print("K-MEANS FROM SCRATCH")
print(rule)

kmeans_scratch = KMeansFromScratch(n_clusters=4, max_iters=100, random_state=42).fit(
    X, init_method='kmeans++'
)

print(f"\nConverged in {kmeans_scratch.n_iter_} iterations")
print(f"Final inertia: {kmeans_scratch.inertia_:.2f}")
print(f"Cluster centers:\n{kmeans_scratch.centroids}")

# --- Reference: scikit-learn's KMeans (best of n_init=10 restarts) --------
kmeans_sklearn = KMeans(n_clusters=4, random_state=42, n_init=10).fit(X)
# Agreement between the two labelings, invariant to cluster-index permutation
scratch_vs_sklearn_ari = adjusted_rand_score(kmeans_scratch.labels, kmeans_sklearn.labels_)

print(f"\nScikit-learn inertia: {kmeans_sklearn.inertia_:.2f}")
print(f"ARI (scratch vs sklearn): {scratch_vs_sklearn_ari:.3f}")

Key Takeaways

Further Resources