# %%
import argparse
import os
from pathlib import Path

import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser

from pathlib import Path

# Subject folder names in the original THINGS-EEG2 release, paired index-by-index
# with the directory names used in the converted output tree.
OLD_NAMES = list(map("sub-{:02d}".format, range(1, 11)))
NEW_NAMES = list(map("EEG2_{:02d}".format, range(1, 11)))


def get_args(s=None):
    """Parse the command-line options of the conversion script.

    Args:
        s: Optional list of argument strings. ``None`` makes argparse read
            ``sys.argv``; pass an explicit list (e.g. ``[]``) when running the
            cells interactively.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # (flag, type, default) for every valued option, in registration order.
    # There is no --val2_split: the official test split is used as val2.
    valued_options = (
        ("--orig_dir", str, "/data/things/eeg2"),
        ("--save_dir", str, "/data/VWE"),
        ("--resolution", int, 224),
        ("--clamp", float, 20.0),
        ("--save_fmt", str, "JPEG"),
        ("--val1_split", float, 0.1),
        ("--seed", int, 42),
    )
    parser = ArgumentParser()
    for flag, kind, default in valued_options:
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args(s)


def _load_eeg_split(base_dir, filename, clamp, t):
    """Load one preprocessed EEG file; return (processed array, raw loaded dict).

    The ``"preprocessed_eeg_data"`` array is averaged over axis 1, truncated to
    the first *t* entries of its third axis after averaging, and clipped to
    ``[-clamp, clamp]``.
    """
    # The .npy file stores a pickled dict, so allow_pickle is required.
    # NOTE(review): only load files from a trusted source.
    data = np.load(os.path.join(base_dir, filename), allow_pickle=True).item()
    # Average across repetitions (axis 1).
    y = np.mean(data["preprocessed_eeg_data"], 1)
    y = y[:, :, :t]
    return np.clip(y, -clamp, clamp), data


def load_eeg_data(base_dir, clamp=20, t=100):
    """Load one subject's preprocessed EEG training and test data.

    Reads ``preprocessed_eeg_training.npy`` and ``preprocessed_eeg_test.npy``
    from *base_dir*. Each split's ``"preprocessed_eeg_data"`` is averaged over
    axis 1 (repetitions), truncated to the first *t* samples of the last
    remaining axis, and clipped to ``[-clamp, clamp]``.

    Args:
        base_dir: Directory containing the two ``.npy`` files.
        clamp: Symmetric clipping bound applied after averaging.
        t: Number of leading samples kept along the truncated axis.

    Returns:
        ``(y_train, y_test, ch_names, times)``. ``ch_names`` and ``times`` come
        from the training file. ``times`` is returned as stored, i.e. NOT
        truncated to *t*. NOTE(review): confirm whether callers expect it to
        match the truncated data.
    """
    y_train, train_data = _load_eeg_split(
        base_dir, "preprocessed_eeg_training.npy", clamp, t
    )
    ch_names = train_data["ch_names"]
    times = train_data["times"]

    y_test, _ = _load_eeg_split(base_dir, "preprocessed_eeg_test.npy", clamp, t)

    return y_train, y_test, ch_names, times


def get_image_list(data_dir):
    """Collect the ``.jpg`` files of the training and test image folders.

    Recursively walks ``<data_dir>/training_images`` and
    ``<data_dir>/test_images`` and returns two lexicographically sorted lists
    of full file paths: training first, test second. Only names ending in the
    lowercase ``.jpg`` suffix are kept. A missing folder yields an empty list.
    """

    def collect(subdir):
        found = [
            os.path.join(root, name)
            for root, _subdirs, names in os.walk(os.path.join(data_dir, subdir))
            for name in names
            if name.endswith(".jpg")
        ]
        return sorted(found)

    return collect("training_images"), collect("test_images")


def resave_image(
    orig_path, save_path, resolution, fmt="JPEG", quality=90, overwrite=False
):
    """Write a square ``resolution x resolution`` RGB copy of an image.

    Args:
        orig_path: Source image path.
        save_path: Destination path.
        resolution: Side length in pixels of the output (aspect ratio is not
            preserved).
        fmt: Pillow format name used for saving, independent of the file
            extension.
        quality: Encoder quality passed to ``Image.save``.
        overwrite: If False and *save_path* already exists, do nothing.
    """
    if os.path.exists(save_path) and not overwrite:
        return
    # The context manager releases the source file handle deterministically.
    with Image.open(orig_path) as src:
        img = src.convert("RGB").resize((resolution, resolution))
    # Write to a temporary file and rename it into place. An interrupted run
    # therefore never leaves a truncated file at save_path, which the exists()
    # check above would otherwise silently reuse on the next run.
    tmp_path = save_path + ".tmp"
    try:
        img.save(tmp_path, fmt, quality=quality)
        os.replace(tmp_path, save_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def renew_list(args, save_dir, img_list):
    """Resave every image of *img_list* into *save_dir* and return the new paths.

    Each output is named ``<stem>_<R>x<R>.<save_fmt>``, where ``<stem>`` is the
    source file name without its final extension and ``R`` is
    ``args.resolution``. Existing outputs are kept unless ``args.overwrite``
    is set.

    Args:
        args: Namespace providing ``resolution``, ``save_fmt`` and ``overwrite``.
        save_dir: Output directory (created if missing).
        img_list: Source image paths.

    Returns:
        Output paths, in the same order as *img_list*.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Same for every image; build it once.
    suffix = f"_{args.resolution}x{args.resolution}.{args.save_fmt}"
    new_img_list = []
    for img_path in tqdm(img_list, desc="Resaving images to: " + save_dir):
        # splitext strips only the final extension, so stems that contain dots
        # stay distinct instead of being truncated at the first dot.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        save_path = os.path.join(save_dir, stem + suffix)
        resave_image(
            img_path,
            save_path,
            args.resolution,
            args.save_fmt,
            overwrite=args.overwrite,
        )
        new_img_list.append(save_path)
    return new_img_list


# %%
def _write_lines(path, items):
    """Write each entry of *items* to the text file *path*, one per line."""
    with open(path, "w") as f:
        for item in items:
            f.write(item + "\n")


def _shuffle_and_split(items, order, val_frac):
    """Reorder *items* by the index sequence *order* and split off a validation prefix.

    Returns ``(train, val)``. ``val`` is the first ``int(len(items) * val_frac)``
    reordered entries and ``train`` is everything after them. Slicing, rather
    than filtering by membership, keeps image and EEG lists index-aligned even
    if an entry is duplicated, and avoids an O(n^2) scan.
    """
    shuffled = [items[k] for k in order]
    n_val = int(len(shuffled) * val_frac)
    return shuffled[n_val:], shuffled[:n_val]


def _save_eeg_split(y, out_dir, count, write):
    """Return the *count* per-sample ``.npy`` paths under *out_dir*, writing them if *write*.

    Sample ``k`` is stored at ``out_dir/{k:010d}.npy`` as ``float16``. When
    *write* is False, *y* is not accessed and only the path list is built.
    """
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for k in tqdm(range(count), desc="Saving EEG data to: " + out_dir):
        path = os.path.join(out_dir, f"{k:010d}.npy")
        if write:
            np.save(path, y[k].astype(np.float16))
        paths.append(path)
    return paths


if __name__ == "__main__":
    args = get_args()
    prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"

    # Images are shared by all subjects. They are resaved once under the first
    # subject's directory, and every subject's list files point there.
    train_img_list, test_img_list = get_image_list(os.path.join(args.orig_dir, "image"))
    N1, N2 = len(train_img_list), len(test_img_list)

    np.random.seed(args.seed)
    train_idx_shuffle = np.arange(N1)
    np.random.shuffle(train_idx_shuffle)

    img_root = os.path.join(args.save_dir, NEW_NAMES[0])
    train_img_list = renew_list(
        args, os.path.join(img_root, "image/train"), train_img_list
    )
    test_img_list = renew_list(
        args, os.path.join(img_root, "image/test"), test_img_list
    )
    train_img_list, val1_image_list = _shuffle_and_split(
        train_img_list, train_idx_shuffle, args.val1_split
    )

    for sub, new_name in zip(OLD_NAMES, NEW_NAMES):
        save_dir = os.path.join(args.save_dir, new_name)
        os.makedirs(save_dir, exist_ok=True)

        _write_lines(
            os.path.join(save_dir, prefix + "train_img_list.txt"), train_img_list
        )
        _write_lines(
            os.path.join(save_dir, prefix + "val1_img_list.txt"), val1_image_list
        )
        # The official test split serves as the second validation set.
        _write_lines(
            os.path.join(save_dir, prefix + "val2_img_list.txt"), test_img_list
        )

        train_dir = os.path.join(save_dir, "eeg/train")
        test_dir = os.path.join(save_dir, "eeg/test")
        # Only eeg/train is checked; if it exists, both splits are assumed done.
        skip_flag = os.path.exists(train_dir) and not args.overwrite
        train_y = test_y = None
        if skip_flag:
            print("EEG data already exists, skipping...")
        else:
            train_y, test_y, _, _ = load_eeg_data(
                os.path.join(args.orig_dir, "eeg", sub), clamp=args.clamp
            )
            # EEG sample k is paired with image k; a count mismatch means the
            # pairing is wrong, so fail loudly instead of truncating.
            if len(train_y) != N1 or len(test_y) != N2:
                raise ValueError(
                    f"{sub}: EEG has {len(train_y)} train / {len(test_y)} test "
                    f"samples, but found {N1} train / {N2} test images"
                )
            # z-score both splits with statistics of the full training set.
            mean, std = train_y.mean(), train_y.std()
            train_y = (train_y - mean) / std
            test_y = (test_y - mean) / std

        train_y_list = _save_eeg_split(train_y, train_dir, N1, not skip_flag)
        test_y_list = _save_eeg_split(test_y, test_dir, N2, not skip_flag)
        # Same permutation and split size as the images, so lists stay aligned.
        train_y_list, val1_y_list = _shuffle_and_split(
            train_y_list, train_idx_shuffle, args.val1_split
        )

        _write_lines(os.path.join(save_dir, prefix + "train_y_list.txt"), train_y_list)
        _write_lines(os.path.join(save_dir, prefix + "val1_y_list.txt"), val1_y_list)
        _write_lines(os.path.join(save_dir, prefix + "val2_y_list.txt"), test_y_list)

# %%
