# %%
from collections import OrderedDict
import glob
import json
import sys
import traceback
import re
import logging
from time import sleep
from einops import repeat
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm
from config import AutoConfig

from config_utils import flatten_dict, load_from_yaml

from IPython.display import display, HTML, clear_output

from datamodule import AllDatamodule, build_dm

import glob

# Dark matplotlib theme for the interactive cells; the pycortex cell further
# down switches back to the "default" style before rendering flatmaps.
plt.style.use("dark_background")
# %%


def load_cfg(run):
    """Load the ``hparams.yaml`` config stored anywhere below the run directory.

    ``run`` is treated as a literal path (glob metacharacters such as ``[`` in
    trial directory names are escaped). The directory tree is searched
    recursively; if several ``hparams.yaml`` files exist, the lexicographically
    first path is used so the choice does not depend on filesystem order.

    Args:
        run: path of the run directory.

    Returns:
        The config object returned by ``load_from_yaml``.

    Raises:
        FileNotFoundError: if no ``hparams.yaml`` exists below ``run``.
    """
    pattern = os.path.join(glob.escape(run), "**", "hparams.yaml")
    paths = sorted(glob.glob(pattern, recursive=True))
    print(paths)
    if not paths:
        raise FileNotFoundError(f"no hparams.yaml found under {run!r}")
    return load_from_yaml(paths[0])


def load_voxel_metric(run):
    """Load the pickled TEST-stage metric dict saved below the run directory.

    Looks recursively for ``stage=TEST*.npy`` under ``run`` (treated as a literal
    path; glob metacharacters are escaped). If several files match, the
    lexicographically first path is used so the result is deterministic.

    Args:
        run: path of the run directory.

    Returns:
        The object stored in the ``.npy`` file (``np.load(...).item()``).

    Raises:
        FileNotFoundError: if no matching file exists below ``run``.
    """
    pattern = os.path.join(glob.escape(run), "**", "stage=TEST*.npy")
    paths = sorted(glob.glob(pattern, recursive=True))
    print(paths)
    if not paths:
        raise FileNotFoundError(f"no stage=TEST*.npy found under {run!r}")
    # allow_pickle is required because the file holds a pickled Python object;
    # only use this on result files you produced yourself.
    return np.load(paths[0], allow_pickle=True).item()


# %%
def load_score(run):
    """Load per-subject TEST Pearson scores of a run, ordered by voxel index.

    Reads the run's config (``load_cfg``) and saved metrics
    (``load_voxel_metric``), then builds the datamodule from that config to get
    each subject's ``voxel_index`` from ``dm.dss[0]``. When the index is not the
    Ellipsis sentinel, the scores are permuted so that position ``i`` holds the
    score of the voxel whose index value is ``i``. This assumes ``voxel_index``
    is a permutation of ``0..n-1``; a missing value raises KeyError.

    Args:
        run: path of the run directory.

    Returns:
        dict mapping subject name -> scores taken from the
        ``"TEST/PearsonCorrCoef/{subject}/all"`` metric entry.
    """
    cfg = load_cfg(run)
    voxel_metric = load_voxel_metric(run)
    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    score_dict = {}
    for subject, subject_metrics in voxel_metric.items():
        scores = subject_metrics[f"TEST/PearsonCorrCoef/{subject}/all"]
        print(subject, scores.shape)
        vi = dm.dss[0][subject].voxel_index
        # Ellipsis is a sentinel (presumably "no voxel subset / identity order"),
        # so test it by identity: `!=` would be an elementwise comparison on
        # array-likes rather than a sentinel check.
        if vi is not ...:
            vi = vi.cpu().numpy()
            # voxel index value -> position in vi (inverse permutation)
            pos_of_voxel = {v: k for k, v in enumerate(vi)}
            scores = scores[[pos_of_voxel[i] for i in np.arange(len(vi))]]
        score_dict[subject] = scores
    return score_dict


# %%
# Per-subject TEST scores of the two trained runs compared below; they are stored
# in all_d under "nm" (labelled NaiveMix) and "afo" (labelled AFO(Gen2Distill)).
nm_run = "/data/results/xdabb/dino_mania/big_model/run_69c71_00000_0_OPTIMIZER_LR=0.0030_2023-05-12_21-10-25/"
nm_score_dict = load_score(nm_run)

# %%
afo_run = "/data/results/xgaa/yesgt_1/"
afo_score_dict = load_score(afo_run)


# %%
def p_metric(y, y_pred, device="cuda"):
    """Correlate targets with predictions using ``metrics.vectorized_correlation``.

    Both inputs are converted to float32 tensors on ``device``, scored, and the
    result is copied back to host memory.

    Args:
        y: target values (array-like). Presumably shaped (n_samples, n_voxels)
            with correlation taken per column -- TODO confirm against
            ``metrics.vectorized_correlation``.
        y_pred: predictions with the same shape as ``y``.
        device: torch device used for the computation. Defaults to ``"cuda"``
            (the previously hard-coded value); pass ``"cpu"`` without a GPU.

    Returns:
        ``numpy.ndarray`` with the output of ``vectorized_correlation``.
    """
    from metrics import vectorized_correlation

    y_t = torch.from_numpy(np.asarray(y, dtype=np.float32)).to(device)
    y_pred_t = torch.from_numpy(np.asarray(y_pred, dtype=np.float32)).to(device)
    return vectorized_correlation(y_t, y_pred_t).cpu().numpy()


def dark_score(dark_postfix, cfg_path="/workspace/configs/dino_mania.yaml"):
    """Correlate each subject's test targets with the last element of every batch.

    Loads the config at ``cfg_path`` and sets ``cfg.DATASET.DARK_POSTFIX`` to
    ``dark_postfix``. It then builds the datamodule and, for every subject in
    ``cfg.DATASET.SUBJECT_LIST``, scores ``batch[1]`` (targets) against
    ``batch[-1]`` from the subject's test loader with ``p_metric``.

    Args:
        dark_postfix: value written to ``cfg.DATASET.DARK_POSTFIX`` (presumably
            selects which stored predictions the datamodule attaches to each
            batch -- TODO confirm in the datamodule).
        cfg_path: YAML config to start from; defaults to the previously
            hard-coded path.

    Returns:
        dict mapping subject name -> output of ``p_metric``.
    """
    cfg: AutoConfig = load_from_yaml(cfg_path)
    cfg.DATASET.DARK_POSTFIX = dark_postfix
    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    score_dict = {}
    for subject in cfg.DATASET.SUBJECT_LIST:
        print(subject, dark_postfix)
        dl = dm.test_dataloader(subject=subject)

        # batch[1] and batch[-1] are assumed to be lists of per-sample tensors
        # (the collate output); extend() flattens them in linear time instead
        # of the quadratic sum(list_of_lists, []).
        ys, darks = [], []
        for batch in dl:
            ys.extend(batch[1])
            darks.extend(batch[-1])
        ys = torch.stack(ys).numpy()
        dark = torch.stack(darks).numpy()

        score_dict[subject] = p_metric(ys, dark)

    return score_dict


# %%
# d = dark_score(".mania_veroi_m_gen2_darkfull")
# %%
# DARK_POSTFIX values to score with dark_score(); each becomes a key of all_d.
# Every live entry ends with a comma so that uncommenting the last line cannot
# silently concatenate two adjacent string literals.
darks = [
    ".mania_veroi_m_gen1",
    ".mania_veroi_m_gen2_darkfull",
    ".mania_veroi_m_gen3_darkfull",
    ".veroi_m_gen2n_darkgt_darkfull",
    ".random_m_gen1",
    ".random_m_gen2_darkfull",
    # ".mania_veroi_m_gen2_darkself",
]
# %%
# all_d maps a row key -> {subject: score array}; one row per dark postfix.
all_d = {}
for dark in darks:
    all_d[dark] = dark_score(dark)
# %%
# Add the two trained runs under short keys so every table row lives in all_d.
all_d['nm'] = nm_score_dict
all_d['afo'] = afo_score_dict
# %%
# Create the output directory (exist_ok keeps the cell re-runnable; os.mkdir
# raised FileExistsError on the second run) and save the score dictionaries.
os.makedirs("/data/results/afogen", exist_ok=True)
torch.save(all_d, "/data/results/afogen/afogen_score_dict.pt")
# %%
# Reload the saved score dictionaries (row key -> {subject: scores}).
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
# refuse to unpickle the NumPy arrays stored here; pass weights_only=False if
# this load fails.
all_d = torch.load("/data/results/afogen/afogen_score_dict.pt")
# In this cell the datamodule is only used to read per-subject noise ceilings.
cfg: AutoConfig = load_from_yaml("/workspace/configs/dino_mania.yaml")
dm: AllDatamodule = build_dm(cfg)
dm.setup()
# Add three pooled entries to every row's dict (mutated in place):
#   'all_nsd'      - scores of subjects whose key contains "NSD", concatenated
#   'all_nsd_nced' - score**2 / noise_ceiling for those same subjects
#   'all_data'     - scores of every subject, concatenated
for row in all_d.keys():
    d = all_d[row]
    
    all_nsd = []
    all_nsd_nced = []
    all_data = []
    for subject in d.keys():
        if "NSD" in subject:
            all_nsd.append(d[subject])
            # squared score over the subject's noise ceiling; 1e-8 avoids
            # division by zero. TODO confirm noise_ceiling is on the same
            # scale as score**2 (e.g. fraction vs. percent).
            all_nsd_nced.append(d[subject]**2/(dm.dss[0][subject].noise_ceiling + 1e-8))
        all_data.append(d[subject])
    # np.concatenate raises ValueError if a row has no NSD subjects.
    all_nsd = np.concatenate(all_nsd)
    all_nsd_nced = np.concatenate(all_nsd_nced)
    all_data = np.concatenate(all_data)
    all_d[row]['all_nsd'] = all_nsd
    all_d[row]['all_nsd_nced'] = all_nsd_nced
    all_d[row]['all_data'] = all_data
# %%
# Flatten all_d into a long table of [row key, score key, summary score].
datas = []
for row in all_d.keys():
    d = all_d[row]
    for col in d.keys():
        s = d[col]
        # Skip per-subject entries whose key contains "NSD"; the match is
        # case-sensitive, so the pooled "all_nsd*" keys are kept.
        if "NSD" in col:
            continue
        # Noise-ceiling-normalized values are summarized by the median, all
        # others by the mean; both ignore NaNs.
        if "nced" in col:
            s = np.nanmedian(s)
        else:
            s = np.nanmean(s)
        datas.append([row, col, s])
df = pd.DataFrame(datas, columns=["row", "col", "score"])
# %%
# Wide table: one row per row key, one column per score key
# (pivot raises ValueError if a (row, col) pair appears twice).
rowcol_df = df.pivot(index="row", columns="col", values="score")
# reorder columns (raises KeyError if any listed key is missing)
rowcol_df = rowcol_df[["all_nsd", "EEG", "MEG", "HCP", "ALG", "all_data", "all_nsd_nced"]]
# reorder rows (raises KeyError if any listed key is missing)
rowcol_df = rowcol_df.loc[["nm", ".mania_veroi_m_gen1", ".mania_veroi_m_gen2_darkfull", "afo", ".mania_veroi_m_gen3_darkfull", ".veroi_m_gen2n_darkgt_darkfull", ".random_m_gen1", ".random_m_gen2_darkfull"]]
# %%
# Rename raw score keys / row keys to display labels. An explicit mapping keeps
# each label attached to the right column/row even if the reorder cell above
# changes; positional assignment would silently mislabel them.
rowcol_df = rowcol_df.rename(
    columns={
        "all_nsd": "NSD",
        "EEG": "EEG",
        "MEG": "MEG",
        "HCP": "HCP",
        "ALG": "ALG",
        "all_data": "ALL",
        "all_nsd_nced": "NSD\n(NC)",
    },
    index={
        "nm": "NaiveMix",
        ".mania_veroi_m_gen1": "veROIGen1",
        ".mania_veroi_m_gen2_darkfull": "veROIGen2",
        "afo": "AFO(Gen2Distill)",
        ".mania_veroi_m_gen3_darkfull": "veROIGen3",
        ".veroi_m_gen2n_darkgt_darkfull": "veROIGen2(NoDK)",
        ".random_m_gen1": "RandomROIGen1",
        ".random_m_gen2_darkfull": "RandomROIGen2",
    },
)
# Drop the "row"/"col" axis names left by pivot, matching the unnamed axes
# that positional label assignment produced (affects display and CSV header).
rowcol_df.index.name = None
rowcol_df.columns.name = None
# %%
# Format every score as a fixed-point string with 3 decimals.
# DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map;
# use whichever this pandas version provides (checked on the class, so a column
# named "map" cannot shadow it).
_fmt_3f = "{:.3f}".format
if hasattr(pd.DataFrame, "map"):
    rowcol_df = rowcol_df.map(_fmt_3f)
else:
    rowcol_df = rowcol_df.applymap(_fmt_3f)
# %%
rowcol_df
# %%
# Save the formatted table to CSV; to_csv does not create /workspace/figs,
# so the directory must already exist.
rowcol_df.to_csv("/workspace/figs/afogen_score.csv")

# %%
# Elementwise score differences for subject NSD_01 only:
#   gen2_diff_gen1: .mania_veroi_m_gen2_darkfull minus .mania_veroi_m_gen1
gen2_diff_gen1 = all_d[".mania_veroi_m_gen2_darkfull"]["NSD_01"] - all_d[".mania_veroi_m_gen1"]["NSD_01"]
# %%
#   afo_diff_nm: "afo" run minus "nm" run
afo_diff_nm = all_d["afo"]["NSD_01"] - all_d["nm"]["NSD_01"]
# %%
# Notebook display of the mean difference.
afo_diff_nm.mean()
# %%
torch.save(gen2_diff_gen1, "/data/results/afogen/gen2_diff_gen1.pt")
torch.save(afo_diff_nm, "/data/results/afogen/afo_diff_nm.pt")
# %%
import cortex
plt.style.use("default")

# Symmetric colour limit shared by the surface map and its colorbar ticks.
DIFF_VLIM = 0.1
# Fractions of width/height kept when cropping each saved flatmap PNG.
CROP_LEFT, CROP_TOP, CROP_RIGHT, CROP_BOTTOM = 0.35, 0.2, 0.65, 0.99

# For NSD subjects 1..8: map the "afo" minus "nm" score difference onto the
# fsaverage surface through the subject's ROI masks, render it with pycortex,
# crop the PNG, and collect the paths for the grid figure below.
pngs = []
for i in range(1, 9):
    s = f"subj0{i}"
    subject = f"NSD_{i:02d}"
    afo_diff_nm = all_d["afo"][subject] - all_d["nm"][subject]
    mask_dir = f"/data/algonauts2023/{s}/roi_masks"
    all_mask = []
    for hemi in ["lh", "rh"]:
        mask_path = os.path.join(mask_dir, f"{hemi}.all-vertices_fsaverage_space.npy")
        mask = np.load(mask_path)
        all_mask.append(mask)
    # left hemisphere vertices first, then right
    all_mask = np.concatenate(all_mask, axis=0)

    # Vertices outside the mask stay 0; masked vertices get the differences in
    # order (assumes the score vector length equals the number of masked vertices).
    all_c = np.zeros(all_mask.shape[0])
    all_c[all_mask == 1] = afo_diff_nm

    vertex = cortex.Vertex(all_c, "fsaverage", vmin=-DIFF_VLIM, vmax=DIFF_VLIM, cmap="coolwarm")

    tmp_png = f'/tmp/c{i}.png'
    cortex.quickshow(vertex, with_colorbar=True, colorbar_ticks=[-DIFF_VLIM, 0, DIFF_VLIM])
    # TODO: increase colorbar font size, reduce tick frequency
    plt.savefig(tmp_png, dpi=300)
    plt.close()

    # Image.crop loads the pixel data, so the crop outlives the closed file and
    # can overwrite the same path afterwards.
    with Image.open(tmp_png) as im:
        width, height = im.size
        left = width * CROP_LEFT
        right = width * CROP_RIGHT
        top = height * CROP_TOP
        bottom = height * CROP_BOTTOM
        cropped_im = im.crop((left, top, right, bottom))

    cropped_im.save(tmp_png)
    pngs.append(tmp_png)
# # %%
# fig, ax = plt.subplots(figsize=(10, 10))
# plt.imshow(plt.imread(tmp_png))
# plt.axis("off")
# plt.show()
# %%
# Tile the cropped flatmaps from the previous cell (NSD_01..NSD_08 in order)
# into a 2x4 grid and save it as a PDF.
fig, axs = plt.subplots(2, 4, figsize=(20, 12))
for i, ax in enumerate(axs.flatten()):
    # assumes pngs holds at least 8 paths; fewer raises IndexError here
    ax.imshow(plt.imread(pngs[i]))
    ax.axis("off")
plt.tight_layout()
plt.savefig("/workspace/figs/all_nsd_afo_diff_nm.pdf")
plt.show()
# %%
