# %%
import argparse
import logging
import os
from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser
import h5py
from pathlib import Path

import nilearn
import nibabel as nib

from torchvision import transforms

# Original Algonauts 2021 subject folder names: "sub01" ... "sub10".
OLD_NAMES = ["sub" + str(i).zfill(2) for i in range(1, 11)]
# Output dataset name; all subjects' data is concatenated and saved under it.
NEW_NAMES = ["ALG"]


def get_args(s=None):
    """Parse the command-line options of this preprocessing script.

    Args:
        s: Optional list of argument strings to parse instead of the real
            command line (handy from a notebook); ``None`` uses ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``orig_dir``, ``save_dir``, ``resolution``,
        ``frames``, ``save_fmt``, ``val1_split``, ``val2_split``, ``seed`` and
        the boolean ``overwrite`` flag.
    """
    # (flag, type, default) for every value-taking option, in declaration order.
    typed_options = (
        ("--orig_dir", str, "/nas/algonauts2021"),
        ("--save_dir", str, "/data/VWET"),
        ("--resolution", int, 224),
        ("--frames", int, 10),
        ("--save_fmt", str, "JPEG"),
        ("--val1_split", float, 0.1),
        ("--val2_split", float, 0.1),
        ("--seed", int, 42),
    )
    cli = ArgumentParser()
    for flag, kind, default in typed_options:
        cli.add_argument(flag, type=kind, default=default)
    cli.add_argument("--overwrite", action="store_true")
    return cli.parse_args(s)


from decord import VideoReader, cpu, gpu


def load_video(file, frames=10, resolution=224):
    """Decode ``frames`` evenly spaced frames of a video and resize each one.

    Frame positions come from ``np.linspace(0, n - 1, frames)`` truncated to
    int, so the first and last frames are always included. Each decoded frame
    is converted to a PIL image and resized to ``(resolution, resolution)``
    with torchvision's ``Resize``.

    Args:
        file: Path of the video to read with decord on the CPU.
        frames: Number of frames to sample.
        resolution: Output height and width in pixels.

    Returns:
        The resized frames stacked along a new leading axis, presumably
        ``(frames, resolution, resolution, channels)`` uint8 (TODO confirm the
        channel layout decord returns).
    """
    resizer = transforms.Resize((resolution, resolution))
    reader = VideoReader(file, ctx=cpu())
    picks = np.linspace(0, len(reader) - 1, frames, dtype=int)

    def _resized(idx):
        # decord NDArray -> numpy -> PIL (needed by Resize) -> numpy again.
        pil_frame = Image.fromarray(reader[idx].asnumpy())
        return np.array(resizer(pil_frame))

    return np.stack([_resized(idx) for idx in picks], 0)


def _column_keys(arr, dtype):
    """Return one hashable key per column of ``arr``; ``None`` means "never matches".

    Each key is the raw bytes of the column after casting to ``dtype``, so equal
    columns get equal keys. For floating/complex dtypes, ``+ 0`` folds -0.0 onto
    0.0 (``==`` treats them as equal), and any column containing NaN gets
    ``None`` because NaN never compares equal under ``==``.
    """
    cols = np.ascontiguousarray(arr.T, dtype=dtype)
    valid = np.ones(cols.shape[0], dtype=bool)
    if np.issubdtype(dtype, np.inexact):
        cols = cols + 0
        valid = ~np.isnan(cols).any(axis=1)
    return [row.tobytes() if ok else None for row, ok in zip(cols, valid)]


def get_roi_indices(roi_data, fb_data):
    """Find the whole-brain columns whose values exactly equal some ROI column.

    The match is the same as ``np.all(fb_data[:, i] == roi_data[:, j])`` for
    some ``j``. Instead of a nested loop over all column pairs, the ROI columns
    go into a hash set, so the cost is linear in the number of columns rather
    than in their product.

    Args:
        roi_data: 2-D numeric array; each column is one ROI voxel.
        fb_data: 2-D numeric array with the same number of rows; each column is
            one whole-brain voxel.

    Returns:
        Ascending list of column indices ``i`` of ``fb_data`` that equal at
        least one column of ``roi_data``. Each index appears at most once.

    Raises:
        ValueError: if the two arrays have different numbers of rows.
    """
    roi_data = np.asarray(roi_data)
    fb_data = np.asarray(fb_data)
    if roi_data.shape[0] != fb_data.shape[0]:
        raise ValueError(
            f"row mismatch: roi_data has {roi_data.shape[0]} rows, "
            f"fb_data has {fb_data.shape[0]}"
        )
    # Compare in a common dtype so e.g. float32 vs float64 behaves like `==`.
    common = np.result_type(roi_data, fb_data)
    roi_keys = {k for k in _column_keys(roi_data, common) if k is not None}
    roi_indices = [
        i
        for i, k in enumerate(_column_keys(fb_data, common))
        if k is not None and k in roi_keys
    ]
    print(f"Found {len(roi_indices)} ROI indices")
    print(f"Expected {roi_data.shape[1]} ROI indices")
    return roi_indices


# %%
if __name__ == "__main__":
    # Cell 1: save `args.frames` evenly spaced, resized frames of each of the
    # first 1000 (sorted) videos as images, one directory per video.
    # `args` and `movie_outer_list` are reused by the next cell.
    args = get_args()

    movie_dir = os.path.join(args.orig_dir, "AlgonautsVideos268_All_30fpsmax/")
    MOVIES = sorted(os.listdir(movie_dir))[:1000]
    movie_outer_list = []
    save_dir = os.path.join(
        args.save_dir,
        NEW_NAMES[0],
        f"movie_{args.resolution}x{args.resolution}_{args.frames}frames",
    )
    # Decode only when the output directory is new or --overwrite is given.
    # The per-movie directory list is built either way because the next cell
    # indexes into it.
    load_flag = args.overwrite or not os.path.exists(save_dir)
    os.makedirs(save_dir, exist_ok=True)
    for movie in tqdm(MOVIES, desc="Movies"):
        outer_dir = os.path.join(save_dir, movie.split(".")[0])
        movie_outer_list.append(outer_dir)
        if not load_flag:
            continue
        vid = load_video(
            os.path.join(movie_dir, movie),
            frames=args.frames,
            resolution=args.resolution,
        )
        os.makedirs(outer_dir, exist_ok=True)
        for i_frame, img_arr in enumerate(vid):
            # Image.convert returns a new image, so its result must be kept.
            im = Image.fromarray(img_arr).convert("RGB")
            im.save(os.path.join(outer_dir, f"{i_frame}.{args.save_fmt}"))


# %%
if __name__ == "__main__":
    # Cell 2: for every subject, keep only the whole-brain voxels that also
    # appear in the mini-track ROI files, and concatenate all subjects along
    # the voxel axis. Then write shuffled train/val1/val2 image lists, one
    # float16 .npy target per stimulus, and the voxel coordinates.
    # Relies on `args` and `movie_outer_list` from the previous cell.
    for split_name in ("val1_split", "val2_split"):
        if getattr(args, split_name) < 0:
            raise ValueError(f"--{split_name} must be non-negative")

    all_fmri_data = []
    all_coords = []
    coords_max = 0
    for sub in OLD_NAMES:
        part_dir = os.path.join(args.orig_dir, "participants_data_v2021")

        # Mini-track ROI arrays, sorted only for a deterministic load order
        # (get_roi_indices treats these columns as an unordered set).
        # allow_pickle runs pickle: only use on trusted dataset files.
        roi_dir = os.path.join(part_dir, "mini_track", sub)
        roi_data = []
        for fname in sorted(os.listdir(roi_dir)):
            data = np.load(os.path.join(roi_dir, fname), allow_pickle=True)
            # Average over axis 1 (presumably stimulus repetitions -- TODO confirm).
            roi_data.append(data["train"].mean(1))
        roi_data = np.concatenate(roi_data, axis=1)

        wb = np.load(
            os.path.join(part_dir, "full_track", sub, "WB.pkl"), allow_pickle=True
        )
        # Grid position of every voxel in the whole-brain mask. Assumes the
        # columns of wb["train"] follow np.nonzero's C order -- TODO confirm.
        coords = np.stack(np.nonzero(wb["voxel_mask"]), axis=1)
        fb_data = wb["train"].mean(1)

        roi_indices = get_roi_indices(roi_data, fb_data)
        if not roi_indices:
            raise RuntimeError(f"No whole-brain voxels of {sub} matched its ROI data")
        coords = coords[roi_indices]

        # Presumably meant to keep subjects apart in coordinate space.
        # NOTE(review): the offset is the running sum of per-subject maxima
        # *including the current subject*, so sub01 is shifted too, and
        # adjacent subjects' ranges are only disjoint when
        # max_k < max_{k+1} + min_{k+1}. Kept unchanged so neuron_coords.npy
        # matches earlier runs.
        coords_max += coords.max()
        coords += coords_max
        all_coords.append(coords)
        all_fmri_data.append(fb_data[:, roi_indices])

    # Rows are stimuli (row i <-> movie_outer_list[i]); columns are voxels.
    fmri_data = np.concatenate(all_fmri_data, axis=1)
    coords = np.concatenate(all_coords, axis=0)

    save_dir = os.path.join(args.save_dir, NEW_NAMES[0])
    os.makedirs(save_dir, exist_ok=True)

    # Seeded shuffle of row indices; val1, val2 and train are consecutive
    # slices of the same permutation.
    train_indices = np.arange(fmri_data.shape[0])
    np.random.seed(args.seed)
    np.random.shuffle(train_indices)

    n_rows = len(train_indices)
    val1_end = int(n_rows * args.val1_split)
    val2_end = int(n_rows * (args.val1_split + args.val2_split))
    val1_indices = train_indices[:val1_end]
    val2_indices = train_indices[val1_end:val2_end]
    # The permutation holds unique values, so "every index not in val1/val2"
    # is exactly the tail slice, in the same order (no O(n^2) membership scan).
    remaining_train_indices = train_indices[val2_end:]

    val1_image_list = [movie_outer_list[i] for i in val1_indices]
    val2_image_list = [movie_outer_list[i] for i in val2_indices]
    train_img_list = [movie_outer_list[i] for i in remaining_train_indices]

    prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"

    def save_list(items, path):
        """Write one item per line to ``path``, which is used as given (not re-joined)."""
        with open(path, "w") as f:
            f.writelines(f"{item}\n" for item in items)

    for name, img_list in [
        ("train", train_img_list),
        ("val1", val1_image_list),
        ("val2", val2_image_list),
    ]:
        save_list(img_list, os.path.join(save_dir, prefix + f"{name}_img_list.txt"))

    def save_to_npy(fmri, indices):
        """Save row k of ``fmri`` as float16 to fmri/<indices[k]>.npy; return the paths."""
        fmri_dir = os.path.join(save_dir, "fmri")
        os.makedirs(fmri_dir, exist_ok=True)
        ret_list = []
        for k in tqdm(range(fmri.shape[0]), desc="Saving fmri data to: " + fmri_dir):
            path = os.path.join(fmri_dir, f"{indices[k]:010d}.npy")
            np.save(path, fmri[k].astype(np.float16))
            ret_list.append(path)
        return ret_list

    for name, indices in [
        ("train", remaining_train_indices),
        ("val1", val1_indices),
        ("val2", val2_indices),
    ]:
        save_list(
            save_to_npy(fmri_data[indices], indices),
            os.path.join(save_dir, f"{name}_y_list.txt"),
        )

    np.save(os.path.join(save_dir, "neuron_coords.npy"), coords)


# %%
