# %%
import os

import numpy as np
import nilearn
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib import cm
import cortex

# %%
cortex.download_subject("fsaverage", download_again=False)
# %%
# Plot the per-vertex noise ceiling of one NSD subject on the fsaverage
# flatmap. `sub` selects the NSD ncsnr files; `mysub` is the VWET-style name
# (both are rebound by the ROI loop further down).
sub = "subj03"
mysub = "NSD_03"
nc_dir = f"/nas/natural-scenes-dataset/nsddata_betas/ppdata/{sub}/fsaverage/betas_fithrf_GLMdenoise_RR"
nc = []
for hemi in ["lh", "rh"]:
    nc_path = os.path.join(nc_dir, f"{hemi}.ncsnr.mgh")
    data = nib.load(nc_path).get_fdata()
    data = np.squeeze(data)
    # Convert ncsnr to a noise-ceiling fraction: ncsnr^2 / (ncsnr^2 + 1).
    # NOTE(review): this looks like the NSD noise-ceiling formula for a
    # single trial (1/n with n=1), expressed as a fraction rather than a
    # percentage — confirm the single-trial case is intended.
    data = (data**2) / (data**2 + 1)
    nc.append(data)
# lh vertices first, then rh. Presumably this matches the order
# cortex.Vertex expects for "fsaverage" — confirm.
# `data` is left holding the rh array after the loop.
nc = np.concatenate(nc, axis=0)

vertex_data = cortex.Vertex(nc, "fsaverage", vmax=0.6, vmin=0)

fig = cortex.quickflat.make_figure(
    vertex_data,
    with_curvature=False,
    with_rois=False,
    with_labels=False,
    with_sulci=False,
    with_colorbar=False,
)
plt.title(sub)
plt.show()

# %%
s = sub

# Number of parcels per ROI scheme; uncomment entries to render more schemes.
roi_num_dict = {
    # "veroi_lll": 4,
    # "veroi_ll": 7,
    # "veroi_l": 11,
    "veroi_m": 18,
    # "random_m": 18,
    # "veroi_s": 37,
    # "veroi_ss": 109,
    # "veroi_sss": 268,
    # "veroi_extreme": 1000,
}

# Imported once here instead of on every iteration of the subject loop.
from PIL import Image

for roi_prefix, roi_num in roi_num_dict.items():
    pngs = []
    for i in range(1, 9):
        s = f"subj{i:02d}"
        mysub = f"NSD_{i:02d}"

        # Vertex mask (lh then rh) on the full fsaverage surface for this subject.
        mask_dir = f"/data/algonauts2023/{s}/roi_masks"
        all_mask = []
        for hemi in ["lh", "rh"]:
            mask_path = os.path.join(mask_dir, f"{hemi}.all-vertices_fsaverage_space.npy")
            mask = np.load(mask_path)
            all_mask.append(mask)
        all_mask = np.concatenate(all_mask, axis=0)

        # Label each masked vertex with its parcel number 1..roi_num. Each
        # f"{roi_prefix}_{k}.npy" file holds indices into the masked-vertex array.
        # A distinct loop name avoids shadowing the subject index `i`.
        roi_dir = f"/data/VWET/{mysub}/roi/"
        roi_arr = np.zeros(all_mask.sum())
        for roi_idx in range(1, roi_num + 1):
            vi = np.load(os.path.join(roi_dir, f"{roi_prefix}_{roi_idx}.npy"))
            roi_arr[vi] = roi_idx
        # NOTE(review): forces the first masked vertex to label 1 regardless
        # of its parcel — presumably a plotting workaround; confirm.
        roi_arr[0] = 1

        # Scatter the masked labels back onto the full surface (0 elsewhere).
        full_roi_arr = np.zeros(all_mask.shape[0])
        full_roi_arr[all_mask == 1] = roi_arr

        vertex_data = cortex.Vertex(full_roi_arr, "fsaverage", cmap="rainbow", vmax=roi_num, vmin=0)

        png_path = f"/tmp/{mysub}.png"
        fig = cortex.quickflat.make_figure(
            vertex_data,
            with_curvature=False,
            with_rois=False,
            with_labels=False,
            with_sulci=False,
            with_colorbar=True,
        )
        plt.title(mysub)
        plt.savefig(png_path)
        plt.close()

        # Crop the saved flatmap to its central region and overwrite the PNG.
        # `with` closes the source file handle, which was previously leaked.
        # crop() loads the pixels, so `cropped_im` is usable after the close.
        tmp_png = png_path
        with Image.open(tmp_png) as im:
            width, height = im.size
            left = width * 0.35
            right = width * 0.65
            top = height * 0.2
            bottom = height * 0.99
            cropped_im = im.crop((left, top, right, bottom))
        cropped_im.save(tmp_png)

        pngs.append(tmp_png)
    # No "# %%" cell marker here. One inside this loop body split the cell
    # mid-loop, and running cell-by-cell then failed with IndentationError.
    # Assemble the 8 cropped subject flatmaps into a 2x4 grid.
    fig, axs = plt.subplots(2, 4, figsize=(14, 9))
    for i in range(8):
        ax = axs[i // 4, i % 4]
        ax.imshow(plt.imread(pngs[i]))
        ax.axis("off")
        ax.set_title(f"NSD_{i+1:02d}", fontsize=30)
    # plt.suptitle(roi_prefix.upper())
    plt.tight_layout()
    # NOTE(review): every roi_prefix writes to this same path, so enabling
    # more than one scheme in roi_num_dict overwrites earlier PDFs.
    plt.savefig("/workspace/figs/cortexroi.pdf")
    plt.show()

# %%
# Inspect the value distribution of whatever `data` currently holds
# (the rh noise-ceiling array from the ncsnr cell, unless rebound since).
np.unique(data, return_counts=True)
# %%
# %%
# Load per-vertex stream labels (lh then rh) in challenge space.
# NOTE(review): `mask_dir` and `all_mask` are leftovers from the last
# iteration of the ROI loop (subj08) — confirm that subject is intended.
v = []
for hemi in ["lh", "rh"]:
    mask_path = os.path.join(mask_dir, f"{hemi}.streams_challenge_space.npy")
    mask = np.load(mask_path)
    v.append(mask)
v = np.concatenate(v, axis=0)
# Group stream labels into early / mid / late boolean masks.
# NOTE(review): label 0 is grouped with "late"; 0 may mean "unlabelled"
# in this label file — confirm.
early = v == 1
mid = (v == 2) | (v == 3) | (v == 4)
late = (v == 5) | (v == 6) | (v == 7) | (v == 0)
# %%
# Scatter the challenge-space labels onto the full fsaverage surface
# (0 outside the mask). This requires len(v) == all_mask.sum().
all_v = np.zeros(all_mask.shape[0])
all_v[all_mask == 1] = v
# %%
vertex_data = cortex.Vertex(all_v, "fsaverage", cmap="Set1")
# %%
cortex.quickshow(vertex_data, with_colorbar=False)
plt.show()
# %%
v.shape
# %%
def color_from_array_value(arr, cmap="Set1"):
    """Map each value of ``arr`` to an RGBA colour from a colormap.

    Values are min-max normalised to ``[0, 1]``, scaled to
    ``[0, cmap.N - 1]`` and truncated to integer lookup-table indices. The
    smallest value therefore gets the colormap's first entry and the
    largest gets its last.

    Args:
        arr: Array-like of numeric values.
        cmap: Colormap name or ``Colormap`` instance (anything accepted by
            ``plt.get_cmap``).

    Returns:
        RGBA float array of shape ``arr.shape + (4,)``. A constant input
        maps every element to the first colour; an empty input yields an
        empty result.
    """
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``pyplot.get_cmap`` accepts the same names and Colormap instances.
    cmap = plt.get_cmap(cmap)
    arr = np.asarray(arr)
    if arr.size == 0:
        # min()/max() raise on empty arrays; return an empty RGBA array.
        return cmap(np.zeros(arr.shape, dtype=int))
    shifted = arr - arr.min()
    span = shifted.max()
    if span == 0:
        # Constant input: avoid 0/0 -> NaN, whose int cast gives garbage indices.
        idx = np.zeros(arr.shape, dtype=int)
    else:
        idx = (shifted / span * (cmap.N - 1)).astype(int)
    # Integer input to a Colormap is treated as lookup-table indices.
    return cmap(idx)


# %%
def color_from_subject(subject_id="subj01", cmap="Set1"):
    """Return RGBA colours for a subject's stream labels in challenge space.

    Loads ``lh.streams_challenge_space.npy`` and
    ``rh.streams_challenge_space.npy`` from the subject's Algonauts 2023
    ``roi_masks`` directory and concatenates them (lh first). Each label is
    then coloured with ``color_from_array_value``.

    Args:
        subject_id: NSD subject directory name, e.g. ``"subj01"``.
        cmap: Colormap name or instance passed to ``color_from_array_value``.

    Returns:
        RGBA array with one row per label. The shape is ``(n, 4)``, assuming
        the label files are 1-D — confirm.
    """
    mask_dir = f"/data/algonauts2023/{subject_id}/roi_masks"
    v = np.concatenate(
        [
            np.load(os.path.join(mask_dir, f"{hemi}.streams_challenge_space.npy"))
            for hemi in ["lh", "rh"]
        ],
        axis=0,
    )
    # The previous body normalised an undefined local ``arr`` and raised
    # UnboundLocalError on every call. The loaded labels ``v`` are what must
    # be coloured, using the same normalisation as color_from_array_value.
    return color_from_array_value(v, cmap=cmap)


# %%
# Stream-label colours for subj01.
colors = color_from_subject("subj01")
# %%
colors
# %%
# Surface info for fsaverage. get_surfinfo's default type is presumably
# curvature (hence the name) — confirm.
curv = cortex.db.get_surfinfo("fsaverage")
# %%
curv
# %%
cortex.quickshow(curv, with_colorbar=False)
plt.show()
# %%
curv.data.shape
# %%
all_v.shape
# %%
early
# %%
curv.data[all_mask == 1].shape
# %%
# Keep only the vertices selected by `all_mask`, i.e. the mask left by the
# last cell that set it.
data = curv.data[all_mask == 1]
# %%
np.save("/tmp/curv.npy", data)

# %%
curv.cmap
# %%
# Per-neuron coordinates for NSD_01. The indexing below implies a 2-D array
# with at least 3 columns.
# NOTE(review): this path uses /data/VWE/ while the ROI loop reads
# /data/VWET/ — confirm which directory is intended.
coords = np.load("/data/VWE/NSD_01/neuron_coords.npy")
# %%
coords.shape
# %%
# Collapse each (x, y, z) row into one scalar, presumably so the "prism"
# colormap gives neighbouring positions contrasting colours.
# NOTE(review): the strides use max() rather than max() + 1, so distinct
# coordinates can collide (e.g. (max_x, 0, 0) and (0, 1, 0)). If coords are
# non-negative integer grid indices, max() + 1 would make the value unique.
v = (
    coords[:, 0]
    + coords[:, 1] * coords[:, 0].max()
    + coords[:, 2] * coords[:, 0].max() * coords[:, 1].max()
)
# %%
# Full-surface array. `all_mask` is whatever the last cell that set it left
# (the ROI loop ends on subj08). The assignment below requires
# len(v) == all_mask.sum().
all_v = np.zeros(all_mask.shape[0])
# %%
all_v[all_mask == 1] = v
# %%
vertex_data = cortex.Vertex(all_v, "fsaverage", cmap="prism")
# %%
cortex.quickshow(vertex_data, with_colorbar=True)
plt.show()
# %%
# `torch` was used below (torch.clamp) but never imported anywhere in the
# file, which raised NameError. Import it here so this cell runs on its own.
import torch

# NOTE(review): allow_pickle=True unpickles arbitrary objects from a
# world-writable /tmp path; only load files you produced yourself.
mus = np.load("/tmp/mus.npy", allow_pickle=True)
# Take the last entry, stretch it by 1.5, and clip to [-1, 1], the
# normalised coordinate range used by grid_sample further down.
mu = mus[-1]
mu = mu * 1.5
# NOTE(review): torch.clamp requires `mu` to be a torch.Tensor. This assumes
# the pickled object array holds tensors — confirm.
mu = torch.clamp(mu, -1, 1)
import matplotlib.pyplot as plt
import numpy as np

def arr_creat(upperleft, upperright, lowerleft, lowerright, width=None, height=None):
    """Build a bilinear integer gradient with the given corner values.

    Row 0 interpolates ``lowerleft`` -> ``lowerright``. The last row
    interpolates ``upperleft`` -> ``upperright``. Every column is linearly
    interpolated between those two rows, and values are produced with
    ``np.linspace(..., dtype=int)``.

    NOTE(review): with ``plt.imshow`` (origin at the top), row 0 — the
    "lower" corners — is drawn at the top of the image; confirm the corner
    naming is intended.

    Args:
        upperleft, upperright, lowerleft, lowerright: Corner values.
        width: Number of columns. Defaults to the module-level ``arrwidth``
            (the previous, implicit behaviour).
        height: Number of rows. Defaults to the module-level ``arrheight``.

    Returns:
        Integer ``ndarray`` of shape ``(height, width, 1)``, ready to be
        concatenated into a multi-channel image along the last axis.
    """
    # Fall back to the globals so existing calls behave exactly as before.
    if width is None:
        width = arrwidth
    if height is None:
        height = arrheight
    arr = np.linspace(
        np.linspace(lowerleft, lowerright, width),
        np.linspace(upperleft, upperright, width),
        height,
        dtype=int,
    )
    return arr[:, :, None]


# Size of the corner-gradient test image (also read by arr_creat as its
# default width/height).
arrwidth = 256
arrheight = 256

# Red ramps 0 -> 255 across columns. Blue ramps 0 -> 255 across rows.
# Green is 255 only at the `lowerleft` corner (row 0, column 0).
r = arr_creat(0, 255, 0, 255)
g = arr_creat(0, 0, 255, 0)
b = arr_creat(255, 255, 0, 0)

# Stack the channels into an (arrheight, arrwidth, 3) integer RGB image.
img = np.concatenate([r, g, b], axis=2)
import scipy.ndimage
# img = scipy.ndimage.rotate(img, 90, reshape=False)
plt.imshow(img)
plt.show()
# %%
# (H, W, 3) integer image -> float tensor of shape (1, 3, H, W) for grid_sample.
img = torch.tensor(img).permute(2, 0, 1).float().unsqueeze(0)
# %%
from einops import rearrange, repeat

# Per-neuron positions (n, c) -> sampling grid of shape (1, n, 1, c).
# NOTE(review): grid_sample expects c == 2 (x, y) in [-1, 1] and a grid
# with the same dtype as `img` (float32) — confirm for `mu`.
mu = repeat(mu, "n c -> b n d c", b=1, d=1)
# %%
from torch.nn.functional import interpolate, grid_sample

# Sample the gradient image at each neuron position (default bilinear mode)
# -> output of shape (1, 3, n, 1).
c = grid_sample(img, mu, align_corners=True)
# %%
# (1, 3, n, 1) -> (n, 3): one RGB triple per neuron.
c = c.squeeze(0).squeeze(-1).t()
# %%
# Truncate the sampled 0-255 channel values to integers.
c = c.numpy().astype(int)
# %%
# %%
# Scatter neuron colours onto the full fsaverage surface (0 elsewhere).
# This requires the number of neurons to equal all_mask.sum().
all_c = np.zeros((all_mask.shape[0], 3))
all_c[all_mask == 1, :] = c
# %%
# NOTE(review): all_c is float64 holding 0-255 values. Check how
# cortex.VertexRGB scales non-uint8 channels before trusting the colours.
vertex = cortex.VertexRGB(
    all_c[:, 0],
    all_c[:, 1],
    all_c[:, 2],
    "fsaverage",
)

# %%
cortex.quickshow(vertex, with_colorbar=False)
plt.show()
# %%
