import copy
import os

import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


def minmax_scale(inputs: torch.Tensor):
    """Min-max scale each slice ``inputs[c]`` of a 3-D tensor independently to [0, 1].

    Args:
        inputs: A 3-D tensor, scaled per index of dim 0 (e.g. the channel
            axis of a ``(C, H, W)`` image tensor).

    Returns:
        A tensor of the same shape. A slice whose values are all equal maps
        to zeros instead of NaN.

    Raises:
        AssertionError: if ``inputs`` is not 3-D.
    """
    assert len(inputs.shape) == 3, f"got unexpected shape {inputs.shape}"
    # keepdim=True gives shape (C, 1, 1), so the statistics broadcast over
    # H and W. Without it, shape (C,) would broadcast against the W axis.
    min_values = inputs.amin(dim=(1, 2), keepdim=True)
    max_values = inputs.amax(dim=(1, 2), keepdim=True)
    value_range = max_values - min_values
    # Avoid 0/0 for constant slices; their numerator is 0, so they become 0.
    value_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    return (inputs - min_values) / value_range


def _to_signed_unit_range(image: torch.Tensor) -> torch.Tensor:
    """Min-max scale each channel of ``image`` to [0, 1], then map it to [-1, 1]."""
    return minmax_scale(image) * 2 - 1


def get_transform(transform_name):
    """Build an image transform from a short name.

    Args:
        transform_name: Case-insensitive ``"totensor"`` for
            ``transforms.ToTensor()``, or ``"[-1, 1]"`` for ``ToTensor``
            followed by per-channel min-max scaling to [-1, 1]. A callable is
            returned unchanged. ``None`` or any unrecognized value means
            "no transform".

    Returns:
        The transform callable, or ``None`` when no transform applies.
    """
    if callable(transform_name):
        return transform_name

    # Previously None crashed on .lower(); only strings are names.
    name = transform_name.lower() if isinstance(transform_name, str) else None
    if name == "totensor":
        return transforms.ToTensor()
    if name == "[-1, 1]":
        # A module-level function (not a lambda) keeps the pipeline picklable,
        # which DataLoader workers started with "spawn" require.
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Lambda(_to_signed_unit_range),
            ]
        )

    print("Not Using Any Transform!")
    return None


# FIXME: deprecated, maybe delete in the future
class MNIST_AAE_Dataset(Dataset):
    """Map-style dataset over a preprocessed MNIST object saved with ``torch.save``.

    The loaded object is indexed by the keys ``"digits"``, ``"images"``,
    ``"encoder_params"`` and (optionally) ``"pred_params"``. Each key is
    expected to hold one entry per sample. Per-sample encoder state dicts and
    predicted parameters can be written back with the ``save_*`` methods and
    persisted with ``save_dataset_to_disk``.
    """

    def __init__(self, processed_data_path) -> None:
        super().__init__()
        self.processed_data_path = processed_data_path
        self.data = torch.load(processed_data_path)
        # Create the placeholder up front (it used to be created lazily in
        # __getitem__) so save_pred_params works before any item is read.
        if "pred_params" not in self.data:
            self.data["pred_params"] = [torch.nan] * len(self)

    def __len__(self):
        """Number of samples, taken from the length of ``data["digits"]``."""
        return len(self.data["digits"])

    def __getitem__(self, index):
        """Return sample ``index`` as a dict keyed like the underlying data."""
        return {
            "index": index,  # for setting the encoder_params after training
            "images": self.data["images"][index],
            "digits": self.data["digits"][index],
            "encoder_params": self.data["encoder_params"][index],  # can be torch.nan
            "pred_params": self.data["pred_params"][index],
        }

    def save_encoder_params(self, index, encoder: nn.Module):
        """Store a snapshot of ``encoder``'s state dict for sample ``index``."""
        # state_dict() references the live parameter tensors. Copy them so
        # that later training of `encoder` does not change the saved entry.
        self.data["encoder_params"][index] = copy.deepcopy(encoder.state_dict())

    def save_pred_params(self, index, pred_params: torch.Tensor):
        """Store ``pred_params`` for sample ``index``."""
        self.data["pred_params"][index] = pred_params

    def save_dataset_to_disk(self, path=None):
        """``torch.save`` the data to ``path`` (default: the path it was loaded from)."""
        if path is None:
            path = self.processed_data_path

        torch.save(self.data, path)


# TODO: consider creating an abstract base class shared by all the datasets we use
class MNISTDataset(Dataset):
    """Map-style dataset over a preprocessed MNIST object saved with ``torch.save``.

    ``data_path`` must be loadable by ``torch.load`` into an object indexed by
    ``"digits"``, ``"images"``, ``"encoder_params"`` and (optionally)
    ``"pred_params"``. Each key is expected to hold one entry per sample.
    Items expose the digit under the ``"labels"`` key.
    """

    def __init__(self, data_path) -> None:
        super().__init__()
        self.processed_data_path = data_path
        self.data = torch.load(data_path)
        # Create the placeholder up front (it used to be created lazily in
        # __getitem__) so save_pred_params works before any item is read.
        if "pred_params" not in self.data:
            self.data["pred_params"] = [torch.nan] * len(self)

    def __len__(self):
        """Number of samples, taken from the length of ``data["digits"]``."""
        return len(self.data["digits"])

    def __getitem__(self, index):
        """Return sample ``index`` as a dict (digit under ``"labels"``)."""
        # TODO: in the future, the base class should enforce the format of returned data
        return {
            "index": index,  # for setting the encoder_params after training
            "images": self.data["images"][index],
            "labels": self.data["digits"][index],
            "encoder_params": self.data["encoder_params"][index],  # can be torch.nan
            "pred_params": self.data["pred_params"][index],
        }

    def save_encoder_params(self, index, encoder: nn.Module):
        """Store a snapshot of ``encoder``'s state dict for sample ``index``."""
        # state_dict() references the live parameter tensors. Copy them so
        # that later training of `encoder` does not change the saved entry.
        self.data["encoder_params"][index] = copy.deepcopy(encoder.state_dict())

    def save_pred_params(self, index, pred_params: torch.Tensor):
        """Store ``pred_params`` for sample ``index``."""
        self.data["pred_params"][index] = pred_params

    def save_dataset_to_disk(self, path=None):
        """``torch.save`` the data to ``path`` (default: the path it was loaded from)."""
        if path is None:
            path = self.processed_data_path

        torch.save(self.data, path)


class FractalDB_Dataset(Dataset):
    """Folder-per-class image dataset laid out as ``root/<class_name>/<image_file>``.

    Every sub-directory of ``root`` is one class. Class indices follow the
    sorted directory names, so labels are stable across runs and filesystems
    (``os.listdir`` order is arbitrary). Images are loaded as single-channel
    (``"L"``) PIL images.
    """

    def __init__(self, root, transform=None):
        """
        Args:
            root (string): Directory whose sub-directories each hold the images of one class.
            transform (str | callable, optional): A transform name understood by
                ``get_transform`` (e.g. ``"totensor"`` or ``"[-1, 1]"``), a callable
                applied to each PIL image, or ``None`` for no transform.
        """
        super().__init__()
        self.root_dir = root
        if transform is None or callable(transform):
            # get_transform only understands names; pass None/callables through.
            self.transform = transform
        else:
            self.transform = get_transform(transform)

        # Only directories are classes, so stray files do not consume a label.
        class_names = sorted(
            entry
            for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        # label index -> class (directory) name
        self.label2name = dict(enumerate(class_names))
        # (image path relative to root, label index) for every image
        self.images_path = [
            (os.path.join(class_name, img_name), class_idx)
            for class_idx, class_name in enumerate(class_names)
            for img_name in sorted(os.listdir(os.path.join(root, class_name)))
        ]

    def get_class_name(self, label):
        """Return the class (directory) name for integer ``label``."""
        # The former `label2name` method was shadowed by the instance dict of
        # the same name and could never be called; that dict is still exposed.
        return self.label2name[label]

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, index):
        rel_path, label = self.images_path[index]
        img_path = os.path.join(self.root_dir, rel_path)
        # The context manager closes the file handle; convert() returns a
        # separate loaded image that remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("L")

        if self.transform:
            image = self.transform(image)

        # Unlike the MNIST datasets, "encoder_params" / "pred_params" are
        # intentionally not returned here.
        return {
            "index": index,  # for setting the encoder_params after training
            "images": image,
            "labels": label,
        }


if __name__ == "__main__":
    # Smoke test: load the processed MNIST training split and print its first batch.
    data_dir = r"./mnist/processed/"
    dataset = MNIST_AAE_Dataset(os.path.join(data_dir, "mnist_train.pt"))

    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    first_batch = next(iter(loader))
    print(first_batch)
