# %%
import argparse
import os
from pathlib import Path

import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser
import nilearn
import nibabel as nib
from pathlib import Path

OLD_NAMES = [f"subj{i:02d}" for i in range(1, 9)]
NEW_NAMES = [f"NSD_{i:02d}" for i in range(1, 9)]

MAGIC_NUMBER = 114514  # for reproducibility


def get_args(s=None):
    """Parse the command-line options of the dataset conversion script.

    Args:
        s: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv`` instead.

    Returns:
        The ``argparse.Namespace`` holding every option below.
    """
    # (flag, type, default) in the order the options show up in ``--help``.
    typed_options = (
        ("--orig_dir", str, "/data/algonauts2023"),
        ("--full_dir", str, "/nas/natural-scenes-dataset"),
        (
            "--orig_design_dir",
            str,
            "/nas/natural-scenes-dataset/nsddata_timeseries/ppdata/subj01/func1pt8mm/design",
        ),
        ("--save_dir", str, "/data/VWET"),
        ("--resolution", int, 288),
        ("--save_fmt", str, "JPEG"),
        ("--val1_count", int, 500),
        ("--val2_count", int, 300),
        ("--seed", int, MAGIC_NUMBER),
    )

    parser = ArgumentParser()
    for flag, kind, default in typed_options:
        parser.add_argument(flag, type=kind, default=default)
    # Boolean switch: absent -> False, present -> True.
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args(s)


def load_fmri_data(data_dir):
    """Load the left- and right-hemisphere training fMRI arrays of a subject.

    Args:
        data_dir: Subject directory that contains
            ``training_split/training_fmri``.

    Returns:
        Tuple ``(lh_fmri, rh_fmri)`` with the arrays stored in
        ``lh_training_fmri.npy`` and ``rh_training_fmri.npy``.
    """
    fmri_dir = os.path.join(data_dir, "training_split", "training_fmri")
    return tuple(
        np.load(os.path.join(fmri_dir, filename))
        for filename in ("lh_training_fmri.npy", "rh_training_fmri.npy")
    )


def get_image_list(data_dir):
    """List the training and test image files of a subject in sorted order.

    Args:
        data_dir: Subject directory containing
            ``training_split/training_images`` and ``test_split/test_images``.

    Returns:
        Tuple ``(train_img_list, test_img_list)``; each is a list of paths
        (image directory joined with file name), sorted by file name.
    """

    def _sorted_paths(img_dir):
        # Sort the bare file names first, then prefix the directory.
        return [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]

    train_paths = _sorted_paths(
        os.path.join(data_dir, "training_split", "training_images")
    )
    test_paths = _sorted_paths(os.path.join(data_dir, "test_split", "test_images"))
    return train_paths, test_paths


def resave_image(
    orig_path, save_path, resolution, fmt="JPEG", quality=90, overwrite=False
):
    """Convert an image to RGB, resize it to a square and save it.

    Args:
        orig_path: Path of the source image.
        save_path: Destination path of the converted image.
        resolution: Edge length in pixels; the output is
            ``resolution x resolution``.
        fmt: Format name handed to ``Image.save``.
        quality: ``quality`` keyword handed to ``Image.save``.
        overwrite: When false, an already existing ``save_path`` is left
            untouched and nothing is read or written.

    Returns:
        None.
    """
    if os.path.exists(save_path) and not overwrite:
        return

    # ``convert`` yields an independent, fully loaded image, so the source
    # file can be closed right away (also when decoding fails) instead of
    # leaving the handle to the garbage collector.
    with Image.open(orig_path) as src:
        rgb = src.convert("RGB")

    rgb.resize((resolution, resolution)).save(save_path, fmt, quality=quality)


def renew_list(args, save_dir, img_list):
    """Resave every image of *img_list* into *save_dir* and return the new paths.

    Each output file is named ``<stem>_<R>x<R>.<save_fmt>`` where ``<stem>`` is
    the source base name up to its first dot and ``R`` is ``args.resolution``.

    Args:
        args: Namespace providing ``resolution``, ``save_fmt`` and
            ``overwrite``.
        save_dir: Output directory; created when missing.
        img_list: Source image paths.

    Returns:
        List of destination paths, in the same order as *img_list*.
    """
    os.makedirs(save_dir, exist_ok=True)
    size_tag = f"{args.resolution}x{args.resolution}"

    resaved_paths = []
    progress = tqdm(
        img_list, desc="Resaving images to: " + save_dir, total=len(img_list)
    )
    for src_path in progress:
        stem = os.path.basename(src_path).split(".")[0]
        dst_path = os.path.join(save_dir, f"{stem}_{size_tag}.{args.save_fmt}")
        resave_image(
            src_path,
            dst_path,
            args.resolution,
            args.save_fmt,
            overwrite=args.overwrite,
        )
        resaved_paths.append(dst_path)
    return resaved_paths


def extract_nsd_idx_from_image_path(image_path):
    """Return the NSD image index encoded in an image file name.

    The index is the run of digits directly after the first ``nsd-`` marker
    of the base name. This covers resaved names such as
    ``train-0001_nsd-00013_288x288.JPEG`` as well as names where no
    underscore follows the index, e.g. ``train-0001_nsd-00013.png``.

    Args:
        image_path: Path (or bare file name) of the image.

    Returns:
        The NSD index as an ``int``.

    Raises:
        ValueError: If the base name contains no ``nsd-<digits>`` part.
    """
    import re  # local import keeps this block self-contained

    image_name = os.path.basename(image_path)
    match = re.search(r"nsd-(\d+)", image_name)
    if match is None:
        raise ValueError(f"no 'nsd-<index>' found in image name: {image_name!r}")
    return int(match.group(1))


# %%
def _write_lines(path, lines):
    """Write every entry of *lines* to the text file *path*, one per line."""
    with open(path, "w") as f:
        f.writelines(f"{line}\n" for line in lines)


def _split_shuffled(items, order, val1_count, val2_count):
    """Reorder *items* by *order* and split them into ``(val1, val2, train)``.

    The first ``val1_count`` reordered entries form val1, the following
    ``val2_count`` form val2, and every entry contained in neither forms
    train (reordered order is kept).
    """
    shuffled = [items[k] for k in order]
    val1 = shuffled[:val1_count]
    val2 = shuffled[val1_count : val1_count + val2_count]
    # Set membership instead of scanning both validation lists per item.
    held_out = set(val1) | set(val2)
    train = [item for item in shuffled if item not in held_out]
    return val1, val2, train


if __name__ == "__main__":
    # Stays a module-level name: the noise-ceiling cell further down reads it.
    args = get_args()

    # Prefix of every list file, e.g. "288x288-JPEG-".
    prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"

    # process images
    for sub_idx, sub in enumerate(OLD_NAMES):
        data_dir = os.path.join(args.orig_dir, sub)
        train_img_list, test_img_list = get_image_list(data_dir)
        N = len(train_img_list)

        # Re-seeded for every subject, so the permutation depends only on the
        # seed and on N. The same permutation is applied to images and fMRI.
        np.random.seed(args.seed)
        train_idx_shuffle = np.arange(N)
        np.random.shuffle(train_idx_shuffle)

        # save to new image dir
        save_dir = os.path.join(args.save_dir, NEW_NAMES[sub_idx])
        train_img_list = renew_list(
            args, os.path.join(save_dir, "image/train"), train_img_list
        )
        test_img_list = renew_list(
            args, os.path.join(save_dir, "image/test"), test_img_list
        )

        val1_image_list, val2_image_list, train_img_list = _split_shuffled(
            train_img_list, train_idx_shuffle, args.val1_count, args.val2_count
        )

        # List files live directly in save_dir. The path is joined exactly
        # once so that a relative --save_dir works as well.
        for split, paths in (
            ("train", train_img_list),
            ("test", test_img_list),
            ("val1", val1_image_list),
            ("val2", val2_image_list),
        ):
            _write_lines(
                os.path.join(save_dir, f"{prefix}{split}_img_list.txt"), paths
            )

        y_save_dir = os.path.join(save_dir, "fmri/train")
        # An existing fmri directory is taken as "already exported" unless
        # --overwrite is given; the coordinate export below is skipped too.
        skip_flag = os.path.exists(y_save_dir) and not args.overwrite
        if skip_flag:
            print(f"Skipping {sub} because {y_save_dir} already exists.")

        # save to new fmri dir
        if not skip_flag:
            # Loaded before the directory is created, so a failed load does
            # not leave an empty directory that a later run would skip.
            lh_fmri, rh_fmri = load_fmri_data(data_dir)
            train_fmri = np.concatenate((lh_fmri, rh_fmri), axis=1)
            if train_fmri.shape[0] != N:
                # Row k is paired with the k-th sorted training image; a
                # count mismatch would silently mis-pair the two.
                raise ValueError(
                    f"{sub}: {train_fmri.shape[0]} fMRI rows but {N} training images"
                )

        os.makedirs(y_save_dir, exist_ok=True)
        # One file per training sample, named by its zero-padded row index.
        train_y_list = [
            os.path.join(y_save_dir, f"{row:010d}.npy") for row in range(N)
        ]
        if not skip_flag:
            for row, path in tqdm(
                enumerate(train_y_list),
                desc="Saving fmri data to: " + y_save_dir,
                total=N,
            ):
                np.save(path, train_fmri[row].astype(np.float16))

        val1_y_list, val2_y_list, train_y_list = _split_shuffled(
            train_y_list, train_idx_shuffle, args.val1_count, args.val2_count
        )
        # No test fMRI is loaded by this script; the test list is written empty.
        test_y_list = []

        for split, paths in (
            ("train", train_y_list),
            ("test", test_y_list),
            ("val1", val1_y_list),
            ("val2", val2_y_list),
        ):
            _write_lines(os.path.join(save_dir, f"{prefix}{split}_y_list.txt"), paths)

        if not skip_flag:
            # coord
            mask_dir = os.path.join(data_dir, "roi_masks")
            lh_mask = np.load(
                os.path.join(mask_dir, "lh.all-vertices_fsaverage_space.npy")
            )
            rh_mask = np.load(
                os.path.join(mask_dir, "rh.all-vertices_fsaverage_space.npy")
            )
            print(lh_mask.shape, rh_mask.shape)

            # Import the submodules explicitly instead of relying on another
            # import having loaded ``nilearn.surface`` as a side effect.
            import nilearn.datasets
            import nilearn.surface

            fsaverage = nilearn.datasets.fetch_surf_fsaverage("fsaverage7")
            lh_coords, lh_faces = nilearn.surface.load_surf_mesh(
                fsaverage["sphere_left"]
            )
            rh_coords, rh_faces = nilearn.surface.load_surf_mesh(
                fsaverage["sphere_right"]
            )

            # Stretch the left x-extent by 1.5 and move the right hemisphere
            # so that it starts at that bound; the two hemispheres then do
            # not overlap along x.
            lh_xmin, lh_xmax = np.min(lh_coords[:, 0]), np.max(lh_coords[:, 0])
            lh_xmax = lh_xmin + (lh_xmax - lh_xmin) * 1.5
            rh_xmin, rh_xmax = np.min(rh_coords[:, 0]), np.max(rh_coords[:, 0])
            if rh_xmin < lh_xmax:
                rh_coords[:, 0] += lh_xmax - rh_xmin
            print(lh_coords.shape, rh_coords.shape)

            # Keep only vertices flagged 1 in the masks.
            # NOTE(review): assumes each mask has one entry per fsaverage7
            # vertex - confirm against the challenge data.
            lh_coords = lh_coords[lh_mask == 1]
            rh_coords = rh_coords[rh_mask == 1]
            coords = np.concatenate((lh_coords, rh_coords), axis=0)

            print(coords.shape)

            np.save(os.path.join(save_dir, "neuron_coords.npy"), coords)

# %%
if __name__ == "__main__":
    # Relies on ``args`` parsed by the preceding ``__main__`` cell.

    # noise ceiling
    for old_sub, new_sub in zip(OLD_NAMES, NEW_NAMES):
        nc_dir = os.path.join(
            args.full_dir,
            "nsddata_betas/ppdata",
            old_sub,
            "fsaverage/betas_fithrf_GLMdenoise_RR/",
        )

        data_dir = os.path.join(args.orig_dir, old_sub)
        mask_dir = os.path.join(data_dir, "roi_masks")
        lh_mask = np.load(os.path.join(mask_dir, "lh.all-vertices_fsaverage_space.npy"))
        rh_mask = np.load(os.path.join(mask_dir, "rh.all-vertices_fsaverage_space.npy"))
        mask = np.concatenate((lh_mask, rh_mask))

        # Only the full-data ``ncsnr`` map ends up in the saved output, so the
        # ``ncsnr_split1`` / ``ncsnr_split2`` maps are not read.
        hemi_nc = []
        for hemi in ["lh", "rh"]:
            path = os.path.join(nc_dir, f"{hemi}.ncsnr.mgh")
            ncsnr = nib.load(path).get_fdata().flatten()
            # Map ncsnr to a fraction in [0, 1).
            # NOTE(review): the 1/3 term presumably corresponds to averaging
            # three trials per image - confirm against the NSD data manual.
            hemi_nc.append(ncsnr**2 / (ncsnr**2 + 1 / 3))
        # Left hemisphere first, matching the order used for ``mask``.
        first_nc = np.concatenate(hemi_nc)

        save_dir = os.path.join(args.save_dir, new_sub)
        # Make this cell usable even when the subject directory is missing.
        os.makedirs(save_dir, exist_ok=True)

        nc = first_nc[mask == 1]

        np.save(os.path.join(save_dir, "nc.npy"), nc)

        # roi mask
        lh_mask = np.load(os.path.join(mask_dir, "lh.streams_challenge_space.npy"))
        rh_mask = np.load(os.path.join(mask_dir, "rh.streams_challenge_space.npy"))
        mask = np.concatenate((lh_mask, rh_mask))
        v = mask
        # Group the stream labels into three index sets: 1 -> early,
        # 2-4 -> mid, 5-7 and 0 -> late.
        early = np.where(v == 1)[0]
        mid = np.where(np.isin(v, (2, 3, 4)))[0]
        late = np.where(np.isin(v, (5, 6, 7, 0)))[0]

        roi_dir = os.path.join(save_dir, "roi")
        os.makedirs(roi_dir, exist_ok=True)
        np.save(os.path.join(roi_dir, "early.npy"), early)
        np.save(os.path.join(roi_dir, "mid.npy"), mid)
        np.save(os.path.join(roi_dir, "late.npy"), late)


# # %%
# import cortex

# # %%
# vertex = cortex.Vertex(first_nc, "fsaverage", vmin=0, vmax=1)
# cortex.quickshow(vertex, with_curvature=True)
# # %%
# nc = first_nc
# nc[nc < 0.2] = 0.2
# nc[mask == 0] = 0
# vertex = cortex.Vertex(nc, "fsaverage", vmin=0, vmax=1)
# cortex.quickshow(vertex, with_curvature=True)
# # %%
# plt.hist(nc[mask==1], bins=100)
# plt.show()
# # %%

# %%
