import os
import pickle

from collections import OrderedDict

from .negative_database_whole import ADE20K_150_NEGATIVES
# from .kitti360_labels import kitti360_labels

from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden

from .oxford_pets import OxfordPets


@DATASET_REGISTRY.register()
class ADE20K_150_Negative(DatasetBase):
    """ADE20K (150 classes) dataset extended with label-only "negative" classes.

    Expected layout under ``cfg.DATASET.ROOT``::

        ade20k_150/
            ade20k_150_classnames.txt
            ade_gt_150cls_train_v3/<int class index>/<images>
            ade_gt_150cls_val/<int class index>/<images>

    Split sub-folder names are parsed with ``int`` and used to index the list of
    class names read from ``ade20k_150_classnames.txt``.

    Negative class names come from ``ADE20K_150_NEGATIVES``, a mapping from a
    known class name to its list of negative names. The flattened negatives are
    appended to the known class names.

    For every training image whose class has negatives, one label-only ``Datum``
    (no image) carrying a negative name is produced and exposed as ``train_u``.
    The validation split is used for both ``val`` and ``test``.
    """

    def __init__(self, cfg):
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, "ade20k_150")
        self.image_dir = self.dataset_dir

        self.class_names = self.read_classnames(
            os.path.join(self.dataset_dir, "ade20k_150_classnames.txt")
        )

        # Only the class-index -> dataset-id mapping is needed here.
        _, label2dataset = self.get_indices(
            "ade_gt_150cls_train_v3", "ade_gt_150cls_val", self.class_names
        )

        # Flattened negative names, in the iteration order of ADE20K_150_NEGATIVES.
        negative_classes = self.get_negatives(ADE20K_150_NEGATIVES)
        all_classnames = self.class_names + negative_classes

        train, negatives, negative_ids = self.read_data(
            self.class_names,
            "ade_gt_150cls_train_v3",
            label2dataset,
            negative_classes=ADE20K_150_NEGATIVES,
        )
        test = self.read_data(self.class_names, "ade_gt_150cls_val", label2dataset)

        # Label-only Datums for the negative ("new domain") classes; no images.
        # NOTE(review): labels are offset by len(all_classnames), so they start
        # past the end of ``all_classnames`` -- confirm downstream expects this.
        train_u = self.create_dataset_from_labels_only(
            negatives, negative_ids, len(all_classnames)
        )

        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, train_u, test = OxfordPets.subsample_classes(
            train, train_u, test, subsample=subsample
        )

        super().__init__(
            train_x=train,
            train_u=train_u,
            val=test,
            test=test,
            classnames=all_classnames,
        )

    def read_classnames(self, text_file):
        """Return the list of class names stored in ``text_file``.

        Each line is stripped and split on ``":"``; only the text after the last
        colon is kept (a line without a colon is kept whole). The position of a
        line in the file is its class index, as used by ``get_indices`` and
        ``read_data`` (``classnames[int(folder)]``).
        """
        with open(text_file, "r", encoding="utf-8") as f:
            return [line.strip().split(":")[-1] for line in f]

    def get_negatives(self, classes_dict):
        """Flatten the negative-name lists in ``classes_dict`` into one list.

        All lists must have the same length; that common length is stored in
        ``self.prev_len`` (it stays -1 if ``classes_dict`` is empty).

        Raises:
            ValueError: if the negative lists do not all have the same length.
        """
        negative_classes = []
        # -1 means "no reference length yet"; a length of 0 is a valid reference.
        self.prev_len = -1
        for key, negatives in classes_dict.items():
            negative_classes.extend(negatives)
            if self.prev_len >= 0:
                if len(negatives) != self.prev_len:
                    raise ValueError(
                        f"Negatives for {key!r} have {len(negatives)} entries, "
                        f"expected {self.prev_len}"
                    )
            else:
                self.prev_len = len(negatives)
        print(f"number of negative classes {len(negative_classes)}")
        return negative_classes

    def get_indices(self, train_dir, test_dir, classnames):
        """Find which class indices have a folder in the train and/or test split.

        Args:
            train_dir: train split sub-directory of ``self.image_dir``.
            test_dir: test split sub-directory of ``self.image_dir``.
            classnames: class names, where the list position is the class index.

        Returns:
            ``(train_test_classnames, label2dataset)``. The first is the names
            whose index occurs in either split. The second maps each such index
            to itself.
        """
        train_dir = os.path.join(self.image_dir, train_dir)
        train_folders = sorted(f.name for f in os.scandir(train_dir) if f.is_dir())

        test_dir = os.path.join(self.image_dir, test_dir)
        test_folders = sorted(f.name for f in os.scandir(test_dir) if f.is_dir())

        train_indices = [int(folder) for folder in train_folders]
        test_indices = [int(folder) for folder in test_folders]

        print(f"Train indices w/ Dataset ID: {train_indices}")
        print(f"Test indices w/ Dataset ID: {test_indices}")

        # Sets give O(1) membership tests inside the loop below.
        train_set = set(train_indices)
        test_set = set(test_indices)

        train_test_classnames = []
        label2dataset = {}
        common_ids = []
        for key, class_name in enumerate(classnames):
            in_train = key in train_set
            in_test = key in test_set
            if not (in_train or in_test):
                continue
            train_test_classnames.append(class_name)
            label2dataset[key] = key
            if in_train and in_test:
                common_ids.append(key)

        print(f"Common ids: {common_ids}")

        return train_test_classnames, label2dataset

    def read_data(self, classnames, split_dir, label2dataset, negative_classes=None):
        """Build ``Datum`` items for every image in the ``split_dir`` split.

        The class name of each folder is ``classnames[int(folder)]``, cut at the
        first comma.

        Without ``negative_classes``, returns the list of items.

        With ``negative_classes`` (class name -> list of negative names):
          - folders whose class has no entry are skipped (with a message);
          - every image is also assigned one negative name, cycling through that
            class's list;
          - returns ``(items, negative_names, negative_ids)``, where the last two
            lists are parallel and have one entry per item.
        """
        split_dir = os.path.join(self.image_dir, split_dir)
        folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
        with_negatives = negative_classes is not None

        items = []
        classes_list = []
        negative_names = []
        negative_ids = []

        print(f"Reading data from {split_dir}")

        # NOTE(review): ``label`` is the position in the *string*-sorted folder
        # list ("10" sorts before "2"), not int(folder), and it still advances
        # for skipped folders. It feeds the negative ids below -- confirm intended.
        for label, folder in enumerate(folders):
            imnames = listdir_nohidden(os.path.join(split_dir, folder))
            classname = classnames[int(folder)].split(",")[0]
            classes_list.append(classname)

            if with_negatives:
                if classname not in negative_classes:
                    print(f"Negatives for {classname} doesnt exist!!!")
                    continue
                negative_classname_list = negative_classes[classname]
                len_negative = len(negative_classname_list)
                i = 0

            for imname in imnames:
                impath = os.path.join(split_dir, folder, imname)
                if not with_negatives:
                    item = Datum(impath=impath, label=label2dataset[int(folder)], classname=classname)
                else:
                    negative_names.append(negative_classname_list[i])
                    negative_ids.append(len_negative * label + i)
                    # Cycle through this class's negatives, one per image.
                    i = (i + 1) % len_negative
                    # label_negative=-1 is a placeholder; presumably "unset" -- confirm with Datum.
                    item = Datum(impath=impath, label=label2dataset[int(folder)], label_negative=-1, classname=classname)
                items.append(item)

        print(f"Classes for {split_dir}: {classes_list}")

        if with_negatives:
            return items, negative_names, negative_ids
        return items

    def create_dataset_from_labels_only(self, classnames, class_ids, num_classes_train_val):
        """Create image-less ``Datum`` items, one per entry of ``classnames``.

        Each item has an empty ``impath``, ``impath_exists=False``, and the label
        ``num_classes_train_val + class_ids[i]``. ``class_ids`` must be at least
        as long as ``classnames``.
        """
        return [
            Datum(
                impath="",
                label=(num_classes_train_val + class_ids[i]),
                classname=classname,
                impath_exists=False,
            )
            for i, classname in enumerate(classnames)
        ]