import glob
import os.path as osp

from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase
from ..base_dataset import Datum, DatasetBase
from dassl.utils import listdir_nohidden


@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home, configured for open-set domain generalization.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Open-set protocol (as implemented here):
        - The first ``num_known_cls`` entries of ``all_cls`` are the known
          classes (``train_cls``) used in the source-domain train/val splits.
        - The remaining entries (``more_test_cls``) are held out. They are
          skipped in the source-domain train/val splits and, in the
          target-domain test split, all share one extra label equal to
          ``len(train_cls)``.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    # Number of leading entries of ``all_cls`` treated as known classes:
    #   15 -> default open-set split (held-out classes start at 'Desk_Lamp')
    #   30 -> split used for figures (held-out classes start at 'Knives')
    #   65 -> closed-set (no held-out classes)
    num_known_cls = 15

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.cfg = cfg
        self.all_cls = [
            'Alarm_Clock', 'Backpack', 'Batteries', 'Bed', 'Bike', 'Bottle', 'Bucket',
            'Calculator', 'Calendar', 'Candles', 'Chair', 'Clipboards', 'Computer', 'Couch', 'Curtains',
            'Desk_Lamp', 'Drill', 'Eraser', 'Exit_Sign', 'Fan', 'File_Cabinet', 'Flipflops', 'Flowers',
            'Folder', 'Fork', 'Glasses', 'Hammer', 'Helmet', 'Kettle', 'Keyboard',
            'Knives', 'Lamp_Shade', 'Laptop', 'Marker', 'Monitor', 'Mop', 'Mouse', 'Mug',
            'Notebook', 'Oven', 'Pan', 'Paper_Clip', 'Pen', 'Pencil', 'Postit_Notes',
            'Printer', 'Push_Pin', 'Radio', 'Refrigerator', 'Ruler', 'Scissors', 'Screwdriver', 'Shelf',
            'Sink', 'Sneakers', 'Soda', 'Speaker', 'Spoon', 'Table', 'Telephone',
            'ToothBrush', 'Toys', 'Trash_Can', 'TV', 'Webcam',
        ]

        # Held-out classes are derived from all_cls (instead of a second
        # hand-maintained list) so the two can never drift apart.
        self.more_test_cls = self.all_cls[self.num_known_cls:]
        self.train_cls = self._delete_cls(self.all_cls.copy())

        # Automatic download is disabled. To re-enable it, call
        # self.download_data(self.data_url, osp.join(root, "office_home_dg.zip"),
        # from_gdrive=True) when self.dataset_dir does not exist.

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train", self.train_cls, self.more_test_cls
        )
        val = read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val", self.train_cls, self.more_test_cls
        )
        test = read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all", self.train_cls, self.more_test_cls
        )

        # Debug option (translated from the original note): set `test = train`
        # to plot domains using the training data's mean/std.

        super().__init__(train_x=train, val=val, test=test)

    def _delete_cls(self, all_cls):
        """Return the entries of ``all_cls`` that are not held-out classes.

        The input list is not modified; order is preserved.
        """
        held_out = set(self.more_test_cls)
        return [cls for cls in all_cls if cls not in held_out]

def read_data(dataset_dir, input_domains, split, train_cls, more_test_cls):
    """Collect ``Datum`` items for the given domains and split.

    Images are read from ``<dataset_dir>/<domain>/<subdir>/<class>/*.jpg``.

    Args:
        dataset_dir (str): dataset root directory.
        input_domains (list[str]): domain folder names. A domain's position
            in this list becomes the item's ``domain`` index.
        split (str): ``"train"`` and ``"all"`` read both the ``train`` and
            ``val`` sub-folders. Any other value reads ``<domain>/<split>``.
        train_cls (list[str]): known class names. ``len(train_cls)`` is the
            single label given to every held-out class.
        more_test_cls (list[str]): held-out (unknown) class folder names.

    Labeling:
        - Known classes get contiguous labels ``0..K-1`` in sorted folder
          order, so labels stay correct even when held-out classes sort
          between known ones. This assumes every directory contains the
          same set of class folders.
        - Held-out classes are skipped for ``"train"``/``"val"`` and mapped
          to ``len(train_cls)`` for any other split.

    Returns:
        list[Datum]
    """
    held_out = set(more_test_cls)
    unknown_label = len(train_cls)
    skip_held_out = split in ("train", "val")

    def _load_data_from_directory(directory):
        # Return (impath, label) pairs for every class folder in `directory`.
        folders = sorted(listdir_nohidden(directory))
        items_ = []
        next_known_label = 0

        for folder in folders:
            if folder in held_out:
                if skip_held_out:
                    continue
                label = unknown_label
            else:
                # Count only known folders so held-out classes that sort in
                # between cannot create gaps or collide with unknown_label.
                label = next_known_label
                next_known_label += 1

            # glob order is filesystem-dependent; sort for reproducibility.
            impaths = sorted(glob.glob(osp.join(directory, folder, "*.jpg")))
            items_.extend((impath, label) for impath in impaths)

        return items_

    # NOTE(review): "train" deliberately includes the val sub-folder as well,
    # so the "val" split overlaps the training data - confirm this is intended.
    if split in ("all", "train"):
        subdirs = ("train", "val")
    else:
        subdirs = (split,)

    items = []

    for domain, dname in enumerate(input_domains):
        impath_label_list = []
        for subdir in subdirs:
            impath_label_list += _load_data_from_directory(
                osp.join(dataset_dir, dname, subdir)
            )

        for impath, label in impath_label_list:
            # Normalize Windows separators so the class folder can be
            # extracted with a "/" split.
            impath = impath.replace('\\', '/')
            class_name = impath.split("/")[-2].lower()
            item = Datum(
                impath=impath,
                label=label,
                domain=domain,
                classname=class_name
            )
            items.append(item)

    return items