import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Image Segmentation using K-Means Clustering
!wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O dog.jpg
--2023-04-15 11:58:51-- https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg
Resolving segment-anything.com (segment-anything.com)... 108.158.245.28, 108.158.245.33, 108.158.245.84, ...
Connecting to segment-anything.com (segment-anything.com)|108.158.245.28|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 221810 (217K) [image/jpeg]
Saving to: ‘dog.jpg’
dog.jpg 100%[===================>] 216.61K 348KB/s in 0.6s
2023-04-15 11:58:53 (348 KB/s) - ‘dog.jpg’ saved [221810/221810]
# Read the image from disk (plt.imread does no colour-space conversion; the
# array is used as loaded)
img = plt.imread('dog.jpg')

# Convert to [0, 1] range (assumes the loaded pixel values span 0-255)
img = img / 255

# Plot image
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7efc88822340>
img.shape
(1365, 2048, 3)
from sklearn.cluster import KMeans
# Reshape image to a 2D array: one row per pixel, one column per colour channel
img_2d = img.reshape(-1, 3)
pd.DataFrame(img_2d).describe()
0 | 1 | 2 | |
---|---|---|---|
count | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 |
mean | 5.542249e-01 | 4.757675e-01 | 2.935504e-01 |
std | 2.068960e-01 | 1.844162e-01 | 1.533922e-01 |
min | 3.921569e-02 | 0.000000e+00 | 0.000000e+00 |
25% | 3.882353e-01 | 3.176471e-01 | 1.647059e-01 |
50% | 5.725490e-01 | 4.705882e-01 | 2.666667e-01 |
75% | 7.294118e-01 | 6.392157e-01 | 4.078431e-01 |
max | 1.000000e+00 | 1.000000e+00 | 1.000000e+00 |
img_2d.shape
(2795520, 3)
# Number of distinct colours (unique rows) in the image
np.unique(img_2d, axis=0).shape
(190571, 3)
# Fit KMeans from scikit-learn (slow!)
kmeans = KMeans(n_clusters=5, random_state=0).fit(img_2d)
kmeans.cluster_centers_, kmeans.labels_
(array([[0.59260829, 0.48656705, 0.2670584 ],
[0.24250138, 0.22963615, 0.12538282],
[0.8270407 , 0.74152134, 0.54547183],
[0.71622954, 0.62308534, 0.39023071],
[0.419752 , 0.3354403 , 0.18325719]]),
array([1, 1, 1, ..., 4, 4, 4], dtype=int32))
# instead use FAISS with GPU
import faiss

# Set up FAISS index
d = img_2d.shape[1]  # Dimension of the feature vectors
n_clusters = 5  # Number of clusters
n_gpus = 2  # Number of GPUs to use

# Initialize a multi-GPU IndexFlatL2 index
index_flat = faiss.IndexFlatL2(d)
index = faiss.index_cpu_to_all_gpus(index_flat, ngpu=n_gpus)

# Configure and run k-means; `index` is used for the assignment step
kmeans_gpu = faiss.Clustering(d, n_clusters)
kmeans_gpu.verbose = True
kmeans_gpu.niter = 20
kmeans_gpu.train(img_2d.astype(np.float32), index)
Sampling a subset of 1280 / 2795520 for training
Clustering 1280 points in 3D to 5 clusters, redo 1 times, 20 iterations
Preprocessing in 0.03 s
Iteration 19 (0.22 s, search 0.00 s): objective=12.4307 imbalance=1.031 nsplit=0
_, I = index.search(img_2d.astype(np.float32), 1)  # Search for nearest centroid
cluster_ids = I.squeeze()
# Create segmented image using cluster centers from sklearn
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
segmented_img = segmented_img.reshape(img.shape)

centroids_gpu = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(n_clusters, d)
# Create segmented image using cluster centers from FAISS
segmented_img_faiss = centroids_gpu[cluster_ids]

# Plot segmented image side by side (left: sklearn, right: FAISS)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(segmented_img)
ax[1].imshow(segmented_img_faiss.reshape(img.shape))
<matplotlib.image.AxesImage at 0x7efa51d55250>
# Now, let's try to segment the image using a different number of clusters using FAISS
def segment_plot(img, k=5):
    """Cluster the pixel colours of ``img`` with FAISS k-means and plot the result.

    Draws two figures: the original image next to its colour-quantised
    version (every pixel replaced by its cluster centroid), and a strip
    showing the colour of each centroid.

    Args:
        img: Image array that can be reshaped to ``(-1, 3)``, i.e. three
            colour channels per pixel.
        k: Number of clusters.

    Returns:
        None. The figures are produced as a side effect.

    Note:
        The index is moved to GPUs with ``ngpu=2``; presumably this requires
        a GPU build of FAISS -- confirm for the target environment.
    """
    # Reshape image to 2D array: one row per pixel, one column per channel
    img_2d = img.reshape(-1, 3)
    # FAISS is fed float32; convert once and reuse for training and search
    features = img_2d.astype(np.float32)

    # Set up FAISS index
    d = features.shape[1]  # Dimension of the feature vectors
    n_clusters = k  # Number of clusters
    n_gpus = 2  # Number of GPUs to use

    # Initialize a multi-GPU IndexFlatL2 index
    index_flat = faiss.IndexFlatL2(d)
    index = faiss.index_cpu_to_all_gpus(index_flat, ngpu=n_gpus)

    kmeans_gpu = faiss.Clustering(d, n_clusters)
    kmeans_gpu.verbose = True
    kmeans_gpu.niter = 20

    kmeans_gpu.train(features, index)
    _, I = index.search(features, 1)  # Search for nearest centroid
    # I has one column (1 neighbour); take it explicitly so the result stays
    # 1-D even for a single-pixel image
    cluster_ids = I[:, 0]

    centroids_gpu = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(n_clusters, d)
    # Create segmented image using cluster centers from FAISS
    segmented_img_faiss = centroids_gpu[cluster_ids]

    # Plot original and segmented image side by side
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[1].imshow(segmented_img_faiss.reshape(img.shape))

    # Plot the color of each cluster
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.imshow(centroids_gpu.reshape(1, n_clusters, 3))
segment_plot(img, k=2)
Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 3D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.04 s
Iteration 19 (0.25 s, search 0.00 s): objective=14.8063 imbalance=1.000 nsplit=0
segment_plot(img, k=10)
Sampling a subset of 2560 / 2795520 for training
Clustering 2560 points in 3D to 10 clusters, redo 1 times, 20 iterations
Preprocessing in 0.04 s
Iteration 19 (0.03 s, search 0.01 s): objective=13.6225 imbalance=1.202 nsplit=0
# Modify segmentation function to mask out all but the given cluster
from copy import deepcopy
def segment_plot_mask(img, k=5):
    """Cluster pixel colours with FAISS k-means and show each cluster separately.

    Draws two figures: the original image next to its colour-quantised
    version, and one panel per cluster in which only that cluster's pixels
    keep their original colour while all other pixels are set to 1.0.

    Args:
        img: Image array that can be reshaped to ``(-1, 3)``, i.e. three
            colour channels per pixel.
        k: Number of clusters.

    Returns:
        None. The figures are produced as a side effect.

    Note:
        The index is moved to GPUs with ``ngpu=2``; presumably this requires
        a GPU build of FAISS -- confirm for the target environment.
    """
    # Reshape image to 2D array: one row per pixel, one column per channel
    img_2d = img.reshape(-1, 3)
    # FAISS is fed float32; convert once and reuse for training and search
    features = img_2d.astype(np.float32)

    # Set up FAISS index
    d = features.shape[1]  # Dimension of the feature vectors
    n_clusters = k  # Number of clusters
    n_gpus = 2  # Number of GPUs to use

    # Initialize a multi-GPU IndexFlatL2 index
    index_flat = faiss.IndexFlatL2(d)
    index = faiss.index_cpu_to_all_gpus(index_flat, ngpu=n_gpus)

    kmeans_gpu = faiss.Clustering(d, n_clusters)
    kmeans_gpu.verbose = True
    kmeans_gpu.niter = 20

    kmeans_gpu.train(features, index)
    _, I = index.search(features, 1)  # Search for nearest centroid
    # I has one column (1 neighbour); take it explicitly so the result stays
    # 1-D even for a single-pixel image
    cluster_ids = I[:, 0]

    centroids_gpu = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(n_clusters, d)
    # Create segmented image using cluster centers from FAISS
    segmented_img_faiss = centroids_gpu[cluster_ids]

    # Plot original and segmented image side by side
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    segmented_img_faiss = segmented_img_faiss.reshape(img.shape)
    ax[1].imshow(segmented_img_faiss)

    # Now, create another figure with one column per cluster and plot each
    # cluster with a mask applied to the original image. Pixels outside the
    # cluster are overwritten with 1.0 (not made transparent).
    # squeeze=False keeps `axes` 2-D so indexing also works when n_clusters == 1.
    fig, axes = plt.subplots(1, n_clusters, figsize=(n_clusters*4, 5), squeeze=False)
    for i in range(n_clusters):
        # Work on a copy so the original pixel data is left untouched
        img_masked = img_2d.copy()
        img_masked[cluster_ids != i] = 1.0
        axes[0, i].imshow(img_masked.reshape(img.shape))
        axes[0, i].set_title(f'Cluster {i}')
        axes[0, i].axis('off')
img.shape
(1365, 2048, 3)
mask = segment_plot_mask(img, k=2)
Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 3D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.02 s
Iteration 19 (0.46 s, search 0.00 s): objective=14.8063 imbalance=1.000 nsplit=0
mask = segment_plot_mask(img, k=3)
Sampling a subset of 768 / 2795520 for training
Clustering 768 points in 3D to 3 clusters, redo 1 times, 20 iterations
Preprocessing in 0.02 s
Iteration 19 (0.46 s, search 0.00 s): objective=12.5181 imbalance=1.009 nsplit=0
mask = segment_plot_mask(img, k=4)
Sampling a subset of 1024 / 2795520 for training
Clustering 1024 points in 3D to 4 clusters, redo 1 times, 20 iterations
Preprocessing in 0.02 s
Iteration 19 (0.15 s, search 0.00 s): objective=12.6442 imbalance=1.031 nsplit=0
# Now, let us segment the image using not only RGB but also the spatial coordinates
# Reshape image to 2D array
img_2d = img.reshape(-1, 3)

# Add spatial coordinates: x indexes axis 0 of the image, y indexes axis 1
x = np.arange(img.shape[0])
y = np.arange(img.shape[1])
# Scale the spatial coordinates to be between a and b
def scale(x, a, b):
    """Linearly rescale ``x`` so its minimum maps to ``a`` and its maximum to ``b``.

    Args:
        x: Array-like of numbers (must be non-empty).
        a: Target value for ``x.min()``.
        b: Target value for ``x.max()``.

    Returns:
        Array of the same shape as ``x`` with values in ``[a, b]``. If all
        values of ``x`` are equal, every element maps to ``a`` instead of
        dividing by zero.
    """
    x = np.asarray(x)
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # Degenerate range (e.g. a one-pixel-wide axis): avoid 0/0 -> NaN
        return np.full(x.shape, a, dtype=float)
    return (b - a) * (x - lo) / span + a
x = scale(x, 0.25, 0.75)
y = scale(y, 0.25, 0.75)

# indexing='ij' makes both grids have shape (len(x), len(y)), i.e. the image's
# own (axis 0, axis 1) layout, so flattening them lines up row-for-row with
# img.reshape(-1, 3). The default 'xy' indexing yields the transposed shape
# and would attach the wrong coordinates to each pixel.
xx, yy = np.meshgrid(x, y, indexing='ij')
xx = xx.reshape(-1, 1)
yy = yy.reshape(-1, 1)

# Feature matrix: 3 colour channels followed by the two spatial coordinates
img_2d_spatial = np.hstack((img_2d, xx, yy))
pd.DataFrame(img_2d_spatial).describe()
0 | 1 | 2 | 3 | 4 | |
---|---|---|---|---|---|
count | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 | 2.795520e+06 |
mean | 5.542249e-01 | 4.757675e-01 | 2.935504e-01 | 5.000000e-01 | 5.000000e-01 |
std | 2.068960e-01 | 1.844162e-01 | 1.533922e-01 | 1.444434e-01 | 1.444081e-01 |
min | 3.921569e-02 | 0.000000e+00 | 0.000000e+00 | 2.500000e-01 | 2.500000e-01 |
25% | 3.882353e-01 | 3.176471e-01 | 1.647059e-01 | 3.750000e-01 | 3.750000e-01 |
50% | 5.725490e-01 | 4.705882e-01 | 2.666667e-01 | 5.000000e-01 | 5.000000e-01 |
75% | 7.294118e-01 | 6.392157e-01 | 4.078431e-01 | 6.250000e-01 | 6.250000e-01 |
max | 1.000000e+00 | 1.000000e+00 | 1.000000e+00 | 7.500000e-01 | 7.500000e-01 |
# Now, modify the segment_plot function to include spatial coordinates
def segment_plot_spatial(img, k=5):
    """Cluster pixels on colour plus position with FAISS k-means and plot the result.

    Each pixel is described by its 3 colour channels and its two image
    coordinates (each rescaled to [0, 1] with ``scale``). Draws two figures:
    the original image next to its colour-quantised version, and one panel
    per cluster in which only that cluster's pixels keep their original
    colour while all other pixels are set to 1.0.

    Args:
        img: Image array that can be reshaped to ``(-1, 3)``, i.e. three
            colour channels per pixel, with the two spatial axes first.
        k: Number of clusters.

    Returns:
        None. The figures are produced as a side effect.

    Note:
        The index is moved to GPUs with ``ngpu=2``; presumably this requires
        a GPU build of FAISS -- confirm for the target environment.
    """
    # Reshape image to 2D array: one row per pixel, one column per channel
    img_2d = img.reshape(-1, 3)

    # Add spatial coordinates, scaled to be between 0 and 1
    rows = scale(np.arange(img.shape[0]), 0, 1)
    cols = scale(np.arange(img.shape[1]), 0, 1)

    # indexing='ij' gives grids shaped like the image's first two axes, so
    # flattening them lines up with img_2d. The default 'xy' indexing yields
    # the transposed shape and would attach the wrong coordinates to pixels.
    row_grid, col_grid = np.meshgrid(rows, cols, indexing='ij')
    row_col = row_grid.reshape(-1, 1)
    col_col = col_grid.reshape(-1, 1)

    img_2d_spatial = np.hstack((img_2d, row_col, col_col))
    # FAISS is fed float32; convert once and reuse for training and search
    features = img_2d_spatial.astype(np.float32)

    # Set up FAISS index
    d = features.shape[1]  # Dimension of the feature vectors
    n_clusters = k  # Number of clusters
    n_gpus = 2  # Number of GPUs to use

    # Initialize a multi-GPU IndexFlatL2 index
    index_flat = faiss.IndexFlatL2(d)
    index = faiss.index_cpu_to_all_gpus(index_flat, ngpu=n_gpus)

    kmeans_gpu = faiss.Clustering(d, n_clusters)
    kmeans_gpu.verbose = True
    kmeans_gpu.niter = 20

    kmeans_gpu.train(features, index)
    _, I = index.search(features, 1)  # Search for nearest centroid
    # I has one column (1 neighbour); take it explicitly so the result stays
    # 1-D even for a single-pixel image
    cluster_ids = I[:, 0]

    centroids_gpu = faiss.vector_float_to_array(kmeans_gpu.centroids).reshape(n_clusters, d)
    # Create segmented image using cluster centers from FAISS with spatial
    # coordinates excluded for plotting
    segmented_img_faiss = centroids_gpu[cluster_ids, :3]

    # Plot original and segmented image side by side
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[1].imshow(segmented_img_faiss.reshape(img.shape))

    # Now, create another figure with one column per cluster and plot each
    # cluster with a mask applied to the original image. Pixels outside the
    # cluster are overwritten with 1.0 (not made transparent).
    # squeeze=False keeps `axes` 2-D so indexing also works when n_clusters == 1.
    fig, axes = plt.subplots(1, n_clusters, figsize=(n_clusters*4, 5), squeeze=False)
    for i in range(n_clusters):
        # Copy of the colour columns only, so masking never touches the features
        img_masked = img_2d_spatial[:, :3].copy()
        img_masked[cluster_ids != i] = 1.0
        axes[0, i].imshow(img_masked.reshape(img.shape))
        axes[0, i].set_title(f'Cluster {i}')
        axes[0, i].axis('off')
segment_plot_spatial(img, k=2)
Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 5D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.03 s
Iteration 19 (0.44 s, search 0.00 s): objective=73.9492 imbalance=1.018 nsplit=0
segment_plot_mask(img, k=2)
Sampling a subset of 512 / 2795520 for training
Clustering 512 points in 3D to 2 clusters, redo 1 times, 20 iterations
Preprocessing in 0.03 s
Iteration 19 (0.47 s, search 0.00 s): objective=14.8063 imbalance=1.000 nsplit=0