# %%
import argparse
import os
from pathlib import Path

import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
from PIL import Image
from scipy.stats import pearsonr as corr
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
import mne
from tqdm import tqdm

from PIL import Image

from argparse import ArgumentParser

from pathlib import Path

# Subject numbers 1..4; the main loop reads one epochs file per entry.
OLD_NAMES = list(range(1, 5))
# Output subdirectory created under --save_dir for all generated artefacts.
NEW_NAMES = ["MEG"]


def get_args(s=None):
    """Parse the command-line options of the THINGS-MEG preprocessing script.

    Args:
        s: Optional list of argument strings. When ``None`` (the default) the
            process arguments are parsed, as with ``ArgumentParser.parse_args``.

    Returns:
        ``argparse.Namespace`` with ``orig_dir``, ``orig_image_dir``, ``clamp``,
        ``save_dir``, ``resolution``, ``save_fmt``, ``val1_split``, ``seed``
        and ``overwrite``.
    """
    parser = ArgumentParser()
    # (flag, type, default) of every valued option, in declaration order.
    valued_options = (
        ("--orig_dir", str, "/data/things/meg1/THINGS-MEG_preprocessed"),
        ("--orig_image_dir", str, "/data/things/all_images"),
        ("--clamp", float, 20.0),
        ("--save_dir", str, "/data/VWET"),
        ("--resolution", int, 224),
        ("--save_fmt", str, "JPEG"),
        # Fraction of the shuffled training trials held out as validation
        # set 1; validation set 2 is the test set.
        ("--val1_split", float, 0.01),
        ("--seed", int, 42),
    )
    for flag, value_type, default in valued_options:
        parser.add_argument(flag, type=value_type, default=default)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args(s)


def load_meg_data(epochs, clamp=20, t=100, sfreq=100, n_jobs=16):
    """Resample epochs and build clipped train and test response arrays.

    Args:
        epochs: Preloaded ``mne.Epochs`` whose ``metadata`` has ``trial_type``
            and ``image_path`` columns. MNE's ``Epochs.resample`` operates in
            place (per the MNE docs), so the caller's object is resampled too.
        clamp: Responses are clipped to ``[-clamp, clamp]``.
        t: Number of leading time samples kept per trial after resampling.
        sfreq: Target sampling frequency passed to ``epochs.resample``.
        n_jobs: Number of parallel jobs used by ``epochs.resample``.

    Returns:
        Tuple ``(y_train, y_test, ch_names, times)``:

        - ``y_train``: ``(n_exp_trials, n_channels, <=t)`` array, one row per
          ``"exp"`` trial, in epoch order.
        - ``y_test``: ``(n_test_images, n_channels, <=t)`` array holding the
          average of all ``"test"`` trials of each unique ``image_path``, in
          order of first appearance.
        - ``ch_names``: channel names of ``epochs``.
        - ``times``: the full resampled time axis (NOT truncated to ``t``).

    Raises:
        ValueError: if ``epochs`` contains no ``"test"`` trials.
    """
    epochs = epochs.resample(sfreq, npad="auto", n_jobs=n_jobs, verbose=True)
    trial_type = epochs.metadata["trial_type"]

    # Train/val: every "exp" trial is kept individually.
    y_train = epochs[trial_type == "exp"]._data[:, :, :t]
    y_train = np.clip(y_train, -clamp, clamp)

    # Test: average the repetitions of each test image. Only "test" trials are
    # searched, so no other trial type can contribute to a test average (and
    # each per-image selection scans far fewer trials).
    is_test = trial_type == "test"
    if not is_test.any():
        raise ValueError("load_meg_data: no trials with trial_type == 'test'")
    test_epochs = epochs[is_test]
    test_paths = test_epochs.metadata["image_path"]
    y_test = np.stack(
        [
            test_epochs[test_paths == p].average().data
            for p in tqdm(test_paths.unique())
        ],
        0,
    )
    y_test = np.clip(y_test[:, :, :t], -clamp, clamp)

    return y_train, y_test, epochs.ch_names, epochs.times


def get_image_list(epochs, orig_image_dir):
    """Map the stimulus paths of ``epochs`` onto files in ``orig_image_dir``.

    Only the basename of each ``metadata["image_path"]`` entry is kept and
    joined onto ``orig_image_dir``.

    Args:
        epochs: Epochs object whose ``metadata`` has ``trial_type`` and
            ``image_path`` columns.
        orig_image_dir: Directory holding the original image files.

    Returns:
        ``(train_img_list, test_img_list)``: one path per ``"exp"`` trial (in
        trial order, duplicates kept), and one path per unique ``"test"``
        ``image_path`` (in order of first appearance).
    """

    def _relocate(paths):
        return [os.path.join(orig_image_dir, os.path.basename(p)) for p in paths]

    meta = epochs.metadata
    exp_paths = epochs[(meta["trial_type"] == "exp")].metadata["image_path"]
    test_paths = epochs[(meta["trial_type"] == "test")].metadata["image_path"]
    # Uniqueness is taken on the full path, before reducing to the basename.
    return _relocate(exp_paths.tolist()), _relocate(test_paths.unique().tolist())


def resave_image(
    orig_path, save_path, resolution, fmt="JPEG", quality=100, overwrite=False
):
    """Re-encode an image as a square RGB file.

    Args:
        orig_path: Path of the source image.
        save_path: Destination path of the re-encoded image.
        resolution: Output side length in pixels; the image is resized to
            ``(resolution, resolution)`` without preserving aspect ratio.
        fmt: PIL format name passed to ``Image.save`` (e.g. ``"JPEG"``).
        quality: Encoder quality passed to ``Image.save``. It is used by JPEG;
            presumably other formats ignore it — confirm per format.
        overwrite: When False and ``save_path`` already exists, do nothing.
    """
    if os.path.exists(save_path) and not overwrite:
        return
    # Open the source as a context manager so its file handle is released
    # immediately; this runs once per image over datasets of tens of
    # thousands of files.
    with Image.open(orig_path) as src:
        img = src.convert("RGB").resize((resolution, resolution))
    img.save(save_path, fmt, quality=quality)


def renew_list(args, save_dir, img_list):
    """Resave every image of ``img_list`` into ``save_dir``; return new paths.

    Each output is named ``<stem>_<R>x<R>.<save_fmt>`` with ``R`` equal to
    ``args.resolution``. Existing outputs are reused unless ``args.overwrite``.

    Args:
        args: Namespace providing ``resolution``, ``save_fmt`` and
            ``overwrite``.
        save_dir: Output directory (created if missing).
        img_list: Source image paths.

    Returns:
        List of output paths, in the same order as ``img_list``.
    """
    os.makedirs(save_dir, exist_ok=True)
    suffix = f"_{args.resolution}x{args.resolution}.{args.save_fmt}"
    new_img_list = []
    for img_path in tqdm(
        img_list, desc="Resaving images to: " + save_dir, total=len(img_list)
    ):
        # splitext strips only the final extension, so names containing extra
        # dots (e.g. "a.b.jpg" vs "a.c.jpg") map to distinct output files
        # instead of colliding on the text before the first dot.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        save_path = os.path.join(save_dir, stem + suffix)
        resave_image(
            img_path,
            save_path,
            args.resolution,
            args.save_fmt,
            overwrite=args.overwrite,
        )
        new_img_list.append(save_path)
    return new_img_list


# %%
if __name__ == "__main__":
    args = get_args()

    # All outputs (images, MEG arrays, split lists, coordinates) go here.
    save_dir = os.path.join(args.save_dir, NEW_NAMES[0])
    os.makedirs(save_dir, exist_ok=True)
    # Filename prefix of every split-list file, e.g. "224x224-JPEG-".
    prefix = f"{args.resolution}x{args.resolution}-{args.save_fmt}-"
    # A channel is kept when the third character of its name is one of these
    # letters (presumably occipital/temporal/parietal sensors — confirm).
    lobe_letters = ("O", "T", "P")

    def save_list(items, path):
        """Write ``items`` to ``path``, one entry per line."""
        # ``path`` is used as given: joining it onto save_dir again would
        # duplicate the directory whenever --save_dir is a relative path.
        with open(path, "w") as f:
            f.writelines(item + "\n" for item in items)

    def save_y(y, split, n):
        """Save ``y[0..n-1]`` as float16 ``.npy`` files; return their paths."""
        y_save_dir = os.path.join(save_dir, f"meg/{split}")
        os.makedirs(y_save_dir, exist_ok=True)
        paths = []
        for k in tqdm(range(n), desc="Saving meg data to: " + y_save_dir):
            path = os.path.join(y_save_dir, f"{k:010d}.npy")
            np.save(path, y[k].astype(np.float16))
            paths.append(path)
        return paths

    all_train_y, all_test_y, all_coords = [], [], []
    coords_max = 0
    for sub_idx, sub in enumerate(OLD_NAMES):
        fif_path = os.path.join(
            args.orig_dir, f"epoched_data_all_P{sub}_downsampled.fif"
        )
        epochs = mne.read_epochs(fif_path, preload=False, verbose=False)
        # Sort trials by image path so all subjects share one trial order;
        # their arrays are concatenated along the channel axis after the loop.
        epochs = epochs[np.argsort(epochs.metadata.image_path)]

        picks = [ch for ch in epochs.ch_names if ch[2] in lobe_letters]
        epochs.load_data()
        epochs = epochs.pick_channels(ch_names=picks)

        if sub_idx == 0:
            # Image lists and the train shuffle come from the first subject
            # and are reused for the concatenated MEG arrays.
            train_img_list, test_img_list = get_image_list(
                epochs, args.orig_image_dir
            )
            N1, N2 = len(train_img_list), len(test_img_list)

            np.random.seed(args.seed)
            train_idx_shuffle = np.arange(N1)
            np.random.shuffle(train_idx_shuffle)
            # Validation set 1 = the first n_val shuffled training trials.
            n_val = int(N1 * args.val1_split)

            train_img_list = renew_list(
                args, os.path.join(save_dir, "image/train"), train_img_list
            )
            test_img_list = renew_list(
                args, os.path.join(save_dir, "image/test"), test_img_list
            )

            # Split by position (not by a membership test) so the image lists
            # stay index-aligned with the MEG lists even if a path repeats.
            shuffled_imgs = [train_img_list[k] for k in train_idx_shuffle]
            val1_image_list = shuffled_imgs[:n_val]
            train_img_list = shuffled_imgs[n_val:]

            save_list(
                train_img_list,
                os.path.join(save_dir, prefix + "train_img_list.txt"),
            )
            save_list(
                val1_image_list,
                os.path.join(save_dir, prefix + "val1_img_list.txt"),
            )
            save_list(
                test_img_list, os.path.join(save_dir, prefix + "val2_img_list.txt")
            )

        # NOTE: load_meg_data resamples ``epochs`` (MNE resamples in place).
        train_y, test_y, _, _ = load_meg_data(epochs, args.clamp)

        # z-score with this subject's own training-trial statistics.
        mean, std = train_y.mean(), train_y.std()
        print("before normalization: ", mean, std)
        train_y = (train_y - mean) / std
        test_y = (test_y - mean) / std
        all_train_y.append(train_y)
        all_test_y.append(test_y)
        print("after normalization: ", train_y.mean(), train_y.std())

        # First 3 entries of each channel's ``loc``, each row repeated once per
        # kept time sample (channel-major), presumably matching a C-order
        # flatten of each (n_channels, t) sample — confirm with the loader.
        # The repeat count comes from the data so it always matches ``t``.
        neuron_coords = np.stack([ch["loc"] for ch in epochs.info["chs"]])[:, :3]
        neuron_coords = np.repeat(neuron_coords, train_y.shape[2], axis=0)
        print(neuron_coords.shape)

        # NOTE(review): a cumulative scalar offset is added to every subject
        # (the first included), presumably to keep subjects' coordinates
        # apart — confirm intent.
        coords_max += neuron_coords.max()
        neuron_coords += coords_max
        all_coords.append(neuron_coords)

    train_y = np.concatenate(all_train_y, axis=1)
    test_y = np.concatenate(all_test_y, axis=1)
    coords = np.concatenate(all_coords, axis=0)

    train_y_list = save_y(train_y, "train", N1)
    test_y_list = save_y(test_y, "test", N2)

    # Same permutation and positional split as the image lists, so line k of
    # each *_img_list.txt pairs with line k of the matching *_y_list.txt.
    shuffled_y = [train_y_list[k] for k in train_idx_shuffle]
    val1_y_list = shuffled_y[:n_val]
    train_y_list = shuffled_y[n_val:]

    save_list(train_y_list, os.path.join(save_dir, prefix + "train_y_list.txt"))
    save_list(val1_y_list, os.path.join(save_dir, prefix + "val1_y_list.txt"))
    save_list(test_y_list, os.path.join(save_dir, prefix + "val2_y_list.txt"))

    np.save(os.path.join(save_dir, "neuron_coords.npy"), coords)


# %%
