import os.path as osp

from .digits_dg import DigitsDG

from Dassl.dassl.data.datasets.base_dataset import DatasetBase

# @DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.

    This loader reads the "train" split of the configured source domains and
    the "all" split of the configured target domains (via
    ``DigitsDG.read_data``), then partitions both across
    ``cfg.DATASET.USERS`` clients for federated training and evaluation.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        """Locate (downloading if absent) Office-Home DG and build client splits.

        Args:
            cfg: config node. Reads ``DATASET.ROOT``, ``DATASET.SOURCE_DOMAINS``,
                ``DATASET.TARGET_DOMAINS``, ``DATASET.USEALL``,
                ``DATASET.NUM_SHOTS``, ``DATASET.USERS`` and ``DATASET.IID``.
        """
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Fetch the zipped dataset from Google Drive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        # NOTE(review): domain-name validation (check_input_domains) is
        # currently disabled for this loader; re-enable if the base class
        # supports it.

        # Source domains supply the training pool; target domains are held
        # out entirely for testing (the domain-generalization setting).
        train = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        test = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        # Shared partitioning options so train and test splits stay consistent.
        split_kwargs = dict(
            num_shots=cfg.DATASET.NUM_SHOTS,
            num_users=cfg.DATASET.USERS,
            is_iid=cfg.DATASET.IID,
            repeat_rate=0,
        )

        # USEALL distributes the full training pool; otherwise each client
        # receives a few-shot subset.
        if cfg.DATASET.USEALL:
            federated_train_x = self.generate_federated_dataset(train, **split_kwargs)
        else:
            federated_train_x = self.generate_federated_fewshot_dataset(
                train, **split_kwargs
            )
        federated_test_x = self.generate_federated_dataset(test, **split_kwargs)

        super().__init__(
            total_train_x=train,
            federated_train_x=federated_train_x,
            federated_test_x=federated_test_x,
        )