# %%
import argparse
import os
from pathlib import Path

import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser

from pathlib import Path

import nilearn
import nibabel as nib

OLD_NAMES = [f"CSI{i}" for i in range(1, 5)]
NEW_NAMES = [f"B5K_{i:02d}" for i in range(1, 5)]


def get_args(s=None):
    """Parse the command-line options of this preprocessing script.

    Args:
        s: Optional list of argument strings. When ``None`` argparse reads
            ``sys.argv[1:]`` instead.

    Returns:
        The ``argparse.Namespace`` holding every option below.
    """
    # (flag, value type, default) for every option that carries a value;
    # registration order is kept so ``--help`` lists them as before.
    valued_options = (
        ("--orig_dir", str, "/nas/BOLD5000"),
        (
            "--roi_mask_dir",
            str,
            "/nas/BOLD5000/openneuron/ds001499-download/derivatives/spm",
        ),
        ("--img_dir", str, "/nas/BOLD5000/images"),
        ("--save_dir", str, "/data/VWE"),
        ("--resolution", int, 288),
        ("--save_fmt", str, "JPEG"),
        ("--val1_split", float, 0.1),
        ("--val2_split", float, 0.1),
        ("--seed", int, 42),
    )

    cli = ArgumentParser()
    for flag, value_type, fallback in valued_options:
        cli.add_argument(flag, type=value_type, default=fallback)
    # Boolean switch: present -> True, absent -> False.
    cli.add_argument("--overwrite", action="store_true")
    return cli.parse_args(s)


def resave_image(
    orig_path, save_path, resolution, fmt="JPEG", quality=100, overwrite=False
):
    """Convert an image to RGB, resize it to a square and write it to disk.

    Args:
        orig_path: Path of the image to read.
        save_path: Destination path of the converted image.
        resolution: Edge length in pixels; the output is
            ``resolution x resolution`` (the aspect ratio is not preserved).
        fmt: Format name handed to ``Image.save``.
        quality: ``quality`` keyword handed to ``Image.save``.
        overwrite: When ``False`` an existing ``save_path`` is left untouched.

    Returns:
        ``None``. Nothing is read or written when ``save_path`` already exists
        and ``overwrite`` is ``False``.
    """
    if os.path.exists(save_path) and not overwrite:
        return
    # Close the source file deterministically instead of relying on garbage
    # collection. ``convert`` returns a new image, so the result stays usable
    # after the source has been closed.
    with Image.open(orig_path) as source:
        rgb = source.convert("RGB")
    rgb.resize((resolution, resolution)).save(save_path, fmt, quality=quality)


def renew_list(args, save_dir, img_list):
    """Resave every image of ``img_list`` into ``save_dir`` and return the new paths.

    Each output file is named ``<stem>_<R>x<R>.<save_fmt>``, where ``<stem>``
    is the source base name up to its first dot and ``R`` is
    ``args.resolution``.

    Args:
        args: Namespace providing ``resolution``, ``save_fmt`` and ``overwrite``.
        save_dir: Output directory; created when missing.
        img_list: Source image paths.

    Returns:
        The destination paths, in the order of ``img_list``.
    """
    os.makedirs(save_dir, exist_ok=True)
    size_tag = f"{args.resolution}x{args.resolution}"
    resaved_paths = []
    progress = tqdm(
        img_list, desc="Resaving images to: " + save_dir, total=len(img_list)
    )
    for source_path in progress:
        # Everything after the first dot of the file name is dropped.
        stem = os.path.basename(source_path).split(".")[0]
        target_path = os.path.join(save_dir, f"{stem}_{size_tag}.{args.save_fmt}")
        resave_image(
            source_path,
            target_path,
            args.resolution,
            args.save_fmt,
            overwrite=args.overwrite,
        )
        resaved_paths.append(target_path)
    return resaved_paths


def scan_subdir_for_image(subdir):
    """Recursively collect the paths of all image files below ``subdir``.

    A file counts as an image when its name ends with one of the suffixes
    listed below; only these exact lower-case and upper-case spellings match.

    Returns:
        A list of paths in ``os.walk`` order.
    """
    image_suffixes = (
        ".jpg",
        ".png",
        ".jpeg",
        ".JPG",
        ".PNG",
        ".JPEG",
        ".tif",
        ".tiff",
        ".TIF",
        ".TIFF",
    )
    return [
        os.path.join(folder, file_name)
        for folder, _subfolders, file_names in os.walk(subdir)
        for file_name in file_names
        if file_name.endswith(image_suffixes)
    ]

def read_list_from_file(file_path):
    """Return the lines of a text file with surrounding whitespace stripped.

    Blank lines are kept as empty strings.
    """
    with open(file_path, "r") as handle:
        return [raw_line.strip() for raw_line in handle]

# %%
def _find_files(top, prefix, suffix):
    """Return every file below ``top`` whose name starts with ``prefix`` and ends with ``suffix``.

    Paths come back in ``os.walk`` order; callers sort when the order matters.
    """
    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(top)
        for name in names
        if name.startswith(prefix) and name.endswith(suffix)
    ]


def _build_image_index(img_dir):
    """Map each image base name below ``img_dir`` to the first path found for it."""
    index = {}
    for path in scan_subdir_for_image(img_dir):
        # setdefault keeps the first hit, like a left-to-right list search.
        index.setdefault(os.path.basename(path), path)
    return index


def _load_union_mask(mask_dir):
    """Combine all ``sub-*.nii.gz`` volumes below ``mask_dir`` into one boolean mask.

    A voxel is selected when the sum of all loaded volumes is positive there.

    Raises:
        FileNotFoundError: If no matching file exists.
    """
    mask_files = _find_files(mask_dir, "sub-", ".nii.gz")
    if not mask_files:
        raise FileNotFoundError(
            f"No 'sub-*.nii.gz' mask files found under {mask_dir}"
        )
    volumes = [nib.load(path).get_fdata() for path in mask_files]
    return np.stack(volumes, axis=-1).sum(-1) > 0


def _load_masked_sessions(session_files, voxel_indices):
    """Load 4-D volumes, keep the selected voxels and concatenate along the last axis.

    Args:
        session_files: Paths of the per-session files, already in session order.
        voxel_indices: Flat (C-order) indices of the voxels to keep.
            NOTE(review): assumes every volume shares the spatial shape of the
            mask the indices were derived from; this is not checked here.

    Returns:
        ``(data, session_ids)``: ``data`` has shape ``(n_voxels, n_samples)``
        and ``session_ids`` holds, per sample, the 1-based position of its
        file in ``session_files`` as a string.

    Raises:
        ValueError: If a loaded volume is not 4-dimensional.
    """
    blocks = []
    session_ids = []
    for session_no, path in enumerate(session_files, start=1):
        volume = nib.load(path).get_fdata()
        if volume.ndim != 4:
            raise ValueError(
                f"Expected a 4-D volume in {path}, got shape {volume.shape}"
            )
        n_samples = volume.shape[-1]
        # C-order flattening of the three spatial axes, so rows line up with
        # the flat indices obtained from the mask.
        flat = volume.reshape(-1, n_samples)
        blocks.append(flat[voxel_indices, :])
        session_ids += [str(session_no)] * n_samples
    return np.concatenate(blocks, axis=1), session_ids


def _write_lines(path, lines):
    """Write ``lines`` to ``path``, one entry per line."""
    with open(path, "w") as f:
        f.writelines(f"{line}\n" for line in lines)


def _save_fmri_rows(fmri, indices, fmri_dir):
    """Save each column of ``fmri`` as a float16 ``.npy`` file named after its sample index.

    Args:
        fmri: Array of shape ``(n_voxels, n_samples)``.
        indices: Global sample index of every column; used as the file name.
        fmri_dir: Output directory; created when missing.

    Returns:
        The written file paths in column order.
    """
    os.makedirs(fmri_dir, exist_ok=True)
    rows = fmri.T
    paths = []
    for row, idx in tqdm(
        zip(rows, indices),
        total=rows.shape[0],
        desc="Saving fmri data to: " + fmri_dir,
    ):
        path = os.path.join(fmri_dir, f"{idx:010d}.npy")
        np.save(path, row.astype(np.float16))
        paths.append(path)
    return paths


def _process_subject(args, sub, save_dir, image_index):
    """Convert one subject: resave stimuli, split the samples and export the masked fMRI.

    Args:
        args: Parsed command-line options (see ``get_args``).
        sub: Subject name as used in the original data (an entry of ``OLD_NAMES``).
        save_dir: Output directory of this subject.
        image_index: Mapping from image base name to image path.

    Raises:
        FileNotFoundError: If stimuli, mask files or session files are missing.
        ValueError: If the number of fMRI samples and stimuli disagree.
    """
    # Stimulus order comes from the subject's image-name list.
    stim_names = read_list_from_file(
        os.path.join(args.orig_dir, f"{sub}_imgnames.txt")
    )
    missing = [name for name in stim_names if name not in image_index]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} stimulus image(s) of {sub} not found under "
            f"{args.img_dir}, e.g. {missing[:5]}"
        )
    stim_paths = [image_index[name] for name in stim_names]

    # Voxel selection: np.nonzero and np.flatnonzero both enumerate in C
    # order, so coords[k] is the position of flat index voxel_indices[k].
    mask = _load_union_mask(os.path.join(args.roi_mask_dir, f"sub-{sub}"))
    coords = np.stack(np.nonzero(mask), axis=-1)
    voxel_indices = np.flatnonzero(mask)

    # Session files, sorted by path so the session numbering is stable.
    session_prefix = f"{sub}_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_ses-"
    session_files = sorted(_find_files(args.orig_dir, session_prefix, ".nii.gz"))
    if not session_files:
        raise FileNotFoundError(
            f"No '{session_prefix}*.nii.gz' files found under {args.orig_dir}"
        )
    fmri_data, session_ids = _load_masked_sessions(session_files, voxel_indices)

    # Explicit check (not ``assert``) so it also runs under ``python -O``.
    if len(session_ids) != len(stim_paths):
        raise ValueError(
            f"{sub}: {len(session_ids)} fMRI samples but "
            f"{len(stim_paths)} stimuli listed"
        )

    # Resaving also creates save_dir, which the list files below rely on.
    resaved_paths = renew_list(
        args, os.path.join(save_dir, "image/all"), stim_paths
    )

    # Seeded shuffle, then val1 | val2 | train as consecutive slices.
    order = np.arange(len(resaved_paths))
    np.random.seed(args.seed)
    np.random.shuffle(order)
    val1_end = int(len(order) * args.val1_split)
    val2_end = int(len(order) * (args.val1_split + args.val2_split))
    train_start = max(val1_end, val2_end)
    splits = {
        "train": order[train_start:],
        "val1": order[:val1_end],
        "val2": order[val1_end:val2_end],
    }

    # Each list file is written to its full path directly. Joining save_dir a
    # second time only worked for absolute --save_dir values and broke
    # relative ones.
    prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"
    for split_name, indices in splits.items():
        _write_lines(
            os.path.join(save_dir, prefix + f"{split_name}_img_list.txt"),
            [resaved_paths[i] for i in indices],
        )
    for split_name, indices in splits.items():
        _write_lines(
            os.path.join(save_dir, prefix + f"{split_name}_session_ids.txt"),
            [session_ids[i] for i in indices],
        )

    # Standardise every split with the global mean/std of the training split.
    train_fmri = fmri_data[:, splits["train"]]
    mean, std = train_fmri.mean(), train_fmri.std()
    fmri_dir = os.path.join(save_dir, "fmri")
    for split_name, indices in splits.items():
        normalised = (fmri_data[:, indices] - mean) / std
        _write_lines(
            os.path.join(save_dir, f"{split_name}_y_list.txt"),
            _save_fmri_rows(normalised, indices, fmri_dir),
        )

    np.save(os.path.join(save_dir, "neuron_coords.npy"), coords)


def main():
    """Convert every subject in ``OLD_NAMES`` into the layout named by ``NEW_NAMES``."""
    args = get_args()

    # Built on first use so a run that skips every subject never scans the
    # image tree, and a run that processes several scans it only once.
    image_index = None

    for sub, new_name in zip(OLD_NAMES, NEW_NAMES):
        save_dir = os.path.join(args.save_dir, new_name)
        # NOTE: the mere presence of the "fmri" directory marks a subject as
        # done; after an interrupted run use --overwrite to redo it.
        if os.path.exists(os.path.join(save_dir, "fmri")) and not args.overwrite:
            print(f"Skip {sub} because it already exists")
            continue

        if image_index is None:
            image_index = _build_image_index(args.img_dir)
        _process_subject(args, sub, save_dir, image_index)


if __name__ == "__main__":
    main()