import torch
import torch.nn as nn
import torch.nn.functional as F

class TreeNode(object):
    """Node of a rooted tree that stores per-site messages in log space.

    A node knows its parent (``ancestor``), its children (``sons``) and the
    length of the branch leading to it.  The inference routines attach the
    following state to the node:

    * ``message``        -- upward (leaves-to-root) log message.
    * ``message_rev``    -- downward log message, set by ``message_passing_rev``.
    * ``argmax_message`` -- back-pointers recorded by ``maximum_post``.
    * ``infer``          -- current state assignment of the node.

    Shape notes such as ``[L, V]`` below are taken from the original inline
    comments (L = sequence length, V = vocabulary size) and are assumptions
    about the caller's tensors rather than something this class enforces.

    ``trans_model`` arguments only need a
    ``get_transition_matrix(branch_length)`` method; per the original
    comments entry ``T[i, j]`` is ``p(son = j | ancestor = i)``.
    """

    # Added to probabilities before taking logs so that zeros stay finite.
    eps = 1e-20

    def __init__(self, name=None, branch_length=None) -> None:
        super().__init__()
        self.name = name
        self.ancestor = None
        self.sons = []
        self.branch_length = branch_length
        # Inference state.  It is created here (as None) because several
        # methods test these attributes with ``is None`` before they are set.
        self.message = None
        self.message_rev = None
        self.argmax_message = None
        self.infer = None

    def set_ancestor(self, node):
        """Record ``node`` as the parent of this node."""
        self.ancestor = node

    def add_son(self, node):
        """Append ``node`` to the list of children."""
        self.sons.append(node)

    def init_message(self, vocab_size, value, device=None):
        """Initialise a leaf from its observed token indices.

        Args:
            vocab_size: number of classes passed to ``F.one_hot``.
            value: integer tensor of observed token indices.
            device: where to place the message; defaults to ``value.device``.

        Sets ``message`` to ``log(one_hot(value) + eps)`` and ``infer`` to the
        observation itself, which is what ``init_infer`` and
        ``gibbs_sampling`` read from leaf nodes.
        """
        if device is None:
            device = value.device
        one_hot = F.one_hot(value, vocab_size).to(device)
        self.message = (one_hot + self.eps).log()
        self.infer = value.to(device)

    def set_value(self, idx):
        """Set entry ``idx`` of ``self.loglikelihood`` to 1.0.

        NOTE(review): ``loglikelihood`` is not created anywhere in this class
        (its initialisation is commented out upstream), so the caller must
        assign it before using this method.
        """
        self.loglikelihood[idx] = 1.0

    def init_infer(self, ):
        """Initialise ``infer`` by copying, per site, the state of a random son.

        Sons whose ``infer`` is still None are initialised recursively first.
        A node without sons keeps its existing ``infer``.

        Raises:
            RuntimeError: if a node without sons has no ``infer`` yet (i.e.
                ``init_message`` was never called on that leaf).
        """
        if not self.sons:
            if self.infer is None:
                raise RuntimeError(
                    "Leaf node %r has no observation; call init_message first." % (self.name,)
                )
            return

        son_infers = []
        for son in self.sons:
            if son.infer is None:
                son.init_infer()
            son_infers.append(son.infer)
        son_infers = torch.stack(son_infers, dim=0).T  # [L x N]
        # Random positive weights -> one randomly chosen son index per site.
        # Created on the sons' device so gather() does not mix devices.
        rand_prob = torch.rand(son_infers.size(), device=son_infers.device)  # [L x N]
        rand_sample = torch.multinomial(rand_prob, 1)  # L x 1
        self.infer = torch.gather(son_infers, 1, rand_sample).squeeze(-1)  # L

    def gibbs_sampling(self, trans_model, prior=None):
        """Resample ``infer`` given the current states of parent and sons.

        Args:
            trans_model: object providing ``get_transition_matrix``.
            prior: log-prior over states, required when this node has no
                ancestor.  It is never modified.
        """
        if self.ancestor is not None:
            # Contribution of the parent: row of T selected by its state.
            trans_matrix = trans_model.get_transition_matrix(self.branch_length)
            prob = trans_matrix[self.ancestor.infer].log()  # L x V
        else:
            assert prior is not None  # [L x V]
            prob = prior

        for son in self.sons:
            trans_matrix = trans_model.get_transition_matrix(son.branch_length)
            # Out-of-place add: an in-place ``+=`` would overwrite the
            # caller's prior tensor and cannot broadcast a [1, V] prior.
            prob = prob + trans_matrix.T[son.infer].log()

        self.infer = torch.multinomial(F.softmax(prob, dim=-1), 1).squeeze(-1)  # L

    def maximum_post(self, trans_model):
        """Max-product upward pass.

        For every son, maximises ``son.message + log T`` over the son's state
        and stores the maximising index in ``son.argmax_message``; the maxima
        are summed over sons into ``self.message``.  Sons whose ``message`` is
        None are processed recursively first.
        """
        log_prob = 0.0
        for son in self.sons:
            if son.message is None:
                son.maximum_post(trans_model)
            trans_matrix = trans_model.get_transition_matrix(son.branch_length)
            # p(xj|xi) = T_{ij}, s_j is the son node
            # message: L x 1 x V, trans_matrix: 1 x V x V -> logMij: [L x V x V]
            logMij = son.message.unsqueeze(1) + trans_matrix.log().unsqueeze(0)
            topv, topi = logMij.topk(1, dim=-1)
            # When the ancestor is in state i, the best son state is
            # argmax_message[:, i].
            son.argmax_message = topi.squeeze(-1)
            log_prob += topv.squeeze(-1)

        self.message = log_prob

    def trace_back(self,):
        """Propagate the chosen states down the tree using the back-pointers.

        Requires ``self.infer`` to be set (indexing ``argmax_message`` along
        dim 1); sons without ``argmax_message`` are skipped.
        """
        for son in self.sons:
            if son.argmax_message is not None:
                # self.infer: L x 1, son.argmax_message: L x V
                son.infer = torch.gather(son.argmax_message, 1, self.infer)  # L x 1
                son.trace_back()

    def message_passing_rev(self, trans_model, prior):
        """Downward pass computing ``message_rev`` for every non-root node.

        ``message_rev`` of a node is a log message over the states of its
        *parent*, combining (a) the upward messages of the node's siblings
        and (b) either the parent's own ``message_rev`` pushed through the
        parent's branch, or ``prior`` when the parent is the root.
        Must be called after ``message_passing`` has filled ``message``.

        Args:
            trans_model: object providing ``get_transition_matrix``.
            prior: log-prior added for children of the root.
        """
        parent_node = self.ancestor
        if parent_node is not None:  # The root itself has no downward message.
            message_from_ancestor_to_brothers = self.message.new_zeros(self.message.size())
            for son in parent_node.sons:
                if son is not self:
                    trans_matrix = trans_model.get_transition_matrix(son.branch_length)
                    message_from_ancestor_to_brothers += torch.mm(son.message.exp(), trans_matrix.T).log()  # [L, V]

            self.message_rev = message_from_ancestor_to_brothers

            grandma_node = parent_node.ancestor
            if grandma_node is not None:
                trans_matrix = trans_model.get_transition_matrix(parent_node.branch_length)
                message_from_grandma_to_mum = torch.mm(parent_node.message_rev.exp(), trans_matrix).log()  # [L, V]
                self.message_rev = message_from_grandma_to_mum + self.message_rev
            else:  # Parent is root
                self.message_rev += prior

        for son in self.sons:
            son.message_passing_rev(trans_model, prior)

    def message_passing(self, trans_model):
        """Sum-product upward pass.

        Leaves keep the message set by ``init_message``.  For an internal
        node, ``message`` becomes the sum over sons of
        ``log(exp(son.message) @ T.T + eps)``.

        Returns:
            The number of branches in the subtree below this node (0 for a
            leaf).
        """
        log_prob = 0.0
        if len(self.sons) == 0:  # Leaves: do nothing!
            return 0
        tot_branches_num = 0
        for son in self.sons:
            branches_num = son.message_passing(trans_model)
            tot_branches_num += (branches_num + 1)
            trans_matrix = trans_model.get_transition_matrix(son.branch_length)
            # p(xj|xi) = T_{ij}, s_j is the son node
            log_prob += (torch.mm(son.message.exp(), trans_matrix.T) + self.eps).log()
        self.message = log_prob
        return tot_branches_num

class TreeModel(nn.Module):
    """Inference routines (sum-product, max-product, Gibbs) over a node tree.

    The model holds the root node and a ``{leaf name: leaf node}`` mapping.
    Nodes are expected to expose the ``TreeNode`` interface used below
    (``sons``, ``ancestor``, ``name``, ``branch_length``, ``init_message``,
    ``message_passing`` ...).
    """

    def __init__(self, root, leaves_set, model=None, vocab=None, prior=None,
                 id2tensor=None, device=None, eps=1e-15) -> None:
        """Create a tree model.

        Args:
            root: root node of the tree.
            leaves_set: dict mapping leaf names to leaf nodes.
            model: substitution model exposing ``get_transition_matrix``;
                may also be supplied later through ``prepare_for_em``.
            vocab: sized vocabulary; only ``len(vocab)`` is used.
            prior: optional prior tensor over states.
            id2tensor: stored as-is for callers; not used by this class.
            device: device for priors and leaf messages (CPU by default).
            eps: smoothing count used by ``background_prior``.
        """
        super().__init__()
        self.root = root
        self.leaves_set = leaves_set
        self.model = model
        self.vocab = vocab
        self.prior = prior
        self.id2tensor = id2tensor
        self.device = torch.device("cpu") if device is None else torch.device(device)
        self.eps = eps

    def remove_leaf_node(self, node):
        """Detach a leaf and prune ancestors that become childless.

        Pruning stops at the root: a node without an ancestor is only
        dropped from ``leaves_set`` (if present).
        """
        assert len(node.sons) == 0, "Make sure the node you want to remove is the leaf node."
        self.leaves_set.pop(node.name, None)
        ancestor = node.ancestor
        if ancestor is None:
            # Reached the root; there is no parent to unlink from.
            return
        ancestor.sons = [s for s in ancestor.sons if s is not node]
        if len(ancestor.sons) == 0:
            self.remove_leaf_node(ancestor)

    @classmethod
    def build_from_existing_tree(cls, tree, root_node, model, vocab, id2tensor, **kwargs):
        """Build a model for the subtree rooted at ``root_node`` of ``tree``.

        ``root_node.leaves_set`` must already exist (it is filled by
        ``count_leaves_num``).  Note that ``root_node.ancestor`` is reset to
        None on the node itself, i.e. the node is shared, not copied.
        """
        root_node.ancestor = None
        leaves = {key: tree.leaves_set[key] for key in root_node.leaves_set}
        return cls(root_node, leaves, model=model, vocab=vocab,
                   prior=getattr(tree, "prior", None), id2tensor=id2tensor, **kwargs)

    @classmethod
    def build_from_newick(cls, clades, **kwargs):
        """Build a model from a nested clade structure.

        ``clades`` must support ``len()``, iteration over child clades and
        the attributes ``name`` and ``branch_length`` (presumably a
        Bio.Phylo clade -- TODO confirm).  Unnamed internal nodes are named
        ``"(child1,child2,...)"``.
        """
        leaves_set = {}

        def add_nodes(clades):
            if len(clades) == 0:  # reach leaf
                leaf_node = TreeNode(branch_length=clades.branch_length, name=clades.name)
                leaves_set[clades.name] = leaf_node
                return leaf_node
            node = TreeNode(branch_length=clades.branch_length)
            name = []
            for clade in clades:
                son_node = add_nodes(clade)
                son_node.set_ancestor(node)
                name.append(son_node.name)
                node.add_son(son_node)
            if clades.name:
                node.name = clades.name
            else:
                node.name = "(" + ",".join(name) + ")"
            return node

        root_node = add_nodes(clades)
        return cls(root_node, leaves_set, **kwargs)

    @classmethod
    def build_example(cls, model, vocab, **kwargs):
        """Build a small fixed tree: root 0 -> (1 -> (3, 4), 2)."""
        root = TreeNode(branch_length=None, name="0")
        x1 = TreeNode(branch_length=0.1, name="1")
        root.add_son(x1)
        x1.set_ancestor(root)

        x2 = TreeNode(branch_length=0.1, name="2")
        root.add_son(x2)
        x2.set_ancestor(root)

        x3 = TreeNode(branch_length=0.1, name="3")
        x1.add_son(x3)
        x3.set_ancestor(x1)

        x4 = TreeNode(branch_length=0.1, name="4")
        x1.add_son(x4)
        x4.set_ancestor(x1)

        leaves_set = {"2": x2, "3": x3, "4": x4}

        return cls(root, leaves_set, model=model, vocab=vocab, **kwargs)

    def clean_message(self, root):
        """Reset ``message`` and ``argmax_message`` of all internal nodes."""
        if len(root.sons) == 0:  # Do nothing for leaves
            return
        root.message = None
        root.argmax_message = None
        for son in root.sons:
            self.clean_message(son)

    def prepare_for_gibbs_sampling(self, id2tokens):
        """Set the log background prior, leaf messages and initial states.

        Args:
            id2tokens: mapping from leaf name to its token-index tensor; it
                must contain every key of ``leaves_set``.
        """
        with torch.no_grad():
            vocab_size = len(self.vocab)
            prior = self.background_prior(vocab_size, list(id2tokens.values())).log()
            self.prior = prior.to(self.device)

            for leaf_id, leaf_node in self.leaves_set.items():
                leaf_node.init_message(vocab_size, id2tokens[leaf_id], self.device)

            self.root.init_infer()

    def gibbs_sampling_posterior(self, root):
        """Run one top-down Gibbs sweep over all internal nodes below ``root``."""
        with torch.no_grad():
            if len(root.sons) == 0:  # No need to infer for leaves
                return
            if root == self.root:
                root.gibbs_sampling(self.model, self.prior)
            else:
                root.gibbs_sampling(self.model)
            for son in root.sons:
                self.gibbs_sampling_posterior(son)

    def calc_prior(self, datasets):
        """Placeholder for learning the prior from data (not implemented).

        Raises:
            NotImplementedError: always.  The previous body referenced an
                undefined name and could never run.
        """
        # TODO: we could learn the prior/model, by traditional methods first.
        # (each site could have their own models.)
        raise NotImplementedError("calc_prior is not implemented yet.")

    def prepare_for_em(self, vocab, model, id2tokens, prior=None):
        """Store model and prior, and initialise the leaves for posterior inference.

        Args:
            vocab: sized vocabulary; only ``len(vocab)`` is used.
            model: substitution model exposing ``get_transition_matrix``.
            id2tokens: mapping from leaf name to its token-index tensor.
                Leaves without an entry are removed from the tree.
            prior: log-prior over states; a uniform ``[1, V]`` log-prior is
                used when omitted.
        """
        # We need:
        # (1) Observations: sequences in the leaves -- id2tokens,
        # (2) substitution models: prior + model;
        vocab_size = len(vocab)
        if prior is None:
            prior = (torch.ones(1, vocab_size) / vocab_size).log()  # uniform
        self.vocab = vocab
        self.model = model
        # infer_posterior / maximum_posterior read self.prior.
        self.prior = prior.to(self.device)
        # Iterate over a copy: remove_leaf_node mutates leaves_set.
        for leaf_id, leaf_node in list(self.leaves_set.items()):
            if leaf_id in id2tokens:
                leaf_node.init_message(vocab_size, id2tokens[leaf_id], self.device)
            else:
                # couldn't find the sequences for this node
                self.remove_leaf_node(leaf_node)

    def calc_branch_joint_dist(self, node, branches_collect):
        """Collect, for every branch below ``node``, the unnormalised joint.

        Appends ``(prob, (node.name, node.ancestor.name, node.branch_length))``
        to ``branches_collect`` where
        ``prob[l, a, b] = D[l, a] * T[a, b] * U[l, b]`` with
        ``D = exp(node.message_rev)`` and ``U = exp(node.message)``.
        """
        if node.ancestor is not None:  # Not root
            D = node.message_rev.exp()  # [L, V] D(ancestor = a)
            trans_matrix = self.model.get_transition_matrix(node.branch_length)  # P(ancestor=a, node=b)
            U = node.message.exp()  # [L, V] U(node=b)
            prob = D.unsqueeze(-1) * trans_matrix.unsqueeze(0) * U.unsqueeze(1)  # [L, V, V]
            branches_collect.append((prob, (node.name, node.ancestor.name, node.branch_length)))

        for son in node.sons:
            self.calc_branch_joint_dist(son, branches_collect)

    def infer_posterior(self, ):
        """Compute per-branch posteriors over (ancestor state, node state).

        Runs the upward and downward passes, then divides every branch joint
        by the per-site likelihood ``sum(exp(root.message + prior))``.

        Returns:
            List of ``(cond_dist, (node name, ancestor name, branch length))``.
        """
        with torch.no_grad():
            self.root.message_passing(self.model)
            self.root.message_passing_rev(self.model, self.prior)

            branches_collect = []
            self.calc_branch_joint_dist(self.root, branches_collect)

            ll = ((self.root.message + self.prior).exp().sum(-1))  # P(O)
            norm = ll.unsqueeze(-1).unsqueeze(-1)
            return [(joint_dist / norm, info) for joint_dist, info in branches_collect]

    def maximum_posterior(self):
        """Max-product inference; fills ``infer`` on all nodes.

        Returns:
            The best log score per site (``topk(1)`` of root message + prior).
        """
        with torch.no_grad():
            self.root.maximum_post(self.model)

            max_logll, max_index = (self.root.message + self.prior).topk(1, dim=-1)
            # We need to track back
            self.root.infer = max_index
            self.root.trace_back()

            return max_logll

    def collect_braches(self, root, dataset):
        """Append ``(parent.infer, son.infer, son.branch_length)`` per branch to ``dataset``."""
        for son in root.sons:
            dataset.append((root.infer, son.infer, son.branch_length))
            self.collect_braches(son, dataset)

    def count_leaves_num(self, root):
        """Annotate every node with ``leaves_num`` and ``leaves_set``.

        Leaves get a ``{name: node}`` dict while internal nodes get a set of
        leaf names; iterating either yields leaf names.

        Returns:
            Number of leaves below (or equal to) ``root``.
        """
        if len(root.sons) == 0:
            root.leaves_num = 1
            root.leaves_set = {root.name: root}
            return 1
        root.leaves_num = 0
        root.leaves_set = set()
        for son in root.sons:
            root.leaves_num += self.count_leaves_num(son)
            root.leaves_set.update(son.leaves_set)
        return root.leaves_num

    def split_tree(self, root, threshold, root_collect):
        """Collect the topmost subtree roots with at most ``threshold`` leaves.

        Requires ``count_leaves_num`` to have been run first.
        """
        if root.leaves_num <= threshold:
            root_collect.append(root)
        else:
            for son in root.sons:
                self.split_tree(son, threshold, root_collect)

    def prepare_for_marginal_likelihood(self, id2tokens):
        """Set the background prior (probability space) and the leaf messages."""
        vocab_size = len(self.vocab)
        # Estimate the prior distribution:
        prior = self.background_prior(vocab_size, list(id2tokens.values()))
        self.prior = prior.to(self.device)

        # Initialize the message for leaves nodes
        for leaf_id, leaf_node in self.leaves_set.items():
            leaf_node.init_message(vocab_size, id2tokens[leaf_id], self.device)

    def marginal_log_likelihood(self, root):
        """Per-site marginal log-likelihood of the subtree rooted at ``root``.

        Requires ``root.leaves_num`` (see ``count_leaves_num``) and a prior in
        probability space (see ``prepare_for_marginal_likelihood``).

        Returns:
            ``(log-likelihood, number of branches + 1)``, or ``(0, 0)`` when
            the subtree has at most one leaf.
        """
        if root.leaves_num > 1:
            branches_num = root.message_passing(self.model)
            ll = ((root.message.exp() * self.prior).sum(-1)).log()
            return ll, (branches_num + 1)
        else:
            return 0, 0

    def background_prior(self, vocab_size, token_list):
        """Per-site empirical token frequencies, smoothed by ``self.eps``.

        Args:
            vocab_size: number of classes.
            token_list: non-empty list of equally long integer tensors.

        Returns:
            Tensor ``[seq_len, vocab_size]`` whose rows sum to one, on the
            device of the first token tensor.
        """
        seq_len = token_list[0].size(0)
        # Allocate on the tokens' device so the accumulation below does not
        # mix devices.
        background_prior = torch.zeros(seq_len, vocab_size, device=token_list[0].device) + self.eps
        for tokens in token_list:
            background_prior += F.one_hot(tokens, vocab_size)

        return background_prior / background_prior.sum(-1, keepdim=True)


if __name__ == "__main__":
    # No script entry point is implemented; running this module directly does nothing.
    pass