# %%
from collections import OrderedDict
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

# plt.style.use("dark_background")
# %%
# km_save_dir = '/data/script_tiny/xaaa/'
# subject_data_dir = '/data/VWET/'
# roi_prefix = 'vwtiny'
# path = "/data/results/xaaa/crn_tiny/stage_2/soup.pth"
# Active configuration (the "tiny" variant commented out above is kept for reference).
# Directory where voxel_outs.pth and the k-means results (centroids/labels/...) are written.
km_save_dir = "/data/script_small/xaaa_mkii/"
# Per-subject data root; ROI index files are written to <subject_id>/roi/ under it.
subject_data_dir = "/data/VWET/"
# Filename prefix of the flat (one file per k-means cluster) ROI files.
roi_prefix = "vw"
# Filename prefix of the merged (hierarchical) ROI files.
hroi_prefix = "htroi"
# Trained checkpoint whose "voxel_out" weights are clustered.
path = "/data/results/xaaa/crn_small/stage_2/soup.pth"
# %%
# Load the checkpoint and gather every per-subject "voxel_out" weight tensor.
w = torch.load(path, map_location="cpu")
m = "voxel_out"
matches = [k for k in w.keys() if m in k and "weight" in k]
# Assumes keys look like "<a>.<b>.<subject_id>.<...>", i.e. the third
# dot-separated field is the subject id -- TODO confirm against the model.
subject_ids = [k.split(".")[2] for k in matches]
voxel_outs = [w[k] for k in matches]
# Number of rows contributed by each subject; later cells use this to split the
# concatenated cluster labels back into per-subject slices.
lengths = [v.shape[0] for v in voxel_outs]
# Stack all subjects' rows into one matrix so they are clustered jointly.
voxel_outs = torch.cat(voxel_outs, dim=0)
# Drop a trailing singleton dim (presumably from a 1x1 conv-style weight -- verify).
voxel_outs = voxel_outs.squeeze(-1)
print(voxel_outs.shape)
print(subject_ids)
print(lengths)
os.makedirs(km_save_dir, exist_ok=True)
torch.save(voxel_outs, os.path.join(km_save_dir, "voxel_outs.pth"))
# %%
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.functional import pairwise_cosine_similarity


# %%
# Quick structural look: cosine similarity between 1000 randomly drawn rows of
# voxel_outs (drawn independently, so repeats are possible; unseeded), shown as a
# Ward-linkage clustermap.
random_indices = torch.randint(0, voxel_outs.shape[0], (1000,))
c = pairwise_cosine_similarity(voxel_outs[random_indices])
sns.clustermap(c, method="ward")
plt.show()


# def kmeans_cluster(voxel_outs, n_clusters=1000):
#     from sklearn.cluster import KMeans
#     kmeans = KMeans(n_clusters=n_clusters, random_state=0, verbose=True).fit(voxel_outs)
#     return kmeans
# # %%
# km = kmeans_cluster(voxel_outs)
# %%
def gpu_kmeans_cluster(
    voxel_outs,
    n_clusters=100,
    *,
    mode="cosine",
    max_iter=1000,
    tol=1e-6,
    verbose=True,
):
    """Cluster the rows of ``voxel_outs`` with ``fast_pytorch_kmeans.KMeans``.

    Args:
        voxel_outs: Tensor whose rows are clustered (presumably
            ``(n_samples, n_features)``; the computation presumably runs on
            the tensor's device -- verify against fast_pytorch_kmeans).
        n_clusters: Number of clusters.
        mode: Distance used by KMeans (``"cosine"`` by default).
        max_iter: Maximum number of k-means iterations.
        tol: Convergence tolerance passed to KMeans.
        verbose: Whether KMeans prints progress.

    Returns:
        ``(kmeans, labels)``: the fitted KMeans object (its ``centroids`` are
        used downstream) and the per-row cluster labels from ``fit_predict``.
    """
    # Imported lazily so the rest of the script does not require the package.
    from fast_pytorch_kmeans import KMeans

    kmeans = KMeans(
        n_clusters=n_clusters,
        verbose=verbose,
        mode=mode,
        max_iter=max_iter,
        tol=tol,
    )
    labels = kmeans.fit_predict(voxel_outs)
    return kmeans, labels


# %%
# for k in [100, 200, 500, 1000, 2000, 4000]:
k = 256
K = k
km, labels = gpu_kmeans_cluster(voxel_outs.to("cuda"), n_clusters=k)
# Notebook display: centroid matrix shape and number of labelled rows.
km.centroids.shape, labels.shape
# plt.hist(labels.cpu(), bins=100)
# plt.show()
from torchmetrics.functional import pairwise_cosine_similarity, pairwise_euclidean_distance

# Cosine similarity between centroids, hierarchically clustered for inspection.
c = km.centroids
# d = torch.cdist(c, c)
d = pairwise_cosine_similarity(c)
# d = pairwise_euclidean_distance(c)
d = d.cpu().numpy()
g = sns.clustermap(d, method="ward")
# sns.clustermap leaves its colorbar axes as the current axes, so plt.title()
# would label the colorbar; title the whole figure instead.
g.fig.suptitle(f"K={k}")
plt.show()
# %%
# Persist the k-means result so later cells can be rerun without re-clustering.
# c / labels are torch tensors here, hence .cpu().numpy().
os.makedirs(km_save_dir, exist_ok=True)
np.save(os.path.join(km_save_dir, f"centroids.npy"), c.cpu().numpy())
np.save(os.path.join(km_save_dir, f"labels.npy"), labels.cpu().numpy())
np.save(os.path.join(km_save_dir, f"subject_ids.npy"), subject_ids)
np.save(os.path.join(km_save_dir, f"lengths.npy"), lengths)
# %%
# Reload the saved k-means result. After this cell every variable is a NumPy
# array (labels is no longer a torch tensor; subject_ids is an array of strings).
c = np.load(os.path.join(km_save_dir, f"centroids.npy"))
labels = np.load(os.path.join(km_save_dir, f"labels.npy"))
subject_ids = np.load(os.path.join(km_save_dir, f"subject_ids.npy"))
lengths = np.load(os.path.join(km_save_dir, f"lengths.npy"))
# %%
# Notebook display: one cluster label per row of voxel_outs.
labels.shape
# %%
# labels is a torch tensor straight out of the k-means cell but already a NumPy
# array when reloaded from labels.npy; calling .cpu() on an ndarray raises
# AttributeError, so only convert tensors.
if isinstance(labels, torch.Tensor):
    labels = labels.cpu().numpy()
# %%
# Distribution of cluster sizes (rows per cluster, pooled over all subjects).
_, counts = np.unique(labels, return_counts=True)
plt.hist(counts, bins=100)
plt.show()
# %%
# Split the concatenated labels back into subjects (the slices follow the order
# of subject_ids / lengths) and record, for every subject and every cluster it
# uses, the subject-local row indices assigned to that cluster. Expects labels
# to be a NumPy array. `sums` collects the size of every (subject, cluster) pair.
sums = []
start = 0
rois = {}
for subject_id, length in zip(subject_ids, lengths):
    rois[subject_id] = {}
    end = start + length
    i_labels = labels[start:end]
    start += length
    for i_k in np.unique(i_labels):
        # print(f"{subject_id} {i_k} {np.sum(i_labels == i_k)}")
        sums.append(np.sum(i_labels == i_k))
        rois[subject_id][i_k] = np.where(i_labels == i_k)[0]
# %%
# Distribution of per-(subject, cluster) sizes.
plt.hist(sums, bins=100)
plt.show()
# %%
# Write one flat ROI file per (subject, k-means cluster):
#   <subject_data_dir>/<subject_id>/roi/<roi_prefix>_<k>.npy
# holding that subject's local row indices assigned to cluster k. Clusters with
# fewer than `min_length` rows for a subject (or absent for that subject) are
# written as an empty array, so every subject gets the same K files.
min_length = 100
for subject_id, subject_rois in rois.items():
    save_dir = os.path.join(subject_data_dir, f"{subject_id}/roi")
    os.makedirs(save_dir, exist_ok=True)
    for i_k in range(K):
        # Dedicated name: `path` holds the checkpoint path from the config cell
        # and must not be clobbered.
        roi_path = os.path.join(save_dir, f"{roi_prefix}_{i_k}.npy")
        indices = subject_rois.get(i_k, np.array([], dtype=np.int64))
        if len(indices) < min_length:
            # Integer dtype so the empty file is still a valid index array.
            indices = np.array([], dtype=np.int64)
        np.save(roi_path, indices)
# # %%
# non_empty_ks = []
# for i_k in range(K):
#     non_empty_ks.append(i_k)
#     for subject_id in rois.keys():
#         if i_k not in rois[subject_id]:
#             non_empty_ks.pop()
#             break
#         if len(rois[subject_id][i_k]) < 1:
#             non_empty_ks.pop()
#             break
# # %%
# non_empty_ks
# # %%
# for k, v in rois["NSD_04"].items():
#     print(k, len(v))
# # %%

# %%
# Notebook display: fraction of (subject, cluster) pairs with fewer than 100 rows,
# i.e. the share the flat-ROI export above writes out as empty files.
(np.array(sums) < 100).sum() / len(sums)
# %%
from torchmetrics.functional import pairwise_cosine_similarity

# Cosine similarity between the saved centroids, hierarchically clustered.
c = np.load(os.path.join(km_save_dir, "centroids.npy"))
c = torch.from_numpy(c)
# d = torch.cdist(c, c)
d = pairwise_cosine_similarity(c)
d = d.cpu().numpy()
g = sns.clustermap(d, method="ward")
# clustermap leaves the colorbar axes current, so plt.title() would label the
# colorbar; title the figure instead. K is read from the loaded centroids so
# this cell does not depend on `k` from the k-means cell.
g.fig.suptitle(f"K={c.shape[0]}")
plt.show()
# %%


from torchmetrics.functional import pairwise_cosine_similarity, pairwise_euclidean_distance
import seaborn as sns

# Reload the saved k-means result so this cell does not need the k-means run
# (alternatively: km_centroids = km.centroids; km_labels = labels).
km_centroids = np.load(os.path.join(km_save_dir, "centroids.npy"))
km_labels = np.load(os.path.join(km_save_dir, "labels.npy"))

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names; use whichever this installation provides.
_whitegrid_style = (
    "seaborn-v0_8-whitegrid"
    if "seaborn-v0_8-whitegrid" in plt.style.available
    else "seaborn-whitegrid"
)
plt.style.use(_whitegrid_style)

c = torch.from_numpy(km_centroids)
d = pairwise_cosine_similarity(c).cpu().numpy()
g = sns.clustermap(
    d,
    figsize=(32, 32),
    method="ward",
    yticklabels=1,
    xticklabels=1,
)
# Cluster ids in the dendrogram-reordered column order of the heatmap; the user
# picks split breakpoints from this order in the next cell.
xtick_labels = [int(tick.get_text()) for tick in g.ax_heatmap.get_xticklabels()]
g.ax_heatmap.tick_params(axis="y", labelleft=True, labelright=True)
# Rows per cluster indexed by cluster id. bincount(minlength=K) keeps the index
# aligned with cluster ids even if some cluster is empty (np.unique counts would
# silently shift every later cluster).
counts = np.bincount(
    np.asarray(km_labels, dtype=np.int64).ravel(), minlength=len(km_centroids)
)
reordered_counts = counts[xtick_labels]
print(reordered_counts)
plt.show()
# One-row strip of cluster sizes in the same order as the clustermap columns.
fig, ax = plt.subplots(figsize=(32, 1))
sns.heatmap(
    reordered_counts.reshape(1, -1), cbar=False, cmap="Reds", xticklabels=xtick_labels
)
plt.show()
# %%
# Ask for the split breakpoints: whitespace-separated cluster ids, each one the
# last cluster (in the clustermap's column order) of a merged group.
split_list = list(
    map(
        int,
        input("Enter the split breakpoints, hint: inspect the clustermap plot: ")
        .strip()
        .split(),
    )
)
# %%
# Breakpoint sets entered in earlier sessions, kept for reference:
# 208 186 193 165 123 196 175 82 130 228 185
# 208 193 165 123 196 175 82 130 228 185
# 208 193 165 123 211 196 175 82 130 228 185
# 119 20 49 37 116 108 118 92
print(split_list)
# %%
# Merge k-means clusters into coarser hierarchical groups. Each breakpoint in
# split_list is the cluster id that ends a group in the clustermap order
# `xtick_labels`. Group i (1-based) spans from just after the previous
# breakpoint up to and including this one. Clusters after the last breakpoint
# are not assigned to any group.
km_labels = torch.from_numpy(km_labels) if isinstance(km_labels, np.ndarray) else km_labels

start = 0
vi_dict = {}  # group id -> sorted global row indices belonging to the group
kvi_dict = {}  # group id -> k-means cluster ids merged into the group
for i, end in enumerate(split_list, start=1):
    stop = xtick_labels.index(end) + 1
    if stop <= start:
        raise ValueError(
            f"breakpoint {end} does not come after the previous breakpoint in the "
            "clustermap order; enter breakpoints left to right"
        )
    labels = xtick_labels[start:stop]
    start = stop

    cluster_voxel_indices = torch.cat(
        [(km_labels == cluster_id).nonzero().flatten() for cluster_id in labels]
    )
    # Tensor.sort() is NOT in-place: it returns (values, indices), so the result
    # must be assigned for the indices to actually end up sorted.
    cluster_voxel_indices = torch.sort(cluster_voxel_indices).values.cpu()
    vi_dict[i] = cluster_voxel_indices
    kvi_dict[i] = labels
    print(f"cluster {i}", labels)
    print(cluster_voxel_indices.shape)
# %%
labels = list(vi_dict.keys())

# For every subject, keep the group members that fall in that subject's slice
# [start, end) of the concatenated rows and convert them to subject-local
# indices. `sums` collects the size of every (subject, group) pair.
sums = []
start = 0
rois = {}
for subject_id, length in zip(subject_ids, lengths):
    rois[subject_id] = {}
    end = start + length
    for i_k, vi in vi_dict.items():
        sub_vi = vi[start <= vi] # inclusive lower bound: the subject's first row is kept
        sub_vi = sub_vi[sub_vi < end]
        rois[subject_id][i_k] = sub_vi - start
        sums.append(len(sub_vi))
    start += length
# %%
# Write one hierarchical ROI file per (subject, group):
#   <subject_data_dir>/<subject_id>/roi/<hroi_prefix>_<group>.npy
# holding that subject's local row indices in the group. Groups with fewer than
# `min_length` rows for a subject are written as an empty array.
min_length = 10
for subject_id, subject_rois in rois.items():
    save_dir = os.path.join(subject_data_dir, f"{subject_id}/roi")
    os.makedirs(save_dir, exist_ok=True)
    for i_k in vi_dict:
        # Dedicated name: `path` holds the checkpoint path from the config cell
        # and must not be clobbered.
        roi_path = os.path.join(save_dir, f"{hroi_prefix}_{i_k}.npy")
        # Per-subject entries are CPU torch tensors; store them as int64 arrays.
        indices = np.asarray(subject_rois.get(i_k, []), dtype=np.int64)
        if len(indices) < min_length:
            indices = np.array([], dtype=np.int64)
        np.save(roi_path, indices)
# %%
# Subjects x groups matrix of row counts. Column c corresponds to group id c + 1,
# and missing (subject, group) entries count as 0.
n_groups = len(vi_dict)
mat = np.zeros((len(subject_ids), n_groups))
for row, subject_id in enumerate(subject_ids):
    subject_rois = rois[subject_id]
    for col in range(n_groups):
        group_id = col + 1
        mat[row, col] = len(subject_rois[group_id]) if group_id in subject_rois else 0
# %%
# Annotated heatmap of row counts: subjects as rows, hierarchical groups as
# columns. vmax=10000 caps the colour scale (presumably so a few very large
# groups do not wash out the rest).
fig = plt.figure(figsize=(20, 10))
sns.heatmap(
    mat,
    cmap="Reds",
    xticklabels=list(vi_dict.keys()),
    yticklabels=subject_ids,
    annot=True,
    fmt=".0f",
    vmax=10000,
)
plt.show()
# %%
