import collections
import os
import torch
import torchvision
import torchvision.transforms as transforms
from typing import Optional
import torchvision.datasets as datasets
from typing import Optional, Callable, Tuple, Any, List
from torchvision.datasets.folder import default_loader
import numpy as np
from support_alignment.core import utils
import copy
from RandAugment import RandAugment


class DatasetWithAttributesWrapper(torch.utils.data.Dataset):
    """Pass-through dataset wrapper that carries an ``attributes`` dictionary.

    Indexing and length are delegated unchanged to ``base_dataset``.
    ``attributes`` starts out empty and is filled in by whoever builds the
    wrapper.
    """

    def __init__(self, base_dataset):
        self.attributes = {}
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, item):
        wrapped = self.base_dataset
        return wrapped[item]


class SubsetDatasetWrapper(DatasetWithAttributesWrapper):
    """View of ``base_dataset`` restricted to the positions listed in ``indices``.

    Element ``k`` of the wrapper is ``base_dataset[indices[k]]``. When the base
    dataset is itself a ``DatasetWithAttributesWrapper``, each of its attribute
    lists is re-indexed with ``indices`` and stored on this wrapper.
    """

    def __init__(self, base_dataset, indices):
        super().__init__(base_dataset)
        self.indices = indices

        # Only attribute-carrying datasets have anything to propagate.
        if not isinstance(base_dataset, DatasetWithAttributesWrapper):
            return
        for name, values in base_dataset.attributes.items():
            self.attributes[name] = [values[idx] for idx in self.indices]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, item):
        base_position = self.indices[item]
        return self.base_dataset[base_position]


class TransformDatasetWrapper(DatasetWithAttributesWrapper):
    """Applies ``transform`` to the input part of every ``(x, y)`` sample.

    Attributes of an attribute-carrying base dataset are copied into a new
    dictionary; the attribute values themselves are shared, not copied.
    """

    def __init__(self, base_dataset, transform):
        super().__init__(base_dataset)
        self.transform = transform
        if isinstance(base_dataset, DatasetWithAttributesWrapper):
            # New dict, same value objects as the base dataset.
            self.attributes = dict(base_dataset.attributes.items())

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, item):
        sample, label = self.base_dataset[item]
        return self.transform(sample), label


class SentryDatasetWrapper(DatasetWithAttributesWrapper):
    """Dataset wrapper returning a base view plus a committee of RandAugment views.

    Each item is ``(base_x, rand_aug_lst, y)`` where ``base_x`` is the sample
    passed through ``transform`` and ``rand_aug_lst`` holds ``committee_size``
    independently augmented versions produced by ``RandAugment(N, M)`` followed
    by the same transform pipeline.

    Parameters:
        base_dataset: dataset yielding ``(x, y)`` pairs.
        transform: a compose-like object exposing a mutable ``transforms`` list.
        committee_size: number of augmented views returned per sample.
        N, M: arguments forwarded to ``RandAugment``.
        is_mnist: when true, samples are converted to RGB before the pipelines
            run and a single-channel ``Grayscale`` step is prepended to both
            pipelines.
    """

    def __init__(
        self, base_dataset, transform, committee_size=3, N=1, M=2.0, is_mnist=False
    ):
        super(SentryDatasetWrapper, self).__init__(base_dataset)
        if isinstance(base_dataset, DatasetWithAttributesWrapper):
            self.attributes = {
                key: attr_list for key, attr_list in base_dataset.attributes.items()
            }
        self.is_mnist = is_mnist
        self.committee_size = committee_size

        # The pipelines are edited in place below, so work on private copies.
        # Mutating the caller's transform would stack one extra Grayscale step
        # onto it for every wrapper built from the same object.
        self.transform = copy.deepcopy(transform) if is_mnist else transform
        self.rand_aug_transforms = copy.deepcopy(transform)

        # TODO: see if this can be dropped entirely without losing performance
        if is_mnist:
            self.transform.transforms.insert(
                0, torchvision.transforms.Grayscale(num_output_channels=1)
            )
            self.rand_aug_transforms.transforms.insert(
                0, torchvision.transforms.Grayscale(num_output_channels=1)
            )

        # RandAugment goes first, i.e. before the optional Grayscale step.
        self.ra_obj = RandAugment(N, M)
        self.rand_aug_transforms.transforms.insert(0, self.ra_obj)

    def __getitem__(self, item):
        x, y = self.base_dataset[item]

        if self.is_mnist:
            # Both pipelines receive an RGB image; the leading Grayscale step
            # added in __init__ brings it back to one channel.
            x = x.convert("RGB")

        base_x = self.transform(x)
        rand_aug_lst = [self.rand_aug_transforms(x) for _ in range(self.committee_size)]
        return base_x, rand_aug_lst, y

    def __len__(self):
        return len(self.base_dataset)


def split_dataset(dataset, split_fraction=0.2, seed=None):
    """Split ``dataset`` into two class-stratified subsets.

    For every distinct label, ``int((1 - split_fraction) * n_label)`` samples
    are drawn without replacement for the first subset and the remaining
    samples of that label go to the second subset.

    Parameters:
        dataset: object exposing ``attributes["classes"]``, a per-sample
            sequence of class labels.
        split_fraction: fraction of each class assigned to the second subset.
        seed: value passed to ``np.random.seed`` (this reseeds NumPy's global
            random state).

    Returns:
        A ``(first, second)`` tuple of ``SubsetDatasetWrapper`` objects.

    Raises:
        ValueError: if the dataset has no ``"classes"`` attribute.
    """
    np.random.seed(seed)

    classes = dataset.attributes.get("classes", None)
    if classes is None:
        raise ValueError(
            'split_dataset requires dataset.attributes["classes"] to be set'
        )
    classes = np.asarray(classes)

    source_idx = []
    target_idx = []

    # Iterate over the labels actually present (sorted), so labels that are
    # not exactly 0..K-1 are still split instead of being silently dropped.
    # For labels 0..K-1 this visits classes in the same order as before, so
    # the random draws are unchanged.
    for label in np.unique(classes):
        label_idx = np.where(classes == label)[0]
        num_source = int((1 - split_fraction) * len(label_idx))
        chosen = np.random.choice(label_idx, num_source, replace=False)
        source_idx.extend(chosen)
        # Only this class's picks can overlap with label_idx, so there is no
        # need to diff against every index chosen so far.
        target_idx.extend(np.setdiff1d(label_idx, chosen, assume_unique=True))

    return SubsetDatasetWrapper(dataset, source_idx), SubsetDatasetWrapper(
        dataset, target_idx
    )


class ClassSubsetDatasetWrapper(DatasetWithAttributesWrapper):
    """Keeps only samples whose class is in ``class_subset`` and relabels them.

    The new label of a sample is the position of its original class id inside
    ``class_subset``. Per-sample labels are read from
    ``base_dataset.attributes["classes"]`` when available; otherwise they are
    collected by iterating over the whole base dataset.

    Parameters:
        base_dataset: dataset yielding ``(x, y)`` pairs.
        class_subset: iterable of original class ids to keep, in the order of
            their new labels.
    """

    def __init__(self, base_dataset, class_subset):
        super(ClassSubsetDatasetWrapper, self).__init__(base_dataset)

        classes = None
        if hasattr(base_dataset, "attributes"):
            classes = base_dataset.attributes.get("classes", None)
        if classes is None:
            # Fallback: materialise every sample just to read its label.
            classes = [int(y) for _, y in base_dataset]

        self.class_mapping = {class_id: i for i, class_id in enumerate(class_subset)}
        self.indices = [
            i for i, y in enumerate(classes) if int(y) in self.class_mapping
        ]
        if isinstance(base_dataset, DatasetWithAttributesWrapper):
            for key, attr_list in base_dataset.attributes.items():
                self.attributes[key] = [attr_list[i] for i in self.indices]
        # Look labels up with the same int() normalisation used by the filter
        # above and by __getitem__, so int-like labels that do not hash like
        # ints cannot pass the filter and then fail here with a KeyError.
        self.attributes["classes"] = [
            self.class_mapping[int(classes[i])] for i in self.indices
        ]

    def __getitem__(self, item):
        x, y = self.base_dataset[self.indices[item]]
        y = self.class_mapping[int(y)]
        return x, y

    def __len__(self):
        return len(self.indices)


class MultipleDomainDataset:
    """Container holding one dataset per domain.

    Adapted from https://github.com/facebookresearch/DomainBed/blob/master/domainbed/datasets.py

    Instance attributes set by ``__init__``:
        input_shape, num_classes, datasets, class_names
    """

    ENVIRONMENTS = tuple()

    def __init__(self):
        self.input_shape = tuple()
        self.num_classes = 0
        self.datasets = list()
        self.class_names = None

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, index):
        return self.datasets[index]

    def mod_class_subset(self, class_subset):
        """Restrict every domain to ``class_subset`` and relabel to 0..len-1.

        Raises ``ValueError`` when an index lies outside
        ``[0, self.num_classes)``. ``None`` entries in ``self.datasets`` are
        kept as ``None``.
        """
        upper = self.num_classes
        for class_index in class_subset:
            if class_index >= upper or class_index < 0:
                raise ValueError(
                    "Class subset index {0} is not in range [0, {1})".format(
                        class_index, upper
                    )
                )

        self.num_classes = len(class_subset)
        self.new_datasets = [
            None if dataset is None else ClassSubsetDatasetWrapper(dataset, class_subset)
            for dataset in self.datasets
        ]
        # Carry the optional per-domain metadata over to the new wrappers.
        carried = ("name", "train_transform", "eval_transform")
        for old, new in zip(self.datasets, self.new_datasets):
            for attr_name in carried:
                if hasattr(old, attr_name):
                    setattr(new, attr_name, getattr(old, attr_name))
        self.datasets = self.new_datasets

        if self.class_names is not None:
            self.class_names = [self.class_names[i] for i in class_subset]

    def apply_mod(self, mod_name, mod_args):
        """Call the method named ``mod_<mod_name>`` with ``mod_args`` unpacked."""
        method_name = "mod_{0}".format(mod_name)
        getattr(self, method_name)(*mod_args)


class MNIST_USPS(MultipleDomainDataset):
    """Two-domain digit benchmark: MNIST (index 0) and USPS (index 1).

    For each domain the torchvision train and test splits are downloaded into
    ``root`` and concatenated. The transforms are stored on each domain dataset
    as ``train_transform`` / ``eval_transform``; they are not applied here.
    """

    ENVIRONMENTS = ("MNIST", "USPS")

    def __init__(self, root):
        super().__init__()
        self.input_shape = (1, 28, 28)
        self.num_classes = 10
        self.class_names = [str(digit) for digit in range(10)]

        normalize = transforms.Normalize(mean=[0.5], std=[0.5])
        # Normalize(-1, 2) undoes Normalize(0.5, 0.5): (y + 1) / 2.
        denormalize = transforms.Normalize(mean=[-1.0], std=[2.0])

        # ---- MNIST ----
        mnist_pipeline = transforms.Compose([transforms.ToTensor(), normalize])
        mnist_train = torchvision.datasets.MNIST(root, train=True, download=True)
        mnist_test = torchvision.datasets.MNIST(root, train=False, download=True)
        mnist_labels = torch.cat((mnist_train.targets, mnist_test.targets))

        mnist_domain = DatasetWithAttributesWrapper(
            torch.utils.data.ConcatDataset([mnist_train, mnist_test])
        )
        mnist_domain.attributes["classes"] = [int(label) for label in mnist_labels]
        mnist_domain.name = "mnist"
        mnist_domain.train_transform = mnist_pipeline
        mnist_domain.eval_transform = mnist_pipeline

        # ---- USPS (resized to 28x28 by its transform) ----
        usps_pipeline = transforms.Compose(
            [transforms.Resize((28, 28)), transforms.ToTensor(), normalize]
        )
        usps_train = torchvision.datasets.USPS(root, train=True, download=True)
        usps_test = torchvision.datasets.USPS(root, train=False, download=True)
        usps_labels = torch.cat(
            (torch.tensor(usps_train.targets), torch.tensor(usps_test.targets))
        )

        usps_domain = DatasetWithAttributesWrapper(
            torch.utils.data.ConcatDataset([usps_train, usps_test])
        )
        usps_domain.attributes["classes"] = [int(label) for label in usps_labels]
        usps_domain.name = "usps"
        usps_domain.train_transform = usps_pipeline
        usps_domain.eval_transform = usps_pipeline

        self.datasets = [mnist_domain, usps_domain]
        self.data_params = {"inv_norm_transform": denormalize}


class CIFAR_STL(MultipleDomainDataset):
    """Two-domain benchmark built from CIFAR10 (index 0) and STL10 (index 1).

    Both datasets are downloaded into ``root``, their train and test splits
    are concatenated, and their labels are remapped onto a shared set of nine
    classes (see ``class_names``). One original class of each dataset has no
    counterpart in the shared label set and its samples are removed.

    The transforms are stored on each domain dataset as ``train_transform`` /
    ``eval_transform``; the underlying torchvision datasets are created
    without a transform.
    """

    ENVIRONMENTS = ("CIFAR", "STL")

    def __init__(self, root):
        super().__init__()
        self.input_shape = (3, 32, 32)
        self.num_classes = 9
        self.class_names = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "horse",
            "ship",
            "truck",
        ]

        norm_transform = transforms.Normalize(mean=[0.5], std=[0.5])
        # Normalize(-1, 2) computes (y + 1) / 2, i.e. the inverse of norm_transform.
        inv_norm_transform = transforms.Normalize(mean=[-1.0], std=[2.0])

        cifar_transform = transforms.Compose([transforms.ToTensor(), norm_transform])

        cifar_tr = torchvision.datasets.CIFAR10(root, train=True, download=True)
        cifar_te = torchvision.datasets.CIFAR10(root, train=False, download=True)

        # Lookup table: original CIFAR10 label -> shared label.
        # Original label 6 maps to -1 and is dropped below; labels above it
        # shift down by one.
        cifar_class_mapping = torch.tensor(
            [0, 1, 2, 3, 4, 5, -1, 6, 7, 8], dtype=torch.long
        )
        cifar_tr_classes = cifar_class_mapping[cifar_tr.targets]
        cifar_te_classes = cifar_class_mapping[cifar_te.targets]

        # Positions of the samples that survive (mapped label >= 0).
        cifar_tr_subset_index = torch.nonzero(cifar_tr_classes >= 0, as_tuple=True)[0]
        cifar_te_subset_index = torch.nonzero(cifar_te_classes >= 0, as_tuple=True)[0]
        # Filter images and labels with the same index so they stay aligned,
        # then overwrite the torchvision datasets' own data/targets in place.
        cifar_tr.data = cifar_tr.data[cifar_tr_subset_index.numpy()]
        cifar_te.data = cifar_te.data[cifar_te_subset_index.numpy()]
        cifar_tr_classes = cifar_tr_classes[cifar_tr_subset_index]
        cifar_te_classes = cifar_te_classes[cifar_te_subset_index]
        cifar_tr.targets = cifar_tr_classes.tolist()
        cifar_te.targets = cifar_te_classes.tolist()

        cifar_combined = torch.utils.data.ConcatDataset([cifar_tr, cifar_te])
        cifar_classes = torch.cat((cifar_tr_classes, cifar_te_classes))
        cifar_combined = DatasetWithAttributesWrapper(cifar_combined)
        # Per-sample labels, in the same order as the concatenated dataset.
        cifar_combined.attributes["classes"] = [int(y) for y in cifar_classes]
        cifar_combined.train_transform = cifar_transform
        cifar_combined.eval_transform = cifar_transform

        # STL images are resized to 32x32 by the transform to match input_shape.
        stl_transform = transforms.Compose(
            [transforms.Resize((32, 32)), transforms.ToTensor(), norm_transform]
        )

        stl_tr = torchvision.datasets.STL10(root, split="train", download=True)
        stl_te = torchvision.datasets.STL10(root, split="test", download=True)

        # Lookup table: original STL10 label -> shared label.
        # Original labels 1 and 2 are swapped, label 7 maps to -1 and is
        # dropped below, and labels 8 and 9 become 7 and 8.
        # NOTE(review): presumably this aligns STL10's class order with
        # class_names above - confirm against the STL10 class list.
        stl_class_mapping = torch.tensor(
            [0, 2, 1, 3, 4, 5, 6, -1, 7, 8], dtype=torch.long
        )
        stl_tr_classes = stl_class_mapping[stl_tr.labels.tolist()]
        stl_te_classes = stl_class_mapping[stl_te.labels.tolist()]

        # Same filtering scheme as for CIFAR: keep samples with mapped label >= 0.
        stl_tr_subset_index = torch.nonzero(stl_tr_classes >= 0, as_tuple=True)[0]
        stl_te_subset_index = torch.nonzero(stl_te_classes >= 0, as_tuple=True)[0]
        stl_tr.data = stl_tr.data[stl_tr_subset_index.numpy()]
        stl_te.data = stl_te.data[stl_te_subset_index.numpy()]
        stl_tr_classes = stl_tr_classes[stl_tr_subset_index]
        stl_te_classes = stl_te_classes[stl_te_subset_index]
        stl_tr.labels = stl_tr_classes.numpy()
        stl_te.labels = stl_te_classes.numpy()

        stl_combined = torch.utils.data.ConcatDataset([stl_tr, stl_te])
        stl_classes = torch.cat((stl_tr_classes, stl_te_classes))
        stl_combined = DatasetWithAttributesWrapper(stl_combined)
        stl_combined.attributes["classes"] = [int(y) for y in stl_classes]
        stl_combined.train_transform = stl_transform
        stl_combined.eval_transform = stl_transform

        self.datasets = [cifar_combined, stl_combined]

        self.data_params = {"inv_norm_transform": inv_norm_transform}


# Adapted from https://github.com/facebookresearch/DomainBed/blob/master/domainbed/datasets.py
class MultipleEnvironmentImageFolder(MultipleDomainDataset):
    """One ``ImageFolder`` domain per sub-directory of ``root``.

    Sub-directories are visited in sorted name order. ``train_aug`` and
    ``eval_aug`` are applied before ``ToTensor`` and normalisation in
    ``self.train_transform`` / ``self.eval_transform``; defaults are used when
    they are ``None``. ``num_classes`` and ``class_names`` are taken from the
    last environment visited.
    """

    def __init__(self, root, train_aug=None, eval_aug=None):
        super().__init__()
        self.env_dirs = sorted(
            entry.name for entry in os.scandir(root) if entry.is_dir()
        )

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
        # Inverse of `normalize`: mean' = -mean / std, std' = 1 / std.
        denormalize = transforms.Normalize(
            mean=[-m / s for m, s in zip(channel_mean, channel_std)],
            std=[1.0 / s for s in channel_std],
        )

        if train_aug is None:
            default_train_steps = [
                transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
                transforms.RandomGrayscale(),
            ]
            train_aug = transforms.Compose(default_train_steps)
        if eval_aug is None:
            eval_aug = transforms.Resize((224, 224))

        self.train_transform = transforms.Compose(
            [train_aug, transforms.ToTensor(), normalize]
        )
        self.eval_transform = transforms.Compose(
            [eval_aug, transforms.ToTensor(), normalize]
        )
        self.data_params = {"inv_norm_transform": denormalize}

        self.datasets = []
        self.class_names = None
        num_classes = 0
        for env_name in self.env_dirs:
            folder = torchvision.datasets.ImageFolder(
                os.path.join(root, env_name), transform=None
            )
            num_classes = len(folder.classes)
            self.class_names = folder.classes
            # Read labels from the file listing so no image has to be loaded.
            domain = DatasetWithAttributesWrapper(folder)
            domain.attributes["classes"] = [
                int(target) for _, target in folder.samples
            ]
            self.datasets.append(domain)

        self.input_shape = None
        self.num_classes = num_classes


class VisDA17(MultipleEnvironmentImageFolder):
    """VisDA-2017 with two domains read from ``<root>/visda17``.

    The ``visda17`` folder must contain exactly the ``train`` and
    ``validation`` directories; the test split has to live somewhere else.
    """

    ENVIRONMENTS = ["Train", "Val"]

    def __init__(self, root):
        self.dir = os.path.join(root, "visda17")

        resize_256 = transforms.Resize((256, 256))
        train_steps = [
            resize_256,
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
        eval_steps = [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
        ]

        super().__init__(
            self.dir,
            train_aug=transforms.Compose(train_steps),
            eval_aug=transforms.Compose(eval_steps),
        )

        expected_dirs = ["train", "validation"]
        if self.env_dirs != expected_dirs:
            raise ValueError(
                """VisDA17 folder must contain train and validation folders and no other folders.
                Found: {0}""".format(
                    self.env_dirs
                )
            )
        self.input_shape = (3, 224, 224)

        # Every domain shares the transforms built by the parent constructor.
        for domain in self.datasets:
            domain.train_transform = self.train_transform
            domain.eval_transform = self.eval_transform


class ImageList(datasets.VisionDataset):
    """A generic Dataset class for domain adaptation in image classification.

    Parameters:
        - **root** (str): Root directory of dataset
        - **classes** (List[str]): The names of all the classes
        - **data_list_file** (str): File to read the image list from.
        - **transform** (callable, optional): A function/transform that takes in a PIL image
          and returns a transformed version. E.g, ``transforms.RandomCrop``.
        - **target_transform** (callable, optional): A function/transform that takes in the
          target and transforms it.
        - **subsample** (bool, optional): If true, keep only 30% of the samples of every label
          smaller than ``num_labels // 2`` (drawn with ``np.random.choice``).
        - **val_transform** (compose-like, optional): Pipeline exposing a ``transforms`` list;
          copied and prefixed with ``RandAugment(N, M)`` when ``sentry`` is true.
        - **sentry** (bool): If true, items are ``(image, augmented_views, target)`` with
          ``committee_size`` augmented views instead of ``(image, target)``.
        - **committee_size**, **N**, **M**: number of augmented views and ``RandAugment``
          arguments; only used when ``sentry`` is true.

    .. note:: In `data_list_file`, each line holds 2 values in the following format.
        ::
            source_dir/dog_xxx.png 0
            source_dir/cat_123.png 1
            target_dir/dog_xxy.png 0
            target_dir/cat_nsdf3.png 1

        The first value is the relative path of an image, and the second value is the label
        of the corresponding image. If your data_list_file has a different format, please
        override `parse_data_file`.
    """

    def __init__(
        self,
        root: str,
        classes: List[str],
        data_list_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        subsample: Optional[bool] = False,
        # sentry params
        val_transform=None,
        sentry=False,
        committee_size=3,
        N=1,
        M=2.0,
    ):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.data = self.parse_data_file(data_list_file)
        self.classes = classes

        # Map every class name to its index. A plain string entry is one name;
        # any other iterable entry is treated as a group of aliases that all
        # map to the same index. (Iterating a string entry would register its
        # individual characters instead of the name.)
        self.class_to_idx = {}
        for idx, entry in enumerate(self.classes):
            aliases = (entry,) if isinstance(entry, str) else entry
            for alias in aliases:
                self.class_to_idx[alias] = idx

        # sentry
        self.sentry = sentry
        if self.sentry:
            if val_transform is None:
                raise ValueError("val_transform must be provided when sentry=True")
            # Work on a copy so the caller's pipeline is not modified.
            self.rand_aug_transforms = copy.deepcopy(val_transform)
            self.committee_size = committee_size
            self.ra_obj = RandAugment(N, M)
            self.rand_aug_transforms.transforms.insert(0, self.ra_obj)

        self.labels_to_idx = self.get_labels_to_idx(self.data)
        self.loader = default_loader

        if subsample:
            self.data = self.subsample(self.data, self.labels_to_idx)
            # Sample positions changed, so the label index must be rebuilt.
            self.labels_to_idx = self.get_labels_to_idx(self.data)

        # Fraction of the samples carried by each label, in sorted label order.
        n = len(self.data)
        self.proportion = [
            len(self.labels_to_idx[key]) / n
            for key in sorted(self.labels_to_idx.keys())
        ]

        inv_norm_transform = transforms.Normalize(
            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
            std=[1.0 / 0.229, 1.0 / 0.224, 1.0 / 0.225],
        )
        self.data_params = {"inv_norm_transform": inv_norm_transform}

        # Nominal value only: the real size depends on the transform in use.
        self.input_shape = (3, 224, 224)

    def subsample(self, data, labels_to_idx):
        """Return ``data`` with the lower half of the labels thinned to 30%.

        Labels smaller than ``len(labels_to_idx) // 2`` keep a random 30% of
        their samples; all other labels are kept in full. The original sample
        order is preserved.
        """
        keep_idx = []
        num_classes = len(labels_to_idx)
        for label in sorted(labels_to_idx.keys()):
            if label < num_classes // 2:
                keep_idx.extend(
                    np.random.choice(
                        labels_to_idx[label],
                        int(0.3 * len(labels_to_idx[label])),
                        replace=False,
                    ).tolist()
                )
            else:
                keep_idx.extend(labels_to_idx[label])
        keep_idx = set(keep_idx)
        return [sample for i, sample in enumerate(data) if i in keep_idx]

    def get_labels_to_idx(self, data):
        """Group sample positions by label: ``{label: [position, ...]}``."""
        labels_to_idx = {}
        for idx, path in enumerate(data):
            label = path[1]
            if label not in labels_to_idx:
                labels_to_idx[label] = [idx]
            else:
                labels_to_idx[label].append(idx)
        return labels_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Parameters:
            - **index** (int): Index
            - **return** (tuple): (image, target) where target is index of the target class;
              (image, augmented_views, target) when ``sentry`` is enabled.
        """
        path, target = self.data[index]
        img = self.loader(path)
        # Without a transform the loaded image is returned as-is; leaving
        # base_img unset here would raise UnboundLocalError below.
        base_img = self.transform(img) if self.transform is not None else img
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)

        if self.sentry:
            rand_aug_lst = [
                self.rand_aug_transforms(img) for _ in range(self.committee_size)
            ]
            return base_img, rand_aug_lst, target
        else:
            return base_img, target

    def __len__(self) -> int:
        return len(self.data)

    def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
        """Parse file to data list

        Blank lines are skipped. The label is the last whitespace-separated
        field of a line, so image paths may contain spaces.

        Parameters:
            - **file_name** (str): The path of data file
            - **return** (list): List of (image path, class_index) tuples
        """
        data_list = []
        with open(file_name, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                path, target = line.rsplit(maxsplit=1)
                if not os.path.isabs(path):
                    path = os.path.join(self.root, path)
                data_list.append((path, int(target)))
        return data_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)


class Office31(ImageList):
    """Office31 Dataset.

    Parameters:
        - **root** (str): Root directory of datasets; data is read from ``<root>/office31``.
        - **task** (str): The task (domain) to create dataset. Choices include ``'A'``: amazon,
          ``'D'``: dslr and ``'W'``: webcam.
        - **download** (bool, optional): If true, ``download_data`` is called for every entry
          of ``download_list``. If false, the expected directories are only checked for
          existence.
        - **kwargs**: forwarded to ``ImageList`` (e.g. ``transform``, ``target_transform``).

    .. note:: In `<root>/office31`, the following layout is expected.
        ::
            amazon/
                images/
                    backpack/
                        *.jpg
                        ...
            dslr/
            webcam/
            image_list/
                amazon.txt
                dslr.txt
                webcam.txt
    """

    # (directory name, archive name, url) triples.
    download_list = [
        (
            "image_list",
            "image_list.zip",
            "https://drive.google.com/uc?export=download&id=1JGjr1bYe0oYkso6prudvKYoFi0FOJcrE",
        ),
        (
            "amazon",
            "amazon.tgz",
            "https://drive.google.com/uc?export=download&id=1xq7gPW14FSLlrerR9nDCryKeSBHNNeFp",
        ),
        (
            "dslr",
            "dslr.tgz",
            "https://drive.google.com/uc?export=download&id=14F7HWvclPehy38aVMNNap1oFLVitUAgA",
        ),
        (
            "webcam",
            "webcam.tgz",
            "https://drive.google.com/uc?export=download&id=11OW_6J7kss6nlgKmbX6jWWQbO3yTmHNV",
        ),
    ]
    # Task code -> image list file, relative to the dataset directory.
    image_list = {
        "A": "image_list/amazon.txt",
        "D": "image_list/dslr.txt",
        "W": "image_list/webcam.txt",
    }
    CLASSES = [
        "back_pack",
        "bike",
        "bike_helmet",
        "bookcase",
        "bottle",
        "calculator",
        "desk_chair",
        "desk_lamp",
        "desktop_computer",
        "file_cabinet",
        "headphones",
        "keyboard",
        "laptop_computer",
        "letter_tray",
        "mobile_phone",
        "monitor",
        "mouse",
        "mug",
        "paper_notebook",
        "pen",
        "phone",
        "printer",
        "projector",
        "punchers",
        "ring_binder",
        "ruler",
        "scissors",
        "speaker",
        "stapler",
        "tape_dispenser",
        "trash_can",
    ]

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        root = os.path.join(root, "office31")
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            # NOTE(review): `download_data` is not defined or imported in the
            # visible part of this file - confirm it is in scope, otherwise the
            # default download=True path raises NameError.
            for entry in self.download_list:
                download_data(root, *entry)
        else:
            # Only the directory name (first field) is needed for the check.
            for entry in self.download_list:
                check_exits(root, entry[0])

        super().__init__(
            root, Office31.CLASSES, data_list_file=data_list_file, **kwargs
        )


def check_exits(root: str, file_name: str):
    """Check whether `file_name` exists under directory `root`.

    Returns ``None`` when the path exists. Otherwise prints a message and
    terminates the program by raising ``SystemExit`` with exit code ``-1``.
    """
    if not os.path.exists(os.path.join(root, file_name)):
        print("Dataset directory {} not found under {}".format(file_name, root))
        # Raise SystemExit directly instead of calling the `exit()` helper:
        # that helper is injected by the `site` module for interactive use and
        # is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(-1)


class OfficeHome(ImageList):
    """`OfficeHome <http://hemanthdv.org/OfficeHome-Dataset/>`_ Dataset.

    Parameters:
        - **root** (str): Root directory of datasets; data is read from ``<root>/officehome``.
        - **task** (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art,
          ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World.
        - **kwargs**: forwarded to ``ImageList`` (e.g. ``transform``, ``target_transform``).

    Nothing is downloaded by this class: the constructor only checks that the
    directories named in ``download_list`` exist.

    .. note:: In `<root>/officehome`, the following layout is expected.
        ::
            Art/
                Alarm_Clock/*.jpg
                ...
            Clipart/
            Product/
            Real_World/
            image_list/
                Art.txt
                Clipart.txt
                Product.txt
                Real_World.txt
    """

    # (directory name, archive name, url) triples.
    download_list = [
        (
            "image_list",
            "image_list.zip",
            "https://cloud.tsinghua.edu.cn/f/ca3a3b6a8d554905b4cd/?dl=1",
        ),
        (
            "Art",
            "Art.tgz",
            "https://cloud.tsinghua.edu.cn/f/4691878067d04755beab/?dl=1",
        ),
        (
            "Clipart",
            "Clipart.tgz",
            "https://cloud.tsinghua.edu.cn/f/0d41e7da4558408ea5aa/?dl=1",
        ),
        (
            "Product",
            "Product.tgz",
            "https://cloud.tsinghua.edu.cn/f/76186deacd7c4fa0a679/?dl=1",
        ),
        (
            "Real_World",
            "Real_World.tgz",
            "https://cloud.tsinghua.edu.cn/f/dee961894cc64b1da1d7/?dl=1",
        ),
    ]

    # Task code -> image list file, relative to the dataset directory.
    image_list = {
        "Ar": "image_list/Art.txt",
        "Cl": "image_list/Clipart.txt",
        "Pr": "image_list/Product.txt",
        "Rw": "image_list/Real_World.txt",
    }
    CLASSES = [
        "Drill",
        "Exit_Sign",
        "Bottle",
        "Glasses",
        "Computer",
        "File_Cabinet",
        "Shelf",
        "Toys",
        "Sink",
        "Laptop",
        "Kettle",
        "Folder",
        "Keyboard",
        "Flipflops",
        "Pencil",
        "Bed",
        "Hammer",
        "ToothBrush",
        "Couch",
        "Bike",
        "Postit_Notes",
        "Mug",
        "Webcam",
        "Desk_Lamp",
        "Telephone",
        "Helmet",
        "Mouse",
        "Pen",
        "Monitor",
        "Mop",
        "Sneakers",
        "Notebook",
        "Backpack",
        "Alarm_Clock",
        "Push_Pin",
        "Paper_Clip",
        "Batteries",
        "Radio",
        "Fan",
        "Ruler",
        "Pan",
        "Screwdriver",
        "Trash_Can",
        "Printer",
        "Speaker",
        "Eraser",
        "Bucket",
        "Chair",
        "Calendar",
        "Calculator",
        "Flowers",
        "Lamp_Shade",
        "Spoon",
        "Candles",
        "Clipboards",
        "Scissors",
        "TV",
        "Curtains",
        "Fork",
        "Soda",
        "Table",
        "Knives",
        "Oven",
        "Refrigerator",
        "Marker",
    ]

    def __init__(self, root: str, task: str, **kwargs):
        assert task in self.image_list
        root = os.path.join(root, "officehome")
        data_list_file = os.path.join(root, self.image_list[task])

        # Only the directory name (first field) is needed for the check.
        for entry in self.download_list:
            check_exits(root, entry[0])

        super().__init__(
            root, OfficeHome.CLASSES, data_list_file=data_list_file, **kwargs
        )
