Project Setup¶

In [1]:
import mlx.core as mx
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
from tqdm.notebook import tqdm, trange
from pathlib import Path
from matplotlib.colors import ListedColormap, BoundaryNorm
import ipywidgets as widgets
from KSS import KSS

File Location¶

In [2]:
# Data folders are resolved relative to the notebook's current working directory.
ROOT = Path.cwd()
MAT_DIR = ROOT.joinpath("MAT Files")  # folder holding each dataset's data .mat file
GT_DIR = ROOT.joinpath("GT Files")    # folder holding each dataset's ground-truth .mat file

Datasets - Files & Keys¶

In [3]:
# Per-dataset configuration: the .mat file names and the variable name stored inside
# each file. The outer keys are the names offered in the dataset dropdown.
DATASETS = {
    "Pavia": dict(
        data_file="Pavia.mat",
        gt_file="Pavia_gt.mat",
        data_key="pavia",
        gt_key="pavia_gt",
    ),
    "PaviaUni": dict(
        data_file="PaviaUni.mat",
        gt_file="PaviaU_gt.mat",
        data_key="paviaU",
        gt_key="paviaU_gt",
    ),
    "Salinas": dict(
        data_file="Salinas_corrected.mat",
        gt_file="Salinas_gt.mat",
        data_key="salinas_corrected",
        gt_key="salinas_gt",
    ),
    "Indian pines": dict(
        data_file="Indian_pines.mat",
        gt_file="Indian_pines_gt.mat",
        data_key="indian_pines",
        gt_key="indian_pines_gt",
    ),
}
In [4]:
DEFAULT_DS = "Pavia"

# Dataset selector plus an output area; load_ds() prints its status message into `out`.
ds = widgets.Dropdown(
    options=list(DATASETS),
    value=DEFAULT_DS,
    description="Dataset:",
)
out = widgets.Output()
display(ds, out)
Dropdown(description='Dataset:', options=('Pavia', 'PaviaUni', 'Salinas', 'Indian pines'), value='Pavia')
Output()
In [5]:
# Snapshot of the config for the dataset selected *right now*. This is evaluated once and is
# not refreshed if the dropdown changes later; re-run this cell after switching datasets.
keys = DATASETS[ds.value]
In [6]:
keys  # display the selected dataset's file names and .mat variable keys
Out[6]:
{'data_file': 'Pavia.mat',
 'gt_file': 'Pavia_gt.mat',
 'data_key': 'pavia',
 'gt_key': 'pavia_gt'}

Loading the Hyperspectral Image Cube¶

In [7]:
def load_ds(name: str) -> None:
    """Load the data cube and ground-truth map for one configured dataset.

    Reads ``DATASETS[name]["data_file"]`` from ``MAT_DIR`` and
    ``DATASETS[name]["gt_file"]`` from ``GT_DIR``. It extracts the variables named by
    ``data_key`` / ``gt_key`` and publishes them as the notebook globals
    ``data_cube`` and ``gt_data``. A one-line summary is written to the ``out``
    output widget.

    Both globals are replaced together, and only after both files have been read and
    validated. A failed load therefore never leaves a cube from one dataset paired
    with the ground truth of another.

    Parameters
    ----------
    name : str
        A key of ``DATASETS`` (e.g. the current dropdown value).

    Raises
    ------
    KeyError
        If ``name`` is not in ``DATASETS`` or a .mat file lacks the expected variable.
    ValueError
        If the first two axes of the data cube differ from the ground-truth shape.
    """
    cfg = DATASETS[name]
    data_path = os.path.join(MAT_DIR, cfg["data_file"])
    gt_path = os.path.join(GT_DIR, cfg["gt_file"])
    data_mat = sio.loadmat(data_path)
    gt_mat = sio.loadmat(gt_path)

    # Report which variables *are* in the file instead of raising a bare KeyError.
    for mat, key, path in ((data_mat, cfg["data_key"], data_path),
                           (gt_mat, cfg["gt_key"], gt_path)):
        if key not in mat:
            available = sorted(k for k in mat if not k.startswith("__"))
            raise KeyError(f"{path}: no variable {key!r} (available: {available})")

    cube = data_mat[cfg["data_key"]]
    gt = gt_mat[cfg["gt_key"]]

    # Later cells broadcast a mask built from the ground truth over the cube's first two
    # axes, so those axes must match the ground-truth shape exactly.
    if cube.shape[:2] != gt.shape:
        raise ValueError(
            f"{name}: data cube spatial shape {cube.shape[:2]} does not match "
            f"ground-truth shape {gt.shape}"
        )

    # Publish both arrays only once everything above has succeeded.
    globals()["data_cube"] = cube
    globals()["gt_data"] = gt
    with out:
        out.clear_output(wait=True)
        print(f"Loaded {name}: X {cube.shape}, GT {gt.shape}")
In [8]:
def _on_change(change):
    """Dropdown observer: reload the data whenever the selected dataset changes."""
    # Ignore notifications about any trait other than the selected value.
    if change["name"] != "value":
        return
    load_ds(change["new"])
In [9]:
# Reload the data whenever the dropdown value changes. Only data_cube / gt_data are refreshed;
# every cell below that derives masks, the model or plots must be re-run manually.
# NOTE(review): re-running the cell that defines _on_change creates a new function object, and
# registering it here likely adds a second observer while the old one stays attached — confirm.
ds.observe(_on_change, names="value")
In [10]:
# Initial load for the current selection; the observer above only reacts to later changes.
load_ds(ds.value)

Ground Truth Data¶

In [11]:
gt_data  # 2-D integer label map; value 0 marks unlabelled background pixels
Out[11]:
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], shape=(1096, 715), dtype=uint8)
In [12]:
gt_data.shape  # (rows, cols); must equal the first two axes of data_cube for the masking below
Out[12]:
(1096, 715)
In [13]:
# Distinct label values present in the ground truth. np.unique already returns them sorted
# in ascending order, so a separate np.sort would be redundant.
gt_labels = np.unique(gt_data)
In [14]:
# Number of distinct ground-truth values. This includes the background value 0 when it occurs
# in gt_data, in which case it is one more than the number of real classes.
num_labels = len(gt_labels)
In [16]:
# Boolean (rows, cols) map: True where the pixel is unlabelled background (label 0).
bg_indices = gt_data == 0
In [17]:
# All-ones MLX array of the cube's spatial size; used as the "keep" value in the mask below.
# NOTE: this name shadows Python's builtin `filter` for the rest of the notebook session.
filter = mx.ones([data_cube.shape[0], data_cube.shape[1]])
In [18]:
# 0.0 on background pixels, 1.0 on labelled pixels.
mask = mx.where(bg_indices, 0.0, filter)
In [20]:
# Append a singleton axis so the 2-D mask broadcasts across every band of the cube.
mask3d = mask[..., None]
In [21]:
# Zero every band of the background pixels (mask3d broadcasts over the last axis).
masked_cube = mx.multiply(data_cube, mask3d)
In [25]:
# Flatten the two spatial axes (row-major): one row per pixel, one column per band.
masked_cube_reshaped = mx.reshape(masked_cube, (-1, masked_cube.shape[-1]))
In [28]:
# Select the ground-truth foreground pixels (gt_data > 0) as rows of an (N, bands) matrix.
# Selecting with the GT mask itself, flattened in the same row-major order as
# masked_cube_reshaped, guarantees that row i here is the i-th True entry of `gt_data > 0`.
# That is the exact mask used to paint cluster labels back onto the image. An
# "any band non-zero" test would silently drop a labelled pixel whose spectrum is all
# zeros, leaving fewer cluster labels than foreground pixels.
cube_np = np.array(masked_cube_reshaped)
fg_mask_flat = np.ravel(gt_data > 0)
filtered_np = cube_np[fg_mask_flat]
filtered_pixels = mx.array(filtered_np)

Fit the Model¶

In [29]:
# One subspace per ground-truth class. Counting the non-zero labels excludes the background
# value 0 explicitly; this equals num_labels - 1 whenever 0 is present, and stays correct
# if it is not.
model = KSS(
    n_clusters=int(np.count_nonzero(gt_labels)),
    subspaces_dims=2,   # dimension of every cluster's subspace
    max_iter=100,
    n_init=150,         # random restarts; presumably the lowest-cost run is kept — confirm in KSS
    verbose=1,
    random_state=10,
)
In [30]:
%%time
# filtered_pixels has one row per pixel, so .T passes a (bands, pixels) matrix;
# presumably KSS.fit expects one sample per column — confirm against the KSS implementation.
model.fit(filtered_pixels.T)
Running KSS with n_clusters=9, subspaces_dims=[2, 2, 2, 2, 2, 2, 2, 2, 2], max_iter=100, n_init=150
 KSS run 1/150
  Run cost: 2.4692e+13
 KSS run 2/150
  Run cost: 2.4692e+13
 KSS run 3/150
  Run cost: 2.4693e+13
 KSS run 4/150
  Run cost: 2.4691e+13
 KSS run 5/150
  Run cost: 2.4688e+13
 KSS run 6/150
  Run cost: 2.4692e+13
 KSS run 7/150
  Run cost: 2.4692e+13
 KSS run 8/150
  Run cost: 2.4692e+13
 KSS run 9/150
  Run cost: 2.4689e+13
 KSS run 10/150
  Run cost: 2.4691e+13
 KSS run 11/150
  Run cost: 2.4690e+13
 KSS run 12/150
  Run cost: 2.4690e+13
 KSS run 13/150
  Run cost: 2.4689e+13
 KSS run 14/150
  Run cost: 2.4690e+13
 KSS run 15/150
  Run cost: 2.4691e+13
 KSS run 16/150
  Run cost: 2.4691e+13
 KSS run 17/150
  Run cost: 2.4693e+13
 KSS run 18/150
  Run cost: 2.4692e+13
 KSS run 19/150
  Run cost: 2.4691e+13
 KSS run 20/150
  Run cost: 2.4685e+13
 KSS run 21/150
  Run cost: 2.4688e+13
 KSS run 22/150
  Run cost: 2.4690e+13
 KSS run 23/150
  Run cost: 2.4691e+13
 KSS run 24/150
  Run cost: 2.4688e+13
 KSS run 25/150
  Run cost: 2.4691e+13
 KSS run 26/150
  Run cost: 2.4688e+13
 KSS run 27/150
  Run cost: 2.4691e+13
 KSS run 28/150
  Run cost: 2.4692e+13
 KSS run 29/150
  Run cost: 2.4690e+13
 KSS run 30/150
  Run cost: 2.4688e+13
 KSS run 31/150
  Run cost: 2.4692e+13
 KSS run 32/150
  Run cost: 2.4691e+13
 KSS run 33/150
  Run cost: 2.4688e+13
 KSS run 34/150
  Run cost: 2.4690e+13
 KSS run 35/150
  Run cost: 2.4687e+13
 KSS run 36/150
  Run cost: 2.4690e+13
 KSS run 37/150
  Run cost: 2.4688e+13
 KSS run 38/150
  Run cost: 2.4689e+13
 KSS run 39/150
  Run cost: 2.4688e+13
 KSS run 40/150
  Run cost: 2.4691e+13
 KSS run 41/150
  Run cost: 2.4690e+13
 KSS run 42/150
  Run cost: 2.4690e+13
 KSS run 43/150
  Run cost: 2.4691e+13
 KSS run 44/150
  Run cost: 2.4690e+13
 KSS run 45/150
  Run cost: 2.4691e+13
 KSS run 46/150
  Run cost: 2.4693e+13
 KSS run 47/150
  Run cost: 2.4693e+13
 KSS run 48/150
  Run cost: 2.4687e+13
 KSS run 49/150
  Run cost: 2.4689e+13
 KSS run 50/150
  Run cost: 2.4691e+13
 KSS run 51/150
  Run cost: 2.4690e+13
 KSS run 52/150
  Run cost: 2.4691e+13
 KSS run 53/150
  Run cost: 2.4691e+13
 KSS run 54/150
  Run cost: 2.4686e+13
 KSS run 55/150
  Run cost: 2.4691e+13
 KSS run 56/150
  Run cost: 2.4690e+13
 KSS run 57/150
  Run cost: 2.4690e+13
 KSS run 58/150
  Run cost: 2.4692e+13
 KSS run 59/150
  Run cost: 2.4687e+13
 KSS run 60/150
  Run cost: 2.4691e+13
 KSS run 61/150
  Run cost: 2.4690e+13
 KSS run 62/150
  Run cost: 2.4690e+13
 KSS run 63/150
  Run cost: 2.4692e+13
 KSS run 64/150
  Run cost: 2.4692e+13
 KSS run 65/150
  Run cost: 2.4689e+13
 KSS run 66/150
  Run cost: 2.4692e+13
 KSS run 67/150
  Run cost: 2.4692e+13
 KSS run 68/150
  Run cost: 2.4691e+13
 KSS run 69/150
  Run cost: 2.4691e+13
 KSS run 70/150
  Run cost: 2.4690e+13
 KSS run 71/150
  Run cost: 2.4690e+13
 KSS run 72/150
  Run cost: 2.4687e+13
 KSS run 73/150
  Run cost: 2.4692e+13
 KSS run 74/150
  Run cost: 2.4692e+13
 KSS run 75/150
  Run cost: 2.4689e+13
 KSS run 76/150
  Run cost: 2.4689e+13
 KSS run 77/150
  Run cost: 2.4685e+13
 KSS run 78/150
  Run cost: 2.4689e+13
 KSS run 79/150
  Run cost: 2.4691e+13
 KSS run 80/150
  Run cost: 2.4690e+13
 KSS run 81/150
  Run cost: 2.4690e+13
 KSS run 82/150
  Run cost: 2.4691e+13
 KSS run 83/150
  Run cost: 2.4691e+13
 KSS run 84/150
  Run cost: 2.4691e+13
 KSS run 85/150
  Run cost: 2.4691e+13
 KSS run 86/150
  Run cost: 2.4689e+13
 KSS run 87/150
  Run cost: 2.4689e+13
 KSS run 88/150
  Run cost: 2.4689e+13
 KSS run 89/150
  Run cost: 2.4691e+13
 KSS run 90/150
  Run cost: 2.4690e+13
 KSS run 91/150
  Run cost: 2.4691e+13
 KSS run 92/150
  Run cost: 2.4686e+13
 KSS run 93/150
  Run cost: 2.4690e+13
 KSS run 94/150
  Run cost: 2.4692e+13
 KSS run 95/150
  Run cost: 2.4691e+13
 KSS run 96/150
  Run cost: 2.4691e+13
 KSS run 97/150
  Run cost: 2.4684e+13
 KSS run 98/150
  Run cost: 2.4690e+13
 KSS run 99/150
  Run cost: 2.4689e+13
 KSS run 100/150
  Run cost: 2.4691e+13
 KSS run 101/150
  Run cost: 2.4688e+13
 KSS run 102/150
  Run cost: 2.4690e+13
 KSS run 103/150
  Run cost: 2.4692e+13
 KSS run 104/150
  Run cost: 2.4692e+13
 KSS run 105/150
  Run cost: 2.4691e+13
 KSS run 106/150
  Run cost: 2.4690e+13
 KSS run 107/150
  Run cost: 2.4690e+13
 KSS run 108/150
  Run cost: 2.4691e+13
 KSS run 109/150
  Run cost: 2.4691e+13
 KSS run 110/150
  Run cost: 2.4689e+13
 KSS run 111/150
  Run cost: 2.4684e+13
 KSS run 112/150
  Run cost: 2.4691e+13
 KSS run 113/150
  Run cost: 2.4688e+13
 KSS run 114/150
  Run cost: 2.4691e+13
 KSS run 115/150
  Run cost: 2.4691e+13
 KSS run 116/150
  Run cost: 2.4691e+13
 KSS run 117/150
  Run cost: 2.4691e+13
 KSS run 118/150
  Run cost: 2.4686e+13
 KSS run 119/150
  Run cost: 2.4690e+13
 KSS run 120/150
  Run cost: 2.4691e+13
 KSS run 121/150
  Run cost: 2.4691e+13
 KSS run 122/150
  Run cost: 2.4689e+13
 KSS run 123/150
  Run cost: 2.4690e+13
 KSS run 124/150
  Run cost: 2.4693e+13
 KSS run 125/150
  Run cost: 2.4686e+13
 KSS run 126/150
  Run cost: 2.4691e+13
 KSS run 127/150
  Run cost: 2.4690e+13
 KSS run 128/150
  Run cost: 2.4690e+13
 KSS run 129/150
  Run cost: 2.4692e+13
 KSS run 130/150
  Run cost: 2.4691e+13
 KSS run 131/150
  Run cost: 2.4691e+13
 KSS run 132/150
  Run cost: 2.4691e+13
 KSS run 133/150
  Run cost: 2.4691e+13
 KSS run 134/150
  Run cost: 2.4691e+13
 KSS run 135/150
  Run cost: 2.4682e+13
 KSS run 136/150
  Run cost: 2.4689e+13
 KSS run 137/150
  Run cost: 2.4690e+13
 KSS run 138/150
  Run cost: 2.4692e+13
 KSS run 139/150
  Run cost: 2.4691e+13
 KSS run 140/150
  Run cost: 2.4692e+13
 KSS run 141/150
  Run cost: 2.4691e+13
 KSS run 142/150
  Run cost: 2.4690e+13
 KSS run 143/150
  Run cost: 2.4688e+13
 KSS run 144/150
  Run cost: 2.4690e+13
 KSS run 145/150
  Run cost: 2.4686e+13
 KSS run 146/150
  Run cost: 2.4691e+13
 KSS run 147/150
  Run cost: 2.4690e+13
 KSS run 148/150
  Run cost: 2.4690e+13
 KSS run 149/150
  Run cost: 2.4691e+13
 KSS run 150/150
  Run cost: 2.4690e+13
CPU times: user 2min, sys: 1min 26s, total: 3min 26s
Wall time: 9min 27s
Out[30]:
<KSS.kss_mlx.KSS at 0x13df1f5f0>
In [31]:
# Cluster id per clustered sample (presumably one per column passed to fit — confirm in KSS).
# Convert once to a NumPy array: every later use (comparisons, np.nonzero, boolean-mask
# assignment into the clustermap) is a NumPy operation.
labels = np.array(model.labels_)
In [41]:
# labels
In [42]:
# For each cluster id, print the (truncated) indices of the samples assigned to it.
for cluster_id in range(1, num_labels):
    print(f"Cluster {cluster_id}: {np.nonzero(labels == cluster_id)[0]}")
Cluster 1: [     4      5     16 ... 104460 104580 105652]
Cluster 2: [     0      9     10 ... 148007 148015 148135]
Cluster 3: [   137    138    139 ... 148149 148150 148151]
Cluster 4: [    63    133    201 ... 148119 148141 148143]
Cluster 5: [   210    211    212 ... 145464 147928 148142]
Cluster 6: [   323    398    630 ... 148130 148132 148148]
Cluster 7: [     1      2      3 ... 145465 145511 146777]
Cluster 8: [    67     68    931 ... 144673 144682 144891]
Cluster 9: [   135    136    204 ... 148140 148144 148145]

Creating the Cluster Map¶

In [34]:
# Paint each labelled pixel's cluster id back onto the image grid; background stays 0.
# Boolean-mask assignment fills the True positions of `label_px` in row-major order. This
# assumes the clustered pixels were taken with the same mask in the same order.
label_px = gt_data > 0
cluster_ids = np.asarray(labels).ravel()
n_foreground = int(np.count_nonzero(label_px))
if cluster_ids.size != n_foreground:
    raise ValueError(
        f"Got {cluster_ids.size} cluster labels for {n_foreground} labelled pixels; "
        "the clustered pixels must be exactly the pixels where gt_data > 0."
    )

clustermap = np.zeros_like(gt_data, dtype=np.int32)
clustermap[label_px] = cluster_ids
In [ ]:
 

Plot: Ground Truth vs. Cluster Assignment¶

In [35]:
def plot_gt_vs_clusters(gt_data, clustermap, title_left="Ground Truth",
                        title_right="KSS Clustering Results", figsize=(9, 7)):
    """Show a ground-truth map and a cluster map side by side with one shared palette.

    Both panels use the same discrete colormap:
    - value 0 is black (background);
    - each integer 1..n_classes gets a categorical colour from ``tab20``
      (colours repeat after 20 ids).

    No matching between cluster ids and ground-truth class ids is performed, so the
    same colour in both panels does not mean the same class.

    Parameters
    ----------
    gt_data : 2-D integer array
        Ground-truth label map (0 = background).
    clustermap : 2-D integer array
        Cluster id per pixel (0 = background / not clustered).
    title_left, title_right : str
        Panel titles.
    figsize : tuple of float, optional
        Figure size in inches; defaults to (9, 7).
    """
    n_classes = max(int(np.max(gt_data)), int(np.max(clustermap)))

    # Index 0 -> black background, then one categorical colour per id 1..n_classes.
    base = plt.cm.tab20.colors
    palette = ['black'] + [base[i % len(base)] for i in range(n_classes)]
    cmap = ListedColormap(palette)
    # Bin edges at k - 0.5 map each integer k to exactly one colour.
    norm = BoundaryNorm(np.arange(n_classes + 2) - 0.5, cmap.N)
    ticks = np.arange(0, n_classes + 1, 1)

    fig, axes = plt.subplots(1, 2, figsize=figsize, constrained_layout=True)
    for ax, image, title in zip(axes, (gt_data, clustermap), (title_left, title_right)):
        im = ax.imshow(image, cmap=cmap, norm=norm, interpolation='nearest')
        ax.set_title(title, fontsize=16)
        # 'image' keeps square pixels and fits the limits to the data; 'equal' adjusts the
        # data limits instead, which can leave blank margins around an imshow image.
        ax.axis('image')
        ax.set_xticks([])
        ax.set_yticks([])
        # One horizontal colorbar per panel with a tick at every integer id.
        fig.colorbar(im, ax=ax, ticks=ticks, orientation='horizontal', fraction=0.046, pad=0.05)

    plt.show()
In [36]:
# Left: ground-truth classes; right: KSS cluster ids (ids are not matched to class ids).
plot_gt_vs_clusters(gt_data, clustermap)
No description has been provided for this image
In [ ]: