# %%
from functools import partial
import os
from matplotlib.lines import Line2D

import numpy as np
import nilearn
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib import cm, ticker
import cortex
import torch

# %%
from config_utils import load_from_yaml
from datamodule import AllDatamodule, build_dm
from models import VEModel

# Device string; not referenced anywhere else in the visible script.
device = "cuda:0"
# device = 'cpu'
# %%
# %%
# Module-level datamodule: the plotting cells below read images from it via `dm.dss[1]`.
cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
dm: AllDatamodule = build_dm(cfg)
dm.setup()
# NOTE: `get_top5` defines its own local `subject`; this global is not used by it.
subject = "NSD_01"


def _pairwise_pearson(a, b):
    """Return the Pearson correlation between every row of ``a`` and every row of ``b``.

    ``out[i, j]`` equals ``np.corrcoef(a[i], b[j])[0, 1]``: inputs are promoted to
    float64 and the result is clipped to [-1, 1], as ``np.corrcoef`` does. Rows with
    zero variance produce NaN. The whole matrix comes from one matrix product instead
    of ``len(a) * len(b)`` separate ``np.corrcoef`` calls.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a = a - a.mean(axis=1, keepdims=True)
    b = b - b.mean(axis=1, keepdims=True)
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return np.clip(a @ b.T, -1.0, 1.0)


def get_top5(dark_postfix=".mania_veroi_m_gen2_darkfull", rois=None):
    """Evaluate image retrieval by correlating ``dark`` rows against ``ys`` rows.

    Builds the datamodule from ``dino_mania.yaml`` with the given ROIs and dark-file
    postfix, then collects every validation batch for subject ``NSD_01``. ``batch[1]``
    is gathered into ``ys`` and ``batch[-1]`` into ``dark`` (presumably measured vs.
    predicted responses — confirm against the datamodule). For each ``dark[i]``, every
    ``ys[j]`` is ranked by Pearson correlation. Retrieval is correct when ``j == i``.

    Args:
        dark_postfix: value written to ``cfg.DATASET.DARK_POSTFIX``.
        rois: value written to ``cfg.DATASET.ROIS``; ``None`` (default) means ``["all"]``.

    Returns:
        ``(top1_acc, top5_acc, top1_idx, top5_idx, dist_mat, corr)`` where

        * ``dist_mat[i, j]`` is Pearson r between ``dark[i]`` and ``ys[j]``. Despite
          the name it is a similarity: higher means closer.
        * ``top1_idx[i]`` is the best-matching ``j`` for query ``i``.
        * ``top5_idx[i]`` holds the 5 best ``j`` for query ``i``, best first.
        * ``top1_acc`` and ``top5_acc`` are the fractions of queries whose own index
          is the top match, or among the top 5 matches.
        * ``corr`` is ``metrics.vectorized_correlation(dark, ys)``.
    """
    if rois is None:
        # Fresh list per call: a list default would be shared (and mutable) across calls.
        rois = ["all"]
    cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
    cfg.DATASET.ROIS = rois
    cfg.DATASET.DARK_POSTFIX = dark_postfix
    dm: AllDatamodule = build_dm(cfg)
    dm.setup()
    subject = "NSD_01"
    dl = dm.val_dataloader(subject=subject)

    # batch[1] / batch[-1] are per-sample sequences of tensors; flatten them across
    # batches (linear-time, unlike repeated list concatenation via sum()).
    ys, darks = [], []
    for batch in dl:
        ys.extend(batch[1])
        darks.extend(batch[-1])
    ys = torch.stack(ys).numpy()
    dark = torch.stack(darks).numpy()

    from metrics import vectorized_correlation

    corr = vectorized_correlation(dark, ys)

    n = ys.shape[0]
    # Similarity matrix: dist_mat[i, j] = Pearson r(dark[i], ys[j]).
    dist_mat = _pairwise_pearson(dark[:n], ys)
    targets = np.arange(n)
    top1_idx = dist_mat.argmax(axis=1)
    top1_acc = (top1_idx == targets).sum() / n
    # Last 5 of the ascending argsort, reversed so the best match comes first.
    top5_idx = np.argsort(dist_mat, axis=1)[:, -5:][:, ::-1]
    # Each row holds distinct indices, so each query contributes at most one hit.
    top5_acc = (top5_idx == targets[:, None]).sum() / n
    return top1_acc, top5_acc, top1_idx, top5_idx, dist_mat, corr


# # %%
# rois = [
#     ["all"],
#     ["early"],
#     ["mid"],
#     ["late"],
# ]
# rois += [[f"veroi_m_{i}"] for i in range(1, 19)]
# top5_idx_dict = {}
# top1_acc_dict = {}
# top5_acc_dict = {}
# corr_dict = {}
# for roi in rois:
#     top1_acc, top5_acc, top1_idx, top5_idx, dist_mat, corr = get_top5(rois=roi)
#     roi = roi[0]
#     top5_idx_dict[roi] = top5_idx
#     top1_acc_dict[roi] = top1_acc
#     top5_acc_dict[roi] = top5_acc
#     corr_dict[roi] = corr
#     print(roi, top1_acc, top5_acc, corr.mean())
# %%
# torch.save((top5_idx_dict, top1_acc_dict, top5_acc_dict, corr_dict), "/tmp/stuff.pt")
# %%
# NOTE(review): newer torch versions default torch.load to weights_only=True, which may
# refuse the pickled numpy objects in this cache — confirm against the installed torch.
top5_idx_dict, top1_acc_dict, top5_acc_dict, corr_dict = torch.load("/tmp/stuff.pt")


def _denormalize_for_display(img_tensor):
    """Convert a normalised CHW image tensor into an HWC array for ``imshow``.

    Multiplies by the per-channel std and adds the per-channel mean used throughout
    this file. The result is not clipped.
    """
    hwc = img_tensor.numpy().transpose(1, 2, 0)
    return hwc * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]


# Rows: query images 0..2. Columns: the GT image followed by its top-5 retrievals.
# NOTE: the trailing `break` means only the first ROI in `top5_idx_dict` is plotted.
for roi, top5_idx in top5_idx_dict.items():
    fig, axs = plt.subplots(3, 6, figsize=(11, 6))
    for ax in axs.flatten():
        ax.axis("off")
    val_ds = dm.dss[1]["NSD_01"]
    for row in range(3):
        gt_img = _denormalize_for_display(val_ds.__getitem__(row)[0])
        top5_imgs = [
            _denormalize_for_display(val_ds.__getitem__(idx)[0]) for idx in top5_idx[row]
        ]
        panels = [gt_img, *top5_imgs]
        for col in range(6):
            axs[row, col].imshow(np.clip(panels[col], 0, 1))
            if row == 0:
                header = "GT" if col == 0 else f"Top{col}"
                axs[row, col].set_title(header, fontsize=20)
    plt.tight_layout()
    plt.savefig(f"/workspace/figs/top5_{roi}.pdf")
    plt.show()

    break
# %%
# Bar plot per ROI: top-1 and top-5 retrieval accuracy (stacked) next to the mean
# Pearson's r, all drawn on the same left-hand axis.
rois = list(top1_acc_dict.keys())
# e.g. "veroi_m_3" -> "veroi3" -> "veROI_3" -> "3"; "all", "early", ... pass through.
x_rois = [s.replace("_m_", "").replace('roi', "ROI_").replace("veROI_", "") for s in rois]
top1_accs = np.array([top1_acc_dict[roi] for roi in rois])
top5_accs = np.array([top5_acc_dict[roi] for roi in rois])
corrs = np.array([corr_dict[roi].mean() for roi in rois])
# Standard error of the mean (std / sqrt(n)) despite the name; used as error bars.
corrs_std = [corr_dict[roi].std() / np.sqrt(corr_dict[roi].shape[0]) for roi in rois]
bar_width = 0.4
fig, ax = plt.subplots(figsize=(12, 3))
x = np.arange(len(rois))
ax.bar(x, top1_accs, label="top1", alpha=0.8, width=bar_width)
# Stack the top-5 increment on top of top-1, so the full bar height is top-5 accuracy.
ax.bar(x, top5_accs - top1_accs, bottom=top1_accs, label="top5", alpha=0.8, width=bar_width)
# Correlation bars sit immediately to the right of the accuracy bars.
ax.bar(x + bar_width, corrs, label="r", alpha=0.8, width=bar_width, yerr=corrs_std)
# Fix tick positions first (centred on each bar pair), then label them. Labelling
# before the ticks are fixed makes matplotlib warn and is order-fragile.
ax.set_xticks(x + bar_width / 2)
ax.set_xticklabels(x_rois, rotation=0)
ax.set_ylabel("Image Retrieval \nAccuracy", fontsize=12)
ax.text(0.025, -0.08, "ROI", transform=ax.transAxes, ha="center", fontsize=12)

# Right-hand axis labelled for the r bars. Those bars are drawn on `ax`, so the twin
# axis must use ax's y-range; its default (0, 1) range would not line up with them.
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.spines["bottom"].set_visible(False)
ax2.set_ylabel("Performance Score \nPearson's R", fontsize=12)

# remove top spines
ax.spines["top"].set_visible(False)
ax2.spines["top"].set_visible(False)

# horizontal grid lines
ax.grid(axis="y", linestyle="-", alpha=0.4)

ax.legend()
plt.savefig("/workspace/figs/image_retrieval_accuracy.pdf")
plt.show()
# Stop here when run as a script; the cells below are meant to be run interactively.
# (`exit()` is a site-module helper for the interactive shell, not for programs.)
raise SystemExit
# %%
### Fix one query image and plot its top-5 retrievals for each ROI

def get_gt_and_top5_image(top5_idx, img_idx, subject="NSD_01"):
    """Load a query image and its top-5 retrieved images, ready for ``imshow``.

    Args:
        top5_idx: 2-D index array such as the ``top5_idx`` returned by ``get_top5``.
            Row ``img_idx`` holds the dataset indices of the retrieved images.
        img_idx: dataset index of the query (ground-truth) image.
        subject: key into ``dm.dss[1]``. Defaults to ``"NSD_01"``, the subject used
            throughout this file. (``dm.dss[1]`` is presumably the validation split
            that ``get_top5`` retrieves over — confirm.)

    Returns:
        ``(gt_img, top5_imgs)``: HWC float arrays with the per-channel normalisation
        undone. ``top5_imgs`` is a list in the order of ``top5_idx[img_idx]``.
        Values are NOT clipped to [0, 1].
    """
    ds = dm.dss[1][subject]
    # Per-channel std / mean (the standard ImageNet statistics).
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]

    def to_display(img):
        # CHW tensor -> HWC array, then x * std + mean per channel.
        return img.numpy().transpose(1, 2, 0) * std + mean

    gt_img = to_display(ds[img_idx][0])
    top5_imgs = [to_display(ds[idx][0]) for idx in top5_idx[img_idx]]
    return gt_img, top5_imgs

fig, axs = plt.subplots(3, 6, figsize=(12, 6))
img_idx = 0
for ax in axs.flatten():
    # Hide ticks, tick labels and all four spines but keep the axes themselves,
    # so the per-row ROI name set with set_ylabel below is still drawn.
    ax.tick_params(axis="both", which="both", length=0)
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
# One row per anatomical ROI: the GT image of query `img_idx`, then its top-5 retrievals.
for row, roi in enumerate(["early", "mid", "late"]):
    top5_idx = top5_idx_dict[roi]
    gt_img, top5_imgs = get_gt_and_top5_image(top5_idx, img_idx)
    panels = [gt_img] + [top5_imgs[k] for k in range(5)]
    is_first_row = roi == "early"
    for col, panel in enumerate(panels):
        axs[row, col].imshow(np.clip(panel, 0, 1))
        if is_first_row:
            axs[row, col].set_title("GT" if col == 0 else f"Top{col}", fontsize=20)
    axs[row, 0].set_ylabel(roi, fontsize=20, rotation=90)

plt.tight_layout()
plt.savefig("/workspace/figs/anatomical_roi_top5.pdf")
plt.show()
# %%
### For each of 5 query images, plot the top-4 retrievals per ROI, tiled into a single 2x2 image
def combine_4_image_into_one(imgs):
    """Tile images into a 2x2 mosaic after clipping each one to [0, 1].

    The first two images are stacked vertically into the left column and the
    remaining ones into the right column. With four equally sized HWC images this
    gives imgs[0] top-left, imgs[1] bottom-left, imgs[2] top-right and imgs[3]
    bottom-right.
    """
    clipped = [np.clip(im, 0, 1) for im in imgs]
    columns = [np.concatenate(half, axis=0) for half in (clipped[:2], clipped[2:])]
    return np.concatenate(columns, axis=1)


def plot_veroi_top4_grid(roi_ids, out_path, n_images=5, figsize=(16, 9), fontsize=20):
    """Plot query images next to 2x2 mosaics of their top-4 retrievals, one column per veROI.

    Row ``r`` shows query image ``r`` (``r = 0 .. n_images - 1``). Column 0 is the GT
    image, and column ``k + 1`` holds the top-4 retrievals for ROI
    ``veroi_m_{roi_ids[k]}``, tiled by ``combine_4_image_into_one``. The first row
    carries the titles ("GT" and each ROI number). The figure is saved to
    ``out_path`` and then shown.

    Uses the module-level ``top5_idx_dict``, ``get_gt_and_top5_image`` and
    ``combine_4_image_into_one`` defined earlier in this file.
    """
    roi_ids = list(roi_ids)
    # squeeze=False keeps `axs` 2-D even for a single row or column.
    fig, axs = plt.subplots(n_images, len(roi_ids) + 1, figsize=figsize, squeeze=False)
    for ax in axs.flat:
        # Hide ticks, tick labels and spines but keep the axes so titles still render.
        ax.tick_params(axis="both", which="both", length=0)
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    for img_idx in range(n_images):
        for col, i_roi in enumerate(roi_ids, start=1):
            roi = f"veroi_m_{i_roi}"
            gt_img, top5_imgs = get_gt_and_top5_image(top5_idx_dict[roi], img_idx)
            if col == 1:
                axs[img_idx, 0].imshow(np.clip(gt_img, 0, 1))
                if img_idx == 0:
                    axs[img_idx, 0].set_title("GT", fontsize=fontsize)
            # combine_4_image_into_one already clips every tile to [0, 1].
            axs[img_idx, col].imshow(combine_4_image_into_one(top5_imgs[:4]))
            if img_idx == 0:
                axs[img_idx, col].set_title(roi.replace("veroi_m_", ""), fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.show()


plot_veroi_top4_grid(range(1, 10), "/workspace/figs/veROI_top4_part1.pdf")
# %%
plot_veroi_top4_grid(range(10, 19), "/workspace/figs/veROI_top4_part2.pdf")
# %%
# %%
plot_veroi_top4_grid(
    range(1, 19), "/workspace/figs/veROI_top4_all.pdf", figsize=(33, 10), fontsize=30
)
# %%
