import os
import json
import torch
import random
import numpy as np
from einops import repeat
from typing import List, Any
from torch.utils.data import Dataset
import torchvision.transforms as transforms

# Dataset names accepted by TORCH_CLS_Dataset. Each name is also used as the
# sub-directory (under the dataset root) in which `<split>/info.json` is
# looked up.
TORCH_IMAGE_CLASSIFCATON_DATASETS = [
   'svhn',
   'cifar100',
   'gtsrb',
   'daimlerpedcls',
   'omniglot',  # 1: fg, 0: bg
   'ucf101',
   'aircraft',
   'dtd',
   'vgg-flowers',
]
# Module-level converter to a PIL image, shared by every
# TORCH_CLS_Dataset.__getitem__ call.
to_PILimage = transforms.ToPILImage()


class TORCH_CLS_Dataset(Dataset):
    """Image-classification dataset backed by NumPy arrays stored on disk.

    The constructor reads ``<path>/<dataset_name>/<split>/info.json`` and loads
    the image and label arrays whose locations are recorded in that file.

    Attributes set from ``info.json``:
        num_classes: value of the ``"num_classes"`` entry.
        image_shape: value of the ``"image_shape"`` entry.
        dataset_name: value of the ``"dataset_name"`` entry.
        split: value of the ``"split"`` entry.
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        transforms: Any = None,
        path: str = "./datasets/torch_datasets/"
    ) -> None:
        """Load the metadata and arrays of one dataset split.

        Args:
            dataset_name: One of ``TORCH_IMAGE_CLASSIFCATON_DATASETS``.
            split: ``"train"``, ``"test"`` or ``"validation"``.
            transforms: Optional callable applied to each PIL image in
                ``__getitem__``; ``None`` returns the PIL image unchanged.
            path: Root directory that contains one folder per dataset.

        Raises:
            AssertionError: If ``dataset_name`` or ``split`` is not supported.
        """
        super().__init__()
        assert dataset_name in TORCH_IMAGE_CLASSIFCATON_DATASETS, f'Not supported torch-dataset: {dataset_name}!'
        assert split in ["train", "test", "validation"], f'Not supported split: {split}!'

        # Use a context manager so the metadata file handle is closed
        # deterministically instead of being left to the garbage collector.
        info_file = os.path.join(path, dataset_name, split, "info.json")
        with open(info_file) as handle:
            infos = json.load(handle)

        # Load both arrays before touching the instance so a failing load
        # never leaves the object half-initialised.
        images = np.load(self._resolve_array_path(infos["image"], path))
        labels = np.load(self._resolve_array_path(infos["label"], path))
        self._image, self._label = images, labels

        self.num_classes = infos["num_classes"]
        self.image_shape = infos["image_shape"]
        self.dataset_name = infos["dataset_name"]
        self.split = infos["split"]
        self.transforms = transforms

    @staticmethod
    def _resolve_array_path(recorded: str, path: str) -> str:
        """Rebase an array path recorded in ``info.json`` onto ``path``.

        The fixed prefix ``./torch_datasets/visual_domain_decathlon/`` is
        replaced by ``path + "/"``; a path without that prefix is returned
        unchanged.
        """
        return recorded.replace("./torch_datasets/visual_domain_decathlon/", path + "/")

    def set_transforms(self, transforms: Any) -> None:
        """Replace the callable applied to every image in ``__getitem__``."""
        self.transforms = transforms

    def __getitem__(self, index, random_choice=False):
        """Return the sample stored at ``index``.

        Args:
            index: Position in the loaded image/label arrays.
            random_choice: Unused; kept so existing call sites keep working.

        Returns:
            ``(img, (label, 0))`` where ``img`` is a 3-channel PIL image, or
            the output of ``self.transforms`` when one is set. The constant
            ``0`` is returned as-is; its meaning is not defined in this file
            (NOTE(review): confirm with the consumer of this dataset).
        """
        img, label = self._image[index], self._label[index]
        img = transforms.ToTensor()(img)
        # The tensor is channel-first ("c h w"), so dim 0 is the channel count.
        channels = img.shape[0]
        if channels == 1:
            # Replicate the single channel so every sample has three channels.
            img = repeat(img, "c h w -> (repeat c) h w", repeat=3)
        img = to_PILimage(img)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, (label, 0)

    def __len__(self):
        """Return the number of samples (length of the label array)."""
        return self._label.shape[0]
    

if __name__ == '__main__':
    # Smoke test: load the train and test splits of every supported dataset
    # and print a one-line summary based on the first sample, if there is one.
    for dataset_key in TORCH_IMAGE_CLASSIFCATON_DATASETS:
        for split_key in ("train", "test"):
            ds = TORCH_CLS_Dataset(dataset_key, split_key)
            ds.set_transforms(transforms.ToTensor())
            first_sample = next(iter(ds), None)
            if first_sample is None:
                continue
            image, _target = first_sample
            print(ds.dataset_name, image.shape, len(ds), ds.num_classes)