# %%
import os
from matplotlib.lines import Line2D

import numpy as np
import nilearn
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib import cm, ticker
import cortex
import torch

# %%
from config_utils import load_from_yaml
from datamodule import AllDatamodule, build_dm
from models import VEModel

# Build the datamodule from the experiment config. setup() is called before the
# per-subject dictionaries are read (presumably it populates them — confirm).
cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
dm: AllDatamodule = build_dm(cfg)
dm.setup()

model_args = (
    cfg,
    dm.num_voxel_dict,
    dm.roi_dict,
    dm.neuron_coords_dict,
    dm.noise_ceiling_dict,
)

model = VEModel(*model_args)

# NOTE(review): torch.load unpickles arbitrary objects; only point it at trusted
# checkpoints (consider weights_only=True if the file holds plain tensors).
sd = torch.load("/data/script_base/xdab/state_dict.pth", map_location="cpu")
# strict=False tolerates key mismatches. A wrong or renamed checkpoint would then
# leave parameters at their initial values with no signal at all, so report the
# mismatches instead of silently discarding them.
load_result = model.load_state_dict(sd, strict=False)
for kind, keys in (
    ("missing", load_result.missing_keys),
    ("unexpected", load_result.unexpected_keys),
):
    if keys:
        preview = ", ".join(keys[:5])
        more = "" if len(keys) <= 5 else f", ... (+{len(keys) - 5} more)"
        print(f"load_state_dict: {len(keys)} {kind} key(s): {preview}{more}")
model = model.eval()
# %%
mu_dict = {}
gate_dict = {}
# Evaluate every subject's neuron projector once, without autograd. For each
# subject, keep its positions (used below as x/y coordinates) and its gate weights.
with torch.no_grad():
    for subject in model.subject_list:
        projector = model.neck.neuron_projectors[subject]
        mu_by_key, gate, _ = projector.forward(1)
        # Take the first value of the returned mapping and select index 0 on
        # axes 0 and 2. Presumably a (1, n, 1, 2) tensor — TODO confirm.
        first_mu = list(mu_by_key.values())[0]
        mu_dict[subject] = first_mu[0, :, 0, :]
        gate_dict[subject] = gate


# %%
from matplotlib.colors import ListedColormap


class Set1Colormap(ListedColormap):
    """Fixed four-colour categorical colormap.

    Despite the name, these are not matplotlib's "Set1" colours. Entry k is used
    below to colour points whose arg-max gate is k, and to draw the legend marker
    for the k-th layer.
    """

    _PALETTE = ("#FF0000", "#eb34cc", "#00FF00", "#0000FF")

    def __init__(self):
        super().__init__(list(self._PALETTE))


# NOTE: this rebinds `cm`, shadowing the `matplotlib.cm` module imported at the
# top of the file; later code refers to this colormap instance as `cm`.
cm = Set1Colormap()


def plot_hanabi(mu, gate, subject, ax=None):
    """Scatter projector positions, coloured by each point's arg-max gate.

    Args:
        mu: array-like indexable as ``mu[:, 0]`` (x) and ``mu[:, 1]`` (y).
        gate: tensor whose ``argmax(dim=1)`` picks the colormap entry per point.
        subject: unused here; kept so existing call sites keep working.
        ax: axes to draw on. When None, a new 5x5-inch figure is created.

    Side effect: ``ax`` becomes pyplot's current axes (via ``plt.sca``). Callers
    that run ``plt.legend`` afterwards therefore target the last plotted axes.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(5, 5))
    plt.sca(ax)

    winning_layer = gate.argmax(dim=1)
    plt.scatter(mu[:, 0], mu[:, 1], c=winning_layer, cmap=cm, s=1, rasterized=True)

    # Square [-1, 1] view on a dark background with a dashed grid every 0.125.
    plt.xlim(-1.0, 1.0)
    plt.ylim(-1.0, 1.0)
    ax.grid(axis="both", linestyle="--", alpha=0.5)
    grid_step = 0.0625 * 2
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.MultipleLocator(grid_step))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_facecolor('#535353')

    # Keep the grid lines but hide every tick mark and tick label.
    for axis in (ax.xaxis, ax.yaxis):
        for tick in axis.get_major_ticks():
            for artist in (tick.tick1line, tick.tick2line, tick.label1, tick.label2):
                artist.set_visible(False)



# %%
fig, axs = plt.subplots(3, 4, figsize=(12, 8))
# Row-major view of the 3x4 grid, so panel i is axs[i // 4, i % 4].
panels = axs.ravel()
for i, subject in enumerate(mu_dict):
    panel = panels[i]
    plot_hanabi(mu_dict[subject], gate_dict[subject], subject, ax=panel)
    panel.set_title(subject, fontsize=20)

# Legend proxy artists: one marker per layer, coloured with the matching
# entry of the categorical colormap `cm`.
layers = ["layer2", "layer5", "layer8", "layer11"]
legend_elements = [
    Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        label=layer_name,
        markerfacecolor=cm.colors[k],
        markersize=10,
        linestyle="",
    )
    for k, layer_name in enumerate(layers)
]

# plt.legend attaches to the current axes, which is the last panel that
# plot_hanabi activated with plt.sca.
plt.legend(handles=legend_elements, loc="lower right", fontsize=20, ncol=4)
plt.tight_layout()
plt.savefig("/workspace/figs/fig2_neuronprojectors.pdf")
plt.show()
# %%
s = "subj01"
mask_dir = f"/data/algonauts2023/{s}/roi_masks"
# Load one fsaverage vertex mask per hemisphere and stack them (lh first,
# then rh) along axis 0.
all_mask = np.concatenate(
    [
        np.load(os.path.join(mask_dir, f"{hemi}.all-vertices_fsaverage_space.npy"))
        for hemi in ("lh", "rh")
    ],
    axis=0,
)
# %%
### neuron projectors
def arr_creat(upperleft, upperright, lowerleft, lowerright, width=None, height=None):
    """Return a (height, width, 1) integer ramp interpolated between four corners.

    Each row is a linear ramp from its left value to its right value. The rows
    blend linearly from the "lower" edge to the "upper" edge:

    - row 0 runs from ``lowerleft`` to ``lowerright``;
    - the last row runs from ``upperleft`` to ``upperright``.

    Values are converted with ``np.linspace(..., dtype=int)``, which floors
    them on NumPy >= 1.20.

    Note: ``plt.imshow`` with the default origin draws row 0 (the "lower" edge)
    at the top.

    Args:
        upperleft, upperright, lowerleft, lowerright: corner values.
        width: number of columns. Defaults to the module-level ``arrwidth``,
            looked up at call time as before.
        height: number of rows. Defaults to the module-level ``arrheight``.

    Returns:
        Integer ndarray of shape (height, width, 1).
    """
    # Previously the size came only from globals defined *after* this function.
    # Explicit parameters remove that hidden coupling, and the fallbacks keep
    # existing calls working unchanged.
    if width is None:
        width = arrwidth
    if height is None:
        height = arrheight
    arr = np.linspace(
        np.linspace(lowerleft, lowerright, width),
        np.linspace(upperleft, upperright, width),
        height,
        dtype=int,
    )
    return arr[:, :, None]


arrwidth = 256
arrheight = 256

# 2-D colour key:
# - red ramps from the left column to the right column;
# - green peaks at the "lower-left" corner;
# - blue ramps from the "lower" row (row 0) to the "upper" row.
r = arr_creat(0, 255, 0, 255)
g = arr_creat(0, 0, 255, 0)
b = arr_creat(255, 255, 0, 0)

img = np.dstack((r, g, b))
import scipy.ndimage
# img = scipy.ndimage.rotate(img, 90, reshape=False)
color_bar_img = img
# NOTE(review): imshow draws row 0 (arr_creat's "lower" edge) at the top. The
# scatter plots, however, put y=-1 at the bottom, and grid_sample reads y=-1
# from row 0, so this key may appear vertically flipped relative to those
# plots. Confirm the intended orientation.
plt.imshow(img)
plt.show()

# HWC integer image -> (1, 3, H, W) float tensor, as grid_sample expects.
img = torch.tensor(img).permute(2, 0, 1).float()[None]
from einops import rearrange, repeat
# (n, c) positions -> (1, n, 1, c) sampling grid.
mu = mu_dict["NSD_01"][None, :, None, :]
# %%
from torch.nn.functional import interpolate, grid_sample

# Sample the colour key at every voxel's projector position (grid_sample's
# default bilinear mode). The output is (1, 3, n, 1); reduce it to n RGB rows
# of integers.
sampled = grid_sample(img, mu, align_corners=True)
c = sampled[0, :, :, 0].t().numpy().astype(int)
# %%
# Per-vertex RGB for the whole surface: masked vertices receive the sampled
# colour, and every other vertex keeps the fill value 1.
all_c = np.ones((len(all_mask), 3))
all_c[all_mask == 1] = c
# %%
red, green, blue = all_c.T
vertex = cortex.VertexRGB(
    red,
    green,
    blue,
    "fsaverage",
)

# %%
tmp_png = '/tmp/c.png'
cortex.quickshow(vertex, with_colorbar=False)
plt.savefig(tmp_png, dpi=300)
# Close the flatmap figure once it is saved, as the layer-gate loop further down
# does. Otherwise, when run as a plain script, the plt.imshow below would likely
# draw into this figure's axes instead of a fresh figure.
plt.close()

from PIL import Image
# Keep the central part of the render: 35%-65% of the width and 20%-90% of the
# height. The `with` block closes the source file handle before the cropped
# image overwrites the same path.
with Image.open(tmp_png) as im:
    width, height = im.size
    left = width * 0.35
    right = width * 0.65
    top = height * 0.2
    bottom = height * 0.9
    cropped_im = im.crop((left, top, right, bottom))

cropped_im.save(tmp_png)
# %%
plt.imshow(plt.imread(tmp_png))
plt.show()
# %%
fig, axs = plt.subplots(1, 2, figsize=(10, 7), width_ratios=[3, 1])
brain_ax, ax = axs

# Left panel: the cropped flatmap previously written to tmp_png.
plt.sca(brain_ax)
plt.imshow(plt.imread(tmp_png))
plt.axis("off")

# Right panel: the 2-D colour key, with a dashed grid every 16 pixels and no
# visible tick marks or labels.
# NOTE(review): imshow's default origin puts row 0 at the top; confirm this
# matches the orientation used in the scatter plots.
plt.sca(ax)
plt.imshow(color_bar_img)
ax.grid(axis="both", linestyle="--", alpha=0.5)
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(ticker.MultipleLocator(16))
ax.set_xticklabels([])
ax.set_yticklabels([])
for axis in (ax.xaxis, ax.yaxis):
    for tick in axis.get_major_ticks():
        for artist in (tick.tick1line, tick.tick2line, tick.label1, tick.label2):
            artist.set_visible(False)

plt.tight_layout()
plt.savefig("/workspace/figs/fig2_neuronprojectors_nsd_01.pdf")
plt.show()
# %%
### layer gate
from PIL import Image  # imported once, instead of on every loop iteration

# Render each of the four gate channels for NSD_01 as a cropped flatmap PNG at
# /tmp/c{i}.png. The paths are also collected in `pngs`.
gate = gate_dict["NSD_01"]
pngs = []
for i in range(4):
    # Scatter gate channel i onto the full fsaverage vertex array; vertices
    # outside the mask stay 0.
    g = gate[:, i]
    all_c = np.zeros((all_mask.shape[0]))
    all_c[all_mask == 1] = g
    # NOTE(review): vmin=0 / vmax=1 assumes gate values lie in [0, 1] — confirm.
    vertex = cortex.Vertex(all_c, "fsaverage", vmin=0, vmax=1, cmap="viridis")

    tmp_png = f'/tmp/c{i}.png'
    cortex.quickshow(vertex, with_colorbar=True)
    plt.savefig(tmp_png, dpi=300)
    plt.close()

    # Crop to 35%-65% of the width and 20%-99% of the height. The `with` block
    # closes the source file before the crop overwrites the same path.
    with Image.open(tmp_png) as im:
        width, height = im.size
        left = width * 0.35
        right = width * 0.65
        top = height * 0.2
        bottom = height * 0.99
        cropped_im = im.crop((left, top, right, bottom))

    cropped_im.save(tmp_png)
    pngs.append(tmp_png)
# %%
layers = ["layer2", "layer5", "layer8", "layer11"]

# One panel per gate channel, each showing the cropped flatmap written to
# /tmp/c{i}.png and titled with its layer name.
fig, axs = plt.subplots(1, 4, figsize=(16, 6))
for i, (panel, layer_name) in enumerate(zip(axs, layers)):
    plt.sca(panel)
    plt.imshow(plt.imread(f"/tmp/c{i}.png"))
    plt.axis("off")
    plt.title(layer_name, fontsize=20)
plt.tight_layout()
plt.savefig("/workspace/figs/fig2_layergate_nsd_01.pdf")
plt.show()
# %%
