import numpy as np 
import torch 
import torch.nn as nn 
import multiprocessing 

import tqdm 
# from .utils import tree_to_newick, tree_to_json, np_pruning, process_chunk_newick 

def tree_to_json(trees, idx):
    """Return a dict-based description of tree ``idx`` of a ``Trees`` batch.

    Every node id in ``range(trees.num_nodes)`` is a key. Nodes not listed in
    ``trees.internal_nodes`` (i.e. the leaves) map to ``None``. Each internal
    node maps to ``[left, right]``, where each entry is a dict
    ``{"node": child_id, "branch": branch_length}`` read from
    ``trees.children[idx]``.

    Args:
        trees: object exposing ``children`` (indexed as ``children[idx]`` ->
            (num_nodes, 4) with columns left child, right child, left branch,
            right branch), ``num_nodes`` and ``internal_nodes``.
        idx: index of the tree inside the batch.

    Returns:
        dict mapping node id -> ``None`` or a two-element list of child dicts.
    """
    children = trees.children[idx]
    # Child ids are stored as floats in ``children``; cast only them to
    # integers. Branch lengths stay floating point so that non-integral
    # lengths are not truncated.
    left, right = children[:, 0].long(), children[:, 1].long()
    left_branch, right_branch = children[:, 2], children[:, 3]

    tree_json = {node: None for node in range(trees.num_nodes)}
    for node in trees.internal_nodes.tolist():
        tree_json[node] = [
            {"node": left[node].item(), "branch": left_branch[node].item()},
            {"node": right[node].item(), "branch": right_branch[node].item()},
        ]
    return tree_json

def tree_to_newick(children, num_leaves, root):
    """Serialise one tree to a Newick string (topology only, no branch lengths).

    Args:
        children: array of shape (num_nodes, >=2) whose first two columns hold
            the left and right child ids of every internal node (a numpy array
            or anything ``np.asarray`` accepts).
        num_leaves: nodes with id ``< num_leaves`` are leaves; they are written
            as their integer id.
        root: id of the node to start serialising from.

    Returns:
        The Newick string terminated by ``';'``. Within each pair of siblings
        the leaf children are written before the internal ones (otherwise left
        before right), e.g. ``"((0,1),(2,3));"``.
    """
    children = np.asarray(children).astype(int)

    def ordered_children(node):
        pair = [int(c) for c in children[node, :2]]
        # Leaves first, then internal nodes; ``sorted`` is stable, so ties keep
        # their left/right order.
        return sorted(pair, key=lambda c: c >= num_leaves)

    tokens = []
    # Explicit stack instead of recursion: deep (caterpillar-shaped) trees
    # would otherwise exceed Python's recursion limit. Stack items are either
    # node ids (int) still to expand or literal tokens (str) to emit.
    stack = [int(root)]
    while stack:
        item = stack.pop()
        if isinstance(item, str):
            tokens.append(item)
        elif item < num_leaves:
            tokens.append(str(item))
        else:
            first, second = ordered_children(item)
            tokens.append('(')
            # Pushed in reverse so they are popped as: first , second )
            # -- the comma is emitted for every pair of siblings, including
            # two internal ones.
            stack.extend([')', second, ',', first])
    return ''.join(tokens) + ';'

def np_pruning(Q, pi, trees, sites, vocab_size=4):
    """Compute per-site likelihoods for every tree of a batch with Felsenstein's pruning.

    Leaf likelihoods are one-hot encodings of the observed characters. Each
    internal node's likelihood is the element-wise product of its children's
    likelihoods, each propagated along its branch by
    ``matrix_exp(branch_length * Q)``.

    Args:
        Q: rate matrix of shape (vocab_size, vocab_size).
        pi: root distribution, contracted with the root likelihoods via ``@``.
            Presumably shape (vocab_size,) -- TODO confirm against callers.
        trees: a ``Trees``-like batch; reads ``num_nodes``, ``batch_size``,
            ``batch_ids``, ``leaves``, ``internal_nodes``, ``root`` and
            ``get_children``.
        sites: 2-D tensor (nb_nucleotides, n_species); entry ``[s, leaf]`` is
            the integer index of the character observed at site ``s`` for
            ``leaf``.
        vocab_size: number of distinct characters (last axis of the table).

    Returns:
        ``dynamic_probas[trees.root] @ pi``: with a 1-D ``pi`` this is a
        (batch_size, nb_nucleotides) tensor of per-site likelihoods.
    """
    # `nb_nucleotides` is the number of sites (columns of the alignment);
    # `n_species` is the number of leaves with observations.
    nb_nucleotides, n_species = sites.shape # Same for all elements in the batch
    observation = sites.clone() # Same for all elements in the batch
    x_index = torch.arange(nb_nucleotides)  # Same for all elements in the batch

    # Likelihood table indexed as [node, batch element, site, character].
    dynamic_probas = torch.zeros((trees.num_nodes, trees.batch_size, nb_nucleotides, vocab_size))

    def posterior_proba_leaf(leaf):
        likelihood = torch.zeros((trees.batch_size, nb_nucleotides, vocab_size))
        # Same for all elements in the batch
        # One-hot: probability 1 on the observed character at every site.
        likelihood[:, x_index, observation[:, leaf].long()] = 1
        return likelihood

    # Initialize the likelihoods
    for leaf in trees.leaves:
        dynamic_probas[leaf] = posterior_proba_leaf(leaf)

    def posterior_proba(nodes, left_children, right_children, left_branch, right_branch):
        # Reuse the stored likelihoods if every batch element already has a
        # non-zero row for `nodes` (i.e. it has been filled in).
        # NOTE(review): a row that underflows to all zeros is indistinguishable
        # from "not yet computed" and triggers recomputation.
        if (dynamic_probas[nodes, trees.batch_ids].sum(dim=-1) != 0).all():
            return dynamic_probas[nodes, trees.batch_ids]
        # For batch elements whose `nodes` entry has no children (all-zero
        # row), get_children falls back to the *_curr values passed in.
        left_children, right_children, left_branch, right_branch = trees.get_children(nodes,
                    left_children, right_children, left_branch, right_branch)

        post_proba_left = posterior_proba(left_children.long(), left_children, right_children, left_branch, right_branch)
        post_proba_right = posterior_proba(right_children.long(), left_children, right_children, left_branch, right_branch)

        # Per-batch transition matrices P = exp(t * Q), shape (batch, V, V).
        prob_matrix_left = torch.matrix_exp(left_branch.view(-1, 1, 1) * Q[None, ...])
        prob_matrix_right = torch.matrix_exp(right_branch.view(-1, 1, 1) * Q[None, ...])

        # (batch, V, V) @ (batch, V, sites) -> sum over the child's character.
        left_likelihood = torch.bmm(prob_matrix_left, torch.transpose(post_proba_left, 1, 2))
        right_likelihood = torch.bmm(prob_matrix_right, torch.transpose(post_proba_right, 1, 2))

        # Back to (batch, sites, V).
        return torch.transpose(left_likelihood * right_likelihood, 1, 2)

    # Internal nodes are visited in ascending id order. Trees.apply always
    # attaches already-existing nodes as children of the newly created one, so
    # children normally have smaller ids and are already filled in.
    for node in trees.internal_nodes:
        left_children, right_children, left_branch, right_branch = None, None, None, None
        nodes = node * torch.ones((trees.batch_size,))
        dynamic_probas[nodes.long(), trees.batch_ids] = posterior_proba(nodes.long(), left_children, right_children, left_branch, right_branch)

    return dynamic_probas[trees.root] @ pi

def process_chunk_newick(chunk, newick_rep, children, num_leaves, root):
    """Append ``(batch_index, newick_string)`` for every tree index in ``chunk``.

    Args:
        chunk: iterable of batch indices exposing ``.item()`` (e.g. a numpy
            integer array).
        newick_rep: list-like sink receiving the pairs through ``append``
            (for instance a ``multiprocessing.Manager().list()`` proxy).
        children: per-tree children arrays, indexed by batch index.
        num_leaves: forwarded to ``tree_to_newick``.
        root: forwarded to ``tree_to_newick``.
    """
    for batch_idx in chunk:
        key = batch_idx.item()
        newick = tree_to_newick(children[batch_idx], num_leaves, root)
        newick_rep.append((key, newick))

class LikelihoodReward(nn.Module):
    """Standardised log-likelihood reward for a batch of trees.

    ``forward`` returns ``(log_likelihood - mu) / std``. Here ``log_likelihood``
    is the log of the per-site likelihoods produced by ``np_pruning``, summed
    over dim 1 (the site axis when ``pi`` is 1-D).

    Args:
        pi: root character distribution (array-like or tensor). It is stored
            with the default torch floating dtype so that it matches the
            likelihood table built in ``np_pruning``; a float64 ``pi`` (e.g.
            from NumPy) would otherwise make ``@ pi`` fail.
        sites: observed characters, array-like of shape (num_sites, num_leaves).
        vocab_size: number of characters. When ``None`` it is inferred from
            the last dimension of ``pi``.
        mu: offset subtracted from the log-likelihood.
        std: scale dividing the centred log-likelihood.
    """

    def __init__(self, pi=None, sites=None, vocab_size=None, mu=0., std=1.):
        super().__init__()
        # ``as_tensor(...).detach().clone()`` copies like ``torch.tensor`` but
        # without the warning it emits when handed an existing tensor.
        pi = torch.as_tensor(pi, dtype=torch.get_default_dtype()).detach().clone()
        sites = torch.as_tensor(sites).detach().clone()
        if vocab_size is None:
            vocab_size = pi.shape[-1]
        self.vocab_size = vocab_size

        # The conversion rate matrix: 1 off the diagonal, and the diagonal set
        # so that every row sums to zero.
        Q = torch.ones((vocab_size, vocab_size))
        ids = torch.arange(vocab_size)
        Q[ids, ids] -= Q.sum(dim=1)

        # Stored as frozen parameters so they follow the module across devices.
        self.pi = nn.Parameter(pi, requires_grad=False)
        self.sites = nn.Parameter(sites, requires_grad=False)
        self.Q = nn.Parameter(Q, requires_grad=False)
        self.mu = nn.Parameter(torch.as_tensor(mu).detach().clone(), requires_grad=False)
        self.std = nn.Parameter(torch.as_tensor(std).detach().clone(), requires_grad=False)

    @torch.no_grad()
    def forward(self, trees):
        """Return the standardised log-likelihood of every tree in ``trees``."""
        likelihood = np_pruning(self.Q.data, self.pi.data, trees, self.sites.data, self.vocab_size)
        loglikelihood = torch.log(likelihood).sum(dim=1)
        return (loglikelihood - self.mu.data) / self.std.data

class Trees:
    """Batch of binary trees built bottom-up by repeatedly merging two nodes.

    Node ids ``0 .. num_leaves - 1`` are leaves. Internal nodes receive ids in
    increasing order starting at ``num_leaves``, and the last one created
    (``num_nodes - 1``) is the root.

    Attributes:
        children: (batch_size, num_nodes, 4) tensor; row ``n`` holds
            (left child, right child, left branch length, right branch length).
        parents: (batch_size, num_nodes) parent id of every node (0 if unset).
        actions: (batch_size, num_pairs, 2) candidate pairs of nodes to merge.
        mask: (batch_size, num_pairs) 1 for still-valid pairs, 0 otherwise.
        X: (num_nodes, num_leaves + 2) node features: one-hot ids for leaves,
            then an "internal" flag column and a "root" flag column.
        edge_list: (num_edges, 3) long tensor of (batch id, child, parent).
        stopped: (batch_size,) incremented when the root is created.
        num_parents: (batch_size,) running counter: +1 when two leaves are
            merged, -1 when two internal nodes are merged.
    """

    def __init__(self, num_leaves, batch_size, log_reward=None):
        self.num_leaves = num_leaves
        self.num_nodes = 2 * self.num_leaves - 1
        self.num_internal_nodes = self.num_leaves - 1

        self.batch_size = batch_size
        self.batch_ids = torch.arange(self.batch_size)

        # Callable used by ``log_reward``; it receives this ``Trees`` instance.
        self._log_reward = log_reward

        # Per node: (left child, right child, left branch, right branch).
        self.children = torch.zeros((self.batch_size, self.num_nodes, 4))
        self.parents = torch.zeros((self.batch_size, self.num_nodes))

        # The root is the last node to be created.
        self.root = self.num_nodes - 1

        # Candidate merges: every unordered pair of leaves, shared (via
        # ``expand``) across the batch until ``apply`` rewrites it.
        self.actions = torch.triu_indices(
            self.num_leaves, self.num_leaves, offset=1
        ).t().expand(self.batch_size, -1, 2)
        self.mask = torch.ones((self.actions.shape[0], self.actions.shape[1]))

        # Node features: one-hot ids for leaves, an "internal" flag for the
        # non-root internal nodes and a "root" flag for the last node.
        self.X = torch.zeros((self.num_nodes, self.num_leaves + 2))
        self.X[:self.num_leaves, :self.num_leaves] = torch.eye(self.num_leaves)
        self.X[self.num_leaves:-1, self.num_leaves] = 1.
        self.X[-1, self.num_leaves + 1] = 1.

        # Id the next merge assigns to the newly created internal node.
        self.next_node = self.num_leaves

        # Branch length written for both children at every merge.
        self.branch_length = torch.ones((self.batch_size,))

        # (batch id, child, parent) rows, e.g. for message passing in a GNN.
        self.edge_list = torch.zeros((0, 3), dtype=torch.long)

        self.stopped = torch.zeros((self.batch_size,))

        self.num_parents = torch.zeros((self.batch_size,))

    @torch.no_grad()
    def apply(self, indices):
        """Merge, in every tree, the pair of nodes selected by ``indices``.

        Args:
            indices: (batch_size,) index into ``self.actions`` for every tree.

        Returns:
            Bool tensor (batch_size,) that is True while ``stopped < 2``.
        """
        actions = self.actions[self.batch_ids, indices]
        left, right = actions[:, 0], actions[:, 1]

        # ``left`` is relabelled to the new node in every pair. Pairs that
        # involve ``right`` are then masked out and relabelled too, so that no
        # pair refers to a node that already has a parent.
        self.actions = torch.where(self.actions == left.view(-1, 1, 1), self.next_node, self.actions)
        self.mask = torch.where(
            (self.actions == right.view(-1, 1, 1)).any(dim=-1), 0., self.mask)
        self.actions = torch.where(self.actions == right.view(-1, 1, 1), self.next_node, self.actions)

        # Record the new node's children and the children's parent.
        self.children[self.batch_ids, self.next_node] = torch.vstack(
            [left, right, self.branch_length, self.branch_length]).t()
        self.parents[self.batch_ids, left] = self.next_node
        self.parents[self.batch_ids, right] = self.next_node

        # Append the two new (batch id, child, parent) edges of every tree.
        next_node_vec = torch.full((self.batch_size,), self.next_node, dtype=torch.long)
        edges_to_append = torch.vstack([
            torch.vstack([self.batch_ids, left, next_node_vec]).t(),
            torch.vstack([self.batch_ids, right, next_node_vec]).t(),
        ])
        self.edge_list = torch.vstack([self.edge_list, edges_to_append.long()])

        # Creating the root is the last (deterministic) action.
        self.stopped += (self.next_node == self.root)

        self.next_node += 1

        self.num_parents += ((left < self.num_leaves) & (right < self.num_leaves)).long()
        self.num_parents -= ((left >= self.num_leaves) & (right >= self.num_leaves)).long()

        return (self.stopped < 2.)

    @torch.no_grad()
    def get_children(self, nodes, left_children_curr=None, right_children_curr=None, left_branch_curr=None, right_branch_curr=None):
        """Return (left children, right children, left branch, right branch) of ``nodes``.

        ``nodes`` holds one node id per batch element. When both ``*_children_curr``
        values are given, batch elements whose node has no children (an
        all-zero children row) get the corresponding ``*_curr`` values instead.
        """
        right_children = self.children[self.batch_ids, nodes, 1]
        left_children = self.children[self.batch_ids, nodes, 0]

        right_branch = self.children[self.batch_ids, nodes, 3]
        left_branch = self.children[self.batch_ids, nodes, 2]
        if left_children_curr is None or right_children_curr is None:
            return left_children, right_children, left_branch, right_branch

        has_children = ((right_children + left_children) > 0).to(dtype=torch.get_default_dtype())
        no_children = 1 - has_children
        return (
            left_children * has_children + no_children * left_children_curr,
            right_children * has_children + no_children * right_children_curr,
            left_branch * has_children + no_children * left_branch_curr,
            right_branch * has_children + no_children * right_branch_curr,
        )

    @property
    @torch.no_grad()
    def internal_nodes(self):
        """Ids of the internal nodes, in ascending order."""
        return torch.arange(self.num_internal_nodes) + self.num_leaves

    @property
    @torch.no_grad()
    def leaves(self):
        """Ids of the leaves, in ascending order."""
        return torch.arange(self.num_leaves)

    @torch.no_grad()
    def is_leaf(self, nodes):
        """Element-wise test ``nodes < num_leaves``."""
        return (nodes < self.num_leaves)

    @torch.no_grad()
    def to_newick(self):
        """Return a numpy unicode array with the Newick string of every tree.

        Entry ``batch_ids[j]`` holds the Newick string of tree ``batch_ids[j]``.
        Trees are serialised in parallel with a process pool.
        """
        batch_ids = self.batch_ids.tolist()
        if not batch_ids:
            return np.empty(0, dtype='U1')

        children_np = self.children.cpu().numpy()
        tasks = [(children_np[idx], self.num_leaves, self.root) for idx in batch_ids]
        # ``starmap`` returns results in task order, so no shared list or
        # index bookkeeping is needed. Never start more workers than trees
        # (the previous ``batch_size // cpu_count`` chunking gave a chunk size
        # of 0 for small batches).
        num_workers = min(multiprocessing.cpu_count(), len(tasks))
        with multiprocessing.Pool(processes=num_workers) as pool:
            newicks = pool.starmap(tree_to_newick, tasks)

        # Size the string dtype from the data: a fixed width would silently
        # truncate long Newick strings.
        width = max(len(s) for s in newicks)
        results = np.empty(len(newicks), dtype=f'U{width}')
        results[batch_ids] = newicks
        return results

    @torch.no_grad()
    def to_json(self):
        """Return a list with the ``tree_to_json`` description of every tree."""
        return [tree_to_json(self, idx=batch_idx) for batch_idx in self.batch_ids]

    @torch.no_grad()
    def edge_list_t(self):
        """Return a (2, num_edges) edge index with node ids offset per tree.

        Node ``n`` of batch element ``b`` becomes ``b * num_nodes + n``.
        """
        offsets = (self.edge_list[:, 0] * self.num_nodes).view(-1, 1)
        return (offsets + self.edge_list[:, 1:]).t().long()

    @property
    @torch.no_grad()
    def expanded_data(self):
        """Node features ``X`` repeated for every tree: (batch_size * num_nodes, F)."""
        num_features = self.X.shape[1]
        return self.X.unsqueeze(0).expand(self.batch_size, -1, -1).reshape(-1, num_features)

    @torch.no_grad()
    def floyd_warshall_batch(self):
        """All-pairs shortest path lengths (in edges) for every tree in the batch.

        Returns:
            (batch_size, (num_leaves + 1) ** 2) tensor. For each ``i`` in
            ``leaves + [root]`` (outer order), it holds the distances from every
            node in ``leaves + [root]`` to ``i``. Unconnected pairs are ``inf``.
        """
        dist = self.adjacency_matrix
        dist = torch.where(dist == 0., torch.inf, dist)
        ids = torch.arange(self.num_nodes)
        # Zero the diagonal of every tree's matrix: index the two node axes,
        # not the batch axis.
        dist[:, ids, ids] = 0

        for k in range(self.num_nodes):
            through_k = dist[:, :, k, None] + dist[:, None, k, :]
            dist = torch.minimum(dist, through_k)

        leaves_root = torch.cat([self.leaves.long(), torch.tensor([self.root], dtype=torch.long)])
        return torch.hstack([dist[:, leaves_root, i] for i in leaves_root])

    @property
    def adjacency_matrix(self):
        """Symmetric (batch_size, num_nodes, num_nodes) 0/1 adjacency from ``edge_list``."""
        adj = torch.zeros((self.batch_size, self.num_nodes, self.num_nodes))
        batch_ids, left_node, right_node = self.edge_list.t()
        adj[batch_ids, left_node, right_node] = 1.
        adj[batch_ids, right_node, left_node] = 1.
        return adj

    @torch.no_grad()
    def merge(self, batch_state):
        """Concatenate ``batch_state``'s generation state onto this batch.

        Merges actions, branch lengths, edge list, stopped flags, parent
        counters, mask, batch ids and batch size.

        NOTE(review): ``children``, ``parents`` and ``next_node`` are not
        merged, and ``batch_state``'s batch ids (and the batch ids stored in
        its edge list) are appended without being offset by the current batch
        size -- confirm callers expect this.
        """
        self.actions = torch.vstack([self.actions, batch_state.actions])
        self.branch_length = torch.hstack([self.branch_length, batch_state.branch_length])
        self.edge_list = torch.vstack([self.edge_list, batch_state.edge_list])
        self.stopped = torch.hstack([self.stopped, batch_state.stopped])
        self.num_parents = torch.hstack([self.num_parents, batch_state.num_parents])
        self.mask = torch.vstack([self.mask, batch_state.mask])

        self.batch_ids = torch.hstack([self.batch_ids, batch_state.batch_ids])
        self.batch_size += batch_state.batch_size

    @torch.no_grad()
    def log_reward(self):
        """Evaluate the ``log_reward`` callable passed at construction on this batch."""
        return self._log_reward(self)
