import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format = 'retina'Image Segmentation using K-Means Clustering
Image Segmentation using K-Means Clustering
!wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O dog.jpg--2023-04-15 11:58:51-- https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg
Resolving segment-anything.com (segment-anything.com)... 108.158.245.28, 108.158.245.33, 108.158.245.84, ...
Connecting to segment-anything.com (segment-anything.com)|108.158.245.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 221810 (217K) [image/jpeg]
Saving to: ‘dog.jpg’
dog.jpg 100%[===================>] 216.61K 348KB/s in 0.6s
2023-04-15 11:58:53 (348 KB/s) - ‘dog.jpg’ saved [221810/221810]
# Read the image. plt.imread returns integer arrays (uint8 for a typical
# 8-bit JPEG) for most formats, but float arrays already in [0, 1] for PNGs.
img = plt.imread('dog.jpg')
# Keep only the colour channels so the later reshape(-1, 3) calls also work
# when the file carries an alpha channel (RGBA); a no-op for an RGB image.
if img.ndim == 3:
    img = img[..., :3]
# Convert to [0, 1] range: scale integer images by their dtype's maximum
# (255 for uint8); float images returned by imread are already normalised.
if np.issubdtype(img.dtype, np.integer):
    img = img / np.iinfo(img.dtype).max
# Plot image
plt.imshow(img)<matplotlib.image.AxesImage at 0x7efc88822340>

img.shape(1365, 2048, 3)
from sklearn.cluster import KMeans
# Flatten the (H, W, 3) image into an (H*W, 3) table: one RGB row per pixel
img_2d = np.reshape(img, (-1, 3))
pd.DataFrame(img_2d).describe()| 0 | 1 | 2 | |
|---|---|---|---|
| count | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 |
| mean | 5.542249e-01 | 4.757675e-01 | 2.935504e-01 |
| std | 2.068960e-01 | 1.844162e-01 | 1.533922e-01 |
| min | 3.921569e-02 | 0.000000e+00 | 0.000000e+00 |
| 25% | 3.882353e-01 | 3.176471e-01 | 1.647059e-01 |
| 50% | 5.725490e-01 | 4.705882e-01 | 2.666667e-01 |
| 75% | 7.294118e-01 | 6.392157e-01 | 4.078431e-01 |
| max | 1.000000e+00 | 1.000000e+00 | 1.000000e+00 |
img_2d.shape(2795520, 3)
np.unique(img_2d, axis=0).shape(190571, 3)
# Fit KMeans from scikit-learn (slow!)
kmeans = KMeans(n_clusters=5, random_state=0).fit(img_2d)kmeans.cluster_centers_, kmeans.labels_(array([[0.59260829, 0.48656705, 0.2670584 ],
[0.24250138, 0.22963615, 0.12538282],
[0.8270407 , 0.74152134, 0.54547183],
[0.71622954, 0.62308534, 0.39023071],
[0.419752 , 0.3354403 , 0.18325719]]),
array([1, 1, 1, ..., 4, 4, 4], dtype=int32))
# instead use FAISS with GPU
import faiss
# Set up FAISS index
d = img_2d.shape[1]  # Dimension of the feature vectors (3 for RGB)
n_clusters = 5  # Number of clusters
n_gpus = 2  # Number of GPUs to use
# Exact (brute-force) L2 index copied to the GPUs; FAISS's k-means uses it
# for the nearest-centroid assignment step during training.
index_flat = faiss.IndexFlatL2(d)
index = faiss.index_cpu_to_all_gpus(index_flat, ngpu=n_gpus)
# k-means driver: n_clusters centroids in d dimensions, 20 iterations,
# with per-iteration progress logging.
kmeans_gpu = faiss.Clustering(d, n_clusters)
kmeans_gpu.verbose = True
kmeans_gpu.niter = 20
kmeans_gpu.train(img_2d.astype(np.float32), index)Sampling a subset of 1280 / 2795520 for training
Clustering 1280 points in 3D to 5 clusters, redo 1 times, 20 iterations
Preprocessing in 0.03 s
Iteration 19 (0.22 s, search 0.00 s): objective=12.4307 imbalance=1.031 nsplit=0
# Assign every pixel to its nearest FAISS centroid (1 nearest neighbour).
# NOTE: relies on Clustering.train having left the final centroids in `index`.
_, I = index.search(img_2d.astype(np.float32), 1)
cluster_ids = I[:, 0]  # (N, 1) neighbour ids -> flat (N,) label vector
# Create segmented image using cluster centers from sklearn
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
segmented_img = segmented_img.reshape(img.shape)
centroids_gpu = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(n_clusters, d)
# Create segmented image using cluster centers from FAISS
segmented_img_faiss = centroids_gpu[cluster_ids]
# Plot segmented images side by side (left: sklearn, right: FAISS)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(segmented_img)
ax[1].imshow(segmented_img_faiss.reshape(img.shape))<matplotlib.image.AxesImage at 0x7efa51d55250>

# Now, let's try to segment the image using a different number of clusters using FAISS
def segment_plot(img, k=5, n_gpus=2):
    """Colour-quantise ``img`` into ``k`` clusters with FAISS k-means and plot it.

    Draws two figures: the original image next to its segmented version
    (every pixel replaced by its cluster's centroid colour), and a strip
    with one swatch per centroid colour. With ``verbose`` on, FAISS logs
    its progress, including the size of the pixel subsample it trains on.

    Parameters
    ----------
    img : numpy.ndarray
        Image of shape (H, W, 3); values are expected in [0, 1] so that
        ``imshow`` renders the centroid colours correctly.
    k : int, default 5
        Number of clusters (colours).
    n_gpus : int, default 2
        Number of GPUs the FAISS index is copied to.
    """
    # One RGB row per pixel, as float32 (converted once, reused below)
    pixels = img.reshape(-1, 3).astype(np.float32)
    d = pixels.shape[1]  # Dimension of the feature vectors (3 for RGB)
    # Exact L2 index on the GPUs; k-means uses it for the assignment step
    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d), ngpu=n_gpus)
    kmeans_gpu = faiss.Clustering(d, k)
    kmeans_gpu.verbose = True
    kmeans_gpu.niter = 20
    kmeans_gpu.train(pixels, index)
    # NOTE: relies on Clustering.train leaving the final centroids in `index`
    _, nearest = index.search(pixels, 1)
    cluster_ids = nearest[:, 0]
    centroids = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(k, d)
    # Replace every pixel by its centroid colour, back in image layout
    segmented = centroids[cluster_ids].reshape(img.shape)
    # Original and segmented image side by side
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[1].imshow(segmented)
    # Colour palette: one swatch per cluster
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.imshow(centroids.reshape(1, k, 3))
segment_plot(img, k=2)Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 3D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.04 s
Iteration 19 (0.25 s, search 0.00 s): objective=14.8063 imbalance=1.000 nsplit=0


segment_plot(img, k=10)Sampling a subset of 2560 / 2795520 for training
Clustering 2560 points in 3D to 10 clusters, redo 1 times, 20 iterations
Preprocessing in 0.04 s
Iteration 19 (0.03 s, search 0.01 s): objective=13.6225 imbalance=1.202 nsplit=0


# Modify segmentation function to mask out all but the given cluster
from copy import deepcopy
def segment_plot_mask(img, k=5, n_gpus=2):
    """Cluster the colours of ``img`` with FAISS k-means and show each cluster.

    Draws the original image next to its segmented version (pixels replaced
    by their centroid colour). It then draws a row of ``k`` panels. Panel
    ``i`` shows the original pixels belonging to cluster ``i`` and paints
    every other pixel white.

    Parameters
    ----------
    img : numpy.ndarray
        Image of shape (H, W, 3) with values expected in [0, 1] (white is
        drawn as 1.0).
    k : int, default 5
        Number of clusters.
    n_gpus : int, default 2
        Number of GPUs the FAISS index is copied to.
    """
    # One RGB row per pixel; FAISS gets a float32 copy
    pixels = img.reshape(-1, 3)
    pixels_f32 = pixels.astype(np.float32)
    d = pixels.shape[1]  # Dimension of the feature vectors (3 for RGB)
    # Exact L2 index on the GPUs; k-means uses it for the assignment step
    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d), ngpu=n_gpus)
    kmeans_gpu = faiss.Clustering(d, k)
    kmeans_gpu.verbose = True
    kmeans_gpu.niter = 20
    kmeans_gpu.train(pixels_f32, index)
    # NOTE: relies on Clustering.train leaving the final centroids in `index`
    _, nearest = index.search(pixels_f32, 1)
    cluster_ids = nearest[:, 0]
    centroids = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(k, d)
    # Original and segmented image side by side
    segmented = centroids[cluster_ids].reshape(img.shape)
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[1].imshow(segmented)
    # One panel per cluster: original colours where the pixel belongs to the
    # cluster, white (1.0) everywhere else. squeeze=False keeps `ax` 2-D so
    # this also works for k == 1 (otherwise subplots returns a bare Axes).
    fig, ax = plt.subplots(1, k, figsize=(k * 4, 5), squeeze=False)
    for i, panel in enumerate(ax[0]):
        img_masked = pixels.copy()
        img_masked[cluster_ids != i] = 1.0
        panel.imshow(img_masked.reshape(img.shape))
        panel.set_title(f'Cluster {i}')
        panel.axis('off')
img.shape(1365, 2048, 3)
mask = segment_plot_mask(img, k=2)Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 3D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.02 s
Iteration 19 (0.46 s, search 0.00 s): objective=14.8063 imbalance=1.000 nsplit=0


mask = segment_plot_mask(img, k=3)Sampling a subset of 768 / 2795520 for training
Clustering 768 points in 3D to 3 clusters, redo 1 times, 20 iterations
Preprocessing in 0.02 s
Iteration 19 (0.46 s, search 0.00 s): objective=12.5181 imbalance=1.009 nsplit=0


mask = segment_plot_mask(img, k=4)Sampling a subset of 1024 / 2795520 for training
Clustering 1024 points in 3D to 4 clusters, redo 1 times, 20 iterations
Preprocessing in 0.02 s
Iteration 19 (0.15 s, search 0.00 s): objective=12.6442 imbalance=1.031 nsplit=0


# Segment using colour *and* position: each pixel's coordinates will be
# appended to its RGB values.
# Flatten to one RGB row per pixel (row-major: the row index varies slowest)
img_2d = np.reshape(img, (-1, 3))
# Raw pixel coordinates: x runs over image rows (height), y over columns (width)
x, y = map(np.arange, img.shape[:2])
# Scale the spatial coordinates to be between a and b
def scale(x, a, b):
    """Linearly rescale the values of ``x`` onto the interval ``[a, b]``.

    The minimum of ``x`` maps to ``a`` and the maximum to ``b``. If every
    value of ``x`` is equal (e.g. the coordinates of a one-pixel-high image
    axis), the range is zero. All values then map to ``a`` instead of
    becoming NaN through a 0/0 division.

    Parameters
    ----------
    x : numpy.ndarray
        Values to rescale (non-empty).
    a, b : float
        Bounds of the target interval.

    Returns
    -------
    numpy.ndarray
        Floating-point array with the same shape as ``x``.
    """
    x = np.asarray(x)
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        return np.full_like(x, a, dtype=float)
    return (b - a) * (x - lo) / span + a
# Map row/column indices into [0.25, 0.75] so position is on a scale
# comparable to the colour channels
x = scale(x, 0.25, 0.75)
y = scale(y, 0.25, 0.75)
# indexing='ij' yields (H, W) grids laid out like the image: xx[i, j] is the
# row coordinate and yy[i, j] the column coordinate of pixel (i, j), so after
# the row-major flatten below entry k lines up with row k of img_2d. The
# default 'xy' indexing yields (W, H) grids and pairs pixels with the wrong
# coordinates.
xx, yy = np.meshgrid(x, y, indexing='ij')
xx = xx.reshape(-1, 1)
yy = yy.reshape(-1, 1)
# Feature table: R, G, B, row, col
img_2d_spatial = np.hstack((img_2d, xx, yy))
pd.DataFrame(img_2d_spatial).describe()| 0 | 1 | 2 | 3 | 4 | |
|---|---|---|---|---|---|
| count | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 |
| mean | 5.542249e-01 | 4.757675e-01 | 2.935504e-01 | 5.000000e-01 | 5.000000e-01 |
| std | 2.068960e-01 | 1.844162e-01 | 1.533922e-01 | 1.444434e-01 | 1.444081e-01 |
| min | 3.921569e-02 | 0.000000e+00 | 0.000000e+00 | 2.500000e-01 | 2.500000e-01 |
| 25% | 3.882353e-01 | 3.176471e-01 | 1.647059e-01 | 3.750000e-01 | 3.750000e-01 |
| 50% | 5.725490e-01 | 4.705882e-01 | 2.666667e-01 | 5.000000e-01 | 5.000000e-01 |
| 75% | 7.294118e-01 | 6.392157e-01 | 4.078431e-01 | 6.250000e-01 | 6.250000e-01 |
| max | 1.000000e+00 | 1.000000e+00 | 1.000000e+00 | 7.500000e-01 | 7.500000e-01 |
# Now, modify the segment_plot function to include spatial coordinates
def segment_plot_spatial(img, k=5, coord_range=(0.0, 1.0), n_gpus=2):
    """Segment ``img`` into ``k`` clusters using colour *and* pixel position.

    Every pixel becomes a 5-D feature ``(R, G, B, row, col)``. The row and
    column coordinates are spaced evenly over ``coord_range``, so a wider
    range gives position more weight relative to colour in the L2 distance.

    Two figures are drawn. The first shows the image beside its segmentation,
    with each pixel coloured by the RGB part of its centroid. The second shows
    one panel per cluster: the original pixels of that cluster on a white
    background.

    Parameters
    ----------
    img : numpy.ndarray
        Image of shape (H, W, 3), values expected in [0, 1].
    k : int, default 5
        Number of clusters.
    coord_range : tuple of float, default (0.0, 1.0)
        ``(low, high)`` interval the row/column coordinates are mapped to.
    n_gpus : int, default 2
        Number of GPUs the FAISS index is copied to.
    """
    n_rows, n_cols = img.shape[:2]
    pixels = img.reshape(-1, 3)
    lo, hi = coord_range
    # Evenly spaced coordinates from lo to hi; unlike min-max scaling of
    # np.arange, this stays finite for a single row or column.
    rows = np.linspace(lo, hi, n_rows)
    cols = np.linspace(lo, hi, n_cols)
    # indexing='ij' yields (H, W) grids laid out like the image, so after the
    # row-major flatten coordinate k belongs to pixel k of `pixels`. The
    # default 'xy' indexing yields (W, H) grids and pairs pixels with the
    # wrong coordinates.
    row_grid, col_grid = np.meshgrid(rows, cols, indexing='ij')
    features = np.hstack(
        (pixels, row_grid.reshape(-1, 1), col_grid.reshape(-1, 1))
    ).astype(np.float32)
    d = features.shape[1]  # RGB + 2 spatial coordinates
    # Exact L2 index on the GPUs; k-means uses it for the assignment step
    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d), ngpu=n_gpus)
    kmeans_gpu = faiss.Clustering(d, k)
    kmeans_gpu.verbose = True
    kmeans_gpu.niter = 20
    kmeans_gpu.train(features, index)
    # NOTE: relies on Clustering.train leaving the final centroids in `index`
    _, nearest = index.search(features, 1)
    cluster_ids = nearest[:, 0]
    centroids = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(k, d)
    # Colour each pixel with the RGB part of its centroid (coordinates dropped)
    segmented = centroids[cluster_ids, :3].reshape(img.shape)
    # Original and segmented image side by side
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[1].imshow(segmented)
    # One panel per cluster: original colours where the pixel belongs to the
    # cluster, white (1.0) everywhere else. squeeze=False keeps `ax` 2-D so
    # this also works for k == 1 (otherwise subplots returns a bare Axes).
    fig, ax = plt.subplots(1, k, figsize=(k * 4, 5), squeeze=False)
    for i, panel in enumerate(ax[0]):
        img_masked = pixels.copy()
        img_masked[cluster_ids != i] = 1.0
        panel.imshow(img_masked.reshape(img.shape))
        panel.set_title(f'Cluster {i}')
        panel.axis('off')


segment_plot_spatial(img, k=2)
# Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 5D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.03 s
Iteration 19 (0.44 s, search 0.00 s): objective=73.9492 imbalance=1.018 nsplit=0


segment_plot_mask(img, k=2)Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 3D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.03 s
Iteration 19 (0.47 s, search 0.00 s): objective=14.8063 imbalance=1.000 nsplit=0

