from collections import defaultdict, Counter
import numpy as np 

class CovLineage:
    """Tree of dot-separated lineage names with a pseudo-time per lineage.

    A lineage name such as ``"B.1.1.529"`` is treated as a path: its parent is
    the longest proper dot-prefix (``"B.1.1"``, then ``"B.1"``, then ``"B"``)
    that is itself present in the vocabulary.

    Attributes:
        lineages: The vocabulary passed to the constructor (indexable by
            lineage index).
        lineage_to_index: Maps lineage name -> index in ``lineages``.
        percentile: Percentile of the observed ``src_time`` values used as a
            lineage's pseudo-time (0 means the earliest observation).
        lineage_time: Maps lineage name -> pseudo-time. Only lineages that
            occur in ``dataset`` have an entry. After the tree is built, an
            ancestor's time is never later than any of its descendants'.
        edges_index: ``(ancestor_index, lineage_index)`` pairs, one per
            lineage that has an ancestor in the vocabulary.
        edges_length: Initialised to an empty list and not populated by this
            class; kept so the attribute continues to exist.
        ancestor_idx: For each lineage index, the index of its nearest
            ancestor in the vocabulary, or ``-1`` if there is none.
        to_ancestor_edges_length: For each lineage index, the pseudo-time
            difference to its ancestor, or ``-1`` if there is no ancestor.
    """

    def __init__(self, vocab, dataset, percentile=0) -> None:
        """Build pseudo-times from ``dataset`` and the tree over ``vocab``.

        Args:
            vocab: Sequence of lineage names (strings).
            dataset: Iterable of mappings, each providing ``"lineage"`` and
                ``"src_time"`` keys.
            percentile: Percentile (0-100) forwarded to ``np.percentile`` when
                summarising a lineage's observed times. Defaults to 0, which
                is the value this class previously hard-coded.

        Raises:
            KeyError: If a lineage that takes part in an ancestor/descendant
                edge has no sample in ``dataset``.
        """
        self.lineages = vocab
        self.lineage_to_index = {lineage: idx for idx, lineage in enumerate(vocab)}
        self.percentile = percentile
        self.build_pseudo_time(dataset)
        self.build_tree_structure()

    def build_tree_structure(self):
        """Link each lineage to its nearest ancestor and compute branch lengths.

        Reads ``self.lineage_time`` (so ``build_pseudo_time`` must have run)
        and lowers ancestor times in place so that every ancestor is at least
        as early as all of its descendants.

        Raises:
            KeyError: If a linked lineage or its ancestor is missing from
                ``self.lineage_time``.
        """
        self.edges_index = []
        self.edges_length = []
        self.ancestor_idx = [-1] * len(self.lineages)
        self.to_ancestor_edges_length = [-1] * len(self.lineages)

        # Pass 1: attach every lineage to the longest proper dot-prefix that is
        # also in the vocabulary. Iteration order is kept as-is so that
        # ``edges_index`` has a stable, vocabulary-driven order.
        for lineage, lineage_idx in self.lineage_to_index.items():
            parts = lineage.split(".")
            for cut in range(len(parts) - 1, 0, -1):
                parent_idx = self.lineage_to_index.get(".".join(parts[:cut]))
                if parent_idx is not None:
                    self.edges_index.append((parent_idx, lineage_idx))
                    self.ancestor_idx[lineage_idx] = parent_idx
                    break

        # Pass 2: propagate the earliest time upwards. Visiting the deepest
        # lineages first guarantees a node's time is final (the minimum over
        # its whole subtree) before it is compared with its own ancestor.
        # Doing this in vocabulary order instead can leave a grandparent later
        # than a grandchild and yield negative branch lengths.
        linked = [idx for idx, parent in enumerate(self.ancestor_idx) if parent != -1]
        linked.sort(key=lambda idx: self.lineages[idx].count("."), reverse=True)
        for child_idx in linked:
            child = self.lineages[child_idx]
            parent = self.lineages[self.ancestor_idx[child_idx]]
            if self.lineage_time[child] < self.lineage_time[parent]:
                self.lineage_time[parent] = self.lineage_time[child]

        # Pass 3: branch length = child time - ancestor time.
        for child_idx in linked:
            child = self.lineages[child_idx]
            parent = self.lineages[self.ancestor_idx[child_idx]]
            branch_length = self.lineage_time[child] - self.lineage_time[parent]
            # Internal invariant after pass 2 (would only fail for NaN times).
            assert branch_length >= 0, (parent, child, branch_length)
            self.to_ancestor_edges_length[child_idx] = branch_length

    def build_pseudo_time(self, dataset):
        """Set ``self.lineage_time`` from the samples in ``dataset``.

        Each lineage's time is the ``self.percentile``-th percentile of the
        ``"src_time"`` values of the samples carrying that ``"lineage"``.

        Args:
            dataset: Iterable of mappings with ``"lineage"`` and ``"src_time"``
                keys. ``src_time`` must be numeric (it is passed to
                ``np.percentile``).
        """
        lineage2times = defaultdict(list)
        for data in dataset:
            lineage2times[data["lineage"]].append(data["src_time"])

        self.lineage_time = {
            lineage: np.percentile(np.asarray(times), self.percentile)
            for lineage, times in lineage2times.items()
        }