# %%
from functools import partial
import logging
import os
import sys
from copy import copy
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
from torch import nn
import torch
from torchvision import transforms
from einops import rearrange, reduce
from torch.utils.data import ConcatDataset, DataLoader

from config import AutoConfig

from datasets import VWEOnDiskDataset
from point_pe import point_position_encoding

from registry import Registry


# Registry of datamodule classes: classes register themselves via
# ``@DATAMODULE.register(name)`` and ``build_dm`` instantiates the one whose
# name equals ``cfg.DATASET.NAME``.
DATAMODULE = Registry()


@DATAMODULE.register("ALL")
class AllDatamodule(pl.LightningDataModule):
    """Datamodule that builds one ``VWEOnDiskDataset`` per (subject, split).

    Subjects are sub-directories of ``cfg.DATASET.ROOT`` selected through
    ``cfg.DATASET.SUBJECT_LIST`` (see ``get_subject_list``). Datasets are kept in
    ``self.dss``: a list of four ``{subject_id: dataset}`` dicts ordered like
    ``self.stage_list`` (train, val1, val2, test). ``self.ds_dict`` maps each
    split name to the *same* dict objects, so both views always agree.

    Dataloader -> split mapping:
        train_dataloader   -> "train"
        val_dataloader     -> "val1"
        test_dataloader    -> "val2"
        predict_dataloader -> "test" (None when no subject has a test split)
    """

    # Subject ids containing any of these substrings get
    # cfg.DATASET.TIME_SERIES_LENGTH; every other subject gets 1.
    _TIME_SERIES_MARKERS = ("EEG", "MEG", "ECoG")
    # File stems loaded from a subject's "roi" directory when no explicit
    # ROI selection is configured (cfg.DATASET.ROIS[0] == "all").
    _KNOWN_ROI_NAMES = ("early", "mid", "late", "all")
    # With an explicit ROI selection, subjects left with fewer voxels are skipped.
    _MIN_SELECTED_VOXELS = 20
    # The "ALG" training split is repeated this many times.
    _ALG_TRAIN_REPEATS = 3

    def __init__(
        self,
        cfg: AutoConfig,
    ):
        super().__init__()
        self.cfg = cfg
        self.stage_list = ["train", "val1", "val2", "test"]
        # One {subject_id: dataset} dict per split, in stage_list order.
        self.dss = [{} for _ in self.stage_list]
        # The same dict objects keyed by split name (aliases, not copies).
        self.ds_dict = dict(zip(self.stage_list, self.dss))

        self.train_batch_size = self.cfg.DATAMODULE.BATCH_SIZE
        # Evaluation batch size = BATCH_SIZE * EVAL_BATCH_SIZE_MULTIPLIER, rounded up.
        self.eval_batch_size = int(
            np.ceil(
                self.cfg.DATAMODULE.BATCH_SIZE
                * self.cfg.DATAMODULE.EVAL_BATCH_SIZE_MULTIPLIER
            )
        )

        self.subject_list = self.get_subject_list(self.cfg.DATASET.SUBJECT_LIST)

        # Zero padding handed to every dataset as ``image_transform2``.
        self.padding = nn.ZeroPad2d(self.cfg.DATASET.PADDING)

    # ------------------------------------------------------------------ #
    # Per-subject metadata (taken from the training datasets)
    # ------------------------------------------------------------------ #
    def _train_attr_dict(self, attr):
        """Return ``{subject_id: getattr(dataset, attr)}`` over the training datasets."""
        return {name: getattr(ds, attr) for name, ds in self.dss[0].items()}

    @property
    def num_voxel_dict(self):
        """``{subject_id: ds.num_voxels}`` for every training dataset."""
        return self._train_attr_dict("num_voxels")

    @property
    def roi_dict(self):
        """``{subject_id: ds.roi_dict}`` for every training dataset."""
        return self._train_attr_dict("roi_dict")

    @property
    def neuron_coords_dict(self):
        """``{subject_id: ds.neuron_coords}`` for every training dataset."""
        return self._train_attr_dict("neuron_coords")

    @property
    def noise_ceiling_dict(self):
        """``{subject_id: ds.noise_ceiling}`` for every training dataset."""
        return self._train_attr_dict("noise_ceiling")

    @property
    def collate_fn(self):
        """Collate function of the first training dataset, shared by all dataloaders.

        Raises:
            IndexError: if no training dataset has been built yet.
        """
        train_datasets = list(self.dss[0].values())
        if not train_datasets:
            raise IndexError(
                "collate_fn needs at least one training dataset; call setup() first"
            )
        return train_datasets[0].collate_fn

    # ------------------------------------------------------------------ #
    # Dataloaders
    # ------------------------------------------------------------------ #
    def _make_dataloader(self, idx, subject, batch_size, shuffle):
        """Build a DataLoader over split ``idx``.

        When ``subject`` is None all subjects of the split are concatenated;
        otherwise only that subject's dataset is used (KeyError if absent).
        """
        datasets = self.dss[idx]
        if subject is None:
            ds = ConcatDataset(list(datasets.values()))
        else:
            ds = datasets[subject]
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
            num_workers=self.cfg.DATAMODULE.NUM_WORKERS,
            pin_memory=self.cfg.DATAMODULE.PIN_MEMORY,
        )

    def train_dataloader(self, subject=None, shuffle=True):
        """Loader over the "train" split with the training batch size."""
        return self._make_dataloader(0, subject, self.train_batch_size, shuffle)

    def val_dataloader(self, subject=None):
        """Loader over the "val1" split, unshuffled, with the eval batch size."""
        return self._make_dataloader(1, subject, self.eval_batch_size, False)

    def test_dataloader(self, subject=None):
        """Loader over the "val2" split, unshuffled, with the eval batch size."""
        return self._make_dataloader(2, subject, self.eval_batch_size, False)

    def predict_dataloader(self, subject=None):
        """Loader over the "test" split, or None when no subject has that split."""
        if not self.dss[3]:
            return None
        return self._make_dataloader(3, subject, self.eval_batch_size, False)

    def teardown(self, stage: Optional[str] = None):
        # No-op: the datasets in self.dss are kept after the run.
        pass

    # ------------------------------------------------------------------ #
    # Subject discovery
    # ------------------------------------------------------------------ #
    @staticmethod
    def get_all_subject_list(root_dir):
        """Return the sorted names of non-hidden, non-empty sub-directories of ``root_dir``."""
        subjects = []
        for name in os.listdir(root_dir):
            path = os.path.join(root_dir, name)
            if name.startswith(".") or not os.path.isdir(path) or not os.listdir(path):
                continue
            subjects.append(name)
        return sorted(subjects)

    def get_subject_list(self, cfg_subject_list):
        """Resolve a subject selection into a list of subject ids.

        ``cfg_subject_list`` is a single string or a list of strings:
          * first entry ``"all"`` (case-insensitive): every subject under
            ``cfg.DATASET.ROOT``;
          * any entry ``"all_<prefix>"``: every subject whose id starts with one of
            the given prefixes (case-insensitive); other entries are logged and ignored;
          * otherwise: the listed ids as given.

        Raises:
            AssertionError: if a resolved id is not a subject directory under the root.
        """
        all_subject_list = self.get_all_subject_list(self.cfg.DATASET.ROOT)

        if isinstance(cfg_subject_list, str):
            cfg_subject_list = [cfg_subject_list]
        # Start from the normalised argument (not the raw config value) so a single
        # string id is one subject rather than a sequence of characters.
        subject_list = list(cfg_subject_list)

        # all in
        if subject_list and subject_list[0].lower() == "all":
            subject_list = all_subject_list

        # all but a subset, selected by id prefix
        if any(s.lower().startswith("all_") for s in cfg_subject_list):
            group_list = []
            for subject_id in cfg_subject_list:
                if not subject_id.lower().startswith("all_"):
                    logging.error(
                        f"subject_id {subject_id} is not start with 'all_' (case insensitive), skippping..."
                    )
                    continue
                group_list.append(subject_id[len("all_"):].lower())
            # any() keeps each subject once even if it matches several prefixes.
            subject_list = [
                sid
                for sid in all_subject_list
                if any(sid.lower().startswith(group) for group in group_list)
            ]

        for sub in subject_list:
            assert sub in all_subject_list, f"{sub} not in {all_subject_list}"
        return subject_list

    # ------------------------------------------------------------------ #
    # Loading helpers used by setup()
    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_list(path):
        """Return the lines of text file ``path`` with surrounding whitespace stripped."""
        with open(path, "r") as f:
            return [line.strip() for line in f]

    @staticmethod
    def _load_float_tensor(path):
        """``np.load`` ``path`` into a float32 tensor, or return None if it does not exist."""
        if not os.path.exists(path):
            return None
        return torch.from_numpy(np.load(path)).float()

    @staticmethod
    def _find_y_list_path(subject_dir, split, r1, r2):
        """Locate the ``*{split}_y_list.txt`` file in ``subject_dir``.

        When several candidates exist, only those starting with ``{r1}x{r2}`` are
        kept. Candidates are sorted so the choice does not depend on directory
        listing order. Returns None (warning for non-"test" splits) if none is found.
        """
        candidates = sorted(
            p for p in os.listdir(subject_dir) if p.endswith(f"{split}_y_list.txt")
        )
        if len(candidates) > 1:
            logging.warning(f"{subject_dir}/{split}_y_list.txt more than one")
            candidates = [p for p in candidates if p.startswith(f"{r1}x{r2}")]
        if not candidates:
            if split != "test":
                logging.warning(f"{subject_dir}/{split}_y_list.txt not exist")
            return None
        return os.path.join(subject_dir, candidates[0])

    def _load_roi_dict(self, roi_dir):
        """Load every ``roi_dir/<name>.*`` whose stem is in ``_KNOWN_ROI_NAMES``.

        Returns a dict that always has an ``"all"`` key; it stays ``...`` unless an
        ``all.*`` file overrides it.
        """
        roi_dict = {"all": ...}
        if os.path.exists(roi_dir):
            for fname in os.listdir(roi_dir):
                roi_name = fname.split(".")[0]
                if roi_name in self._KNOWN_ROI_NAMES:
                    roi_dict[roi_name] = np.load(os.path.join(roi_dir, fname))
        return roi_dict

    def _select_rois(self, roi_dir):
        """Concatenate the voxel indices of the ROIs named in ``cfg.DATASET.ROIS``.

        Missing ``roi_dir/<roi>.npy`` files are skipped. Returns
        ``(voxel_index, roi_dict)``: ``voxel_index`` is a long tensor of the
        concatenated indices, and ``roi_dict`` maps each found ROI to its positions
        inside ``voxel_index`` (plus ``"all": ...``). Returns None when the selection
        is empty.
        """
        voxel_chunks = []
        roi_dict = {}
        offset = 0
        for roi in self.cfg.DATASET.ROIS:
            roi_path = os.path.join(roi_dir, f"{roi}.npy")
            if not os.path.exists(roi_path):
                continue
            vi = np.load(roi_path)
            voxel_chunks.append(vi)
            roi_dict[roi] = np.arange(len(vi)) + offset
            offset += len(vi)
        if offset == 0:
            return None
        roi_dict["all"] = ...
        voxel_index = torch.from_numpy(np.concatenate(voxel_chunks)).long()
        return voxel_index, roi_dict

    # ------------------------------------------------------------------ #
    # Dataset construction
    # ------------------------------------------------------------------ #
    def setup(self, stage: Optional[str] = None, overwrite: bool = False):
        """Create a ``VWEOnDiskDataset`` for every subject/split found on disk.

        The Lightning ``stage`` argument is ignored: every split in
        ``self.stage_list`` is scanned on each call. Datasets built by an earlier
        call are kept unless ``overwrite`` is True. A subject/split is skipped when
        its image list file is missing, or when an explicit ROI selection
        (``cfg.DATASET.ROIS[0] != "all"``) finds no ROI file or leaves fewer than
        ``_MIN_SELECTED_VOXELS`` voxels.
        """
        r1, r2 = self.cfg.DATASET.RESOLUTION
        # Only 96x96 file lists are looked up as configured; every other
        # resolution reads the 288x288 file lists.
        if (r1, r2) != (96, 96):
            r1, r2 = 288, 288
        preffix = f"{r1}x{r2}-{self.cfg.DATASET.IMAGE_FMT}-"
        root = self.cfg.DATASET.ROOT

        for subject_id in self.subject_list:
            subject_dir = os.path.join(root, subject_id)
            for idx, split in enumerate(self.stage_list):
                # Checked first so repeated setup() calls do no file I/O.
                if subject_id in self.dss[idx] and not overwrite:
                    continue

                image_list_path = os.path.join(
                    subject_dir, f"{preffix}{split}_img_list.txt"
                )
                if not os.path.exists(image_list_path):
                    if split != "test":
                        logging.warning(f"{image_list_path} not exist")
                    continue

                # Voxel selection; ``...`` means no ROI-based selection.
                roi_dir = os.path.join(subject_dir, "roi")
                if self.cfg.DATASET.ROIS[0] == "all":
                    voxel_index = ...
                    roi_dict = self._load_roi_dict(roi_dir)
                else:
                    selection = self._select_rois(roi_dir)
                    if selection is None:
                        continue
                    voxel_index, roi_dict = selection
                    if len(voxel_index) < self._MIN_SELECTED_VOXELS:
                        # my grandma runs faster than this code
                        continue

                image_list = self._load_list(image_list_path)
                y_list = None  # missing y is allowed (for test set)
                y_list_path = self._find_y_list_path(subject_dir, split, r1, r2)
                if y_list_path is not None:
                    y_list = self._load_list(y_list_path)

                session_ids = None
                session_ids_path = os.path.join(
                    subject_dir, f"{preffix}{split}_session_ids.txt"
                )
                if os.path.exists(session_ids_path):
                    session_ids = self._load_list(session_ids_path)

                neuron_coords = self._load_float_tensor(
                    os.path.join(subject_dir, "neuron_coords.npy")
                )
                noise_ceiling = self._load_float_tensor(
                    os.path.join(subject_dir, "nc.npy")
                )
                eye_coords = self._load_float_tensor(
                    os.path.join(subject_dir, f"eye_coords_{split}.npy")
                )

                if any(marker in subject_id for marker in self._TIME_SERIES_MARKERS):
                    time_series_length = self.cfg.DATASET.TIME_SERIES_LENGTH
                else:
                    time_series_length = 1

                if subject_id == "ALG" and split == "train":
                    # ALG training set is too small compared to other datasets,
                    # so we duplicate it (targets too, when present).
                    image_list = image_list * self._ALG_TRAIN_REPEATS
                    if y_list is not None:
                        y_list = y_list * self._ALG_TRAIN_REPEATS

                ds = VWEOnDiskDataset(
                    image_paths=image_list,
                    y_paths=y_list,
                    voxel_index=voxel_index,
                    roi_dict=roi_dict,
                    noise_ceiling=noise_ceiling,
                    neuron_coords=neuron_coords,
                    eye_coords=eye_coords,
                    subject_id=subject_id,
                    session_ids=session_ids,
                    resolution=self.cfg.DATASET.RESOLUTION,
                    image_transform2=self.padding,
                    time_series_length=time_series_length,
                    feature_extractor_mode=self.cfg.DATAMODULE.FEATURE_EXTRACTOR_MODE,
                    img_fmt=self.cfg.DATASET.IMAGE_FMT,
                    video_frames=self.cfg.DATASET.VIDEO_FRAMES,
                    random_frames=self.cfg.DATASET.RANDOM_FRAMES
                    if split == "train"
                    else False,
                    clamp_value=self.cfg.DATASET.CLAMP_VALUE,
                    dark_postfix=self.cfg.DATASET.DARK_POSTFIX,
                )

                # Replace the dataset's neuron coordinates by their positional encoding.
                ds.neuron_coords = point_position_encoding(
                    points=ds.neuron_coords,
                    max_steps=self.cfg.POSITION_ENCODING.MAX_STEPS,
                    features=self.cfg.POSITION_ENCODING.FEATURES,
                    periods=self.cfg.POSITION_ENCODING.PERIODS,
                )

                self.dss[idx][subject_id] = ds

    def __repr__(self):
        lines = ["DataModule: \n"]
        for split, datasets in zip(self.stage_list, self.dss):
            num_datas = sum(len(ds) for ds in datasets.values())
            lines.append(f"  {split}: {num_datas:,} datas, {len(datasets)} subjects\n")
        return "".join(lines)


def build_dm(cfg: AutoConfig):
    """Instantiate the datamodule registered under ``cfg.DATASET.NAME``.

    The registered class is constructed with ``cfg`` as its only argument.
    """
    datamodule_cls = DATAMODULE[cfg.DATASET.NAME]
    return datamodule_cls(cfg)


# %%
if __name__ == "__main__":
    # Smoke test 1: every subject under cfg.DATASET.ROOT, no worker processes.
    from config_utils import get_cfg_defaults

    cfg = get_cfg_defaults()
    cfg.DATASET.NAME = "ALL"
    cfg.DATASET.SUBJECT_LIST = ["all"]
    # Alternative: cfg.DATASET.SUBJECT_LIST = "ALL"
    cfg.DATAMODULE.NUM_WORKERS = 0

    dm = build_dm(cfg)
    dm.setup()
# %%
if __name__ == "__main__":
    # Smoke test 2: a single subject restricted to the "early" ROI,
    # reusing the config from the previous cell.
    cfg.DATASET.ROIS = ["early"]
    cfg.DATASET.SUBJECT_LIST = ["NSD_01"]

    dm = build_dm(cfg)
    dm.setup()
# %%
