import sys
import os
import torch

class HierarchyTree:
    """Label hierarchy loaded from plain-text files for one dataset.

    Three files named after ``dataset`` are read from ``data_folder``:

    * ``<dataset>_child_parent_pairs.txt`` -- first line is a pair count,
      then one ``child;parent`` integer pair per line. A parent field
      containing ``None`` marks a root (its parent is stored as ``None``).
    * ``<dataset>_labels.txt`` -- first line is a count, then one integer
      label per line.
    * ``<dataset>_all_labels.txt`` -- same format as ``_labels.txt``.

    Blank lines in any of these files are ignored.

    Attributes set by the constructor:
        child_parent_pairs: dict mapping child id -> parent id (``None`` for roots).
        npairs: pair count read from the header line (not validated).
        child_parent_pairs_tensor: ``torch.long`` tensor indexed by child id,
            holding the parent id, or -1 for roots and ids that never appear
            as a child. Placed on CUDA when CUDA is available.
        labels / nlabels / n_labels: labels from ``_labels.txt``, the header
            count, and the actual number of labels read.
        all_labels / n_all_labels: labels from ``_all_labels.txt`` and the
            actual number read.
        label_depth: dict label -> number of parent links followed before
            reaching a node that is not itself a child in the pairs file
            (a root listed with parent ``None`` has depth 1; a label absent
            from the pairs file has depth 0).
    """

    def __init__(self, dataset = "cifar100", data_folder = ''):
        """Load every hierarchy file for ``dataset`` from ``data_folder``.

        Args:
            dataset: prefix of the data file names.
            data_folder: directory containing the files ('' = current directory).

        Raises:
            OSError: if a data file cannot be opened.
            ValueError: on malformed file contents or a cyclic hierarchy.
        """
        self.dataset = dataset
        self.data_folder = data_folder
        self.init_child_parent_pairs()
        self.init_child_parent_pairs_tensor()
        self.init_labels()
        self.init_all_labels()
        self.init_label_depth()

    def _data_path(self, suffix):
        """Return the path of ``<dataset><suffix>`` inside ``data_folder``."""
        return os.path.join(self.data_folder, self.dataset + suffix)

    def _read_int_list(self, suffix):
        """Read a count header followed by one integer per non-blank line.

        Returns:
            ``(header_count, values)``; the header is returned as-is and is
            not checked against ``len(values)``.
        """
        with open(self._data_path(suffix), "r") as f:
            header = int(f.readline())
            values = [int(line) for line in f if line.strip()]
        return header, values

    def init_child_parent_pairs(self):
        """Parse the child/parent file into ``self.child_parent_pairs``.

        Raises:
            ValueError: if a line does not have exactly two ``;``-separated
                fields, if the child field is ``None``, or if a child id
                appears more than once.
        """
        self.child_parent_pairs = dict()
        with open(self._data_path("_child_parent_pairs.txt"), "r") as f:
            self.npairs = int(f.readline())
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                fields = line.split(";")
                # Check arity before indexing so malformed lines raise the
                # intended ValueError instead of an IndexError.
                if len(fields) != 2:
                    raise ValueError("should be a pair")
                child_field, parent_field = fields
                # None marks the root and is only valid as the parent.
                if 'None' in child_field:
                    raise ValueError("first key is none !")
                child = int(child_field)
                parent = None if 'None' in parent_field else int(parent_field)
                if child in self.child_parent_pairs:
                    raise ValueError("Key already exist!")
                self.child_parent_pairs[child] = parent

    def init_child_parent_pairs_tensor(self):
        """Build ``self.child_parent_pairs_tensor`` from ``child_parent_pairs``.

        Entry ``i`` holds the parent id of child ``i``, or -1 when ``i`` is a
        root or never appears as a child. The tensor is moved to CUDA only
        when CUDA is available, so CPU-only machines still work.
        """
        pairs = self.child_parent_pairs
        # Cover the largest child id so non-contiguous ids still fit; this
        # equals len(pairs) when ids are exactly 0..len(pairs)-1.
        size = max(len(pairs), max(pairs, default=-1) + 1)
        tensor = torch.full((size,), -1, dtype=torch.long)
        for child, parent in pairs.items():
            if parent is not None:
                tensor[child] = parent
        if torch.cuda.is_available():
            tensor = tensor.cuda()
        self.child_parent_pairs_tensor = tensor

    def init_labels(self):
        """Load ``_labels.txt`` into ``labels`` (+ ``nlabels`` header, ``n_labels`` count)."""
        self.nlabels, self.labels = self._read_int_list("_labels.txt")
        self.n_labels = len(self.labels)

    def init_all_labels(self):
        """Load ``_all_labels.txt`` into ``all_labels`` and set ``n_all_labels`` to its length."""
        _, self.all_labels = self._read_int_list("_all_labels.txt")
        self.n_all_labels = len(self.all_labels)

    def init_label_depth(self):
        """Compute ``self.label_depth`` for every label in ``all_labels``.

        Depth is the number of parent links followed until reaching a node
        that is not a child in ``child_parent_pairs``.

        Raises:
            ValueError: if following parent links revisits a node (a cycle),
                which would otherwise loop forever.
        """
        self.label_depth = dict()
        for label in self.all_labels:
            depth = 0
            node = label
            visited = set()
            while node in self.child_parent_pairs:
                if node in visited:
                    raise ValueError("cycle in hierarchy at label %r" % (label,))
                visited.add(node)
                node = self.child_parent_pairs[node]
                depth += 1
            self.label_depth[label] = depth

# Test function.
def main():
    """Smoke-test HierarchyTree by dumping its parsed contents to stdout.

    Builds a tree with the default arguments (dataset "cifar100", files
    looked up relative to the current directory), then prints each child id
    followed by its parent, a separator, every entry of ``labels``, another
    separator, and every entry of ``all_labels`` -- one value per line.
    """
    tree = HierarchyTree()

    for child, parent in tree.child_parent_pairs.items():
        print(child)
        print(parent)

    separator = "_____________\n"

    print(separator)
    for value in tree.labels:
        print(value)

    print(separator)
    for value in tree.all_labels:
        print(value)


# main() #Launch test function on loading
