import os.path as osp
import random

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon (2,344),
        Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.

    Open-set setup used here:
        - The classes in ``more_test_cls`` are treated as out-of-distribution
          (OOD). They are dropped from the ``train``/``crossval`` splits.
        - In the ``test`` split every OOD image gets the single extra label
          ``len(self.train_cls)``.
        - The remaining (in-distribution) classes are labeled by their index
          in ``self.train_cls``. The label column in the split files is
          ignored.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    # Relative image paths (as written in the split files) that are skipped.
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        self.cfg = cfg
        self.all_cls = [
            "dog", "elephant", "giraffe", "guitar", "horse", "house", "person",
        ]
        self.more_test_cls = ["horse", "house", "person"]  # OOD classes
        self.train_cls = self._delete_cls(self.all_cls)  # ID classes

        print('The num of classes for training is::::::::')
        print(len(self.train_cls))

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "all")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _delete_cls(self, all_cls):
        """Return a copy of *all_cls* with every OOD class removed.

        The argument is not modified. Like ``list.remove``, this raises
        ValueError if a class in ``self.more_test_cls`` is not present.
        """
        kept = list(all_cls)
        for cls in self.more_test_cls:
            kept.remove(cls)
        return kept

    def _split_path(self, dname, split):
        """Return the path of the ``<domain>_<split>_kfold.txt`` split file."""
        return osp.join(self.split_dir, dname + "_" + split + "_kfold.txt")

    def _read_data(self, input_domains, split):
        """Build ``Datum`` items for *input_domains* from one split.

        Args:
            input_domains: domain names; each item's ``domain`` field is the
                domain's position in this sequence.
            split: ``"all"`` (train + crossval folds combined),
                ``"crossval"`` or ``"test"``.

        Returns:
            list of ``Datum``.
        """
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                impath_label_list = self._read_split_pacs(
                    self._split_path(dname, "train"), "train"
                )
                impath_label_list += self._read_split_pacs(
                    self._split_path(dname, "crossval"), "crossval"
                )
            else:
                impath_label_list = self._read_split_pacs(
                    self._split_path(dname, split), split
                )

            for impath, label in impath_label_list:
                items.append(
                    Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        # Path ends in ".../<class>/<file>".
                        classname=impath.split("/")[-2],
                    )
                )

        return items

    def _read_split_pacs(self, split_file, split):
        """Parse one split file into ``(absolute image path, label)`` pairs.

        Each non-blank line is ``<domain>/<class>/<file> <label>``. The file's
        label column is ignored and the label is recomputed from ``<class>``:
        in-distribution classes map to their index in ``self.train_cls``.
        OOD classes are dropped unless *split* is ``"test"``; there they all
        map to ``len(self.train_cls)``. Paths listed in ``_error_paths`` are
        skipped.

        Raises:
            ValueError: on a malformed line or an unknown class directory.
        """
        keep_ood = split == "test"
        ood_classes = set(self.more_test_cls)
        ood_label = len(self.train_cls)  # one shared "unknown" label
        id_labels = {cls: idx for idx, cls in enumerate(self.train_cls)}
        error_paths = set(self._error_paths)

        items = []
        with open(split_file, "r", encoding="utf-8") as f:
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines

                fields = line.rsplit(maxsplit=1)
                if len(fields) != 2:
                    raise ValueError(
                        f"{split_file}:{lineno}: expected '<path> <label>', "
                        f"got {line!r}"
                    )
                impath = fields[0]

                if impath in error_paths:
                    continue

                # Decide OOD-ness from the class directory itself, not from a
                # substring search over the whole line.
                parts = impath.split("/")
                cls = parts[1] if len(parts) > 1 else None
                if cls not in self.all_cls:
                    raise ValueError(
                        f"{split_file}:{lineno}: unknown class {cls!r} "
                        f"in path {impath!r}"
                    )

                if cls in ood_classes:
                    if not keep_ood:
                        continue
                    label = ood_label
                else:
                    label = id_labels[cls]

                items.append((osp.join(self.image_dir, impath), label))

        return items
