import os
import numpy as np
import torch
from typing import Callable
import einops
import os
from tqdm import tqdm
import torch
from PIL import Image
from ....utils.json_stuff import load_json_as_dict


class MurtyNSD:
    """Paired (brain response, image) dataset loaded fully into memory from two ``.npy`` files.

    Items are aligned by index: item ``i`` pairs row ``i`` of the (transposed)
    brain-signal array with image ``i`` of the image array.
    """

    def __init__(
        self,
        brain_signals_filename: str,
        image_data_filename: str,
        transforms: Callable = None,
    ):
        """Load brain signals and images.

        Args:
            brain_signals_filename (str): Path to a ``.npy`` array of brain signals.
                It is transposed on load (``.T``) so that axis 0 indexes items.
            image_data_filename (str): Path to a ``.npy`` array of images laid out as
                ``(items, h, w, c)`` with ``np.uint8`` pixel values in ``[0, 255]``.
            transforms (Callable, optional): Applied to each ``PIL.Image`` in
                ``__getitem__``. When ``None`` the raw ``PIL.Image`` is returned.

        Raises:
            AssertionError: If either file does not exist.
            ValueError: If the two arrays disagree on the number of items.
        """
        assert os.path.exists(
            brain_signals_filename
        ), f"Expected file to exist: {brain_signals_filename}"
        assert os.path.exists(
            image_data_filename
        ), f"Expected file to exist: {image_data_filename}"

        # (items, brain_signals) after the transpose.
        brain_signals = np.load(brain_signals_filename).T

        # Stored on disk as (items, h, w, c) with np.uint8 pixels in [0, 255];
        # kept in memory as (items, c, h, w).
        image_data = np.load(image_data_filename)
        self.brain_signals = torch.tensor(brain_signals)
        self.image_data = np.transpose(image_data, (0, 3, 1, 2))

        # Fail early. Otherwise a mismatch only surfaces as an IndexError for some
        # idx, or extra brain responses are silently never used.
        num_signal_items = self.brain_signals.shape[0]
        num_image_items = self.image_data.shape[0]
        if num_signal_items != num_image_items:
            raise ValueError(
                f"Item count mismatch: {num_signal_items} brain responses in "
                f"{brain_signals_filename} vs {num_image_items} images in "
                f"{image_data_filename}"
            )

        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the brain response and image for item ``idx``.

        Returns:
            dict: ``"brain_response"`` is row ``idx`` of ``self.brain_signals``.
            ``"image_tensor"`` is the image as a ``PIL.Image``, passed through
            ``self.transforms`` when that is set. Despite the key name, it stays a
            ``PIL.Image`` when no transforms are given.
        """
        # Back to (h, w, c), the layout PIL.Image.fromarray expects.
        image_numpy_hwc = np.moveaxis(self.image_data[idx], 0, -1)
        image = Image.fromarray(image_numpy_hwc)

        if self.transforms is not None:
            image = self.transforms(image)
        return {
            "brain_response": self.brain_signals[idx],
            "image_tensor": image,
        }

    def __len__(self):
        """Return the number of items (images) in the dataset."""
        return self.image_data.shape[0]


class NaturalScenesImageEncodingDataset:
    def __init__(
        self, image_filenames_and_labels_folder: str, image_encodings_folder: str
    ):
        """Dataset whose inputs are pre-computed image encodings and whose targets are brain responses.

        It was built with the Murty185 dataset in mind, but may be extended to
        other datasets when needed.

        The number of items is the number of entries in ``image_encodings_folder``.
        For each index ``i`` in ``0 .. n-1`` the JSON record ``{i}.json`` is read
        from ``image_filenames_and_labels_folder``. That record must provide an
        ``"image_encoding_filename"`` (loaded with ``torch.load``) and a
        ``"brain_response"`` (kept as the label).

        Args:
            image_filenames_and_labels_folder (str): Folder holding the JSON files
                produced by
                `brain_candy.datasets.image_encoding_dataset.ImageEncodingDatasetBuilder`.
            image_encodings_folder (str): Folder holding the ``.pth`` encoding files
                produced by the same builder. Only its entry count is used here.
        """
        assert os.path.exists(
            image_filenames_and_labels_folder
        ), f"{image_filenames_and_labels_folder}"

        item_count = len(os.listdir(image_encodings_folder))

        self.image_encodings = []
        self.labels = []

        progress = tqdm(range(item_count), desc="Loading image encodings and labels")
        for record_index in progress:
            record_path = os.path.join(
                image_filenames_and_labels_folder, f"{record_index}.json"
            )
            record = load_json_as_dict(filename=record_path)

            encoding_path = record["image_encoding_filename"]
            assert os.path.exists(encoding_path)
            # NOTE: torch.load unpickles its input; only point this at trusted files.
            self.image_encodings.append(torch.load(encoding_path))
            self.labels.append(record["brain_response"])

    def __getitem__(self, idx):
        """Fetch one item.

        Args:
            idx (int): Position of the item to return.

        Returns:
            dict: ``"image_encoding"`` is the stored encoding with its leading
            size-1 batch dimension removed. ``"brain_response"`` is the label
            converted with ``torch.tensor``.
        """
        stored_encoding = self.image_encodings[idx]
        assert stored_encoding.shape[0] == 1, "Expected batch size to be 1"
        unbatched_encoding = stored_encoding.squeeze(0)
        label_tensor = torch.tensor(self.labels[idx])
        return {
            "image_encoding": unbatched_encoding,
            "brain_response": label_tensor,
        }

    def __len__(self):
        """Number of loaded (encoding, label) pairs."""
        return len(self.labels)
