import glob
import os.path as osp

from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase
from ..base_dataset import Datum, DatasetBase
from dassl.utils import listdir_nohidden


@DATASET_REGISTRY.register()
class Office31DG(DatasetBase):
    """Office-31 for open-set domain generalization.

    The full Office-31 benchmark has 4,110 images of 31 office-object
    classes in 3 domains (Amazon, Webcam, Dslr). This loader uses only the
    21 classes listed in ``all_cls``:

    - ``train_cls`` (``all_cls`` minus ``more_test_cls``) are the known
      classes seen during training;
    - ``more_test_cls`` appear only in the test split, where they all share
      a single extra "unknown" label (``len(train_cls)``).

    Training and validation data come from the source domains without the
    test-only classes. Test data comes from the target domains with them.

    Reference:
        - Saenko et al. Adapting visual category models to
        new domains. ECCV 2010.
    """

    dataset_dir = "office31"
    domains = ["amazon", "webcam", "dslr"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.cfg = cfg
        # Every class used by this protocol. The list is alphabetical, which
        # matches the sorted folder listing in ``read_data``.
        self.all_cls = [
            'back_pack', 'bike', 'calculator', 'headphones', 'keyboard',
            'laptop_computer', 'monitor', 'mouse', 'mug', 'pen',
            'phone', 'printer', 'projector', 'punchers', 'ring_binder',
            'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser',
            'trash_can',
        ]

        # Held out of training. These classes appear only in the test split,
        # where they are treated as a single "unknown" class.
        self.more_test_cls = [
            'phone', 'printer', 'projector', 'punchers', 'ring_binder',
            'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser',
            'trash_can',
        ]

        self.train_cls = self._delete_cls(self.all_cls)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train",
            self.train_cls, self.more_test_cls
        )
        # There is no dedicated validation split: validation re-reads the
        # same source-domain images used for training.
        val = read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train",
            self.train_cls, self.more_test_cls
        )
        # The test split keeps the test-only classes ("all").
        test = read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all",
            self.train_cls, self.more_test_cls
        )

        super().__init__(train_x=train, val=val, test=test)

    def _delete_cls(self, all_cls):
        """Return ``all_cls`` without the classes in ``self.more_test_cls``.

        The relative order of the remaining classes is preserved, and
        ``all_cls`` itself is not modified.

        Args:
            all_cls (list[str]): full list of class names.

        Returns:
            list[str]: the classes that are not test-only.

        Raises:
            ValueError: if a test-only class is absent from ``all_cls``.
        """
        missing = [c for c in self.more_test_cls if c not in all_cls]
        if missing:
            raise ValueError(
                "test-only classes not found in class list: {}".format(missing)
            )
        held_out = set(self.more_test_cls)
        return [c for c in all_cls if c not in held_out]

def read_data(dataset_dir, input_domains, split, train_cls, more_test_cls):
    """Collect ``Datum`` items for one split of the open-set Office-31 protocol.

    Each domain is read from ``<dataset_dir>/<domain>/``, which is expected
    to contain one sub-folder of ``*.jpg`` images per class.

    Args:
        dataset_dir (str): root folder holding one sub-folder per domain.
        input_domains (list[str]): domain folder names to read. A domain's
            index in this list becomes the item's ``domain`` value.
        split (str): ``"all"`` also loads the test-only classes, labelled as
            "unknown". Any other value (``"train"``, ``"val"``) skips them.
        train_cls (list[str]): known classes. A class's label is its index
            in this list.
        more_test_cls (list[str]): test-only classes. All of them share the
            label ``len(train_cls)``.

    Returns:
        list[Datum]: one item per image. Folders that belong to neither
        class list are ignored.
    """
    # Label by position in ``train_cls`` rather than by position in the
    # folder listing. Extra folders on disk (e.g. the other Office-31
    # classes) must not shift the known labels or collide with the
    # "unknown" label.
    known_labels = {name: idx for idx, name in enumerate(train_cls)}
    unknown_label = len(train_cls)
    test_only = set(more_test_cls)
    include_unknown = split == "all"

    def _load_data_from_directory(directory):
        """Return ``(impath, label)`` pairs for the class folders in ``directory``."""
        items_ = []
        for folder in sorted(listdir_nohidden(directory)):
            if folder in test_only:
                if not include_unknown:
                    continue
                label = unknown_label
            elif folder in known_labels:
                label = known_labels[folder]
            else:
                # Folder is not part of this protocol.
                continue
            # glob order is arbitrary; sort for reproducible item order.
            impaths = sorted(glob.glob(osp.join(directory, folder, "*.jpg")))
            items_.extend((impath, label) for impath in impaths)
        return items_

    items = []
    for domain, dname in enumerate(input_domains):
        domain_dir = osp.join(dataset_dir, dname)
        for impath, label in _load_data_from_directory(domain_dir):
            # Normalise Windows separators so the class folder can be split off.
            impath = impath.replace('\\', '/')
            class_name = impath.split("/")[-2].lower()
            items.append(
                Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
            )

    return items