# %%
import os

from abc import ABC, abstractmethod
import random

import numpy as np
import torch
from torch import Tensor
from typing import Any, Tuple, List, Dict, Union
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from PIL import Image
from torchvision import transforms

import logging

# Default image pipeline applied after resizing: `ToTensor` turns the PIL image
# into a float tensor in [0, 1], then `Normalize` standardizes each RGB channel
# with the given per-channel mean / std (the commonly used ImageNet statistics).
IMAGE_TRAMSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Correctly spelled alias; the misspelled name is kept for existing callers.
IMAGE_TRANSFORM = IMAGE_TRAMSFORM


class Data(Dict):
    """A dict-like record holding every field of a single datapoint.

    Fields are stored in the instance ``__dict__``, so they can be read both
    as attributes (``data.img``, which editors can autocomplete) and as items
    (``data["img"]``). The inherited ``dict`` storage is never written to, so
    every mapping method that would otherwise consult that (always empty)
    storage -- ``len``, ``in``, iteration, ``get``, ``del`` -- is routed to
    ``__dict__`` here.

    The fields ``subject_id``, ``img``, ``y``, ``session_id`` and
    ``eye_coords`` always exist and default to ``None``; keyword arguments
    override them or add new fields.
    """

    def __init__(self, **kwds):
        self.subject_id = None
        self.img = None
        self.y = None
        self.session_id = None
        self.eye_coords = None
        self.__dict__.update(kwds)

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __delitem__(self, key):
        del self.__dict__[key]

    def __contains__(self, key):
        return key in self.__dict__

    def __iter__(self):
        return iter(self.__dict__)

    def __len__(self):
        # Also makes bool(data) reflect the stored fields instead of the
        # (empty) inherited dict storage.
        return len(self.__dict__)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)

    def __repr__(self) -> str:
        """Summarize each field: shapes for arrays, lengths for sequences,
        plain values for scalars/strings/None, and the type name otherwise."""
        parts = ["Data: \t"]
        for k, v in self.__dict__.items():
            if isinstance(v, (Tensor, np.ndarray)):
                desc = list(v.shape)
            elif isinstance(v, (list, tuple)):
                desc = len(v)
            elif v is None or isinstance(v, (str, int, float, bool)):
                desc = v
            else:
                desc = f"<{type(v).__name__}>"
            parts.append(f"{k}: {desc}, \t")
        return "".join(parts)


class VWEOnDiskDataset(Dataset):
    """Dataset for voxel-wise encoding that reads every datapoint from disk.

    Each datapoint is an image / brain-response pair. One dataset holds one
    subject (multiple subjects are combined in a datamodule) and may contain
    several scan sessions. Each response unit ("voxel") is associated with a
    location in ``neuron_coords``; for M/EEG time series all time points of
    the same sensor share the same location.

    ``__getitem__`` returns the tuple
    ``(img, y, subject_id, session_id, eye_coords, darkness)``; use
    ``collate_fn`` to batch such tuples.
    """

    def __init__(
        self,
        image_paths: List[str],  # length is B
        subject_id: str,
        neuron_coords: Tensor = None,  # [n, 3], n is the number of voxels for fMRI, n = t x c for MEEG
        eye_coords: Tensor = None,  # [B, 2]
        eye_coords_transform: Any = None,  # transform for eye_coords, default is None
        session_ids: List[str] = None,  # length is B
        y_paths: List[str] = None,  # None for no y (e.g. NSD test set)
        noise_ceiling: Tensor = None,  # noise ceiling for each voxel
        time_series_length: int = 1,  # for MEEG, time series length (100), for fMRI is 1
        resolution: Tuple[int, int] = (224, 224),  # load image and resize to this resolution
        image_transform=IMAGE_TRAMSFORM,  # normalize
        image_transform2=None,  # transform after resize, maybe padding
        y_transform=None,
        datapoint_idxs: Tensor = ...,  # indices of datapoints to use, for filtering out bad data
        voxel_index: Tensor = ...,  # indices of y to use, for sub cluster of voxels
        roi_dict: Dict[str, Tensor] = None,  # for evaluating on ROIs
        meta_info_dict=None,  # for storing additional info, not in used
        feature_extractor_mode=False,  # for slow feature extraction backbones (e.g. SNN)
        img_fmt="JPEG",
        video_frames=10,
        random_frames=False,
        clamp_value=20,
        dark_postfix="",
    ) -> None:
        super().__init__()

        self.image_paths = image_paths
        # An empty list of target files means "no targets", same as None.
        if y_paths is not None and len(y_paths) == 0:
            y_paths = None
        self.y_paths = y_paths
        if self.y_paths is not None:
            assert len(self.image_paths) == len(
                self.y_paths
            ), f"image and y should have same length subject_id: {subject_id}"
        self.subject_id = subject_id
        self.session_ids = session_ids
        self.neuron_coords = neuron_coords
        self.eye_coords = eye_coords
        self.resolution = resolution
        self.image_transform = image_transform
        self.image_transform2 = image_transform2
        self.eye_coords_transform = eye_coords_transform
        self.time_series_length = time_series_length
        self.img_fmt = img_fmt
        self.video_frames = video_frames
        self.random_frames = random_frames
        self.clamp_value = clamp_value
        self.noise_ceiling = noise_ceiling
        self.dark_postfix = dark_postfix
        # Set early: `_load_image` dispatches on it, so it must exist before
        # anything below could read a datapoint.
        self.feature_extractor_mode = feature_extractor_mode

        self.y_transform = y_transform
        self.datapoint_idxs = (
            torch.arange(len(self.image_paths))
            if datapoint_idxs is ...
            else datapoint_idxs
        )
        self.voxel_index = voxel_index
        self.roi_dict = roi_dict
        self.meta_info_dict = meta_info_dict

        self.resize = transforms.Resize(self.resolution)

        if self.neuron_coords is None:
            assert self.y_paths is not None, "neuron_coords or y_paths is required"
            # Identity check: `==` against an ndarray/tensor would not give a
            # plain bool.
            if self.voxel_index is ...:
                # NOTE: an earlier variant special-cased subject "EEG" here with
                # (num_voxels, 1); the original author flagged that as a bug
                # but kept it because EEG models were already trained with it.
                self.neuron_coords = self.make_dummy_neuron_coords(
                    self.num_voxels // time_series_length, time_series_length
                )
            elif subject_id == "EEG":
                n = 170  # hard-coded sensor count for "EEG" -- presumably channels
                self.neuron_coords = self.make_dummy_neuron_coords(
                    n, time_series_length
                )
            else:
                raise NotImplementedError
        self.neuron_coords = self.neuron_coords[self.voxel_index]
        self.noise_ceiling = (
            self.noise_ceiling[self.voxel_index]
            if self.noise_ceiling is not None
            else None
        )

        # For slow feature-extraction backbones (e.g. SNN) the "image" paths
        # point to precomputed .npy features, so no image pipeline is used.
        if self.feature_extractor_mode:
            self.image_transform = None
            self.image_transform2 = None
            self.resize = None

    def __repr__(self) -> str:
        # Note: `num_voxels` reads the first target file.
        return (
            super().__repr__()
            + "\n"
            + f"subject_id: {self.subject_id}, "
            + f"time_series_length: {self.time_series_length}, "
            + f"data length: {len(self)}, "
            + f"num_voxels: {self.num_voxels}, "
            + f"neuron_coords: {self.neuron_coords.shape if self.neuron_coords is not None else None}, "
            + f"eye_coords: {self.eye_coords.shape if self.eye_coords is not None else None}, "
        )

    def __len__(self):
        return len(self.datapoint_idxs)

    @property
    def num_voxels(self):
        """Length of the (flattened, voxel-selected, transformed) target of
        the first datapoint, or None (with an error log) without targets."""
        if self.y_paths is None:
            logging.error("y_paths is None, num_voxels is unknown")
            return None
        # Only the target is needed: skip decoding (and randomly sampling) an
        # image.
        return self._getitem_y(0).shape[0]

    def _apply_image_transforms(self, img):
        """Apply `image_transform` then `image_transform2`, skipping unset ones."""
        if self.image_transform is not None:  # 0-255 -> 0-1, normalize
            img = self.image_transform(img)
        if self.image_transform2 is not None:  # padding, resize
            img = self.image_transform2(img)
        return img

    def load_image(self, path) -> Tensor:
        """
        API for calling from outside: load `path` and apply the same
        transforms `__getitem__` uses.
        """
        return self._apply_image_transforms(self._load_image(path))

    @staticmethod
    def make_dummy_neuron_coords(n, t=1):
        """Build placeholder coordinates of shape (n * t, 1).

        Row block ``i`` (its ``t`` consecutive entries) holds ``i / (n - 1)``,
        i.e. values are spread evenly over [0, 1] and shared by all ``t``
        time points of unit ``i``. For ``n <= 1`` all values are 0 (the
        normalization is skipped instead of dividing by zero).
        """
        # Compute i / t in float64 then round to float32, matching the scalar
        # Python division used previously.
        rows = (torch.arange(n, dtype=torch.float64) / t).float()
        coords = rows.unsqueeze(1).repeat(1, t)
        if coords.numel() > 0:
            peak = coords.max()
            if peak > 0:
                coords = coords / peak
        return coords.reshape(-1, 1)

    def _select_voxels(self, y: Tensor) -> Tensor:
        """Flatten a raw target, keep `voxel_index`, and cast to float32."""
        return y.flatten()[self.voxel_index].float()

    def _getitem_y(self, index):
        """Return the processed target for `index`, or None without targets."""
        if self.y_paths is None:
            return None
        y = self._select_voxels(self._get_y(index))
        if self.y_transform is not None:
            y = self.y_transform(y)
        return y

    def _getitem_darkness(self, index):
        """Return the companion "darkness" target for `index`, or None.

        The companion file is the target path with `dark_postfix` inserted
        before the extension (``a/b.npy`` -> ``a/b<postfix>.npy``). None is
        returned when no postfix is configured, there are no targets, or the
        file does not exist. `y_transform` is not applied.
        """
        if not self.dark_postfix or self.y_paths is None:
            return None
        # Insert the postfix before the extension only; a plain str.replace
        # would also rewrite ".npy" inside directory names.
        root, ext = os.path.splitext(self.y_paths[self.datapoint_idxs[index]])
        path = f"{root}{self.dark_postfix}{ext}"
        if not os.path.exists(path):
            return None
        return self._select_voxels(self._load_y(path))

    def __getitem__(self, index, skip_img=False) -> Tuple[Any, ...]:
        """Return ``(img, y, subject_id, session_id, eye_coords, darkness)``.

        `img` is None when `skip_img` is set; `y`/`darkness` may be None (see
        `_getitem_y` / `_getitem_darkness`); `session_id` defaults to "" and
        `eye_coords` to None when not provided.
        """
        img = None
        if not skip_img:
            img = self._apply_image_transforms(self._get_image(index))

        y = self._getitem_y(index)
        darkness = self._getitem_darkness(index)

        sample_idx = self.datapoint_idxs[index]
        session_id = "" if self.session_ids is None else self.session_ids[sample_idx]
        eye_coords = None if self.eye_coords is None else self.eye_coords[sample_idx]

        return img, y, self.subject_id, session_id, eye_coords, darkness

    @staticmethod
    def collate_fn(batch: List[Tuple[Any, ...]]) -> Tuple[Any, ...]:
        """Batch the tuples produced by `__getitem__`.

        Images and eye coordinates are stacked along a new first dim (or None
        when the first sample has none); subject/session ids become numpy
        arrays; `y` and `darkness` stay per-sample tuples, since entries may
        differ in shape or be None.
        """
        img, y, subject_id, session_id, eye_coords, darkness = zip(*batch)
        img = torch.stack(img, dim=0) if img[0] is not None else None
        subject_ids = np.asarray(subject_id)
        session_ids = np.asarray(session_id)
        eye_coords = (
            torch.stack(eye_coords, dim=0) if eye_coords[0] is not None else None
        )
        return img, y, subject_ids, session_ids, eye_coords, darkness

    @staticmethod
    def _load_feature(path) -> Tensor:
        """Load a precomputed .npy feature array as a float32 tensor."""
        # torch.from_numpy takes no dtype argument; cast afterwards.
        return torch.from_numpy(np.load(path)).float()

    def _load_image(self, path) -> Union[Image.Image, Tensor]:
        """Load one resized RGB PIL image, or a feature tensor in
        feature-extractor mode.

        A path not ending in `img_fmt` is treated as a directory of video
        frames named ``{i}.{img_fmt}``; frame 0 is used unless
        `random_frames` is set, in which case a random frame in
        ``[0, video_frames - 1]`` is picked.
        """
        if self.feature_extractor_mode:
            return self._load_feature(path)
        if not path.endswith(self.img_fmt):
            # load a frame from video
            i = random.randint(0, self.video_frames - 1) if self.random_frames else 0
            path = os.path.join(path, f"{i}.{self.img_fmt}")
        # convert() fully loads the pixels, so the file can be closed here.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        return self.resize(img)  # only resize

    def _load_y(self, path) -> Tensor:
        """Load a .npy target as float32, clamped to [-clamp_value, clamp_value].

        Raises:
            FileNotFoundError: if `path` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"y file {path} not found")
        y = torch.from_numpy(np.load(path)).float()
        return torch.clamp(y, -self.clamp_value, self.clamp_value)

    def _get_image(self, index):
        """Load the raw (untransformed) image for dataset position `index`."""
        return self._load_image(self.image_paths[self.datapoint_idxs[index]])

    def _get_y(self, index):
        """Load the raw (unflattened) target for dataset position `index`."""
        return self._load_y(self.y_paths[self.datapoint_idxs[index]])


class VWEInMemoryDecodedDataset(VWEOnDiskDataset):
    """Deprecated: `VWEOnDiskDataset` that decodes all images and targets up front.

    Every image is loaded once in `__init__` (optionally cached to a `.pt`
    file under `cache_dir`), and every target file is loaded into memory.
    Because `_load_image` runs once per path, a video directory with
    `random_frames=True` keeps the single frame picked at load time.

    The cache key covers only the image paths and `resolution`; changing
    other loading options (e.g. `img_fmt`, `video_frames`) reuses a stale
    cache -- delete the file in that case.
    """

    def __init__(
        self,
        *args,
        cache_decoded_images=False,
        # cache_decoded_ys=True,
        cache_dir="/data/cache",
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        if cache_decoded_images:
            self._X = self._load_cached_images(cache_dir)
        else:
            self._X = self._load_images()

        self._y = self._load_ys()

    def _load_cached_images(self, cache_dir):
        """Return decoded images from the on-disk cache, building it if absent.

        Sets `image_hash` and `image_cache_path` on the instance.
        """
        import hashlib

        key = ";".join(self.image_paths) + str(self.resolution)
        self.image_hash = hashlib.md5(key.encode()).hexdigest()
        self.image_cache_path = os.path.join(cache_dir, self.image_hash + ".pt")
        if os.path.exists(self.image_cache_path):
            return self._torch_load(self.image_cache_path)

        images = self._load_images()
        os.makedirs(cache_dir, exist_ok=True)
        # Write to a temporary file, then rename, so an interrupted save never
        # leaves a truncated cache that later runs would try to load.
        tmp_path = f"{self.image_cache_path}.{os.getpid()}.tmp"
        torch.save(images, tmp_path)
        os.replace(tmp_path, self.image_cache_path)
        return images

    @staticmethod
    def _torch_load(path):
        """Load a cache file written by `_load_cached_images`."""
        # The cache pickles the objects returned by `_load_image` (PIL images
        # unless in feature-extractor mode), which torch >= 2.6 refuses to load
        # under its default weights_only=True. The file is produced by this
        # class, so full unpickling is accepted here -- never point `cache_dir`
        # at untrusted files.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # torch < 1.13 has no `weights_only` argument
            return torch.load(path)

    def _get_image(self, index):
        # The base __init__ may read a datapoint (e.g. via `num_voxels`) before
        # the in-memory images exist; fall back to disk in that case.
        if not hasattr(self, "_X"):
            return super()._get_image(index)
        return self._X[self.datapoint_idxs[index]]

    def _get_y(self, index):
        # Same early-access fallback as `_get_image`.
        if not hasattr(self, "_y"):
            return super()._get_y(index)
        return self._y[self.datapoint_idxs[index]]

    def _load_images(self):
        """Decode every image in `image_paths` (all of them, not only
        `datapoint_idxs`), in order."""
        return [
            self._load_image(path)
            for path in tqdm(self.image_paths, desc="Loading images")
        ]

    def _load_ys(self):
        """Load every target in `y_paths`, in order; None without targets."""
        if self.y_paths is None:
            return None
        return [self._load_y(path) for path in tqdm(self.y_paths, desc="Loading ys")]
