import numpy as np
import os
from PIL import Image
from scipy.io import loadmat
import matplotlib.pyplot as plt

from sklearn.metrics import adjusted_rand_score
import torch
from weighted_tmbkkm  import WTKernelMiniBatchKMeans
from sklearn.cluster import MiniBatchKMeans
from scipy.optimize import linear_sum_assignment
from kernel_kmeans import initialize_kernel_kmeans_plusplus
from tqdm import tqdm



def remap_labels(y_pred, y_true):
    """Relabel predicted clusters so they best agree with ground-truth labels.

    Builds the (pred x true) confusion matrix and solves the assignment
    problem (Hungarian algorithm) that maximises the matrix trace, i.e. the
    number of points whose remapped prediction equals the true label.

    Args:
        y_pred: array-like of non-negative integer cluster labels.
        y_true: array-like of non-negative integer labels, same shape as y_pred.

    Returns:
        (y_pred_remapped, mapping) where ``mapping[pred_label] = true_label``
        and ``y_pred_remapped = mapping[y_pred]``.

    Raises:
        ValueError: if the shapes differ or any label is negative.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)

    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"y_pred and y_true must have the same shape, got {y_pred.shape} and {y_true.shape}"
        )
    if y_pred.size == 0:
        # Nothing to match: return empty results instead of failing in .max().
        return y_pred.astype(int), np.zeros(0, dtype=int)
    if y_pred.min() < 0 or y_true.min() < 0:
        # Negative labels would silently index from the end of the matrix.
        raise ValueError("Labels must be non-negative integers.")

    # Labels are ints in [0, D); a square matrix lets every pred label get a target.
    D = int(max(y_pred.max(), y_true.max())) + 1

    # Confusion matrix, accumulated in one vectorised pass (handles repeats).
    cm = np.zeros((D, D), dtype=int)
    np.add.at(cm, (y_pred.ravel(), y_true.ravel()), 1)

    # Maximise the trace <=> minimise the negated confusion matrix.
    row_ind, col_ind = linear_sum_assignment(-cm)

    mapping = np.zeros(D, dtype=int)
    mapping[row_ind] = col_ind

    return mapping[y_pred], mapping


def gpu_pairwise_kernels_batch(X, Y=None, metric='rbf', gamma=None, batch_size=1024):
    """
    Compute the pairwise kernel on GPU using PyTorch in batches.

    The RBF kernel is K[i, j] = exp(-gamma * ||X[i] - Y[j]||^2). The squared
    distance is expanded as ||x||^2 + ||y||^2 - 2 x.y. Cancellation in that
    expansion can make it slightly negative in float32, so it is clamped at
    0. This guarantees every kernel value is <= 1.

    Args:
        X (np.ndarray): Input data array, shape (n_samples_X, n_features).
        Y (np.ndarray or None): Optional second input data array. If None, compute the kernel with X itself.
        metric (str): The kernel metric to use. Currently only 'rbf' is supported.
        gamma (float or None): Kernel coefficient for 'rbf'. If None, it defaults to 1 / n_features.
        batch_size (int): The batch size for GPU computation (rows of X and Y per tile).

    Returns:
        np.ndarray: float32 kernel matrix of shape (n_samples_X, n_samples_Y), on the CPU.

    Raises:
        ValueError: if ``metric`` is not 'rbf'.
    """
    # Validate before moving any data to the device.
    if metric != 'rbf':
        raise ValueError(f"Unsupported metric '{metric}'. Currently only 'rbf' is supported.")

    device = torch.device('mps') if torch.backends.mps.is_available() else (
        torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu'))

    with torch.no_grad():
        X = torch.tensor(X, dtype=torch.float32, device=device)
        Y = X if Y is None else torch.tensor(Y, dtype=torch.float32, device=device)

        if gamma is None:
            gamma = 1.0 / X.shape[1]

        n_samples_X = X.shape[0]
        n_samples_Y = Y.shape[0]

        # The full matrix lives on the CPU; only one tile at a time is on the device.
        K = np.zeros((n_samples_X, n_samples_Y), dtype=np.float32)

        # Squared norms of Y do not depend on the X tile: compute them once.
        Y_sq_norms = (Y ** 2).sum(dim=-1).view(1, -1)

        for i in range(0, n_samples_X, batch_size):
            end_i = min(i + batch_size, n_samples_X)
            X_batch = X[i:end_i]
            X_norm = (X_batch ** 2).sum(dim=-1).view(-1, 1)

            for j in range(0, n_samples_Y, batch_size):
                end_j = min(j + batch_size, n_samples_Y)
                Y_batch = Y[j:end_j]
                Y_norm = Y_sq_norms[:, j:end_j]

                sq_dist = X_norm + Y_norm - 2 * torch.mm(X_batch, Y_batch.T)
                # Floating-point cancellation can yield tiny negatives (-> K > 1).
                sq_dist.clamp_(min=0.0)

                K[i:end_i, j:end_j] = torch.exp(-gamma * sq_dist).cpu().numpy()

        # Release device tensors early.
        del X
        del Y
        del Y_sq_norms

        return K
    
def gpu_pairwise_weighted_jaccard_batch(X, Y=None, batch_size=1024):
    """
    Compute the pairwise weighted Jaccard kernel on GPU using PyTorch in batches.

    For non-negative vectors x and y the weighted Jaccard similarity is
        J(x, y) = sum_i min(x_i, y_i) / sum_i max(x_i, y_i)
    A constant 1e-12 is added to the denominator, so a pair of all-zero
    vectors scores 0 rather than NaN.

    Args:
        X (np.ndarray): Non-negative data, shape (n_samples_X, n_features).
        Y (np.ndarray or None): Optional second non-negative data array,
            shape (n_samples_Y, n_features). If None, X is compared with itself.
        batch_size (int): Rows of X (and of Y) processed per tile.

    Returns:
        np.ndarray: float32 kernel matrix of shape (n_samples_X, n_samples_Y).

    Raises:
        ValueError: if X or Y contains a negative entry.
    """
    # Prefer Apple MPS, then CUDA, then fall back to the CPU.
    if torch.backends.mps.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    X = torch.tensor(X, dtype=torch.float32, device=device)
    Y = X if Y is None else torch.tensor(Y, dtype=torch.float32, device=device)

    if (X < 0).any() or (Y < 0).any():
        raise ValueError("Weighted Jaccard requires non-negative data.")

    rows, _n_features = X.shape
    cols = Y.shape[0]

    # Result accumulates on the CPU, one tile at a time.
    K = np.zeros((rows, cols), dtype=np.float32)

    with torch.no_grad():
        for r0 in range(0, rows, batch_size):
            r1 = min(r0 + batch_size, rows)
            left = X[r0:r1].unsqueeze(1)          # (tile_rows, 1, n_features)

            for c0 in range(0, cols, batch_size):
                c1 = min(c0 + batch_size, cols)
                right = Y[c0:c1].unsqueeze(0)     # (1, tile_cols, n_features)

                # Broadcast to (tile_rows, tile_cols, n_features), reduce features.
                numerator = torch.min(left, right).sum(dim=2)
                denominator = torch.max(left, right).sum(dim=2) + 1e-12

                K[r0:r1, c0:c1] = (numerator / denominator).cpu().numpy()

    del X, Y
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    return K


def load_bsds500_with_segs(data_root = "BSDS500/BSDS500", split="train", size=(256,256)):
    """
    Load BSDS500 images, boundary maps, and segmentation maps.

    Reads every ``*.jpg`` under ``<data_root>/data/images/<split>`` in sorted
    order. Each image is paired with ``<data_root>/data/groundTruth/<split>/<name>.mat``.

    Args:
        data_root: Root directory of the BSDS500 dataset.
        split: Sub-directory name, e.g. "train".
        size: Target size passed to PIL ``resize`` (PIL interprets it as (width, height)).

    Returns:
        A list of (X, bmap, segs) tuples where
          - X    : uint8 array, shape (H, W, 3), the LANCZOS-resized RGB image
          - bmap : uint8 array, shape (H, W), the annotator-averaged boundary map
                   scaled so its maximum is 255 (all zeros if there are no boundaries)
          - segs : list of np.int32 arrays, each shape (H, W), one per annotator,
                   with region labels shifted down by 1 (1..K -> 0..K-1)
    """
    imgs_dir = os.path.join(data_root, "data", "images", split)
    gt_dir   = os.path.join(data_root, "data", "groundTruth", split)

    out = []

    for fname in sorted(os.listdir(imgs_dir)):
        if not fname.lower().endswith(".jpg"):
            continue
        name = os.path.splitext(fname)[0]

        # ——— load & resize RGB image (file handle closed by the context manager) ———
        with Image.open(os.path.join(imgs_dir, fname)) as raw:
            img = raw.convert("RGB").resize(size, resample=Image.LANCZOS)
        X = np.array(img, dtype=np.uint8)   # (H, W, 3)

        # ——— load the .mat and extract all annotators ———
        m = loadmat(os.path.join(gt_dir, f"{name}.mat"))
        gts = m["groundTruth"][0]           # array of annotator-structs

        # ——— build boundary map (average over annotators) ———
        b = np.mean(np.stack([gt["Boundaries"][0][0] for gt in gts], 0), 0)
        peak = b.max()
        if peak > 0:
            # Guard: an image without boundaries would otherwise divide by zero (NaN).
            b = b / peak * 255
        b = Image.fromarray(b.astype(np.uint8)).resize(size, resample=Image.NEAREST)
        bmap = np.array(b, dtype=np.uint8)  # (H, W)

        # ——— build segmentation maps list ———
        segs = []
        for gt in gts:
            seg = gt["Segmentation"][0][0]  # region labels at original resolution
            # Nearest-neighbour resize keeps labels integral (no blending).
            seg = Image.fromarray(seg.astype(np.int32)).resize(size, resample=Image.NEAREST)
            # Convert indices from 1..K to 0..K-1.
            segs.append(np.array(seg, dtype=np.int32) - 1)  # (H, W)

        out.append((X, bmap, segs))

    return out

def extract_vectors_and_labels(X, seg, coord_scale = 1.0):
    """
    Turn an image and its segmentation into per-pixel feature vectors and labels.

    Given:
      X           : uint8 image of shape (H, W, C), typically RGB (C == 3)
      seg         : integer segmentation mask of shape (H, W)
      coord_scale : multiplier applied to the normalised pixel coordinates,
                    controlling how much spatial position weighs against colour
    Returns:
      flattened_X : float32 array of shape (H*W, C+2) – [colour / 255, x_norm, y_norm]
                    with x_norm, y_norm in [0, coord_scale]
      seg_flat    : int32 array of shape (H*W,)      – region labels in row-major order
    """
    # normalize colors to [0, 1]
    X = X.astype(np.float32) / 255.0
    H, W, C = X.shape
    N = H * W

    # flatten colours and labels in the same row-major pixel order
    X_flat = X.reshape(N, C)                         # (N, C)
    seg_flat = seg.astype(np.int32).reshape(N)       # (N,)

    # normalized coordinate grids, shape (H, W). A single-column or single-row
    # image would divide by zero (-> NaN), so the divisor is at least 1.
    ys, xs = np.indices((H, W))
    x_norm = coord_scale * xs.astype(np.float32) / max(W - 1, 1)
    y_norm = coord_scale * ys.astype(np.float32) / max(H - 1, 1)

    # concatenate to (N, C + 2)
    flattened_X = np.concatenate(
        (X_flat, x_norm.reshape(N, 1), y_norm.reshape(N, 1)), axis=1
    )
    return flattened_X, seg_flat



def reconstruct_mask(flat_seg, H, W):
    """Reshape a flat per-pixel label vector back into an (H, W) mask.

    Args:
        flat_seg: array-like of length H*W, in row-major pixel order.
        H, W: target mask height and width.

    Returns:
        np.ndarray of shape (H, W).

    Raises:
        ValueError: if ``flat_seg`` does not contain exactly H*W elements.
            An explicit check is used rather than ``assert`` so that the
            validation is not stripped under ``python -O``.
    """
    flat_seg = np.asarray(flat_seg)
    if flat_seg.size != H * W:
        raise ValueError("Size mismatch: flat_seg must have length H*W")
    return flat_seg.reshape(H, W)

def minibatch_kmeans(X, num_clusters, batch_size=1024, n_iterations=100):
    """
    Cluster the rows of ``X`` with scikit-learn's MiniBatchKMeans.

    The estimator is seeded with ``random_state=0`` so repeated runs are reproducible.

    Args:
        X (np.ndarray): Data to cluster, one sample per row.
        num_clusters (int): Number of clusters (``n_clusters``).
        batch_size (int): Mini-batch size.
        n_iterations (int): Maximum number of iterations (``max_iter``).

    Returns:
        np.ndarray: The fitted estimator's ``labels_`` (one label per row of X).
    """
    model = MiniBatchKMeans(
        n_clusters=num_clusters,
        batch_size=batch_size,
        max_iter=n_iterations,
        random_state=0,
    )
    # fit() returns the estimator itself.
    return model.fit(X).labels_

def wtruncated_kkmneans(X, num_clusters, batch_size=1024, n_iterations=100, tau= 100, gamma=0.5, shift=0.0, jaccard=False, random_state=0):
    """
    Cluster the rows of ``X`` with weighted truncated mini-batch kernel k-means.

    Steps:
      1. Build the full (N, N) kernel over the rows of X. It is a weighted
         Jaccard kernel if ``jaccard`` is True, otherwise an RBF kernel with
         coefficient ``gamma``.
      2. Optionally add ``shift`` to the kernel diagonal.
      3. Seed with ``initialize_kernel_kmeans_plusplus``.
      4. Fit ``WTKernelMiniBatchKMeans``.

    Args:
        X: Data matrix, one sample per row.
        num_clusters: Number of clusters.
        batch_size: Mini-batch size for the clusterer. It is also the tile size
            of the Jaccard kernel computation; the RBF kernel uses its own
            default tile size.
        n_iterations: Iterations passed to the clusterer.
        tau: Passed through to ``WTKernelMiniBatchKMeans``. Its meaning is
            defined by that class — TODO confirm.
        gamma: RBF kernel coefficient (ignored when ``jaccard`` is True).
        shift: If > 0, added to every diagonal entry of the kernel.
        jaccard: Use the weighted Jaccard kernel instead of RBF.
        random_state: Seed for both the k-means++ initialisation and the
            clusterer (default 0, matching the previous hard-coded seed).

    Returns:
        The fitted clusterer's ``labels_``.
    """
    if jaccard:
        K = gpu_pairwise_weighted_jaccard_batch(X, batch_size=batch_size)
    else:
        K = gpu_pairwise_kernels_batch(X, metric='rbf', gamma=gamma)

    if shift > 0:
        # Add to the diagonal in place. The previous ``K + shift * np.eye(N)``
        # materialised a dense float64 N x N identity (~34 GB for 256x256
        # images) and upcast K to float64.
        K[np.diag_indices_from(K)] += shift

    km = WTKernelMiniBatchKMeans(n_clusters=num_clusters, batch_size=batch_size, n_iterations=n_iterations, tau=tau, lazy=False, random_state=random_state)

    rng = np.random.RandomState(random_state)

    init_labels, init_distances_squared, init_C = initialize_kernel_kmeans_plusplus(K, num_clusters, rng)
    km.fit(X, K, init_labels, init_distances_squared, init_C)

    return km.labels_


def gen_plot(axes,
        plt_idx,
        idx,
        data_root,
        size,
        batch_size,
        num_iters,
        tau,
        gamma,
        coord_scale,
        shift):
    """Cluster one BSDS500 training image two ways and draw it as one figure row.

    Both mini-batch k-means and weighted truncated kernel k-means are run on
    the [colour, x, y] pixel features of training image ``idx``, with the
    number of clusters taken from the first annotator's segmentation. Each
    ARI is printed. Row ``plt_idx`` of ``axes`` is then filled as follows:
    column 0 shows the image, column 1 the k-means mask and column 2 the
    kernel k-means mask. Every axis in that row is switched off.
    """
    dataset = load_bsds500_with_segs(data_root, split="train", size=(size, size))
    image, _boundary, annotations = dataset[idx]

    # Per-pixel features and the first annotator's labels.
    features, gt_labels = extract_vectors_and_labels(image, annotations[0], coord_scale=coord_scale)
    k = len(np.unique(gt_labels))
    height, width, _ = image.shape

    # Baseline: mini-batch k-means.
    km_labels = minibatch_kmeans(features, num_clusters=k, batch_size=batch_size, n_iterations=num_iters)
    km_labels, _ = remap_labels(km_labels, gt_labels)
    km_ari = adjusted_rand_score(km_labels, gt_labels)
    print(f"mini-batch k-means ARI: {km_ari:.3f}")
    km_mask = reconstruct_mask(km_labels, height, width)

    # Weighted truncated kernel k-means.
    wt_labels = wtruncated_kkmneans(
        features,
        num_clusters=k,
        batch_size=batch_size,
        n_iterations=num_iters,
        tau=tau,
        gamma=gamma,
        shift=shift,
    )
    wt_labels, _ = remap_labels(wt_labels, gt_labels)
    kkm_ari = adjusted_rand_score(wt_labels, gt_labels)
    print(f"weighted truncated k-means ARI: {kkm_ari:.3f}")
    wt_mask = reconstruct_mask(wt_labels, height, width)

    row = axes[plt_idx]
    row[0].imshow(image)
    row[0].set_title("Image")
    mask_panels = (
        (row[1], km_mask, f"Mini-batch k-means\nARI: {km_ari:.3f}"),
        (row[2], wt_mask, f"Truncated Mini-batch Kernel k-means\nARI: {kkm_ari:.3f}"),
    )
    for ax, mask, title in mask_panels:
        ax.imshow(mask, cmap='nipy_spectral')
        ax.set_title(title)

    # Hide ticks and frames on every panel of this row.
    for ax in row:
        ax.axis('off')


def append_ari(file_name,
        train,
        idx,
        batch_size,
        num_iters,
        tau,
        gamma,
        coord_scale,
        shift,
        jaccard):
    """Score both clusterers on ``train[idx]`` and append one CSV row.

    The number of clusters comes from the first annotator's segmentation.
    The line ``"<idx>,<km_ari>,<kkm_ari>\\n"`` is appended to ``file_name``,
    with both ARIs written to three decimal places.
    """
    image, _boundary, annotations = train[idx]

    features, gt_labels = extract_vectors_and_labels(image, annotations[0], coord_scale=coord_scale)
    k = len(np.unique(gt_labels))

    km_labels = minibatch_kmeans(features, num_clusters=k, batch_size=batch_size, n_iterations=num_iters)
    km_labels, _ = remap_labels(km_labels, gt_labels)
    km_ari = adjusted_rand_score(km_labels, gt_labels)

    wt_labels = wtruncated_kkmneans(
        features,
        num_clusters=k,
        batch_size=batch_size,
        n_iterations=num_iters,
        tau=tau,
        gamma=gamma,
        shift=shift,
        jaccard=jaccard,
    )
    wt_labels, _ = remap_labels(wt_labels, gt_labels)
    kkm_ari = adjusted_rand_score(wt_labels, gt_labels)

    csv_row = f"{idx},{km_ari:.3f},{kkm_ari:.3f}\n"
    with open(file_name, 'a') as f:
        f.write(csv_row)


def process_all_images(
        data_root,
        file_name,
        size,
        batch_size,
        num_iters,
        tau,
        gamma,
        coord_scale,
        shift,
        jaccard=False
    ):
    """Compute and append ARIs for every training image not yet in ``file_name``.

    The CSV has the header ``index,km_ari,kkm_ari``, and the run is resumable.
    Indices already present in the file are skipped, and the remaining ones
    are processed in ascending order through ``append_ari``.

    Header handling:
      - If the file does not exist or is empty, the header is written first.
      - When reading an existing file, any line whose first field is not an
        integer (the header or blank lines) is ignored. This means an empty
        file no longer raises ``StopIteration``, and a header-less file no
        longer drops its first row.
    """
    train = load_bsds500_with_segs(data_root, split="train", size=(size,size))

    if not os.path.exists(file_name) or os.path.getsize(file_name) == 0:
        with open(file_name, 'w') as f:
            f.write("index,km_ari,kkm_ari\n")

    already_processed = set()
    with open(file_name, 'r') as f:
        for line in f:
            first_field = line.strip().split(',')[0]
            try:
                already_processed.add(int(first_field))
            except ValueError:
                # header or blank line
                continue

    # Sorted so runs proceed deterministically in index order.
    to_process = sorted(set(range(len(train))) - already_processed)

    for i in tqdm(to_process):
        append_ari(file_name, train, i, batch_size, num_iters, tau, gamma, coord_scale, shift, jaccard)

# Example usage:
if __name__ == "__main__":

    data_root = "BSDS500/BSDS500"
    fig, axes = plt.subplots(3, 3, figsize=(16, 12))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    # Showcase images, one figure row each:
    # (row, training-image index, tau, gamma, coord_scale)
    showcase_configs = [
        (0, 10, 256, 0.1, 0.15),    # tiger
        (1, 18, 256, 0.25, 0.25),   # church
        (2, 28, 150, 0.1, 0.15),    # bird
    ]
    for row_idx, image_idx, cfg_tau, cfg_gamma, cfg_coord_scale in showcase_configs:
        gen_plot(
            axes,
            row_idx,
            image_idx,
            data_root,
            size=256,
            batch_size=1024,
            num_iters=300,
            tau=cfg_tau,
            gamma=cfg_gamma,
            coord_scale=cfg_coord_scale,
            shift=0.0,
        )
    plt.tight_layout()
    fig.savefig("bsds.png", dpi=300)

    # Full-dataset evaluation (disabled). Uncomment to append per-image ARIs
    # to a CSV and then report the mean ARI of each method.
    #
    # file_name = "bsds_full_results.csv"
    # if not os.path.exists(file_name):
    #     with open(file_name, 'w') as f:
    #         f.write("index,km_ari,kkm_ari\n")
    #
    # process_all_images(
    #     data_root = "BSDS500/BSDS500",
    #     file_name = file_name,
    #     size = 256,
    #     batch_size = 1024,
    #     num_iters = 300,
    #     tau = 256,
    #     gamma = 0.3,
    #     coord_scale = 0.25,
    #     shift = 0.0,
    #     jaccard = True
    # )
    #
    # km_ari = []
    # kkm_ari = []
    # with open(file_name, 'r') as f:
    #     next(f)  # skip the header
    #     for line in f:
    #         index, km, kkm = line.strip().split(',')
    #         km_ari.append(float(km))
    #         kkm_ari.append(float(kkm))
    #
    # km_ari = np.array(km_ari)
    # kkm_ari = np.array(kkm_ari)
    #
    # print(f"Mean ARI for k-means: {km_ari.mean():.3f} ± {km_ari.std():.3f}")
    # print(f"Mean ARI for kernel k-means: {kkm_ari.mean():.3f} ± {kkm_ari.std():.3f}")
    #
    # Previously recorded result with gamma = 0.25:
    #   Mean ARI for k-means:        0.202 ± 0.122
    #   Mean ARI for kernel k-means: 0.197 ± 0.139

