# %%
import argparse
import copy
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import pytorch_lightning as pl

import cortex
from matplotlib.pyplot import cm
from config import AutoConfig

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

from datamodule import AllDatamodule, build_dm
from models import VEModel

# plt.style.use("dark_background")


# %%
# Nested score lookup consumed by get_scores as all_d[model_name][subject_id].
# NOTE(review): torch.load unpickles the file, so only point this at trusted,
# locally produced results.
AFOGEN_SCORE_PATH = "/data/results/afogen/afogen_score_dict.pt"
all_d = torch.load(AFOGEN_SCORE_PATH)

# %%
def get_all_mask(i_sub, root="/data/algonauts2023"):
    """Load one subject's left- and right-hemisphere vertex masks, concatenated.

    Reads ``{root}/subjNN/roi_masks/{lh,rh}.all-vertices_fsaverage_space.npy``
    (NN is ``i_sub`` zero-padded to two digits) and concatenates them along
    axis 0, left hemisphere first.

    Args:
        i_sub: Subject number, formatted as ``subj{i_sub:02d}`` in the path.
        root: Directory holding the per-subject folders. Defaults to the
            original hard-coded dataset location.

    Returns:
        np.ndarray: The lh mask followed by the rh mask.
    """
    mask_dir = os.path.join(root, f"subj{i_sub:02d}", "roi_masks")
    hemi_masks = [
        np.load(os.path.join(mask_dir, f"{hemi}.all-vertices_fsaverage_space.npy"))
        for hemi in ("lh", "rh")
    ]
    return np.concatenate(hemi_masks, axis=0)


# %%
def get_scores(sub, model):
    """Look up the stored scores of ``model`` for subject ``sub``.

    The empty string acts as a "no baseline" sentinel and yields the scalar 0,
    so ``get_scores(sub, a) - get_scores(sub, "")`` is simply model ``a``'s
    scores. Any other name is read from the module-level ``all_d`` as
    ``all_d[model][sub]``.
    """
    return 0 if model == "" else all_d[model][sub]


# Render one cropped flatmap PNG per (subject, comparison) to /tmp/{i}.png.
# ``i`` runs row-major over subjects x comparisons, which is the order the
# grid-assembly cell reads the PNGs back in.
# Each comparison is [model_a, model_b]; the map shows score(a) - score(b),
# where "" means "no baseline" (i.e. model_a's raw scores).
i = 0
for i_sub in range(1, 9):
    sub = f"NSD_{i_sub:02d}"
    all_mask = get_all_mask(i_sub)
    for model in [
        ["nm", ""],
        [".mania_veroi_m_gen1", "nm"],
        ["afo", ".mania_veroi_m_gen1"],
        ["afo", ""],
        ['afo', 'nm'],
    ]:
        score1, score2 = get_scores(sub, model[0]), get_scores(sub, model[1])
        score_diff = score1 - score2

        # Scatter the masked scores into a full-length vertex array; vertices
        # outside the mask stay 0.
        full_data = np.zeros(all_mask.shape)
        full_data[all_mask == 1] = score_diff

        if model[1] == "":
            # Raw scores: sequential colormap on [0, 1].
            vertex_data = cortex.Vertex(
                full_data, "fsaverage", cmap="Purples", vmax=1.0, vmin=0
            )
        else:
            # Differences: diverging colormap symmetric around 0.
            vertex_data = cortex.Vertex(
                full_data, "fsaverage", cmap="bwr", vmax=0.1, vmin=-0.1
            )

        png_path = f"/tmp/{i}.png"
        flat_fig = cortex.quickflat.make_figure(
            vertex_data,
            with_curvature=False,
            with_rois=False,
            with_labels=False,
            with_sulci=False,
            with_colorbar=False,
        )
        # Save/close the flatmap figure explicitly instead of relying on
        # pyplot's notion of the "current" figure.
        flat_fig.savefig(png_path)
        plt.close(flat_fig)

        # Keep the central part of the rendered image (crop fractions are
        # empirical). Close the source file before overwriting it.
        with Image.open(png_path) as im:
            width, height = im.size
            cropped_im = im.crop(
                (width * 0.35, height * 0.2, width * 0.65, height * 0.9)
            )
        cropped_im.save(png_path)

        i += 1

# %%
# Assemble the cropped flatmaps into an 8 (subjects) x 5 (comparisons) grid.
fig, axs = plt.subplots(8, 5, figsize=(10, 17))
axs = axs.flatten()
panels = [
    (["nm", ""], "NM"),
    ([".mania_veroi_m_gen1", "nm"], "S1 - NM"),
    (["afo", ".mania_veroi_m_gen1"], "S3 - S1"),
    (["afo", ""], "S3"),
    (['afo', 'nm'], "S3 - NM"),
]
i = 0
for i_sub in range(1, 9):
    for model, title in panels:
        ax = axs[i]
        ax.imshow(plt.imread(f"/tmp/{i}.png"))

        # Column headers on the top row, subject labels on the first column.
        if i_sub == 1:
            ax.set_title(title, fontsize=20)
        if model == ["nm", ""]:
            ax.set_ylabel(f"NSD_{i_sub:02d}", fontsize=20)

        # Strip ticks and the frame but keep the y-label (presumably why
        # ax.axis("off") is not used here, since it would hide the label too).
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        i += 1
plt.tight_layout()
plt.savefig("/workspace/figs/supfig1_afocortex.pdf")
plt.show()
# %%
def _save_colorbar(value_range, cmap, ticks, tick_labels, out_path):
    """Save a standalone horizontal colorbar for ``cmap`` over ``value_range``.

    A hidden 1x2 dummy image provides the mappable (its data min/max set the
    colour limits); only the colorbar axes remain visible in the saved file.
    Uses matplotlib.pyplot rather than ``pylab`` so the module-level ``pl``
    (pytorch_lightning) is not shadowed.
    """
    fig = plt.figure(figsize=(9, 1))
    img_ax = fig.add_subplot()
    img = img_ax.imshow(np.array([value_range]), cmap=cmap)
    img_ax.set_visible(False)
    cax = fig.add_axes([0.1, 0.2, 0.8, 0.6])
    fig.colorbar(img, orientation="horizontal", cax=cax)
    cax.set_xticks(ticks)
    cax.set_xticklabels(tick_labels, fontsize=20)
    fig.savefig(out_path, bbox_inches='tight')
    plt.show()


# Colorbar for the raw-score panels (Purples, 0..1).
_save_colorbar(
    [0, 1],
    "Purples",
    [0, 1],
    ["0", "1"],
    "/workspace/figs/supfig1_afocortex_colorbar1.pdf",
)
# %%
# Colorbar for the difference panels (bwr, -0.1..0.1).
_save_colorbar(
    [-0.1, 0.1],
    "bwr",
    [-0.1, 0, 0.1],
    ["-0.1", "0", "0.1"],
    "/workspace/figs/supfig1_afocortex_colorbar2.pdf",
)
# %%
def _render_cropped_flatmap(full_data, raw_scores, png_path):
    """Render ``full_data`` as an fsaverage flatmap PNG at ``png_path``, cropped in place.

    ``raw_scores`` selects the colour scheme: Purples on [0, 1] for raw
    scores, otherwise bwr on [-0.1, 0.1] for differences.
    """
    if raw_scores:
        vertex_data = cortex.Vertex(
            full_data, "fsaverage", cmap="Purples", vmax=1.0, vmin=0
        )
    else:
        vertex_data = cortex.Vertex(
            full_data, "fsaverage", cmap="bwr", vmax=0.1, vmin=-0.1
        )

    flat_fig = cortex.quickflat.make_figure(
        vertex_data,
        with_curvature=False,
        with_rois=False,
        with_labels=False,
        with_sulci=False,
        with_colorbar=False,
    )
    flat_fig.savefig(png_path)
    plt.close(flat_fig)

    # Keep the central part of the image (crop fractions are empirical);
    # close the source file before overwriting it.
    with Image.open(png_path) as im:
        width, height = im.size
        cropped = im.crop((width * 0.35, height * 0.2, width * 0.65, height * 0.9))
    cropped.save(png_path)


def plot_2x4(model, title):
    """Plot one comparison for all 8 subjects as a 2x4 grid and save it as a PDF.

    Args:
        model: ``[model_a, model_b]``; each panel shows
            ``get_scores(sub, model_a) - get_scores(sub, model_b)``. An empty
            ``model_b`` plots model_a's raw scores.
        title: Used in the output name ``/workspace/figs/supfig1_{title}.pdf``.
    """
    import tempfile

    fig, axs = plt.subplots(2, 4, figsize=(17, 10))
    axs = axs.flatten()
    # Intermediate PNGs go into a private temp dir; the original reused a
    # /tmp path built from a leftover module-level ``i``.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for ax, i_sub in zip(axs, range(1, 9)):
            sub = f"NSD_{i_sub:02d}"
            all_mask = get_all_mask(i_sub)

            score_diff = get_scores(sub, model[0]) - get_scores(sub, model[1])

            # Scatter the masked scores into a full-length vertex array;
            # vertices outside the mask stay 0.
            full_data = np.zeros(all_mask.shape)
            full_data[all_mask == 1] = score_diff

            png_path = os.path.join(tmp_dir, f"{sub}.png")
            _render_cropped_flatmap(full_data, model[1] == "", png_path)

            ax.imshow(plt.imread(png_path))
            ax.set_title(sub, fontsize=20)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    # Operate on this function's figure explicitly, not pyplot's "current" one.
    fig.tight_layout()
    fig.savefig(f"/workspace/figs/supfig1_{title}.pdf")
    plt.show()
# %%
plot_2x4(["afo", ""], "AFO")
# %%
plot_2x4(["afo", "nm"], "AFO-NM")
# %%
plot_2x4(["afo", ".mania_veroi_m_gen1"], "S3-S1")
# NOTE(review): a second, identical plot_2x4(["afo", "nm"], "AFO-NM") call was
# removed (it re-rendered the same PDF). If another comparison was intended
# (e.g. S1-NM), add it explicitly.
# %%
