import networkx as nx
import torch
from collections import defaultdict
from data.io import read_tree

class Tree(object):
    """Wrap a tree loaded by ``read_tree`` and extract sample-spanning subtrees.

    ``read_tree`` is defined elsewhere in the project; this class uses its
    result with ``networkx`` path functions and reads ``name`` and
    ``branch_length`` from the nodes (presumably Bio.Phylo-style clades --
    confirm against ``data.io``).
    """

    def __init__(self, tree_path) -> None:
        """Load the tree at *tree_path* and precompute lookup tables."""
        self.tree = read_tree(tree_path)
        self.all_nodes = list(self.tree.nodes())
        self.all_edges = list(self.tree.edges())
        # Nodes whose name contains "EPI_"; unnamed nodes are skipped.
        self.leaves_nodes = [
            node for node in self.tree.nodes
            if node.name is not None and "EPI_" in node.name
        ]
        # Path (list of nodes) from the first node to every reachable node.
        # NOTE(review): assumes the first node is the root of the tree.
        self.shorted_paths = nx.shortest_path(self.tree, source=self.all_nodes[0])
        self.id2nodes = {node.name: node for node in self.all_nodes}
        self.node2index = {node: i for i, node in enumerate(self.all_nodes)}

    def get_shortest_path(self, edges_index, num_nodes):
        """Return the hop distance from node 0 to each node as a long tensor.

        Args:
            edges_index: iterable of ``(parent, child)`` integer pairs.
            num_nodes: number of nodes; nodes are ``0 .. num_nodes - 1``.

        Returns:
            A 1-D ``torch.long`` tensor of length *num_nodes*. Nodes that are
            not reachable from node 0 keep the value 0. An empty tensor is
            returned when *num_nodes* is 0.
        """
        layers = torch.zeros(num_nodes).long()
        if num_nodes == 0:
            # Node 0 would not exist and networkx would raise NodeNotFound.
            return layers
        tree = nx.DiGraph()
        tree.add_nodes_from(range(num_nodes))
        tree.add_edges_from(edges_index)
        for node, depth in nx.shortest_path_length(tree, 0).items():
            layers[node] = depth
        return layers

    def extract_subtree(self, sample_epi_ids):
        """Build the compact subtree spanning the nodes named in *sample_epi_ids*.

        Chains of nodes without a branching point are collapsed into a single
        edge whose length is the sum of the collapsed ``branch_length`` values.

        Args:
            sample_epi_ids: node names; each must be a key of ``self.id2nodes``
                (``KeyError`` otherwise).

        Returns:
            A dict with:
              * ``node_observed_masks``: 1 for subtree nodes whose name
                contains "EPI", else 0.
              * ``edges_index``: ``(parent, child)`` positions into the
                subtree node list.
              * ``edges_length``: length of each edge, aligned with
                ``edges_index``.
              * ``observed_revserse_index``: for each subtree node that is a
                sample, its first position in *sample_epi_ids*.
              * ``layers``: depth of each subtree node, see
                :meth:`get_shortest_path`.
        """
        sample_nodes = [self.id2nodes[x] for x in sample_epi_ids]
        subtree_edges = []
        self._extract_subtree(
            [self.shorted_paths[node] for node in sample_nodes], subtree_edges
        )

        # Map node -> position in first-seen order (parent before child for
        # every edge).  A dict keeps this O(1) per lookup instead of the
        # O(n) ``list.index`` scans.
        node_positions = {}
        edges_index = []
        edges_length = []
        for parent, child, length in subtree_edges:
            parent_pos = node_positions.setdefault(parent, len(node_positions))
            child_pos = node_positions.setdefault(child, len(node_positions))
            edges_index.append((parent_pos, child_pos))
            edges_length.append(length)

        if not node_positions and sample_nodes:
            # No edge is emitted only when every sample is the same node
            # (single sample or duplicates); keep that node so the result is
            # a one-node tree rather than an empty one.
            node_positions[sample_nodes[0]] = 0

        subtree_nodes = list(node_positions)

        # Unnamed (internal) nodes are treated as unobserved instead of
        # raising TypeError on ``"EPI" in None``.
        # NOTE(review): __init__ matches "EPI_" while this matches "EPI".
        observed_masks = [
            1 if node.name is not None and "EPI" in node.name else 0
            for node in subtree_nodes
        ]

        # First position of each sample node, matching ``list.index``.
        sample_positions = {}
        for position, node in enumerate(sample_nodes):
            sample_positions.setdefault(node, position)
        observed_revserse_index = [
            sample_positions[node]
            for node in subtree_nodes
            if node in sample_positions
        ]

        return {
            "node_observed_masks": observed_masks,
            "edges_index": edges_index,
            "edges_length": edges_length,
            "observed_revserse_index": observed_revserse_index,
            "layers": self.get_shortest_path(edges_index, len(observed_masks)),
        }

    def _extract_subtree(self, nodes_ancestors, edges, root=None):
        """Recursively append collapsed ``(parent, child, length)`` edges.

        Args:
            nodes_ancestors: list of paths (lists of nodes). All paths must
                start with the same node; on recursive calls they start just
                below *root*.
            edges: output list, extended in place in depth-first order.
            root: subtree node the paths hang from, or ``None`` on the
                top-level call (no edge is emitted for the shared prefix then).

        Raises:
            ValueError: if the paths do not share their first node.
        """
        if not nodes_ancestors:
            return

        def _branch_length(node):
            # A missing (None) branch length counts as zero.
            value = node.branch_length
            return 0 if value is None else value

        if len(nodes_ancestors) == 1:
            # A single remaining path collapses into one edge to its last node.
            path = nodes_ancestors[0]
            length = sum(_branch_length(node) for node in path)
            if root is not None:
                edges.append((root, path[-1], length))
            return

        # Walk the prefix shared by all paths, accumulating its length.
        first_path = nodes_ancestors[0]
        shortest = min(len(path) for path in nodes_ancestors)
        depth = 0
        length = 0
        while depth < shortest and all(
            path[depth] == first_path[depth] for path in nodes_ancestors
        ):
            length += _branch_length(first_path[depth])
            depth += 1

        if depth == 0:
            # Without a shared first node there is no branching point to use
            # as the parent of the sub-branches.
            raise ValueError("all paths must start with the same node")

        next_root = first_path[depth - 1]
        if root is not None:
            edges.append((root, next_root, length))

        # Group the remaining suffixes by the child they continue through.
        branches = defaultdict(list)
        for path in nodes_ancestors:
            # A path ending at next_root is a sample that is itself the
            # branching node (ancestor of another sample, or a duplicate);
            # it is already represented by next_root.
            if depth < len(path):
                branches[path[depth]].append(path[depth:])

        for group in branches.values():
            self._extract_subtree(group, edges, root=next_root)
