import warnings
import os
import math
import numpy as np
import datasets
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as tvdatasets

from torch.utils.data import Dataset
from pycocotools.coco import COCO
from PIL import Image

from .base_provider import DataProvider
from .my_dataloader import MyRandomResizedCrop, MyDistributedSampler
from support.inat21_supclass import inat21_idx_to_class, inat21_class_to_idx
# Working directory captured at import time (not referenced by the code visible in this file).
pwd = os.getcwd()
# print(pwd)

# Root directories handed to the dataset constructors of the providers below.
# Several datasets share DATASET_ROOT directly; the others use a sub-directory of it.
DATASET_ROOT = "./datasets"
MNIST_ROOT = DATASET_ROOT
FMNIST_ROOT = DATASET_ROOT
KMNIST_ROOT = DATASET_ROOT
EMNIST_ROOT = DATASET_ROOT
CALTECH_ROOT = DATASET_ROOT
CELEBA_ROOT = DATASET_ROOT
INAT_ROOT = os.path.join(DATASET_ROOT, "inaturalist")
IMAGENET_ROOT = os.path.join(DATASET_ROOT, "imagenet1k")
CIFAR100_ROOT = os.path.join(DATASET_ROOT, "cifar100")
CIFAR10_ROOT = os.path.join(DATASET_ROOT, "cifar10")
SVHN_ROOT = os.path.join(DATASET_ROOT, "SVHN")
# CELEBA_ROOT = os.path.join(DATASET_ROOT, "celeba")


# Names exported by ``from <module> import *``.
__all__ = ["ImagenetALDataProvider", "CIFAR10ALDataProvider", "CIFAR100ALDataProvider", "SubsetSequentialSampler",
           "MNISTALDataProvider", "KMNISTALDataProvider", "FMNISTALDataProvider", "SVHNALDataProvider", "IMAGENET1KDataProvider",
           "CALTECH101ALDataProvider", "EMNISTLETALDataProvider", "EMNISTDIGALDataProvider", "CELEBADataProvider",
           "INATURALIST21SUPERDataProvider"]

from typing import Sequence
from torch.utils.data.sampler import Sampler


class SubsetSequentialSampler(Sampler[int]):
    r"""Yield the elements of a given index sequence in their stored order.

    Unlike ``SubsetRandomSampler`` no shuffling takes place: every iteration
    walks ``indices`` from the first element to the last, so the order is
    fully determined by the caller.

    Args:
        indices (sequence): a sequence of dataset indices to iterate over.
    """
    indices: Sequence[int]

    def __init__(self, indices: Sequence[int]) -> None:
        # The sequence is stored by reference (no copy), so later in-place
        # changes made by the caller are visible to subsequent iterations.
        self.indices = indices

    def __iter__(self):
        """Return a fresh iterator over the stored indices."""
        return iter(self.indices)

    def __len__(self) -> int:
        """Return the number of stored indices."""
        return len(self.indices)


def split_balanced_valid_set_from_unlab(unlab_idx, valid_size, train_dataset):
    """Carve a class-balanced validation subset out of the unlabeled indices.

    Every index in ``unlab_idx`` is looked up in ``train_dataset`` to obtain
    its target, the indices are grouped by target, and the same number of
    indices (``valid_size // number_of_classes_seen``) is drawn at random,
    without replacement, from each group using NumPy's global RNG.

    Args:
        unlab_idx: Iterable of dataset indices that are currently unlabeled.
        valid_size: Requested total size of the validation subset. The
            actual size is rounded down to a multiple of the class count.
        train_dataset: Indexable object returning ``(sample, target)``;
            targets must be hashable because they are used as dict keys.

    Returns:
        ``(remaining_unlab_idx, valid_idx)`` as NumPy arrays. The remaining
        indices come from ``np.setdiff1d`` and are therefore sorted and unique.

    Raises:
        ValueError: Propagated from ``np.random.choice`` when a class holds
            fewer indices than the per-class quota.
    """
    unlab_idx = np.array(unlab_idx)
    if unlab_idx.size == 0:
        # Nothing to split; avoids dividing by a class count of zero below.
        return np.array([], dtype=int), np.array([], dtype=int)

    # Group indices by target, preserving first-seen class order.
    indices_by_class = {}
    for ui in unlab_idx:
        _, target = train_dataset[ui]
        indices_by_class.setdefault(target, []).append(ui)

    per_class = valid_size // len(indices_by_class)
    valid_idx = []
    for members in indices_by_class.values():
        valid_idx.extend(np.random.choice(members, size=per_class, replace=False))

    # Explicit dtype keeps an empty selection integer-typed (usable as indices).
    valid_idx = np.array(valid_idx, dtype=unlab_idx.dtype)
    unlab_idx = np.setdiff1d(unlab_idx, valid_idx)
    return unlab_idx, valid_idx


class ImagenetALDataProvider(DataProvider):
    """Active-learning data provider over an ``ImageFolder``-style dataset.

    Construction builds the loaders used by an active-learning loop and
    stores them as attributes:

    * ``train`` -- iterates the labeled indices (``lab_indexes``),
    * ``unlab`` -- iterates the unlabeled indices (``unlab_indexes``),
    * ``valid`` -- validation loader (falls back to ``test`` if none is built),
    * ``test``  -- loader over ``test_dataset``.

    Subclasses change the data source by overriding ``name``, ``n_classes``,
    ``save_path``, ``train_dataset`` and ``test_dataset``. A subclass may
    additionally define ``valid_dataset(transform)`` to provide a dedicated
    validation set.
    """

    DEFAULT_PATH = "/media/ying-peng/Data/dataset/tiny-imagenet-200"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        init_lab_size=None,    # init. lab. set has the same size of valid. set
        n_worker=32,
        resize_scale=0.08,
        distort_color=None,
        image_size=224,
        num_replicas=None,
        rank=None,
        lab_idx=None,
        val_idx=None,
        unlab_idx=None,
        train_transform=None,
        valid_transform=None,
        split_valid_from_train=True,
        valid_size=None,
    ):
        """Build the labeled, unlabeled, validation and test loaders.

        Args:
            save_path: Dataset directory; ``None`` defers to ``save_path``'s
                default resolution.
            train_batch_size: Batch size of the train, unlab and valid loaders.
            test_batch_size: Batch size of the test loader.
            init_lab_size: Size of the initial labeled set, as an ``int`` or
                as a ``float`` fraction in ``(0, 1)`` of the training set.
                Required.
            n_worker: ``num_workers`` for every loader.
            resize_scale: Lower bound of the random-resized-crop scale range.
            distort_color: ``"torch"``, ``"tf"`` or ``None`` (no color jitter).
            image_size: ``int`` or list of ``int`` image resolutions.
            num_replicas: When given, the test loader uses a
                ``DistributedSampler`` with this many replicas.
            rank: Rank passed to the ``DistributedSampler``.
            lab_idx: Pre-computed labeled indices.
            val_idx: Pre-computed validation indices.
            unlab_idx: Pre-computed unlabeled indices.
            train_transform: Transform for the training dataset; built by
                ``build_train_transform`` when ``None``.
            valid_transform: Transform for the test dataset; built by
                ``build_valid_transform`` when ``None``.
            split_valid_from_train: If true the validation loader reads from
                the training dataset, otherwise from ``valid_dataset`` (when
                the subclass defines it).
            valid_size: Number of validation samples to take when ``val_idx``
                is not supplied.

        Raises:
            ValueError: ``init_lab_size`` is ``None``.
            RuntimeError: ``val_idx`` was supplied with
                ``split_valid_from_train=False`` but the provider defines no
                ``valid_dataset``.
        """
        # NOTE: this silences warnings for the whole process, not just this object.
        warnings.filterwarnings("ignore")
        self._save_path = save_path
        self.extra_flag = None

        self.image_size = image_size  # int or list of int
        self.distort_color = "None" if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(
                    img_size
                )
            self.active_img_size = max(self.image_size)  # active resolution for test
            if valid_transform is None:
                valid_transforms = self._valid_transform_dict[self.active_img_size]
            else:
                valid_transforms = valid_transform
        else:
            self.active_img_size = self.image_size
            if valid_transform is None:
                valid_transforms = self.build_valid_transform()
            else:
                valid_transforms = valid_transform
        train_loader_class = torch.utils.data.DataLoader

        # Resolve the training transform once so that every dataset built
        # below receives a usable transform even when the caller passed None.
        if train_transform is None:
            resolved_train_transform = self.build_train_transform()
        else:
            resolved_train_transform = train_transform
        train_dataset = self.train_dataset(resolved_train_transform)

        if init_lab_size is None:
            raise ValueError("init_lab_size is None")

        # init_lab_size serves as the initially labeled set size
        if not isinstance(init_lab_size, int):
            assert isinstance(init_lab_size, float) and 0 < init_lab_size < 1
            init_lab_size = int(len(train_dataset) * init_lab_size)

        if lab_idx is None and unlab_idx is None:
            # presumably returns (remaining indices, sampled indices) --
            # defined on the DataProvider base class, verify there.
            unlab_idx, lab_idx = self.random_sample_valid_set(
                len(train_dataset), init_lab_size
            )
        self.lab_indexes = lab_idx
        self.unlab_indexes = unlab_idx
        self.val_indexes = None

        # Validation loader: four situations, keyed on whether val_idx was
        # supplied and whether validation data comes from the training set.
        if split_valid_from_train:
            if val_idx is None:
                # Take the leading unlabeled indices as the validation set.
                assert valid_size is not None
                self.val_indexes = self.unlab_indexes[:valid_size]
                self.unlab_indexes = self.unlab_indexes[valid_size:]
            else:
                self.val_indexes = val_idx
            self.valid = self._make_loader(
                train_loader_class,
                train_dataset,
                train_batch_size,
                SubsetSequentialSampler(self.val_indexes),
                n_worker,
            )
        elif callable(getattr(self, "valid_dataset", None)):
            # The provider ships a dedicated validation dataset.
            valid_dataset = self.valid_dataset(resolved_train_transform)
            if val_idx is not None:
                self.val_indexes = val_idx
            elif valid_size is not None:
                rand_val = np.random.permutation(
                    np.arange(len(valid_dataset), dtype=int)
                ).tolist()
                self.val_indexes = rand_val[0:valid_size]
            if self.val_indexes is None:
                val_sampler = None  # iterate the whole validation dataset
            else:
                val_sampler = SubsetSequentialSampler(self.val_indexes)
            self.valid = self._make_loader(
                torch.utils.data.DataLoader,
                valid_dataset,
                train_batch_size,
                val_sampler,
                n_worker,
            )
        elif val_idx is not None:
            raise RuntimeError(f"No validation set in {self.name()} dataset.")
        else:
            # no valid set will be created; the test loader is used instead (see below)
            self.valid = None

        self.train = self._make_loader(
            train_loader_class,
            train_dataset,
            train_batch_size,
            SubsetSequentialSampler(self.lab_indexes),
            n_worker,
        )
        self.unlab = self._make_loader(
            train_loader_class,
            train_dataset,
            train_batch_size,
            SubsetSequentialSampler(self.unlab_indexes),
            n_worker,
        )

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            # NOTE: the non-distributed test loader shuffles its samples.
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def _make_loader(loader_class, dataset, batch_size, sampler, n_worker):
        """Instantiate ``loader_class`` with the loader settings shared by this provider."""
        return loader_class(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=n_worker,
            pin_memory=False,
        )

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "imagenet"

    @property
    def data_shape(self):
        """Return ``(channels, height, width)`` at the active image size."""
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 200

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``DEFAULT_PATH`` is used
        when it exists on disk, else ``~/dataset/imagenet``.
        """
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/imagenet")
        return self._save_path

    @property
    def data_url(self):
        """Always raises: this dataset cannot be downloaded automatically."""
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        """Return an ``ImageFolder`` over ``train_path`` with the given transform."""
        return tvdatasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        """Return an ``ImageFolder`` over ``valid_path`` with the given transform."""
        return tvdatasets.ImageFolder(self.valid_path, _transforms)

    @property
    def train_path(self):
        """Return the ``train`` sub-directory of ``save_path``."""
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        """Return the ``val`` sub-directory of ``save_path``."""
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        """Return the per-channel normalization transform applied after ``ToTensor``."""
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    def build_train_transform(self, image_size=None, print_log=True):
        """Compose the training-time augmentation pipeline.

        The pipeline is: random resized crop -> random horizontal flip ->
        optional color jitter (selected by ``distort_color``) -> ``ToTensor``
        -> ``normalize``.

        Args:
            image_size: Target size (``int`` or list); defaults to
                ``self.image_size``.
            print_log: Whether to print the augmentation settings.

        Returns:
            A ``transforms.Compose`` instance.
        """
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print(
                "Color jitter: %s, resize_scale: %s, img_size: %s"
                % (self.distort_color, self.resize_scale, image_size)
            )

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            # The first format string has two placeholders, so
            # get_candidate_image_size() is expected to return a 2-tuple.
            print(
                "Use MyRandomResizedCrop: %s, \t %s"
                % MyRandomResizedCrop.get_candidate_image_size(),
                "sync=%s, continuous=%s"
                % (
                    MyRandomResizedCrop.SYNC_DISTRIBUTED,
                    MyRandomResizedCrop.CONTINUOUS,
                ),
            )
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == "torch":
            color_transform = transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
            )
        elif self.distort_color == "tf":
            color_transform = transforms.ColorJitter(
                brightness=32.0 / 255.0, saturation=0.5
            )
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        return transforms.Compose(train_transforms)

    def build_valid_transform(self, image_size=None):
        """Compose the evaluation pipeline: resize, center crop, tensor, normalize.

        The image is first resized to ``ceil(image_size / 0.875)`` and then
        center-cropped to ``image_size``.

        Args:
            image_size: Crop size; defaults to ``self.active_img_size``.
        """
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose(
            [
                transforms.Resize(int(math.ceil(image_size / 0.875))),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def assign_active_img_size(self, new_img_size):
        """Switch the active resolution and re-point the valid/test transforms.

        The matching validation transform is built on first use and cached in
        ``_valid_transform_dict``.
        """
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[
                self.active_img_size
            ] = self.build_valid_transform()
        # change the transform of the valid and test set
        # NOTE(review): when the validation loader was built on the training
        # dataset object (split_valid_from_train=True), that object is shared
        # with the train/unlab loaders, so this assignment also replaces the
        # training transform -- confirm this is intended.
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        """Return a cached list of ``(images, labels)`` batches from the training set.

        A fixed random subset of ``n_images`` training images (seeded with
        ``DataProvider.SUB_SEED``) is loaded once per active image size and
        kept in memory; later calls for the same size return the cached list
        regardless of the arguments passed.

        Args:
            n_images: Number of training images to draw.
            batch_size: Batch size of the temporary loader.
            num_worker: Worker count; defaults to the train loader's.
            num_replicas: When given, a ``MyDistributedSampler`` is used.
            rank: Rank passed to ``MyDistributedSampler``.
        """
        # used for resetting BN running statistics
        cache_key = "sub_train_%d" % self.active_img_size
        if self.__dict__.get(cache_key, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(
                    image_size=self.active_img_size, print_log=False
                )
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset,
                    num_replicas,
                    rank,
                    True,
                    np.array(chosen_indexes),
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    chosen_indexes
                )
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            # Materialise every batch so later calls need no data loading.
            self.__dict__[cache_key] = []
            for images, labels in sub_data_loader:
                self.__dict__[cache_key].append((images, labels))
        return self.__dict__[cache_key]


class CIFAR10ALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``CIFAR10``.

    The dataset objects are always rooted at the module constant
    ``CIFAR10_ROOT``; ``save_path`` only feeds ``train_path``/``valid_path``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "cifar10"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``CIFAR10_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = CIFAR10_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the training split (``train=True``, ``download=True``)."""
        return tvdatasets.CIFAR10(
            root=CIFAR10_ROOT, train=True, download=True, transform=_transforms
        )

    def test_dataset(self, _transforms):
        """Build the test split (``train=False``, ``download=True``)."""
        return tvdatasets.CIFAR10(
            root=CIFAR10_ROOT, train=False, download=True, transform=_transforms
        )

    @property
    def train_path(self):
        """Return ``save_path`` unchanged (no ``train`` sub-directory)."""
        return self.save_path

    @property
    def valid_path(self):
        """Return ``save_path`` unchanged (no ``val`` sub-directory)."""
        return self.save_path


class CIFAR100ALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``CIFAR100``.

    The dataset objects are always rooted at the module constant
    ``CIFAR100_ROOT`` and are created with ``download=False``, so the data
    must already be present on disk.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "cifar100"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 100

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``CIFAR100_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = CIFAR100_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the training split (``train=True``, ``download=False``)."""
        return tvdatasets.CIFAR100(
            root=CIFAR100_ROOT, train=True, download=False, transform=_transforms
        )

    def test_dataset(self, _transforms):
        """Build the test split (``train=False``, ``download=False``)."""
        return tvdatasets.CIFAR100(
            root=CIFAR100_ROOT, train=False, download=False, transform=_transforms
        )

    @property
    def train_path(self):
        """Return ``save_path`` unchanged (no ``train`` sub-directory)."""
        return self.save_path

    @property
    def valid_path(self):
        """Return ``save_path`` unchanged (no ``val`` sub-directory)."""
        return self.save_path


class MNISTALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``MNIST``.

    The dataset objects are always rooted at the module constant ``MNIST_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "mnist"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``MNIST_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = MNIST_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the training split (``train=True``, ``download=True``)."""
        return tvdatasets.MNIST(
            root=MNIST_ROOT, train=True, download=True, transform=_transforms
        )

    def test_dataset(self, _transforms):
        """Build the test split (``train=False``, ``download=True``)."""
        return tvdatasets.MNIST(
            root=MNIST_ROOT, train=False, download=True, transform=_transforms
        )


class FMNISTALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``FashionMNIST``.

    The dataset objects are always rooted at the module constant ``FMNIST_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "fmnist"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``FMNIST_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = FMNIST_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the training split (``train=True``, ``download=True``)."""
        return tvdatasets.FashionMNIST(
            root=FMNIST_ROOT, train=True, download=True, transform=_transforms
        )

    def test_dataset(self, _transforms):
        """Build the test split (``train=False``, ``download=True``)."""
        return tvdatasets.FashionMNIST(
            root=FMNIST_ROOT, train=False, download=True, transform=_transforms
        )


class KMNISTALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``KMNIST``.

    The dataset objects are always rooted at the module constant ``KMNIST_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "kmnist"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``KMNIST_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = KMNIST_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the training split (``train=True``, ``download=True``)."""
        return tvdatasets.KMNIST(
            root=KMNIST_ROOT, train=True, download=True, transform=_transforms
        )

    def test_dataset(self, _transforms):
        """Build the test split (``train=False``, ``download=True``)."""
        return tvdatasets.KMNIST(
            root=KMNIST_ROOT, train=False, download=True, transform=_transforms
        )


class EMNISTDIGALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``EMNIST`` (``digits`` split).

    The dataset objects are always rooted at the module constant ``EMNIST_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "emnistdig"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``EMNIST_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = EMNIST_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the ``digits`` training split (``train=True``, ``download=True``)."""
        return tvdatasets.EMNIST(
            root=EMNIST_ROOT,
            split="digits",
            train=True,
            download=True,
            transform=_transforms,
        )

    def test_dataset(self, _transforms):
        """Build the ``digits`` test split (``train=False``, ``download=True``)."""
        return tvdatasets.EMNIST(
            root=EMNIST_ROOT,
            split="digits",
            train=False,
            download=True,
            transform=_transforms,
        )


class EMNISTLETALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``EMNIST`` (``letters`` split).

    The dataset objects are always rooted at the module constant ``EMNIST_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "emnistlet"

    @property
    def n_classes(self):
        """Return the number of output classes configured for this split."""
        return 27

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``EMNIST_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = EMNIST_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the ``letters`` training split (``train=True``, ``download=True``)."""
        return tvdatasets.EMNIST(
            root=EMNIST_ROOT,
            split="letters",
            train=True,
            download=True,
            transform=_transforms,
        )

    def test_dataset(self, _transforms):
        """Build the ``letters`` test split (``train=False``, ``download=True``)."""
        return tvdatasets.EMNIST(
            root=EMNIST_ROOT,
            split="letters",
            train=False,
            download=True,
            transform=_transforms,
        )


class SVHNALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``SVHN``.

    The dataset objects are always rooted at the module constant ``SVHN_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "svhn"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``SVHN_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = SVHN_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the ``train`` split (``download=True``)."""
        return tvdatasets.SVHN(
            root=SVHN_ROOT, split="train", download=True, transform=_transforms
        )

    def test_dataset(self, _transforms):
        """Build the ``test`` split (``download=True``)."""
        return tvdatasets.SVHN(
            root=SVHN_ROOT, split="test", download=True, transform=_transforms
        )


class CALTECH101ALDataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``Caltech101``.

    The single ``Caltech101`` collection is divided into a 70% training part
    and a 30% test part with a fixed-seed shuffle. The split indices are
    computed once per instance and cached in ``_train_idx`` / ``_test_idx``.
    The dataset objects are always rooted at the module constant
    ``CALTECH_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "caltech101"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 101

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``CALTECH_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is None:
            self._save_path = CALTECH_ROOT
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser(f"~/dataset/{self.name()}")
        return self._save_path

    def train_dataset(self, _transforms):
        """Return a ``Caltech101`` dataset restricted to the training indices."""
        return self._subset_dataset(_transforms, "_train_idx")

    def test_dataset(self, _transforms):
        """Return a ``Caltech101`` dataset restricted to the test indices."""
        return self._subset_dataset(_transforms, "_test_idx")

    def _subset_dataset(self, _transforms, idx_attr):
        """Build ``Caltech101`` and keep only the entries listed in ``idx_attr``."""
        if '_train_idx' not in self.__dict__ or '_test_idx' not in self.__dict__:
            self._load_dataset()
        keep = getattr(self, idx_attr)
        subset = tvdatasets.Caltech101(root=CALTECH_ROOT, transform=_transforms, download=True)
        # Filter the per-sample lists in lockstep so that sample i keeps its label.
        subset.index = [subset.index[i] for i in keep]
        subset.y = [subset.y[i] for i in keep]
        return subset

    def _load_dataset(self):
        """Compute and cache the fixed 70/30 train/test index split."""
        ori = tvdatasets.Caltech101(root=CALTECH_ROOT, download=True)
        dataset_size = len(ori.index)
        dataset_idx = np.arange(dataset_size)
        # A private RandomState(0) yields the same permutation as seeding the
        # global generator with 0, without resetting NumPy's process-wide RNG.
        np.random.RandomState(0).shuffle(dataset_idx)
        train_size = 0.7
        cut = round(train_size * dataset_size)
        self._train_idx = dataset_idx[:cut]
        self._test_idx = dataset_idx[cut:]


class CELEBADataProvider(ImagenetALDataProvider):
    """Active-learning provider backed by torchvision's ``CelebA``.

    Datasets are created with ``target_type="landmarks"`` and are always
    rooted at the module constant ``CELEBA_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "celeba"

    @property
    def n_classes(self):
        """Return the output dimensionality configured for this dataset."""
        return 10

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``CELEBA_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = CELEBA_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def train_dataset(self, _transforms):
        """Build the ``train`` split with landmark targets (``download=True``)."""
        return tvdatasets.CelebA(
            root=CELEBA_ROOT,
            split="train",
            target_type="landmarks",
            download=True,
            transform=_transforms,
        )

    def test_dataset(self, _transforms):
        """Build the ``test`` split with landmark targets (``download=True``)."""
        return tvdatasets.CelebA(
            root=CELEBA_ROOT,
            split="test",
            target_type="landmarks",
            download=True,
            transform=_transforms,
        )


class InatDataset(Dataset):
    """
    Dataset for the iNaturalist challenge 2021.

    Annotations are read from ``<INAT_ROOT>/<split>.json`` through the COCO
    API; images are opened lazily in ``__getitem__``. Each sample is returned
    as ``(image, super_class_index)``, where the index is looked up from the
    category's ``supercategory`` name.

    More info on:
      * https://github.com/visipedia/inat_comp/tree/master/2021
      * https://www.kaggle.com/c/inaturalist-2021/overview/description
    """

    def __init__(self, transform, split:str="train"):
        """Load the annotation file for ``split`` and index its image IDs.

        Args:
            transform: Callable applied to every loaded PIL image, or
                ``None`` to return the RGB image unchanged.
            split: Base name of the annotation JSON file under ``INAT_ROOT``.
        """
        self.root = INAT_ROOT
        self.image_path = self.root
        self.transform = transform

        self.coco = COCO(os.path.join(self.root, f"{split}.json"))
        # Create the dataset based on image IDs (images only loaded on demand)
        self.ids = sorted(self.coco.getImgIds())
        self.idx_to_supclass = inat21_idx_to_class
        self.supclass_to_idx = inat21_class_to_idx
        print("Loaded dataset with size {}".format(len(self.ids)))

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, super_class_index)`` for the image at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get the image ID from the dataset at the given index
        image_id = self.ids[idx]
        # Only the first annotation of the image is used for the label
        annotation_ids = self.coco.getAnnIds(imgIds=image_id)
        target = self.coco.loadAnns(annotation_ids)[0]
        # Map the fine-grained category to its super-category name
        label = target["category_id"]
        sup_label = self.coco.cats[label]["supercategory"]

        # Query the image filename (relative to the dataset root)
        image_filename = self.coco.loadImgs(image_id)[0]["file_name"]

        # Load the image using PIL using the RGB color space
        image_path = os.path.join(self.image_path, image_filename)
        image = Image.open(image_path).convert("RGB")
        # Transform the image (augmentation, to Torch, ...) when one is set
        if self.transform is not None:
            image = self.transform(image)

        return image, self.supclass_to_idx[sup_label]
    

class INATURALIST21SUPERDataProvider(ImagenetALDataProvider):
    """Active-learning provider for iNaturalist 2021 with super-category labels.

    Samples come from ``InatDataset``, which reads ``<split>.json`` under the
    module constant ``INAT_ROOT``.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "inaturalist21"

    @property
    def n_classes(self):
        """Return the number of super-category classes."""
        return 11

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``INAT_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is None:
            self._save_path = INAT_ROOT
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser(f"~/dataset/{self.name()}")
        return self._save_path

    def train_dataset(self, _transforms):
        """Return the training split (``train.json``)."""
        return InatDataset(_transforms, split="train")

    def test_dataset(self, _transforms):
        """Return the held-out validation split (``val.json``).

        Previously this loaded the ``train`` split again, which made every
        test metric an evaluation on training data.
        """
        # assumes the iNat21 ``val.json`` annotation file (and its images) are
        # present under INAT_ROOT -- TODO confirm for this deployment.
        return InatDataset(_transforms, split="val")


class MYIMAGENETDATASET(torch.utils.data.Dataset):
    """Map-style dataset adapter around an indexable collection of records.

    Each record is expected to expose an ``"image"`` and a ``"label"`` entry
    (presumably a Hugging Face dataset split -- verify against the caller).
    """

    def __init__(self, huggingface_dataset, transforms):
        """Store the wrapped dataset and the per-image transform.

        Args:
            huggingface_dataset: Indexable collection of records.
            transforms: Callable applied to each image, or ``None``.
        """
        self.hf_dataset = huggingface_dataset
        self.transforms = transforms
        self.root = IMAGENET_ROOT

    @property
    def transform(self):
        """Alias of ``transforms``, matching the attribute name torchvision datasets use."""
        return self.transforms

    @transform.setter
    def transform(self, value):
        # Code that swaps ``dataset.transform`` at runtime must take effect
        # here too, so writes are forwarded to ``transforms``.
        self.transforms = value

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.hf_dataset)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the record at ``index``."""
        item = self.hf_dataset[int(index)]
        image = item["image"]
        # PIL images that are not RGB (e.g. grayscale or CMYK) are converted
        # so that three-channel transforms downstream receive three channels.
        if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        if self.transforms is not None:
            image = self.transforms(image)
        return image, item["label"]


class IMAGENET1KDataProvider(ImagenetALDataProvider):
    """Active-learning provider for ImageNet-1k stored as an on-disk ``datasets`` dump.

    The dump is loaded from the module constant ``DATASET_ROOT`` on first use
    and then shared by the training and test datasets.
    """

    @staticmethod
    def name():
        """Return the dataset identifier string."""
        return "imagenet1k"

    @property
    def n_classes(self):
        """Return the number of target classes."""
        return 1000

    @property
    def save_path(self):
        """Resolve and cache the dataset directory.

        An explicitly supplied path wins; otherwise ``IMAGENET_ROOT`` is used
        when it exists on disk, else ``~/dataset/<name>``.
        """
        if self._save_path is not None:
            return self._save_path
        resolved = IMAGENET_ROOT
        if not os.path.exists(resolved):
            resolved = os.path.expanduser(f"~/dataset/{self.name()}")
        self._save_path = resolved
        return resolved

    def _loaded_splits(self):
        """Load the on-disk dataset once (tracked by ``extra_flag``) and return it."""
        if self.extra_flag is None:
            self.ds = datasets.load_from_disk(DATASET_ROOT)
            self.extra_flag = True
        return self.ds

    def train_dataset(self, _transforms):
        """Wrap the ``train`` split in ``MYIMAGENETDATASET``."""
        splits = self._loaded_splits()
        return MYIMAGENETDATASET(splits["train"], _transforms)

    def test_dataset(self, _transforms):
        """Wrap the ``validation`` split in ``MYIMAGENETDATASET``."""
        splits = self._loaded_splits()
        return MYIMAGENETDATASET(splits["validation"], _transforms)
    