# %%
import argparse
import logging
import os
from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser
import h5py
from pathlib import Path

import nilearn
import nibabel as nib

from torchvision import transforms

# Subject directory names as found on disk (OLD_NAMES) and the "HCP_001"-style
# names they are saved under (NEW_NAMES). Both start empty and are filled in
# by the __main__ section of this script.
OLD_NAMES = []
NEW_NAMES = []

# File stems of the stimulus movies; ".mp4" is appended when they are opened.
MOVIES = [
    "7T_MOVIE1_CC1_v2",
    "7T_MOVIE2_HO1_v2",
    "7T_MOVIE3_CC2_v2",
    "7T_MOVIE4_HO2_v2",
]
# Number of time points taken from each movie, in the same order as MOVIES.
# Each time point is written out as `args.fps` consecutive frames.
FRAMES = [921, 918, 915, 901]


def get_args(s=None):
    """Build the command-line parser for this script and parse ``s``.

    Args:
        s: Optional list of argument strings. When ``None``, argparse falls
            back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    # (flag, type, default) for every option that takes a value.
    valued_options = (
        ("--orig_dir", str, "/nas/HCP7T"),
        ("--save_dir", str, "/data/VWE"),
        ("--resolution", int, 224),
        ("--fps", int, 10),
        ("--delay", int, 4),
        ("--save_fmt", str, "JPEG"),
        ("--val1_split", float, 0.034),
        ("--val2_split", float, 0.034),
        ("--seed", int, 42),
    )

    parser = ArgumentParser()
    for flag, kind, default in valued_options:
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args(s)


from decord import VideoReader, cpu, gpu


def load_video(file, fps=4, resolution=224):
    """Decode a video, sample it at roughly ``fps`` frames per second and resize.

    The clip duration is truncated to whole seconds before the number of
    samples is computed, and the sampled frame indices are spread evenly over
    the full clip with ``np.linspace``.

    Args:
        file: Path of the video, passed to decord's ``VideoReader``.
        fps: Number of frames to keep per second of video.
        resolution: Side length each frame is resized to (square output).

    Returns:
        The resized frames stacked along a new first axis.
        # presumably (num_frames, resolution, resolution, 3) uint8 — TODO confirm
    """
    reader = VideoReader(file, ctx=cpu())
    n_total = len(reader)
    # Whole seconds only: the fractional remainder of the clip is ignored here.
    duration_s = int(n_total / reader.get_avg_fps())
    n_samples = int(duration_s * fps)
    sample_at = np.linspace(0, n_total - 1, n_samples, dtype=int)

    to_square = transforms.Resize((resolution, resolution))
    frames = [
        np.array(to_square(Image.fromarray(reader[idx].asnumpy())))
        for idx in tqdm(sample_at, desc="Loading video")
    ]
    return np.stack(frames, 0)


def resave_image(
    orig_path, save_path, resolution, fmt="JPEG", quality=100, overwrite=False
):
    """Write a square, RGB copy of an image unless the target already exists.

    Args:
        orig_path: Path of the image to read.
        save_path: Path the converted image is written to.
        resolution: Side length of the (square) output image.
        fmt: Format name handed to ``Image.save``.
        quality: ``quality`` keyword handed to ``Image.save``.
        overwrite: When true, rewrite ``save_path`` even if it already exists.
    """
    if overwrite or not os.path.exists(save_path):
        rgb = Image.open(orig_path).convert("RGB")
        resized = rgb.resize((resolution, resolution))
        resized.save(save_path, fmt, quality=quality)


def renew_list(args, save_dir, img_list):
    """Re-save every image in ``img_list`` into ``save_dir`` and return the new paths.

    Each output file is named ``<stem>_<R>x<R>.<save_fmt>``, where ``<stem>``
    is the source basename up to its first dot and ``R`` is
    ``args.resolution``.

    Args:
        args: Object providing ``resolution``, ``save_fmt`` and ``overwrite``.
        save_dir: Destination directory; created if missing.
        img_list: Paths of the source images.

    Returns:
        List of destination paths, in the same order as ``img_list``.
    """
    os.makedirs(save_dir, exist_ok=True)
    resaved = []
    progress = tqdm(
        img_list, desc="Resaving images to: " + save_dir, total=len(img_list)
    )
    for src in progress:
        stem = os.path.basename(src).split(".")[0]
        size_tag = f"{args.resolution}x{args.resolution}"
        dst = os.path.join(save_dir, f"{stem}_{size_tag}.{args.save_fmt}")
        resave_image(
            src,
            dst,
            args.resolution,
            args.save_fmt,
            overwrite=args.overwrite,
        )
        resaved.append(dst)
    return resaved


def get_visual_indices(parcel, visual):
    """Return the sorted, unique indices where ``parcel`` matches a wanted label.

    Args:
        parcel: Array of labels. For a 1-D array the result indexes its
            elements; for higher ranks it indexes the first axis.
        visual: Iterable of label values to select. May be empty.

    Returns:
        Sorted integer array of the matching indices, without duplicates.
        It is empty when ``visual`` is empty or nothing matches.
    """
    # One vectorised membership test replaces a per-label scan, and it also
    # handles an empty ``visual`` (np.concatenate of no arrays would raise).
    matches = np.isin(parcel, visual)
    # np.unique already returns its result sorted.
    return np.unique(np.where(matches)[0])

def shuffle_by_chuck(arr, chunk_size=20):
    """Shuffle ``arr`` in contiguous chunks, keeping the order inside each chunk.

    The array is cut into ``len(arr) // chunk_size`` full chunks plus one
    trailing remainder chunk (which may be empty). The chunks are shuffled
    with NumPy's global RNG and glued back together.

    Args:
        arr: Sequence to shuffle (sliced and then concatenated).
        chunk_size: Length of each full chunk.

    Returns:
        A new array holding the elements of ``arr`` in chunk-shuffled order.
    """
    cut = (len(arr) // chunk_size) * chunk_size
    pieces = [arr[start : start + chunk_size] for start in range(0, cut, chunk_size)]
    # The remainder is always appended, even when empty, so it takes part in
    # the shuffle like any other chunk.
    pieces.append(arr[cut:])
    np.random.shuffle(pieces)
    return np.concatenate(pieces)
# %%
if __name__ == "__main__":
    # For interactive (cell-by-cell) runs, pass an explicit argument list
    # instead, e.g. get_args(["--overwrite"]).
    args = get_args()

    # --- Select the grayordinates that belong to "visual" parcels ----------
    parcel = nib.load(
        os.path.join(
            args.orig_dir,
            "Q1-Q6_RelatedParcellation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.59k_fs_LR.dlabel.nii",
        )
    ).get_fdata()
    # Only the first row of the label data is used.
    parcel = parcel[0, :]

    df = pd.read_csv(os.path.join(args.orig_dir, "Glasser_2016_Table.csv"))
    # Third column of the table is searched for the word "visual"; the label
    # of a row is taken to be its 1-based position in the table.
    descs = df.values[:, 2]
    lhv = [i_row + 1 for i_row, desc in enumerate(descs) if "visual" in desc.lower()]
    # Right-hemisphere labels are derived as left label + 180.
    # NOTE(review): assumes the parcellation numbers areas that way — confirm.
    rhv = [i + 180 for i in lhv]
    visual = lhv + rhv

    visual_indices = get_visual_indices(parcel, visual)

    # --- Sphere coordinates of the selected vertices ------------------------
    lh_path = os.path.join(args.orig_dir, "L.sphere.59k_fs_LR.surf.gii")
    lh_coords, lh_faces = nilearn.surface.load_surf_mesh(lh_path)
    rh_path = os.path.join(args.orig_dir, "R.sphere.59k_fs_LR.surf.gii")
    rh_coords, rh_faces = nilearn.surface.load_surf_mesh(rh_path)
    coords = np.concatenate((lh_coords, rh_coords), axis=0)

    assert coords.shape[0] == parcel.shape[0]
    coords = coords[visual_indices]

    # --- Subject naming -----------------------------------------------------
    # NOTE(review): os.listdir returns entries in arbitrary order, so the
    # subject -> HCP_xxx mapping depends on the filesystem. Left unchanged so
    # existing outputs keep their names; consider sorting for new datasets.
    OLD_NAMES = os.listdir(os.path.join(args.orig_dir, "HCP"))
    NEW_NAMES = [f"HCP_{i+1:03d}" for i in range(len(OLD_NAMES))]

    # --- Movie frames: stored once, under the first subject's directory -----
    movie_dir = os.path.join(args.orig_dir, "movie/Post_20140821_version/")
    movie_outer_list = []
    save_dir = os.path.join(
        args.save_dir,
        NEW_NAMES[0],
        f"movie_{args.resolution}x{args.resolution}_{args.fps}fps",
    )
    # Frames are only (re)extracted when the target directory is missing or
    # --overwrite is given. The check must run before makedirs below.
    load_flag = args.overwrite or not os.path.exists(save_dir)
    os.makedirs(save_dir, exist_ok=True)

    for i_movie, movie in enumerate(MOVIES):
        n_timepoints = FRAMES[i_movie]

        vid = None
        if load_flag:
            movie_path = os.path.join(movie_dir, movie + ".mp4")
            vid = load_video(movie_path, fps=args.fps, resolution=args.resolution)
            vid = vid[: n_timepoints * args.fps]

        for i_frame in tqdm(range(n_timepoints), desc=movie + " frames"):
            # The last `delay` time points of each movie are dropped.
            if n_timepoints - i_frame <= args.delay:
                break
            outer_dir = os.path.join(save_dir, f"{movie}_time{i_frame:04d}")
            movie_outer_list.append(outer_dir)
            if vid is None:
                # Frames already on disk: only the directory list is rebuilt.
                continue
            os.makedirs(outer_dir, exist_ok=True)
            for i_fps in range(args.fps):
                img_arr = vid[i_frame * args.fps + i_fps]
                # convert() returns a new image; the result has to be kept.
                im = Image.fromarray(img_arr).convert("RGB")
                im.save(os.path.join(outer_dir, f"{i_fps}.{args.save_fmt}"))

# %%
def _save_list(items, path):
    """Write ``items`` to ``path``, one ``str(item)`` per line."""
    with open(path, "w") as f:
        f.writelines(str(item) + "\n" for item in items)


def _find_clean_dtseries(fmri_dir):
    """Return the path of the cleaned dtseries file in ``fmri_dir``, or None."""
    for fname in os.listdir(fmri_dir):
        if "s_1.6mm_MSMAll_hp2000_clean.dtseries.nii" in fname:
            return os.path.join(fmri_dir, fname)
    return None


def _load_visual_fmri(data_dir, session_ids, visual_indices, delay):
    """Load, crop and normalise each session, then concatenate them over time.

    For every session the selected columns are kept, the first ``delay`` rows
    are dropped and the result is passed through ``hcp_utils.normalize``.

    Returns:
        Array of shape (n_voxel, n_time), or None when a session directory
        holds no cleaned dtseries file.
    """
    import hcp_utils as hcp

    runs = []
    for sessid in session_ids:
        fmri_path = _find_clean_dtseries(os.path.join(data_dir, sessid))
        if fmri_path is None:
            return None
        data = nib.load(fmri_path).get_fdata()
        data = data[:, visual_indices]
        data = data[delay:, :]
        runs.append(hcp.normalize(data))
    return np.concatenate(runs, axis=0).T  # (n_voxel, n_time)


def _save_fmri_frames(fmri, fmri_dir):
    """Save each time point of ``fmri`` (n_voxel, n_time) as a float16 .npy file.

    Returns:
        The list of written paths, ordered by time index.
    """
    frames = fmri.T  # (n_time, n_voxel)
    os.makedirs(fmri_dir, exist_ok=True)
    paths = []
    for idx in tqdm(range(frames.shape[0]), desc="Saving fmri data to: " + fmri_dir):
        path = os.path.join(fmri_dir, f"{idx:010d}.npy")
        np.save(path, frames[idx].astype(np.float16))
        paths.append(path)
    return paths


if __name__ == "__main__":
    # process images
    n_samples = len(movie_outer_list)

    # Movie index of every sample, aligned with movie_outer_list. Each movie
    # contributes FRAMES[i] - delay samples (its last `delay` time points were
    # dropped when movie_outer_list was built), so the same count is used here.
    big_img_list = movie_outer_list
    big_sess_list = []
    for i_movie, frame in enumerate(FRAMES):
        n_kept = min(frame, max(frame - args.delay, 0))
        big_sess_list += [i_movie] * n_kept
    assert len(big_sess_list) == n_samples

    i_sub = 0
    for sub in tqdm(OLD_NAMES, desc="subjects"):

        # i_sub only advances for subjects that are kept, so a skipped
        # subject's output name is reused by the next one.
        save_dir = os.path.join(args.save_dir, NEW_NAMES[i_sub])
        os.makedirs(save_dir, exist_ok=True)

        data_dir = os.path.join(args.orig_dir, "HCP", sub, "MNINonLinear", "Results")
        # os.listdir order is arbitrary; sort so the sessions are concatenated
        # in a deterministic order.
        # NOTE(review): assumes session directory names sort in the same order
        # as MOVIES — confirm against the dataset layout.
        session_ids = sorted(os.listdir(data_dir))
        if len(session_ids) != 4:
            logging.warning(
                f"Skipping {sub} because it has {len(session_ids)} sessions"
            )
            continue

        fmri_dir = os.path.join(save_dir, "fmri")
        skip_flag = os.path.exists(fmri_dir) and not args.overwrite
        fmri_data = None
        if skip_flag:
            print(f"Skipping {sub} because fmri data already exists")
        else:
            fmri_data = _load_visual_fmri(
                data_dir, session_ids, visual_indices, args.delay
            )
            if fmri_data is None:
                logging.warning(
                    f"Skipping {sub} because a session has no cleaned dtseries file"
                )
                continue
            assert fmri_data.shape[0] == len(visual_indices)
            if fmri_data.shape[1] != n_samples:
                logging.warning(
                    f"Skipping {sub} because it has {fmri_data.shape[1]} frames"
                )
                continue

        i_sub += 1

        # --- Train / val1 / val2 split (identical for every subject) --------
        np.random.seed(args.seed)
        order = shuffle_by_chuck(np.arange(n_samples), 20)

        n_val1 = int(len(order) * args.val1_split)
        n_val12 = int(len(order) * (args.val1_split + args.val2_split))
        val1_indices = order[:n_val1]
        val2_indices = order[n_val1:n_val12]
        # Set lookup instead of scanning the index arrays for every element.
        held_out = set(val1_indices.tolist()) | set(val2_indices.tolist())
        remaining_train_indices = [i for i in order if i not in held_out]

        val1_image_list = [big_img_list[i] for i in val1_indices]
        val2_image_list = [big_img_list[i] for i in val2_indices]
        val1_session_ids = [big_sess_list[i] for i in val1_indices]
        val2_session_ids = [big_sess_list[i] for i in val2_indices]
        train_img_list = [big_img_list[i] for i in remaining_train_indices]
        train_session_ids = [big_sess_list[i] for i in remaining_train_indices]

        prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"

        for name, img_list in [
            ("train", train_img_list),
            ("val1", val1_image_list),
            ("val2", val2_image_list),
        ]:
            _save_list(
                img_list, os.path.join(save_dir, prefix + f"{name}_img_list.txt")
            )

        for name, sess_list in [
            ("train", train_session_ids),
            ("val1", val1_session_ids),
            ("val2", val2_session_ids),
        ]:
            _save_list(
                sess_list, os.path.join(save_dir, prefix + f"{name}_session_ids.txt")
            )

        # --- fMRI targets: one .npy file per time point ----------------------
        if skip_flag:
            # Files are assumed to be on disk already; only rebuild their paths.
            fmri_list = [
                os.path.join(fmri_dir, f"{idx:010d}.npy") for idx in range(n_samples)
            ]
        else:
            fmri_list = _save_fmri_frames(fmri_data, fmri_dir)

        for name, indices in [
            ("train", remaining_train_indices),
            ("val1", val1_indices),
            ("val2", val2_indices),
        ]:
            _save_list(
                [fmri_list[i] for i in indices],
                os.path.join(save_dir, f"{name}_y_list.txt"),
            )

        np.save(os.path.join(save_dir, "neuron_coords.npy"), coords)


# %%
