Project Setup¶
In [1]:
import mlx.core as mx
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
from tqdm.notebook import tqdm, trange
from pathlib import Path
from matplotlib.colors import ListedColormap, BoundaryNorm
import ipywidgets as widgets
from KSS import KSS
File Location¶
In [2]:
ROOT = Path.cwd()
MAT_DIR = ROOT / "MAT Files"
GT_DIR = ROOT / "GT Files"
Datasets - Files & Keys¶
In [3]:
DATASETS = {
    "Pavia": {
        "data_file": "Pavia.mat",
        "gt_file": "Pavia_gt.mat",
        "data_key": "pavia",
        "gt_key": "pavia_gt",
    },
    "PaviaUni": {
        "data_file": "PaviaUni.mat",
        "gt_file": "PaviaU_gt.mat",
        "data_key": "paviaU",
        "gt_key": "paviaU_gt",
    },
    "Salinas": {
        "data_file": "Salinas_corrected.mat",
        "gt_file": "Salinas_gt.mat",
        "data_key": "salinas_corrected",
        "gt_key": "salinas_gt",
    },
    "Indian pines": {
        "data_file": "Indian_pines.mat",
        "gt_file": "Indian_pines_gt.mat",
        "data_key": "indian_pines",
        "gt_key": "indian_pines_gt",
    },
}
In [4]:
DEFAULT_DS = "Pavia"
ds = widgets.Dropdown(options=list(DATASETS.keys()),
                      value=DEFAULT_DS,
                      description="Dataset:")
out = widgets.Output()
display(ds, out)
Dropdown(description='Dataset:', options=('Pavia', 'PaviaUni', 'Salinas', 'Indian pines'), value='Pavia')
Output()
In [5]:
keys = DATASETS[ds.value]
In [6]:
keys
Out[6]:
{'data_file': 'Pavia.mat',
'gt_file': 'Pavia_gt.mat',
'data_key': 'pavia',
'gt_key': 'pavia_gt'}
Hyperspectral Image Cube¶
In [7]:
def load_ds(name: str):
    """Load data for the selected dataset and put into globals()."""
    cfg = DATASETS[name]
    data_mat = sio.loadmat(os.path.join(MAT_DIR, cfg["data_file"]))
    gt_mat = sio.loadmat(os.path.join(GT_DIR, cfg["gt_file"]))
    globals()["data_cube"] = data_mat[cfg["data_key"]]
    globals()["gt_data"] = gt_mat[cfg["gt_key"]]
    with out:
        out.clear_output(wait=True)
        print(f"Loaded {name}: X {data_cube.shape}, GT {gt_data.shape}")
In [8]:
def _on_change(change):
    if change["name"] == "value":
        load_ds(change["new"])
In [9]:
ds.observe(_on_change, names="value")
In [10]:
load_ds(ds.value)
Ground Truth Data¶
In [11]:
gt_data
Out[11]:
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], shape=(1096, 715), dtype=uint8)
In [12]:
gt_data.shape
Out[12]:
(1096, 715)
In [13]:
gt_labels = np.sort(np.unique(gt_data))  # unique ground-truth labels; 0 is the background class
In [14]:
num_labels = len(gt_labels)  # number of labels, background included
In [16]:
bg_indices = gt_data == 0  # boolean map: True where the ground truth marks background
In [17]:
ones_map = mx.ones([data_cube.shape[0], data_cube.shape[1]])  # all-ones map with the cube's spatial shape
In [18]:
mask = mx.where(bg_indices, 0.0, ones_map)  # 0 for background pixels, 1 for labeled pixels
In [20]:
mask3d = mask[..., None]  # add a band axis so the mask broadcasts over all spectral bands
In [21]:
masked_cube = mx.multiply(data_cube, mask3d)  # zero out the background spectra
In [25]:
masked_cube_reshaped = mx.reshape(masked_cube, (masked_cube.shape[0]*masked_cube.shape[1], masked_cube.shape[2]))  # flatten to (pixels, bands) in row-major order
In [28]:
cube_np = np.array(masked_cube_reshaped)
nonzero_mask = np.any(cube_np != 0, axis=1)  # rows with at least one nonzero band
filtered_np = cube_np[nonzero_mask]          # drop the all-zero (background) pixels
filtered_pixels = mx.array(filtered_np)
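As a quick sanity check (a sketch, not part of the original run), the number of retained spectra should equal the number of labeled ground-truth pixels, since the clustermap built below assigns one cluster label per foreground pixel in row-major order. This assumes no labeled pixel has an all-zero spectrum.

# Sanity check (sketch): retained spectra should line up one-to-one with labeled GT pixels.
n_labeled = int(np.count_nonzero(gt_data))
assert filtered_pixels.shape[0] == n_labeled, (filtered_pixels.shape[0], n_labeled)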
Fit the Model¶
In [29]:
model = KSS(
    n_clusters=num_labels - 1,  # one cluster per ground-truth class, excluding the background label
    subspaces_dims=2,           # model each cluster as a 2-dimensional subspace
    max_iter=100,
    n_init=150,
    verbose=1,
    random_state=10)
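For context, K-Subspaces (KSS) clustering alternates between assigning each sample to the subspace it is closest to and re-estimating each cluster's basis from its assigned samples (typically via a truncated SVD/PCA). The sketch below shows only the assignment step, assuming column-orthonormal bases; the imported KSS class may differ in its details.

# Illustrative sketch of one KSS assignment step (not the imported implementation).
# X: (n_samples, n_features); bases: list of (n_features, d) column-orthonormal matrices.
def assign_to_subspaces(X, bases):
    residuals = []
    for U in bases:
        proj = X @ U @ U.T                                   # project each sample onto span(U)
        residuals.append(np.linalg.norm(X - proj, axis=1))   # residual = distance to the subspace
    return np.argmin(np.stack(residuals, axis=1), axis=1)    # index of the closest subspace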
In [30]:
%%time
model.fit(filtered_pixels.T)  # transposed so samples are columns; KSS here appears to expect (bands, pixels)
Running KSS with n_clusters=9, subspaces_dims=[2, 2, 2, 2, 2, 2, 2, 2, 2], max_iter=100, n_init=150
KSS run 1/150  Run cost: 2.4692e+13
[... per-run progress for runs 2-149 omitted; all run costs fall between 2.4682e+13 and 2.4693e+13 ...]
KSS run 150/150  Run cost: 2.4690e+13
CPU times: user 2min, sys: 1min 26s, total: 3min 26s
Wall time: 9min 27s
Out[30]:
<KSS.kss_mlx.KSS at 0x13df1f5f0>
In [31]:
labels = model.labels_  # cluster assignment for each foreground pixel
In [42]:
for k in range(1, num_labels):
    ilist = np.nonzero(labels == k)[0]
    print(f"Cluster {k}: {ilist}")
Cluster 1: [     4      5     16 ... 104460 104580 105652]
Cluster 2: [     0      9     10 ... 148007 148015 148135]
Cluster 3: [   137    138    139 ... 148149 148150 148151]
Cluster 4: [    63    133    201 ... 148119 148141 148143]
Cluster 5: [   210    211    212 ... 145464 147928 148142]
Cluster 6: [   323    398    630 ... 148130 148132 148148]
Cluster 7: [     1      2      3 ... 145465 145511 146777]
Cluster 8: [    67     68    931 ... 144673 144682 144891]
Cluster 9: [   135    136    204 ... 148140 148144 148145]
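The raw index lists above are hard to read at a glance; a per-cluster pixel count is easier to compare. The sketch below is not part of the original run and assumes the cluster labels run from 1 to n_clusters, as the loop above suggests.

# Sketch: summarize cluster sizes instead of printing raw index lists.
labels_np = np.array(labels).astype(int)             # works for NumPy or MLX label arrays
sizes = np.bincount(labels_np, minlength=num_labels)
for k in range(1, num_labels):
    print(f"Cluster {k}: {sizes[k]} pixels")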
Creating a Clustermap¶
In [34]:
label_px = gt_data > 0                               # boolean map; background pixels are False
clustermap = np.zeros_like(gt_data, dtype=np.int32)  # zero-filled 2D array the size of gt_data
clustermap[label_px] = labels                        # fill the foreground pixels (row-major order matches the earlier reshape)
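Note that KSS cluster IDs are arbitrary, so the colours in the right-hand panel below will not line up with the ground-truth classes even where the segmentation is good. If a like-for-like colouring is wanted, one option (a sketch built on the labels and clustermap constructed above, not part of the original notebook) is to remap each cluster to the ground-truth class it overlaps most, e.g. with the Hungarian assignment from scipy:

# Sketch: align cluster IDs with ground-truth classes via maximum-overlap matching.
from scipy.optimize import linear_sum_assignment

gt_fg = gt_data[label_px].astype(int)      # ground-truth class of every labeled pixel
cl_fg = np.array(labels).astype(int)       # cluster ID of the same pixels, in the same order
n = int(max(gt_fg.max(), cl_fg.max())) + 1
conf_mat = np.zeros((n, n), dtype=np.int64)
np.add.at(conf_mat, (gt_fg, cl_fg), 1)     # overlap counts: class (row) vs. cluster (column)

row_ind, col_ind = linear_sum_assignment(-conf_mat)  # maximize total overlap
remap = {int(c): int(g) for g, c in zip(row_ind, col_ind)}

aligned = np.zeros_like(clustermap)
for c, g in remap.items():
    if c == 0:
        continue                           # keep the background at 0
    aligned[clustermap == c] = g
# plot_gt_vs_clusters(gt_data, aligned) would then use comparable colours in both panels.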
Plot - Ground Truth vs. Cluster Assignment¶
In [35]:
def plot_gt_vs_clusters(gt_data, clustermap, title_left="Ground Truth", title_right="KSS Clustering Results"):
    gt_max = int(np.max(gt_data))
    km_max = int(np.max(clustermap))
    n_classes = max(gt_max, km_max)
    # color 0 = background (black); then categorical colors for 1..n_classes
    base = plt.cm.tab20.colors  # plenty of distinct colors
    palette = ['black'] + [base[i % len(base)] for i in range(n_classes)]
    cmap = ListedColormap(palette)
    norm = BoundaryNorm(np.arange(n_classes + 2) - 0.5, cmap.N)  # hard bins at integers
    fig, axes = plt.subplots(1, 2, figsize=(9, 7), constrained_layout=True)
    im0 = axes[0].imshow(gt_data, cmap=cmap, norm=norm, interpolation='nearest')
    axes[0].set_title(title_left, fontsize=16)
    axes[0].axis('equal'); axes[0].set_xticks([]); axes[0].set_yticks([])
    im1 = axes[1].imshow(clustermap, cmap=cmap, norm=norm, interpolation='nearest')
    axes[1].set_title(title_right, fontsize=16)
    axes[1].axis('equal'); axes[1].set_xticks([]); axes[1].set_yticks([])
    # one colorbar per panel with integer ticks (0..n_classes)
    ticks = np.arange(0, n_classes + 1, 1)
    cbar0 = fig.colorbar(im0, ax=axes[0], ticks=ticks, orientation='horizontal', fraction=0.046, pad=0.05)
    cbar1 = fig.colorbar(im1, ax=axes[1], ticks=ticks, orientation='horizontal', fraction=0.046, pad=0.05)
    plt.show()
In [36]:
plot_gt_vs_clusters(gt_data, clustermap)