# %%
import argparse
import os
from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser
import h5py
from pathlib import Path

OLD_NAMES = [f"sub-{i:02d}" for i in range(1, 4)]
NEW_NAMES = [f"fMRI1_{i:02d}" for i in range(1, 4)]

ROIS = [
    "V1",
    "V2",
    "V3",
    "hV4",
    "VO1",
    "VO2",
    "LO1 (prf)",
    "LO2 (prf)",
    "V3b",
    "V3a",
    "lEBA",
    "rEBA",
    "lFFA",
    "rFFA",
]


def get_args(s=None):
    """Parse the options that control image resaving and fMRI export.

    Args:
        s: Arguments to parse. ``None`` lets argparse read ``sys.argv[1:]``;
            a sequence of strings is parsed as given; a single string is
            split with shell-like rules first (so ``""`` yields all defaults
            and ``"--seed 1 --overwrite"`` behaves like a command line).

    Returns:
        argparse.Namespace: the parsed options (``orig_dir``, ``img_dir``,
        ``save_dir``, ``resolution``, ``save_fmt``, ``val1_split``,
        ``val2_split``, ``seed``, ``overwrite``).
    """
    if isinstance(s, str):
        # argparse iterates a bare string character by character, which only
        # happens to work for the empty string; tokenize it properly instead.
        import shlex

        s = shlex.split(s)

    parser = ArgumentParser()
    parser.add_argument(
        "--orig_dir",
        type=str,
        default="/nas/ThingsfMRI1/betas_csv",
        help="directory holding the per-subject metadata CSVs and response data",
    )
    parser.add_argument(
        "--img_dir",
        type=str,
        default="/data/things/all_images",
        help="directory the stimulus file names are resolved against",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/data/VWE",
        help="output root; one sub-directory is written per subject",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=288,
        help="side length in pixels of the resaved square images",
    )
    parser.add_argument(
        "--save_fmt",
        type=str,
        default="JPEG",
        help="image format used when saving; also used as the file extension",
    )
    parser.add_argument(
        "--val1_split",
        type=float,
        default=0.1,
        help="fraction of shuffled train trials held out as val1",
    )
    # NOTE(review): val2_split is parsed but not read by the code visible in
    # this file -- confirm whether it is still needed.
    parser.add_argument("--val2_split", type=float, default=0.0)
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="seed for the NumPy shuffle that draws the val1 split",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="re-create outputs that already exist instead of skipping them",
    )
    return parser.parse_args(s)


#
# def load_fmri_data(data_dir):
#     fmri_dir = os.path.join(data_dir, "training_split", "training_fmri")
#     lh_fmri = np.load(os.path.join(fmri_dir, "lh_training_fmri.npy"))
#     rh_fmri = np.load(os.path.join(fmri_dir, "rh_training_fmri.npy"))
#     return lh_fmri, rh_fmri


# def get_image_list(data_dir):
#     train_img_dir = os.path.join(data_dir, "training_split", "training_images")
#     test_img_dir = os.path.join(data_dir, "test_split", "test_images")

#     # Create lists with all training and test image file names, sorted
#     train_img_list = os.listdir(train_img_dir)
#     train_img_list.sort()
#     train_img_list = [os.path.join(train_img_dir, x) for x in train_img_list]
#     test_img_list = os.listdir(test_img_dir)
#     test_img_list.sort()
#     test_img_list = [os.path.join(test_img_dir, x) for x in test_img_list]
#     return train_img_list, test_img_list


def resave_image(
    orig_path, save_path, resolution, fmt="JPEG", quality=100, overwrite=False
):
    """Write a square RGB copy of ``orig_path`` to ``save_path``.

    The source image is converted to RGB, resized to
    ``resolution`` x ``resolution`` and saved in format ``fmt`` with the given
    ``quality``. When ``save_path`` already exists and ``overwrite`` is false,
    nothing is opened or written.

    Returns:
        None
    """
    already_saved = os.path.exists(save_path)
    if overwrite or not already_saved:
        source = Image.open(orig_path)
        rgb = source.convert("RGB")
        square = rgb.resize((resolution, resolution))
        square.save(save_path, fmt, quality=quality)


def renew_list(args, save_dir, img_list):
    """Resave every image of ``img_list`` into ``save_dir`` and list the copies.

    ``save_dir`` is created if needed. Each output file is named
    ``<stem>_<res>x<res>.<save_fmt>``, where ``<stem>`` is the source base
    name up to its first dot and ``<res>`` is ``args.resolution``. Existing
    files are only rewritten when ``args.overwrite`` is set.

    Returns:
        list: the output paths, in the same order as ``img_list``.
    """
    os.makedirs(save_dir, exist_ok=True)

    side = args.resolution
    suffix = f"_{side}x{side}.{args.save_fmt}"

    resaved_paths = []
    progress = tqdm(
        img_list, desc="Resaving images to: " + save_dir, total=len(img_list)
    )
    for src_path in progress:
        # Everything after the first dot of the file name is discarded.
        stem = os.path.basename(src_path).split(".")[0]
        dst_path = os.path.join(save_dir, stem + suffix)
        resave_image(
            src_path,
            dst_path,
            side,
            args.save_fmt,
            overwrite=args.overwrite,
        )
        resaved_paths.append(dst_path)
    return resaved_paths


# %%
if __name__ == "__main__":
    # NOTE(review): an empty argument string is passed, so command-line flags
    # are ignored and every option takes its default. This looks intended for
    # running the file as an interactive "# %%" cell -- confirm before
    # switching to get_args() for command-line use.
    args = get_args("")

    def save_list(lines, path):
        """Write the entries of ``lines`` to the text file ``path``, one per line.

        ``path`` is used as given. Callers already pass a path located under
        the subject directory, so it must not be joined onto that directory a
        second time (doing so breaks as soon as ``save_dir`` is relative).
        """
        with open(path, "w") as f:
            f.writelines(f"{line}\n" for line in lines)

    def save_to_npy(fmri, indices, fmri_dir):
        """Store each column of ``fmri`` as its own float16 ``.npy`` file.

        Column ``k`` is written to ``<fmri_dir>/<indices[k]:010d>.npy``.
        ``fmri_dir`` is created if needed.

        Returns:
            list: the written file paths, in column order.
        """
        os.makedirs(fmri_dir, exist_ok=True)
        paths = []
        for k, response in tqdm(
            enumerate(fmri.T),
            desc="Saving fmri data to: " + fmri_dir,
            total=fmri.shape[1],
        ):
            path = os.path.join(fmri_dir, f"{indices[k]:010d}.npy")
            np.save(path, response.astype(np.float16))
            paths.append(path)
        return paths

    # process images
    for i, sub in enumerate(OLD_NAMES):
        sti_df = pd.read_csv(os.path.join(args.orig_dir, f"{sub}_StimulusMetadata.csv"))

        all_stimu = sti_df["stimulus"]
        train_mask = (sti_df["trial_type"] == "train").values
        test_mask = (sti_df["trial_type"] == "test").values
        test_stimu = all_stimu[test_mask]

        img_dir = args.img_dir
        all_img_list = [os.path.join(img_dir, x) for x in all_stimu]
        test_img_list = [os.path.join(img_dir, x) for x in test_stimu]
        # Test stimuli occur in several trials; keep one (sorted) entry each.
        slim_test_img_list = np.unique(test_img_list)

        # save to new image dir
        save_dir = os.path.join(args.save_dir, NEW_NAMES[i])
        image_dir = os.path.join(save_dir, "image/all")
        all_img_list = renew_list(args, image_dir, all_img_list)
        # Same target directory as above: this call mainly maps the unique
        # test images to their resaved paths.
        slim_test_img_list = renew_list(args, image_dir, slim_test_img_list)

        # Shuffle the train trial positions reproducibly, then split off val1.
        train_indices = np.where(train_mask)[0]
        np.random.seed(args.seed)
        np.random.shuffle(train_indices)

        # The positions are unique, so "everything not in val1" is exactly the
        # tail of the shuffled array; slicing avoids a quadratic membership scan.
        n_val1 = int(len(train_indices) * args.val1_split)
        val1_indices = train_indices[:n_val1]
        remaining_train_indices = train_indices[n_val1:]

        val1_image_list = [all_img_list[j] for j in val1_indices]
        train_img_list = [all_img_list[j] for j in remaining_train_indices]

        prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"

        # The unique test images are exported under the name "val2".
        for name, img_list in [
            ("train", train_img_list),
            ("val2", slim_test_img_list),
            ("val1", val1_image_list),
        ]:
            save_list(
                img_list, os.path.join(save_dir, prefix + f"{name}_img_list.txt")
            )

        vox_f = os.path.join(args.orig_dir, f"{sub}_VoxelMetadata.csv")
        vox_df = pd.read_csv(vox_f)

        # Keep the voxels whose available ROI columns sum to exactly 1
        # (presumably 0/1 membership flags, i.e. voxels in a single ROI --
        # TODO confirm against the metadata format).
        rois = [r for r in ROIS if r in vox_df.columns]
        mask = vox_df[rois].sum(axis=1) == 1
        voxel_indices = np.where(mask)[0]

        # coord
        xyz = ["voxel_x", "voxel_y", "voxel_z"]
        coords = vox_df[xyz].values[voxel_indices]
        print(coords.shape)

        np.save(os.path.join(save_dir, "neuron_coords.npy"), coords)

        fmri_dir = os.path.join(save_dir, "fmri")
        if os.path.exists(fmri_dir) and not args.overwrite:
            print(f"Skipping {sub} because fmri data already exists")
            continue
        path = os.path.join(args.orig_dir, f"{sub}_ResponseData.h5")
        print("Loading ", path)
        responses_df = pd.read_hdf(path)

        drop_col = ["voxel_id"]
        responses_df = responses_df.drop(drop_col, axis=1)

        fmri_data = responses_df.values

        # Every remaining column must correspond to one stimulus-table row.
        # Raised explicitly so the check survives ``python -O``.
        if fmri_data.shape[1] != len(sti_df):
            raise ValueError(f"{fmri_data.shape[1]} != {len(sti_df)}")

        print(fmri_data.shape)
        fmri_data = fmri_data[voxel_indices, :]
        print(fmri_data.shape)

        train_fmri = fmri_data[:, remaining_train_indices]
        val1_fmri = fmri_data[:, val1_indices]

        # average test fmri: one column per unique test stimulus, averaged
        # over every trial that showed it. np.unique sorts the stimuli, which
        # is the same order as slim_test_img_list (same names behind a common
        # directory prefix).
        test_fmri = np.stack(
            [
                fmri_data[:, np.where(all_stimu == img)[0]].mean(axis=1)
                for img in np.unique(test_stimu)
            ],
            axis=1,
        )
        print(test_fmri.shape)
        print(len(voxel_indices), len(slim_test_img_list))
        expected_shape = (len(voxel_indices), len(slim_test_img_list))
        if test_fmri.shape != expected_shape:
            raise ValueError(f"{test_fmri.shape} != {expected_shape}")

        # Normalize every split with the scalar mean/std of the train split.
        train_mean, train_std = train_fmri.mean(), train_fmri.std()
        train_fmri = (train_fmri - train_mean) / train_std
        val1_fmri = (val1_fmri - train_mean) / train_std
        test_fmri = (test_fmri - train_mean) / train_std

        # Averaged test responses have no trial position of their own; number
        # them after the largest train position so their file names cannot
        # collide with the train/val1 files written to the same directory.
        virtual_test_indices = (
            np.arange(len(slim_test_img_list)) + train_indices.max() + 1
        )
        for name, fmri, indices in [
            ("train", train_fmri, remaining_train_indices),
            ("val2", test_fmri, virtual_test_indices),
            ("val1", val1_fmri, val1_indices),
        ]:
            save_list(
                save_to_npy(fmri, indices, fmri_dir),
                os.path.join(save_dir, f"{name}_y_list.txt"),
            )

# %%
