# %%
from collections import OrderedDict
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

# plt.style.use("dark_background")
# %%
# km_save_dir = '/data/script_tiny/xaaa/'
# subject_data_dir = '/data/VWET/'
# roi_prefix = 'vwtiny'
# path = "/data/results/xaaa/crn_tiny/stage_2/soup.pth"
# Directory holding the saved k-means artefacts (centroids, labels, subject ids,
# lengths) that later cells load, and where vi_dict.pt is written.
km_save_dir = "/data/script_base/xdabc_mkii/"
# Root of the per-subject data; ROI index files are written to <subject>/roi below it.
subject_data_dir = "/data/VWET/"
# File-name prefix of the per-k-means-cluster ROI files.
roi_prefix = "weight"
# File-name prefix of the hierarchical ROI files; reassigned several times further down.
hroi_prefix = "veroi"

# Every file named `weights_fname` at any depth below `weights_dir` is treated
# as one run's checkpoint.
weights_dir = "/data/results/xdabb/dino_bb/big_model/"
weights_fname = "soup.pth"
import glob

weight_paths = glob.glob(os.path.join(weights_dir, "**", weights_fname), recursive=True)
# %%
# Collect the "voxel_out" weight tensors of every checkpoint found above.
all_run_voxel_outs = []
for path in weight_paths:
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    w = torch.load(path, map_location="cpu")
    m = "voxel_out"
    # Keys whose name contains both "voxel_out" and "weight".
    matches = [k for k in w.keys() if m in k and "weight" in k]
    # The third dot-separated component of each key is taken as the subject id
    # (presumably the key layout of the model's state dict - verify).
    subject_ids = [k.split(".")[2] for k in matches]
    voxel_outs = [w[k] for k in matches]
    # Size of each subject's tensor along dim 0 (assumed to be its voxel count - TODO confirm).
    lengths = [v.shape[0] for v in voxel_outs]
    # Concatenate all subjects along dim 0, then drop the last dim if it has size 1.
    voxel_outs = torch.cat(voxel_outs, dim=0)
    voxel_outs = voxel_outs.squeeze(-1)
    all_run_voxel_outs.append(voxel_outs)
    print(voxel_outs.shape)
    print(subject_ids)
    print(lengths)
# Stack the runs on a new leading dim; every run must therefore yield the same
# shape, and torch.stack fails if no checkpoint was found.
# `subject_ids` and `lengths` keep the values of the last checkpoint processed.
all_run_voxel_outs = torch.stack(all_run_voxel_outs, dim=0)
# os.makedirs(km_save_dir, exist_ok=True)
# torch.save(voxel_outs, os.path.join(km_save_dir, "voxel_outs.pth"))
# %%
# Average the stacked weights over the run dimension.
voxel_outs = all_run_voxel_outs.mean(dim=0)
print(voxel_outs.shape)
# NOTE(review): when the file is executed top to bottom as a script this stops
# it here, so everything below only runs cell-by-cell - confirm this is intended.
exit()
# %%
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.functional import pairwise_cosine_similarity


# %%
# Quick look at the weight structure: cosine similarity between 1000 rows of
# `voxel_outs` drawn at random (torch.randint samples with replacement, so
# repeats are possible), displayed as a Ward-linkage clustermap.
random_indices = torch.randint(0, voxel_outs.shape[0], (1000,))
c = pairwise_cosine_similarity(voxel_outs[random_indices])
sns.clustermap(c, method="ward")
plt.show()


# def kmeans_cluster(voxel_outs, n_clusters=1000):
#     from sklearn.cluster import KMeans
#     kmeans = KMeans(n_clusters=n_clusters, random_state=0, verbose=True).fit(voxel_outs)
#     return kmeans
# # %%
# km = kmeans_cluster(voxel_outs)
# %%
def gpu_kmeans_cluster(
    voxel_outs,
    n_clusters=100,
    *,
    mode="cosine",
    max_iter=1000,
    tol=1e-6,
    verbose=True,
):
    """Cluster the rows of ``voxel_outs`` with ``fast_pytorch_kmeans.KMeans``.

    The clustering runs on whatever device ``voxel_outs`` lives on; the caller
    in this file moves the tensor to CUDA before calling.

    Args:
        voxel_outs: Tensor handed unchanged to ``KMeans.fit_predict``.
        n_clusters: Number of clusters to fit.
        mode: Similarity used by the k-means implementation (default
            ``"cosine"``, the value previously hard-coded here).
        max_iter: Upper bound on the number of k-means iterations.
        tol: Convergence tolerance forwarded to ``KMeans``.
        verbose: Whether the k-means implementation prints its progress.

    Returns:
        tuple: ``(kmeans, labels)`` - the fitted ``KMeans`` object and the
        cluster assignments returned by ``fit_predict``.
    """
    # Imported lazily so the rest of the file does not need the package.
    from fast_pytorch_kmeans import KMeans

    kmeans = KMeans(
        n_clusters=n_clusters,
        verbose=verbose,
        mode=mode,
        max_iter=max_iter,
        tol=tol,
    )
    labels = kmeans.fit_predict(voxel_outs)
    return kmeans, labels


# %%
# for k in [100, 200, 500, 1000, 2000, 4000]:
# Number of k-means clusters; `K` is the copy read by the ROI-saving cell below.
k = 1000
K = k
km, labels = gpu_kmeans_cluster(voxel_outs.to("cuda"), n_clusters=k)
# Bare expression: only displays the two shapes when run as a notebook cell.
km.centroids.shape, labels.shape
# plt.hist(labels.cpu(), bins=100)
# plt.show()
from torchmetrics.functional import (
    pairwise_cosine_similarity,
    pairwise_euclidean_distance,
)

# Pairwise cosine similarity between the fitted centroids, drawn as a
# Ward-linkage clustermap.
c = km.centroids
# d = torch.cdist(c, c)
d = pairwise_cosine_similarity(c)
# d = pairwise_euclidean_distance(c)
d = d.cpu().numpy()
sns.clustermap(
    d,
    method="ward",
)
plt.title(f"K={k}")
plt.show()
# %%
# os.makedirs(km_save_dir, exist_ok=True)
# np.save(os.path.join(km_save_dir, f"centroids.npy"), c.cpu().numpy())
# np.save(os.path.join(km_save_dir, f"labels.npy"), labels.cpu().numpy())
# np.save(os.path.join(km_save_dir, f"subject_ids.npy"), subject_ids)
# np.save(os.path.join(km_save_dir, f"lengths.npy"), lengths)
# %%
# Reload the saved k-means artefacts from `km_save_dir`, in this order:
# centroids, per-voxel labels, subject ids and per-subject lengths.
c, labels, subject_ids, lengths = (
    np.load(os.path.join(km_save_dir, artefact_name))
    for artefact_name in (
        "centroids.npy",
        "labels.npy",
        "subject_ids.npy",
        "lengths.npy",
    )
)
# %%
# Bare expression: only displays the shape when run as a notebook cell.
labels.shape
# %%
# Histogram of how many entries of `labels` fall into each cluster, followed
# by the size of the smallest cluster.
_, counts = np.unique(labels, return_counts=True)
plt.hist(counts, bins=100)
plt.show()
print(min(counts))
# %%
# Split the concatenated label vector back into per-subject slices using
# `lengths`, and record for every subject which positions belong to each cluster.
#   rois[subject_id][cluster] -> positions within that subject's own slice
#   sums                      -> one size per (subject, cluster) pair that occurs
sums = []
start = 0
rois = {}
for subject_id, length in zip(subject_ids, lengths):
    rois[subject_id] = {}
    end = start + length
    i_labels = labels[start:end]
    # Advance the offset for the next subject; `i_labels` is already sliced.
    start += length
    # Only clusters that actually occur for this subject get an entry.
    for i_k in np.unique(i_labels):
        # print(f"{subject_id} {i_k} {np.sum(i_labels == i_k)}")
        sums.append(np.sum(i_labels == i_k))
        rois[subject_id][i_k] = np.where(i_labels == i_k)[0]
# %%
# Distribution of the per-subject cluster sizes collected in `sums`.
plt.hist(sums, bins=100)
plt.show()
# %%
# Fraction of (subject, cluster) pairs with fewer than 100 voxels.
# `sums` is a plain Python list, so it has to be converted to an array before
# the element-wise comparison (a bare `sums < 100` raises TypeError).
(np.array(sums) < 100).sum() / len(sums)
# %%
# Write one index file per (subject, k-means cluster):
#   <subject_data_dir>/<subject_id>/roi/<roi_prefix>_<cluster>.npy
# A file is written for every cluster id in range(K) so that all subjects end
# up with the same set of files. Clusters that do not occur for a subject, or
# that have fewer than `min_length` voxels, are saved as an empty array.
min_length = 100
# Empty ROIs use the same integer dtype as the np.where() results so that the
# saved arrays can always be used as index arrays (np.array([]) is float64).
empty_indices = np.array([], dtype=np.intp)
for subject_id, subject_rois in rois.items():
    save_dir = os.path.join(subject_data_dir, f"{subject_id}/roi")
    os.makedirs(save_dir, exist_ok=True)
    for i_k in range(K):
        path = os.path.join(save_dir, f"{roi_prefix}_{i_k}.npy")
        indices = subject_rois.get(i_k, empty_indices)
        if len(indices) < min_length:
            indices = empty_indices
        np.save(path, indices)
# # %%
# non_empty_ks = []
# for i_k in range(K):
#     non_empty_ks.append(i_k)
#     for subject_id in rois.keys():
#         if i_k not in rois[subject_id]:
#             non_empty_ks.pop()
#             break
#         if len(rois[subject_id][i_k]) < 1:
#             non_empty_ks.pop()
#             break
# # %%
# non_empty_ks
# # %%
# for k, v in rois["NSD_04"].items():
#     print(k, len(v))
# # %%

# %%
# Fraction of the entries in `sums` that are smaller than 10 (bare expression:
# only displayed when run as a notebook cell).
(np.array(sums) < 10).sum() / len(sums)
# %%
from torchmetrics.functional import pairwise_cosine_similarity

# %%
# Re-draw the centroid cosine-similarity clustermap from the centroids saved
# on disk, so this cell works without re-running k-means.
c = np.load(os.path.join(km_save_dir, "centroids.npy"))
c = torch.from_numpy(c)
# d = torch.cdist(c, c)
d = pairwise_cosine_similarity(c)
d = d.cpu().numpy()
sns.clustermap(d, method="ward")
# Take K from the loaded centroids: the global `k` is only defined by the
# k-means cell and may be missing or stale when working from saved files.
plt.title(f"K={c.shape[0]}")
plt.show()
# %%


from torchmetrics.functional import (
    pairwise_cosine_similarity,
    pairwise_euclidean_distance,
)
import seaborn as sns

# km_centroids = km.centroids
# km_labels = labels

# Figure version of the centroid-similarity clustermap: no tick labels, no
# dendrograms, no colour bar; saved to /tmp/mat.png for the composite figure.
km_centroids = np.load(os.path.join(km_save_dir, f"centroids.npy"))
km_labels = np.load(os.path.join(km_save_dir, f"labels.npy"))

# NOTE(review): recent Matplotlib releases renamed the bundled seaborn styles
# to "seaborn-v0_8-<name>"; this name may need updating - confirm against the
# installed version.
plt.style.use("seaborn-whitegrid")
c = km_centroids
c = torch.from_numpy(c)
# d = torch.cdist(c, c)
d = pairwise_cosine_similarity(c)
# d = pairwise_euclidean_distance(c)
d = d.cpu().numpy()
g = sns.clustermap(
    # c.cpu().numpy(),
    d,
    figsize=(16, 16),
    method="ward",
    # method="weighted",
    # metric="euclidean",
    # metric="cosine",
    yticklabels=0,
    xticklabels=0,
    row_cluster=True,
    col_cluster=True,
    # cmap="coolwarm",
)
# xtick_labels = [int(tick.get_text()) for tick in g.ax_heatmap.get_xticklabels()]
# g.ax_heatmap.tick_params(axis="y", labelleft=True, labelright=True)
# print(xtick_labels)
# unique, counts = np.unique(km_labels, return_counts=True)
# reordered_counts = counts[xtick_labels]
# print(reordered_counts)

# Hide both dendrograms; the rows/columns stay in clustered order.
g.ax_row_dendrogram.set_visible(False)
g.ax_col_dendrogram.set_visible(False)
# Hide the colour bar.
g.cax.set_visible(False)

# mat_path = os.path.join(f"/tmp/mat.pdf")
# plt.savefig(mat_path, bbox_inches="tight", pad_inches=0)
# `mat_path` is read back by the cell that stitches matrix and dendrogram together.
mat_path = os.path.join(f"/tmp/mat.png")
plt.savefig(mat_path, bbox_inches="tight", pad_inches=0, dpi=300)
plt.show()
# fig, ax = plt.subplots(figsize=(32, 1))
# sns.heatmap(
#     reordered_counts.reshape(1, -1), cbar=False, cmap="Reds", xticklabels=xtick_labels
# )
# plt.show()
# %%
# NOTE(review): recent Matplotlib releases renamed the bundled seaborn styles
# to "seaborn-v0_8-<name>"; confirm this name against the installed version.
plt.style.use("seaborn-white")
import scipy.cluster.hierarchy as shc

# Ward linkage over the centroid similarity matrix `d`. Because `d` is passed
# as a 2-D array, each of its rows is treated as one observation vector (it is
# not interpreted as a condensed distance matrix).
Z = shc.linkage(d, method="ward", optimal_ordering=False)
# %%
# fig, ax = plt.subplots(figsize=(10, 10))
# Number of flat clusters obtained per prefix (filled in after plotting).
roi_num_dict = {}
# Distance threshold at which the linkage tree `Z` is cut, and the file prefix
# used for the resulting ROIs. The alternatives below are other cut heights;
# the trailing numbers appear to be the resulting cluster counts (they match
# the dict recorded at the end of the file) - TODO confirm.
max_dist = 27.25
hroi_prefix = "veroi"  # 18
# max_dist = 40
# hroi_prefix = "veroi_l"
# max_dist = 18
# hroi_prefix = "veroi_s"
# max_dist = 10
# hroi_prefix = "veroi_ss"
# max_dist = 6
# hroi_prefix = "veroi_sss"  # 268
# max_dist = 1
# hroi_prefix = "veroi_extreme" # 1000

# Sample 19 RGBA colours from the "rainbow" colormap; get_cluster_colors()
# uses the last one as the extra trailing colour.
# NOTE(review): plt.cm.get_cmap is deprecated in recent Matplotlib releases -
# confirm against the installed version.
cmap = plt.cm.get_cmap("rainbow", 19)
colors = [cmap(i) for i in range(19)]
# Flat cluster id for each k-means centroid: cut the tree at `max_dist`.
dn_labels = shc.fcluster(Z, max_dist, criterion="distance")


# use 18 colors to plot dendrogram
def rgb_hex(color):
    """Convert an ``(r, g, b)`` colour to its ``#rrggbb`` hex string.

    The channels may be given on the 0-1 scale or on the 0-255 scale. When
    every channel lies in ``[0, 1]`` the colour is treated as 0-1 and scaled
    by 255, so ambiguous inputs made only of 0s and 1s, e.g. ``(0, 0, 1)``,
    are read as 0-1 (``(1, 1, 1)`` is white).

    Args:
        color: Sized iterable of exactly three numeric channels.

    Returns:
        str: Lower-case hex string such as ``"#ff8000"``.

    Raises:
        AssertionError: If ``color`` is not an iterable of length 3 (kept as
            an assertion so existing callers see the same exception type).
    """
    message = "color must be an iterable of length 3."
    assert hasattr(color, "__iter__"), message
    assert len(color) == 3, message
    if all(0 <= c <= 1 for c in color):
        # Channels are on the 0-1 scale: rescale to 0-255.
        channels = [int(round(c * 255)) for c in color]
    else:
        # Already 0-255; still convert to int because "%02x" rejects floats
        # (e.g. (255.0, 128.0, 0.0) used to raise TypeError).
        channels = [int(round(c)) for c in color]
    return "#%02x%02x%02x" % tuple(channels)


def get_cluster_colors(
    n_clusters, my_set_of_18_rgb_colors, alpha=0.8, alpha_outliers=0.05
):
    """Return one colour per cluster plus a trailing outlier colour.

    Colours are taken from ``my_set_of_18_rgb_colors`` in order and the
    palette is cycled when there are more clusters than colours.

    Args:
        n_clusters: Number of cluster colours to return.
        my_set_of_18_rgb_colors: Non-empty sequence of colours (any length);
            its last entry doubles as the outlier colour.
        alpha: Unused; kept for backward compatibility.
        alpha_outliers: Unused; kept for backward compatibility.

    Returns:
        list: ``n_clusters + 1`` colours; the last one is the outlier colour.
    """
    cluster_colors = my_set_of_18_rgb_colors
    # cluster_colors = [c + [alpha] for c in cluster_colors]
    # outlier_color = [0, 0, 0, alpha_outliers]
    outlier_color = my_set_of_18_rgb_colors[-1]
    # Cycle by the actual palette size: a fixed modulus (previously 20) indexes
    # past the end of a shorter palette as soon as n_clusters exceeds it.
    n_colors = len(cluster_colors)
    return [cluster_colors[i % n_colors] for i in range(n_clusters)] + [outlier_color]

# One legend string per flat cluster (cluster id and number of centroids in it).
labels_str = [
    f"cluster #{l}: n={c}\n" for (l, c) in zip(*np.unique(dn_labels, return_counts=True))
]
n_clusters = len(labels_str)

# Hex colour per cluster: the alpha channel of each RGBA tuple is dropped
# before conversion. The list has n_clusters + 1 entries.
cluster_colors = [
    rgb_hex(c[:-1])
    for c in get_cluster_colors(n_clusters, colors, alpha=0.8, alpha_outliers=0.05)
]
# Colour of every leaf (k-means centroid), indexed directly by its flat cluster
# id. fcluster ids start at 1 (per SciPy), so entry 0 of `cluster_colors` is
# never used and the highest id maps to the trailing extra colour.
cluster_colors_array = [cluster_colors[l] for l in dn_labels]
# Colour of every merge in the linkage: row i of Z creates node id
# i + 1 + len(Z); ids <= len(Z) are original leaves. A merge keeps the colour
# of its children when both agree and is drawn black ("k") otherwise.
link_cols = {}
for i, i12 in enumerate(Z[:, :2].astype(int)):
    c1, c2 = (link_cols[x] if x > len(Z) else cluster_colors_array[x] for x in i12)
    link_cols[i + 1 + len(Z)] = c1 if c1 == c2 else "k"

# plot dendrogram with colored clusters
# Draw the full dendrogram of `Z` with links coloured through `link_cols`,
# mark the cut height, and save it to /tmp/dendrogram.png for the composite figure.
fig = plt.figure(figsize=(16, 4))
# plt.title("Hierarchical Clustering Dendrogram")
# plt.xlabel("Data points")
# plt.ylabel("Distance")

# plot dendrogram based on clustering results
shc.dendrogram(
    Z,
    labels=dn_labels,
    color_threshold=max_dist,
    # truncate_mode="level",
    # p=1,
    # show_leaf_counts=True,
    # leaf_rotation=90,
    # leaf_font_size=10,
    # show_contracted=False,
    # Look up the precomputed colour of each merge node.
    link_color_func=lambda x: link_cols[x],
    above_threshold_color="k",
    # distance_sort="descending",
    ax=plt.gca(),
)
# Dashed line at the height where the tree is cut into flat clusters.
plt.axhline(max_dist, color="grey", linestyle="--", linewidth=2)
# for i, s in enumerate(labels_str):
#     plt.text(
#         0.8,
#         0.95 - i * 0.04,
#         s,
#         transform=plt.gca().transAxes,
#         va="top",
#         color=cluster_colors[i],
#     )

fig.patch.set_facecolor("white")

# Same call as in the threshold cell above; recomputed so this cell can be
# re-run on its own after changing `max_dist`.
dn_labels = shc.fcluster(Z, max_dist, criterion="distance")

# display the dendrogram
# plt.title("Dendrogram")
# plt.ylabel("Distance")
plt.axis("off")
dend_path = "/tmp/dendrogram.png"
plt.savefig(dend_path, bbox_inches="tight", pad_inches=0, dpi=300)
plt.show()
# Record how many flat clusters this threshold produced.
print(len(np.unique(dn_labels)))
roi_num_dict[hroi_prefix] = len(np.unique(dn_labels))

# %%
# put mat and dendrogram together
# Stitch the saved dendrogram and similarity-matrix images into one canvas and
# save the combined figure.
import cv2
# fig, axs = plt.subplots(2, 1, figsize=(16, 18))
# axs = axs.flatten()
# Canvas of 2500 rows x 2000 columns x 4 channels; the assignments below
# require both images to have 4 channels after loading.
big_im = np.zeros((2500, 2000, 4), dtype=np.float32)
im = plt.imread(dend_path)
# trim blank space
# NOTE(review): this keeps the rows whose sum over width and channels differs
# from 255 * 3. Matplotlib documents PNGs as being read as floats in [0, 1],
# so the sum will rarely equal 765 and the filter probably keeps every row -
# confirm whether the trim is still wanted.
im = im[im.sum(axis=1).sum(axis=1) != 255 * 3]
print(im.shape)
# Resize to a width of 2000 px, keeping the aspect ratio (cv2 takes (width, height)).
new_width = 2000
new_height = int(im.shape[0] * new_width / im.shape[1])
im = cv2.resize(im, (new_width, new_height))
print(im.shape)
# Dendrogram goes at the top of the canvas.
big_im[:im.shape[0], :, :] = im
# axs[0].imshow(im)
# axs[0].axis("off")
im = plt.imread(mat_path)
# trim blank space (same caveat as for the dendrogram image above)
im = im[im.sum(axis=1).sum(axis=1) != 255 * 3]
# Resize to a width of 2000 px, keeping the aspect ratio.
new_width = 2000
new_height = int(im.shape[0] * new_width / im.shape[1])
im = cv2.resize(im, (new_width, new_height))
print(im.shape)
# Similarity matrix is aligned to the bottom of the canvas.
big_im[-im.shape[0]:, :, :] = im
# axs[1].imshow(im)
# axs[1].axis("off")
# remove space between subplots
# plt.subplots_adjust(wspace=0, hspace=0)
fig = plt.figure(figsize=(16, 16))
# Rows and columns are swapped for display, so the two panels end up side by side.
plt.imshow(big_im.transpose(1, 0, 2))
plt.axis("off")
plt.savefig("/workspace/figs/veROIdendrogram.jpeg", bbox_inches="tight", pad_inches=0, dpi=96)
plt.show()

# %%
# Map every hierarchical cluster to its voxels.
#   vi_dict[i]  -> sorted global indices of all voxels whose k-means label
#                  belongs to hierarchical cluster i
#   kvi_dict[i] -> the k-means cluster ids merged into hierarchical cluster i
vi_dict = {}
kvi_dict = {}
for i in np.unique(dn_labels):
    # Positions in `dn_labels` are k-means cluster ids.
    labels = np.nonzero(dn_labels == i)[0]
    per_centroid_voxels = [np.nonzero(km_labels == l)[0] for l in labels]
    cluster_voxel_indices = np.sort(np.concatenate(per_centroid_voxels))
    vi_dict[i] = cluster_voxel_indices
    kvi_dict[i] = labels
# %%
torch.save(vi_dict, os.path.join(km_save_dir, "vi_dict.pt"))
# %%
# Build a random control: for every cluster in `vi_dict`, draw the same number
# of voxel indices from a seeded shuffle of all indices, without overlap
# between clusters (consecutive chunks of one permutation).
random_vi_dict = {}
start = 0
all_vi = np.arange(0, sum(lengths))
# Fixed seed so the random ROIs are reproducible.
np.random.seed(0)
np.random.shuffle(all_vi)
for i in vi_dict:
    random_vi_dict[i] = all_vi[start : start + len(vi_dict[i])]
    start += len(vi_dict[i])

# Sanity check: every random ROI has the size of its real counterpart.
for i in random_vi_dict:
    assert len(random_vi_dict[i]) == len(vi_dict[i])

# NOTE: from here on `vi_dict` refers to the random ROIs and files are written
# with the "random_m" prefix; skip this cell to export the real ROIs instead.
hroi_prefix = "random_m"
vi_dict = random_vi_dict
# %%
labels = list(vi_dict.keys())

# Split each cluster's global voxel indices by subject: subject s owns the
# half-open range [start, end) of the concatenated voxel axis, and the stored
# indices are made relative to that subject's own start.
sums = []
start = 0
rois = {}
for subject_id, length in zip(subject_ids, lengths):
    rois[subject_id] = {}
    end = start + length
    for i_k, vi in vi_dict.items():
        sub_vi = vi[start <= vi]  # keeps indices >= start, so the subject's first voxel is included
        sub_vi = sub_vi[sub_vi < end]
        rois[subject_id][i_k] = sub_vi - start
        sums.append(len(sub_vi))
    start += length
# # %%
# Subject x cluster matrix of ROI sizes. Column j holds cluster id j + 1, i.e.
# the cluster ids are assumed to be 1..len(vi_dict); missing ids count as 0.
mat = np.zeros((len(subject_ids), len(vi_dict)))
for i in range(len(subject_ids)):
    for j in range(len(vi_dict)):
        if j + 1 in rois[subject_ids[i]]:
            mat[i, j] = len(rois[subject_ids[i]][j + 1])
        else:
            mat[i, j] = 0
fig = plt.figure(figsize=(20, 10))
sns.heatmap(
    mat,
    cmap="Reds",
    xticklabels=list(vi_dict.keys()),
    yticklabels=subject_ids,
    annot=True,
    fmt=".0f",
    # Colour scale saturates at 10000 voxels.
    vmax=10000,
)
plt.show()

# %%
# Write one index file per (subject, cluster in `vi_dict`):
#   <subject_data_dir>/<subject_id>/roi/<hroi_prefix>_<cluster>.npy
# With min_length = 0 the size filter never triggers, so every ROI is saved
# as is; raise it to blank out small ROIs.
min_length = 0
for subject_id in rois.keys():
    save_dir = os.path.join(subject_data_dir, f"{subject_id}/roi")
    os.makedirs(save_dir, exist_ok=True)
    for i_k in vi_dict.keys():
        path = os.path.join(save_dir, f"{hroi_prefix}_{i_k}.npy")
        if i_k in rois[subject_id]:
            indices = rois[subject_id][i_k]
        else:
            # Cluster missing for this subject: save an empty array.
            indices = []
            indices = np.array(indices)
        if len(indices) < min_length:
            indices = np.array([])
        np.save(path, indices)

# %%
# Export hierarchical ROIs at several granularities in one go: for every
# (cut height, file prefix) pair, cut the linkage tree `Z`, map the resulting
# clusters to voxels, split them by subject and write the index files.
# This repeats the single-threshold cells above inside a loop.
roi_num_dict = {}
for max_dist, hroi_prefix in zip(
    [85, 50, 40, 27.25, 18, 10, 6, 1],
    [
        "veroi_lll",
        "veroi_ll",
        "veroi_l",
        "veroi_m",
        "veroi_s",
        "veroi_ss",
        "veroi_sss",
        "veroi_extreme",
    ],
):
    # Diagnostic plot of the tree coloured at this cut height.
    dend = shc.dendrogram(
        Z,
        color_threshold=max_dist,
        # truncate_mode="lastp",
    )

    # Flat cluster id for each k-means centroid at this cut height.
    dn_labels = shc.fcluster(Z, max_dist, criterion="distance")
    plt.title(hroi_prefix)
    plt.ylabel("Distance")
    plt.show()

    roi_num_dict[hroi_prefix] = len(np.unique(dn_labels))
    # vi_dict[i]  -> sorted global voxel indices of hierarchical cluster i
    # kvi_dict[i] -> k-means cluster ids merged into hierarchical cluster i
    vi_dict = {}
    kvi_dict = {}
    for i in np.unique(dn_labels):
        cluster_voxel_indices = []
        # Positions in `dn_labels` are k-means cluster ids.
        labels = (dn_labels == i).nonzero()[0]
        for l in labels:
            voxel_indices = (km_labels == l).nonzero()[0]
            cluster_voxel_indices.append(voxel_indices)
        cluster_voxel_indices = np.concatenate(cluster_voxel_indices)
        cluster_voxel_indices.sort()
        cluster_voxel_indices = cluster_voxel_indices
        vi_dict[i] = cluster_voxel_indices
        kvi_dict[i] = labels
        # print(f"cluster {i}")
        # print(cluster_voxel_indices.shape)
    labels = list(vi_dict.keys())

    # Split each cluster's global indices by subject; subject s owns the
    # half-open range [start, end) and indices are stored relative to `start`.
    sums = []
    start = 0
    rois = {}
    for subject_id, length in zip(subject_ids, lengths):
        rois[subject_id] = {}
        end = start + length
        for i_k, vi in vi_dict.items():
            sub_vi = vi[start <= vi]  # keeps indices >= start, so the subject's first voxel is included
            sub_vi = sub_vi[sub_vi < end]
            rois[subject_id][i_k] = sub_vi - start
            sums.append(len(sub_vi))
        start += length

    # With min_length = 0 the size filter below never triggers.
    min_length = 0
    for subject_id in rois.keys():
        save_dir = os.path.join(subject_data_dir, f"{subject_id}/roi")
        os.makedirs(save_dir, exist_ok=True)
        for i_k in vi_dict.keys():
            path = os.path.join(save_dir, f"{hroi_prefix}_{i_k}.npy")
            if i_k in rois[subject_id]:
                indices = rois[subject_id][i_k]
            else:
                # Cluster missing for this subject: save an empty array.
                indices = []
                indices = np.array(indices)
            if len(indices) < min_length:
                indices = np.array([])
            np.save(path, indices)


# %%
# Number of flat clusters obtained for each prefix in the loop above.
print(roi_num_dict)
# %%
# Bare dict literal with no effect on program state; presumably a copy of the
# output printed above from an earlier run, kept for reference - TODO confirm.
{
    "veroi_lll": 4,
    "veroi_ll": 7,
    "veroi_l": 11,
    "veroi_m": 18,
    "veroi_s": 37,
    "veroi_ss": 109,
    "veroi_sss": 268,
    "veroi_extreme": 1000,
}
# %%
