import torch 
import torch.nn as nn 
import numpy as np 

@torch.no_grad()
def np_pruning(Q, pi, trees, sites):
    """Return the log-likelihood of each tree in a batch using Felsenstein pruning.

    Partial likelihoods are propagated from the leaves up to the root, weighted
    by ``pi`` at the root node, and the per-site log-likelihoods are summed.

    Args:
        Q: substitution rate matrix, assumed (vocab_size, vocab_size). Each
            tree's transition matrix is ``matrix_exp(branch_length * Q)``.
        pi: weights over root states; ``len(pi)`` defines the alphabet size and
            ``pi.device`` is where the likelihood table is allocated.
        trees: batched tree container. Attributes read here: ``batch_size``,
            ``num_nodes``, ``num_leaves``, ``batch_ids``, ``children``,
            ``branch_length`` and ``root``. Leaves are expected at node
            indices ``[0, num_leaves)`` and internal nodes after them.
        sites: integer-coded observations; the leaf initialisation below
            indexes it as (num_sites, num_leaves) — TODO confirm with callers.

    Returns:
        Log-likelihood summed over sites, one value per selected tree.
    """
    vocab_size = len(pi)
    num_sites = sites.shape[0]

    # (batch, node, site, state) table of conditional (partial) likelihoods.
    likelihoods = torch.zeros(
        (trees.batch_size, trees.num_nodes, num_sites, vocab_size), device=pi.device
    )

    # Leaves: indicator of the observed state. A code outside [0, vocab_size)
    # matches no state and leaves an all-zero row.
    for idx in range(vocab_size):
        likelihoods[trees.batch_ids, :trees.num_leaves, :, idx] = (
            (sites == idx).t().to(dtype=likelihoods.dtype)
        )

    # The branch length is reshaped to one scalar per tree (bmm below needs a
    # matching batch dimension), so the transition matrix is identical for every
    # edge and both children of a tree: compute it once, outside the loop.
    # NOTE(review): all edges of a tree share a single branch length here —
    # confirm that this is the intended model.
    transition = torch.matrix_exp(trees.branch_length.view(-1, 1, 1) * Q[None, ...])

    # Since the nodes are appended sequentially to the tree, iterating in index
    # order guarantees that a node's children are computed before the node.
    for idx in range(trees.num_leaves, trees.num_nodes):
        idx_children = trees.children[trees.batch_ids, idx]
        left_children = idx_children[:, 0].long()
        right_children = idx_children[:, 1].long()
        left_likelihoods = likelihoods[trees.batch_ids, left_children]
        right_likelihoods = likelihoods[trees.batch_ids, right_children]

        # (B, V, V) @ (B, V, S) -> (B, V, S): row i of the transition matrix is
        # the parent state, so this sums over each child's state.
        marginal_left = torch.bmm(transition, left_likelihoods.transpose(1, 2))
        marginal_right = torch.bmm(transition, right_likelihoods.transpose(1, 2))

        likelihoods[trees.batch_ids, idx] = (marginal_left * marginal_right).transpose(1, 2)

    # NOTE(review): partial likelihoods are never rescaled, so for large trees
    # the product may underflow to 0 and the log becomes -inf.
    marginal_likelihoods = likelihoods[trees.batch_ids, trees.root] @ pi
    return torch.log(marginal_likelihoods).sum(dim=-1)

@torch.no_grad()
def compute_site_likelihood(Q, vocab_size, trees, site):
    """Return the table of partial likelihoods for a single site.

    Runs Felsenstein pruning from the leaves upward. Unlike ``np_pruning``, it
    neither weights the root by a state distribution nor takes logs: it returns
    the full per-node table.

    Args:
        Q: substitution rate matrix, assumed (vocab_size, vocab_size).
        vocab_size: number of character states.
        trees: batched tree container. Attributes read here: ``batch_size``,
            ``num_nodes``, ``num_leaves``, ``batch_ids``, ``children``,
            ``branch_length`` and ``device``.
        site: integer-coded state per leaf for one site — presumably shape
            (num_leaves,); TODO confirm with callers.

    Returns:
        Tensor of shape (batch_size, num_nodes, vocab_size). Leaf rows are
        indicators of the observed state; internal rows hold the conditional
        likelihood of the subtree given each state at that node.
    """
    likelihoods = torch.zeros(
        (trees.batch_size, trees.num_nodes, vocab_size), device=trees.device
    )

    # Leaves: indicator of the observed state (all-zero if the code is out of range).
    for idx in range(vocab_size):
        likelihoods[trees.batch_ids, :trees.num_leaves, idx] = (
            (site == idx).t().to(dtype=likelihoods.dtype)
        )

    # One branch length per tree, shared by every edge, so the transition
    # matrix is loop-invariant and the same for both children: compute it once.
    transition = torch.matrix_exp(trees.branch_length.view(-1, 1, 1) * Q[None, ...])

    # Nodes are appended sequentially, so children always precede their parent.
    for idx in range(trees.num_leaves, trees.num_nodes):
        idx_children = trees.children[trees.batch_ids, idx]
        left_children = idx_children[:, 0].long()
        right_children = idx_children[:, 1].long()
        left_likelihoods = likelihoods[trees.batch_ids, left_children]
        right_likelihoods = likelihoods[trees.batch_ids, right_children]

        # (B, V, V) @ (B, V, 1) -> (B, V): sum over each child's state.
        marginal_left = torch.bmm(transition, left_likelihoods.unsqueeze(-1)).squeeze(dim=-1)
        marginal_right = torch.bmm(transition, right_likelihoods.unsqueeze(-1)).squeeze(dim=-1)

        likelihoods[trees.batch_ids, idx] = marginal_left * marginal_right

    return likelihoods

class LogReward(nn.Module):
    """Tempered log-likelihood reward for a batch of trees.

    ``forward`` scores trees with ``np_pruning`` and returns
    ``(loglikelihood - shift) / temperature``.

    Args:
        pi: weights over root states, passed to ``np_pruning``. Required:
            ``len(pi)`` sets ``vocab_size``.
        sites: integer-coded observations, stored as ``self.data``.
        sub_rate_matrix: substitution rate matrix, stored as ``self.Q``.
        temperature: divisor applied to the shifted log-likelihood.

    NOTE(review): ``pi``, ``sites`` and ``Q`` are plain attributes, not
    registered buffers, so ``module.to(device)`` does not move them.
    """

    def __init__(self, pi=None, sites=None, sub_rate_matrix=None, temperature=1.):
        super().__init__()
        if pi is None:
            # Without this, `len(pi)` below fails with an opaque
            # "object of type 'NoneType' has no len()".
            raise TypeError("LogReward requires `pi` (root state weights); got None")
        self.pi = pi
        self.data = sites
        self.vocab_size = len(pi)
        self.temperature = temperature
        self.Q = sub_rate_matrix

        # Constant subtracted from the log-likelihood before tempering
        # (intended for numerical stability); 0 means no shift.
        self.shift = 0.

    @torch.no_grad()
    def forward(self, trees):
        """Return the tempered log-likelihood of each tree in ``trees``."""
        loglikelihood = np_pruning(self.Q, self.pi, trees, self.data)
        return (loglikelihood - self.shift) / self.temperature
