# %%
from collections import OrderedDict
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

# plt.style.use("dark_background")
# %%
# km_save_dir = '/data/script_tiny/xaaa/'
# subject_data_dir = '/data/VWET/'
# roi_prefix = 'vwtiny'
# path = "/data/results/xaaa/crn_tiny/stage_2/soup.pth"
# Directory the k-means artefacts (centroids / labels / subject ids / lengths)
# are loaded from further down in this script.
km_save_dir = "/data/script_base/xdabc_mkii/"
subject_data_dir = "/data/VWET/"
# Name prefixes for the ROI sets; `hroi_prefix` later keys `roi_num_dict`.
roi_prefix = "weight"
hroi_prefix = "veroi"

# Every `soup.pth` checkpoint found anywhere below `weights_dir`.
weights_dir = "/data/results/xdabb/dino_bb/big_model/"
weights_fname = "soup.pth"
import glob

weight_paths = glob.glob(
    os.path.join(weights_dir, "**", weights_fname),
    recursive=True,
)

# Both spellings are bound to the same value.
k = K = 1000

from torchmetrics.functional import (
    pairwise_cosine_similarity,
    pairwise_euclidean_distance,
)


# %%
# Load the four saved arrays from `km_save_dir`, in the same order as before:
# centroids, labels, subject ids and lengths.
c, labels, subject_ids, lengths = (
    np.load(os.path.join(km_save_dir, fname))
    for fname in ("centroids.npy", "labels.npy", "subject_ids.npy", "lengths.npy")
)
# %%
from torchmetrics.functional import (
    pairwise_cosine_similarity,
    pairwise_euclidean_distance,
)
import seaborn as sns

# km_centroids = km.centroids
# km_labels = labels

# Centroids and per-sample labels saved under `km_save_dir`.
km_centroids = np.load(os.path.join(km_save_dir, "centroids.npy"))
km_labels = np.load(os.path.join(km_save_dir, "labels.npy"))

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 dropped the old aliases; use the new name whenever it is registered and
# fall back to the legacy one on older installations.
_grid_style = "seaborn-v0_8-whitegrid"
plt.style.use(_grid_style if _grid_style in plt.style.available else "seaborn-whitegrid")

c = torch.from_numpy(km_centroids)
# Pairwise cosine similarity between centroids (alternatives tried before:
# torch.cdist(c, c), pairwise_euclidean_distance(c)).
# NOTE(review): with a single argument torchmetrics zeroes the diagonal by
# default (zero_diagonal) -- confirm that self-similarity of 0 is intended.
d = pairwise_cosine_similarity(c)
d = d.cpu().numpy()
g = sns.clustermap(
    # c.cpu().numpy(),
    d,
    figsize=(16, 16),
    method="ward",
    # method="weighted",
    # metric="euclidean",
    # metric="cosine",
    yticklabels=0,
    xticklabels=0,
    row_cluster=True,
    col_cluster=True,
    # cmap="coolwarm",
)
# xtick_labels = [int(tick.get_text()) for tick in g.ax_heatmap.get_xticklabels()]
# g.ax_heatmap.tick_params(axis="y", labelleft=True, labelright=True)
# print(xtick_labels)
# unique, counts = np.unique(km_labels, return_counts=True)
# reordered_counts = counts[xtick_labels]
# print(reordered_counts)

# Hide both dendrograms and the colorbar so only the reordered heatmap remains.
g.ax_row_dendrogram.set_visible(False)
g.ax_col_dendrogram.set_visible(False)
g.cax.set_visible(False)

# mat_path = "/tmp/mat.pdf"
# plt.savefig(mat_path, bbox_inches="tight", pad_inches=0)
# This path is read back later when the matrix is stacked under the dendrogram.
mat_path = "/tmp/mat.png"
plt.savefig(mat_path, bbox_inches="tight", pad_inches=0, dpi=300)
plt.show()
# fig, ax = plt.subplots(figsize=(32, 1))
# sns.heatmap(
#     reordered_counts.reshape(1, -1), cbar=False, cmap="Reds", xticklabels=xtick_labels
# )
# plt.show()
# %%
# Matplotlib 3.6 renamed "seaborn-white" to "seaborn-v0_8-white" and 3.8
# removed the old alias; prefer the new name when it is registered.
_white_style = "seaborn-v0_8-white"
plt.style.use(_white_style if _white_style in plt.style.available else "seaborn-white")
import scipy.cluster.hierarchy as shc

# Ward linkage computed on `d`.
# NOTE(review): `d` is passed as a 2-D array, which SciPy interprets as one
# observation per row rather than as a distance matrix -- confirm intended.
Z = shc.linkage(d, method="ward", optimal_ordering=False)
# %%
# fig, ax = plt.subplots(figsize=(10, 10))
roi_num_dict = {}

# Distance threshold used to cut the tree, paired with the prefix under which
# the resulting cluster count is recorded. Alternative presets kept for
# reference (trailing numbers are presumably the resulting cluster counts):
#   max_dist = 40     hroi_prefix = "veroi_l"
#   max_dist = 18     hroi_prefix = "veroi_s"
#   max_dist = 10     hroi_prefix = "veroi_ss"
#   max_dist = 6      hroi_prefix = "veroi_sss"      # 268
#   max_dist = 1      hroi_prefix = "veroi_extreme"  # 1000
max_dist = 27.25
hroi_prefix = "veroi"  # 18

# Sample 19 evenly spaced colors from the "rainbow" colormap.
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap(name, lut)` is the equivalent call available in all versions.
cmap = plt.get_cmap("rainbow", 19)
colors = [cmap(i) for i in range(19)]

# Flat cluster id for every row of `d`, cutting the tree at `max_dist`.
dn_labels = shc.fcluster(Z, max_dist, criterion="distance")


# use 18 colors to plot dendrogram
def rgb_hex(color):
    """Convert an ``(r, g, b)`` color to its ``#rrggbb`` hex string.

    Channels may be given on the 0-1 scale or on the 0-255 scale. When every
    channel lies in [0, 1] (including ambiguous pure 0/1 combinations such as
    ``(0, 0, 1)`` or ``(1, 1, 1)``) the 0-1 scale is assumed.

    Args:
        color: iterable of exactly three numeric channels. Any iterable is
            accepted, including generators.

    Returns:
        The lowercase hex string, e.g. ``"#ff8000"``.

    Raises:
        AssertionError: if ``color`` is not an iterable of length 3.
    """
    message = "color must be an iterable of length 3."
    assert hasattr(color, "__iter__"), message
    # Materialise first so unsized iterables (generators) can be length-checked.
    channels = tuple(color)
    assert len(channels) == 3, message
    if all(0 <= channel <= 1 for channel in channels):
        channels = tuple(channel * 255 for channel in channels)  # rgb given as 0-1
    # "%x" only accepts integers, so float channels on the 0-255 scale (e.g.
    # NumPy floats) are rounded here instead of raising TypeError.
    return "#%02x%02x%02x" % tuple(int(round(channel)) for channel in channels)


def get_cluster_colors(
    n_clusters, my_set_of_18_rgb_colors, alpha=0.8, alpha_outliers=0.05
):
    """Return one color per cluster followed by a trailing outlier color.

    Colors are taken from the palette in order and wrap around when there are
    more clusters than palette entries. The previous implementation wrapped at
    a hard-coded 20, which raised ``IndexError`` as soon as ``n_clusters``
    exceeded the palette length (the palette built in this file has 19
    entries).

    Args:
        n_clusters: number of cluster colors to produce.
        my_set_of_18_rgb_colors: non-empty sequence of colors to cycle through.
        alpha: unused; kept so existing calls keep working.
        alpha_outliers: unused; kept so existing calls keep working.

    Returns:
        A list of ``n_clusters + 1`` colors; the last one is the final palette
        entry, used as the outlier color.

    Raises:
        IndexError: if the palette is empty.
    """
    cluster_colors = my_set_of_18_rgb_colors
    outlier_color = cluster_colors[-1]
    palette_size = len(cluster_colors)
    return [cluster_colors[i % palette_size] for i in range(n_clusters)] + [
        outlier_color
    ]

# One "cluster #id: n=size" line per flat cluster (only referenced by the
# commented-out legend code below, and to count the clusters).
cluster_ids, cluster_sizes = np.unique(dn_labels, return_counts=True)
labels_str = [
    f"cluster #{cid}: n={size}\n" for cid, size in zip(cluster_ids, cluster_sizes)
]
n_clusters = len(labels_str)

# Hex color per cluster (the last tuple entry of each palette color is dropped
# before conversion), then one hex color per leaf, indexed by its cluster id.
palette = get_cluster_colors(n_clusters, colors, alpha=0.8, alpha_outliers=0.05)
cluster_colors = [rgb_hex(rgba[:-1]) for rgba in palette]
cluster_colors_array = [cluster_colors[label] for label in dn_labels]

# Color every merge in Z: a link inherits its children's color when both agree
# and is drawn black ("k") otherwise. Ids up to len(Z) are looked up as leaves;
# larger ids refer to merges recorded earlier in this loop.
n_merges = len(Z)
link_cols = {}
for merge_idx, children in enumerate(Z[:, :2].astype(int)):
    left_color, right_color = [
        cluster_colors_array[node] if node <= n_merges else link_cols[node]
        for node in children
    ]
    merged_id = merge_idx + 1 + n_merges
    link_cols[merged_id] = left_color if left_color == right_color else "k"

# plot dendrogram with colored clusters
fig = plt.figure(figsize=(16, 4))
# plt.title("Hierarchical Clustering Dendrogram")
# plt.xlabel("Data points")
# plt.ylabel("Distance")

# Draw the tree on the current axes using the per-link colors computed above.
shc.dendrogram(
    Z,
    labels=dn_labels,
    color_threshold=max_dist,
    # truncate_mode="level",
    # p=1,
    # show_leaf_counts=True,
    # leaf_rotation=90,
    # leaf_font_size=10,
    # show_contracted=False,
    link_color_func=lambda node_id: link_cols[node_id],
    above_threshold_color="k",
    # distance_sort="descending",
    ax=plt.gca(),
)
# Dashed line marking the cut height.
plt.axhline(max_dist, color="grey", linestyle="--", linewidth=2)
# for i, s in enumerate(labels_str):
#     plt.text(
#         0.8,
#         0.95 - i * 0.04,
#         s,
#         transform=plt.gca().transAxes,
#         va="top",
#         color=cluster_colors[i],
#     )

fig.patch.set_facecolor("white")

# Same cut as used for the colors above, recomputed here.
dn_labels = shc.fcluster(Z, max_dist, criterion="distance")

# Save and show the bare dendrogram (no axes, title or labels).
# plt.title("Dendrogram")
# plt.ylabel("Distance")
plt.axis("off")
dend_path = "/tmp/dendrogram.png"
plt.savefig(dend_path, bbox_inches="tight", pad_inches=0, dpi=300)
plt.show()

# Report the number of flat clusters and record it under the active prefix.
n_flat_clusters = len(np.unique(dn_labels))
print(n_flat_clusters)
roi_num_dict[hroi_prefix] = n_flat_clusters

# %%
# put mat and dendrogram together
import cv2


def _load_panel(path, width):
    """Load an image, crop blank rows off its top and bottom, and resize it.

    The previous trim compared each row's total pixel sum against ``255 * 3``.
    ``plt.imread`` returns PNG data as floats in [0, 1] and the sum ran over
    the whole row, so that test practically never matched and nothing was
    trimmed. Here a row counts as blank when all of its RGB values are at the
    white level for the array's dtype.

    Args:
        path: image file to read.
        width: target width in pixels; height follows the aspect ratio.

    Returns:
        The cropped, resized image array.
    """
    panel = plt.imread(path)
    white = 1.0 if np.issubdtype(panel.dtype, np.floating) else 255
    inked_rows = np.flatnonzero((panel[..., :3] < white).any(axis=(1, 2)))
    if inked_rows.size:
        # Crop only the outer margins so the panel content is never distorted.
        panel = panel[inked_rows[0] : inked_rows[-1] + 1]
    height = int(panel.shape[0] * width / panel.shape[1])
    return cv2.resize(panel, (width, height))


canvas_width = 2000
dend_im = _load_panel(dend_path, canvas_width)
print(dend_im.shape)
mat_im = _load_panel(mat_path, canvas_width)
print(mat_im.shape)

# Size the canvas from the two panels instead of a fixed 2500 rows, so the
# matrix sits directly under the dendrogram with neither a gap nor an overlap.
# NOTE(review): assumes both PNGs are RGBA (4 channels), as the fixed 4-channel
# canvas did before.
big_im = np.zeros(
    (dend_im.shape[0] + mat_im.shape[0], canvas_width, 4), dtype=np.float32
)
big_im[: dend_im.shape[0], :, :] = dend_im
big_im[dend_im.shape[0] :, :, :] = mat_im

fig = plt.figure(figsize=(16, 16))
plt.imshow(big_im)
plt.axis("off")
# plt.savefig("/workspace/figs/veROIdendrogram.jpeg", bbox_inches="tight", pad_inches=0, dpi=96)
plt.savefig("/workspace/figs/veROIdendrogram.png", bbox_inches="tight", pad_inches=0, dpi=144)
plt.show()
# %%
def plot_colorbar(
    cmap='rainbow',
    ticks=np.arange(1, 19),
    orientation='horizontal',
    save_path="/workspace/figs/veROIdendrogram_colorbar.pdf",
):
    """Draw a standalone colorbar, save it to ``save_path`` and show it.

    Args:
        cmap: colormap name (or Colormap instance) to display.
        ticks: tick positions; the color scale spans ``min(ticks)`` to
            ``max(ticks)``. Any sequence is accepted, not only NumPy arrays.
        orientation: ``'horizontal'`` or ``'vertical'``.
        save_path: output file for the figure; defaults to the path that was
            previously hard-coded.
    """
    import matplotlib as mpl

    # Lists/tuples have no .min()/.max(); normalise to an array up front.
    ticks = np.asarray(ticks)
    fig, ax = plt.subplots(figsize=(10, 1))
    fig.subplots_adjust(bottom=0.5)
    # `mpl.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
    # equivalent lookup that exists in both old and new versions.
    cmap = plt.get_cmap(cmap)
    norm = mpl.colors.Normalize(vmin=ticks.min(), vmax=ticks.max())
    cb = mpl.colorbar.ColorbarBase(ax, cmap=cmap, norm=norm, ticks=ticks, orientation=orientation)
    # enlarge the tick labels
    cb.ax.tick_params(labelsize=20)
    plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.show()


plot_colorbar()
# %%
# For every flat cluster id in dn_labels:
#   kvi_dict[id] -> positions in dn_labels that carry that id
#   vi_dict[id]  -> sorted positions in km_labels whose value is one of those
#                   positions
vi_dict = {}
kvi_dict = {}
for i in np.unique(dn_labels):
    member_centroids = (dn_labels == i).nonzero()[0]
    # A single membership test replaces one full scan of km_labels per member
    # followed by concatenate + sort. Every entry of km_labels matches at most
    # one member, so the selected indices are the same and already ascending.
    vi_dict[i] = np.isin(km_labels, member_centroids).nonzero()[0]
    kvi_dict[i] = member_centroids
# %%
# torch.save(vi_dict, os.path.join(km_save_dir, "vi_dict.pt"))
# %%
labels = list(vi_dict)

# Split every cluster's global indices into per-subject, subject-local indices.
# Subjects occupy consecutive index ranges whose sizes are given by `lengths`,
# in the order of `subject_ids`.
sums = []
rois = {}
start = 0
for subject_id, length in zip(subject_ids, lengths):
    end = start + length
    subject_rois = {}
    for i_k, vi in vi_dict.items():
        # Half-open window [start, end): the subject's first index is kept.
        sub_vi = vi[(vi >= start) & (vi < end)]
        subject_rois[i_k] = sub_vi - start
        sums.append(len(sub_vi))
    rois[subject_id] = subject_rois
    start = end
# # %%
# Count matrix: one row per subject, one column per entry of vi_dict.
# Columns are taken from the same key list that labels the x axis, rather than
# assuming the keys are exactly 1..n, so a column can never be paired with the
# wrong tick label.
roi_keys = list(vi_dict.keys())
mat = np.zeros((len(subject_ids), len(roi_keys)))
for row, subject_id in enumerate(subject_ids):
    subject_rois = rois[subject_id]
    for col, roi_key in enumerate(roi_keys):
        # A key missing for this subject counts as zero.
        mat[row, col] = len(subject_rois.get(roi_key, ()))

fig = plt.figure(figsize=(10, 5))
# xticklabels = [f"veROI {i}" for i in roi_keys]
g = sns.heatmap(
    mat,
    cmap="Reds",
    xticklabels=roi_keys,
    yticklabels=subject_ids,
    annot=False,
    fmt=".0f",
    vmax=9000,
    cbar=True,
)
# add text to colorbar
cbar = g.collections[0].colorbar
cbar.ax.set_ylabel("Voxel Counts", fontsize=16)
# reduce colorbar tick count
cbar.set_ticks([0, 2000, 4000, 6000, 8000])
plt.xlabel("veROI", fontsize=18)
# plt.ylabel("Subject", fontsize=18)
# increase tick label size
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.savefig("/workspace/figs/veROIheatmap.pdf", bbox_inches="tight", pad_inches=0)
plt.show()

# %%
