import torch 
import torch.nn as nn 

import torch_geometric as pyg 

from gfn.utils import ForwardPolicyMeta, BaseNN, GammaFuncMeta  

class GammaFuncDepth(GammaFuncMeta):
    """Gamma weight equal to ``log((2n - 3)!!)`` of the remaining-leaf count.

    For a state with ``n = num_remaining_leaves`` the weight is
    ``sum_{k=2}^{n} log(2k - 3)``, which is ``0`` when ``n <= 1``.
    """

    @staticmethod
    def _log_double_factorial(num_remaining):
        """Return ``sum_{k=2}^{n} log(2k - 3)`` for every entry ``n``.

        The sum is evaluated for all entries at once, so entries of a batch
        may hold different counts. The previous recursive formulation only
        stopped once *every* entry had reached 1; with mixed counts it took
        logs of negative numbers and recursed without end.

        Args:
            num_remaining: tensor (or tensor-like) of leaf counts.

        Returns:
            A floating tensor with the same shape as ``num_remaining``.
            Integer inputs yield the default float dtype; floating inputs
            keep their dtype.
        """
        counts = torch.as_tensor(num_remaining)
        if counts.is_floating_point():
            out_dtype = counts.dtype
        else:
            out_dtype = torch.get_default_dtype()

        largest = int(counts.max()) if counts.numel() > 0 else 1
        if largest < 2:
            # Empty sum for every entry; always a tensor, never a bare int.
            return torch.zeros(counts.shape, dtype=out_dtype, device=counts.device)

        # Term k contributes log(2k - 3); entry n uses the terms with k <= n.
        ks = torch.arange(2, largest + 1, device=counts.device)
        log_terms = torch.log((2 * ks - 3).to(out_dtype))
        active = ks <= counts.unsqueeze(-1)
        return (log_terms * active).sum(dim=-1)

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return the log double-factorial weight of ``batch_state_t``.

        Only ``batch_state_t.num_remaining_leaves`` is read;
        ``batch_state_tp1`` is accepted for interface compatibility and is
        not used.
        """
        return self._log_double_factorial(batch_state_t.num_remaining_leaves)
        # Earlier alternative weighting, kept for reference:
        # ones = torch.ones((batch_state_t.batch_size,), device=batch_state_t.device)
        # return (
        #     ones * batch_state_tp1.next_node
        # )

class GammaFuncInvDepth(GammaFuncMeta):
    """Gamma weight equal to ``-log((2n - 3)!!)`` of the remaining-leaf count.

    For a state with ``n = num_remaining_leaves`` the weight is
    ``-sum_{k=2}^{n} log(2k - 3)``, which is ``0`` when ``n <= 1``.
    """

    @staticmethod
    def _log_double_factorial(num_remaining):
        """Return ``sum_{k=2}^{n} log(2k - 3)`` for every entry ``n``.

        The sum is evaluated for all entries at once, so entries of a batch
        may hold different counts. The previous recursive formulation only
        stopped once *every* entry had reached 1; with mixed counts it took
        logs of negative numbers and recursed without end.

        Args:
            num_remaining: tensor (or tensor-like) of leaf counts.

        Returns:
            A floating tensor with the same shape as ``num_remaining``.
            Integer inputs yield the default float dtype; floating inputs
            keep their dtype.
        """
        counts = torch.as_tensor(num_remaining)
        if counts.is_floating_point():
            out_dtype = counts.dtype
        else:
            out_dtype = torch.get_default_dtype()

        largest = int(counts.max()) if counts.numel() > 0 else 1
        if largest < 2:
            # Empty sum for every entry; always a tensor, never a bare int.
            return torch.zeros(counts.shape, dtype=out_dtype, device=counts.device)

        # Term k contributes log(2k - 3); entry n uses the terms with k <= n.
        ks = torch.arange(2, largest + 1, device=counts.device)
        log_terms = torch.log((2 * ks - 3).to(out_dtype))
        active = ks <= counts.unsqueeze(-1)
        return (log_terms * active).sum(dim=-1)

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return the negated log double-factorial weight of ``batch_state_t``.

        Only ``batch_state_t.num_remaining_leaves`` is read;
        ``batch_state_tp1`` is accepted for interface compatibility and is
        not used.
        """
        return -self._log_double_factorial(batch_state_t.num_remaining_leaves)
        # Earlier alternative weighting, kept for reference:
        # ones = torch.ones((batch_state_t.batch_size,), device=batch_state_t.device)
        # return (
        #     batch_state_t.num_internal_nodes - ones * (batch_state_tp1.next_node - batch_state_tp1.num_leaves)
        # )
        
class ForwardPolicy(ForwardPolicyMeta):
    """Forward policy whose state embedding comes from a graph convolution.

    An ``SGConv`` layer (``K=3``) produces per-node embeddings that are summed
    into one vector per state. An MLP maps that vector to action logits and a
    separate linear head maps it to a scalar flow value.
    """

    def __init__(self, hidden_dim, num_leaves, eps=.3, device='cpu'):
        """Build the convolution, the logit MLP and the flow head.

        Args:
            hidden_dim: width of the node/graph embedding.
            num_leaves: number of leaves; sets the input feature width
                (``num_leaves + 2``) and the number of logits.
            eps: forwarded to the base-class initialiser.
            device: forwarded to the base class and used to place all layers.
        """
        super().__init__(eps=eps, device=device)
        self.hidden_dim = hidden_dim
        self.num_leaves = num_leaves
        self.device = device

        # num_leaves choose 2 outputs -- presumably one logit per unordered
        # pair of nodes; confirm against the environment's action space.
        num_logits = num_leaves * (num_leaves - 1) // 2

        self.gcn = pyg.nn.SGConv(
            in_channels=num_leaves + 2,
            out_channels=hidden_dim,
            K=3,
        ).to(self.device)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_logits),
        ).to(self.device)
        self.mlp_flows = nn.Linear(hidden_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return one embedding per state: the sum of its node embeddings."""
        per_node = self.gcn(batch_state.expanded_data, batch_state.edge_list_t())
        # Regroup the flat node list into (batch, nodes, hidden) before pooling.
        per_state = per_node.reshape(batch_state.batch_size, -1, self.hidden_dim)
        return per_state.sum(dim=1)

    def get_pol(self, latent_emb, mask):
        """Return ``(policy, flows)`` computed from a state embedding.

        Entries where ``mask`` is 0 have their logit replaced by
        ``self.masked_value`` (not defined in this class; presumably
        inherited from the base class) before the softmax.
        """
        raw_logits = self.mlp(latent_emb)
        masked_logits = raw_logits * mask + (1 - mask) * self.masked_value
        policy = torch.softmax(masked_logits, dim=-1)
        flows = self.mlp_flows(latent_emb).squeeze(dim=-1)
        return policy, flows

class LearnableGamma(GammaFuncMeta):
    """Gamma weight produced by a network applied to a state transition.

    The network input is the flattened adjacency matrix of the current state
    followed by the flattened adjacency matrix of the next state.
    """

    def __init__(self, num_leaves, hidden_dim, total_iters=1, device='cpu'):
        """Size and build the network.

        Args:
            num_leaves: number of leaves; the node count is taken to be
                ``2 * num_leaves - 1``.
            hidden_dim: hidden width passed to ``BaseNN``.
            total_iters: forwarded to the base-class initialiser.
            device: device the network is moved to.
        """
        super().__init__(total_iters=total_iters)
        self.num_nodes = 2 * num_leaves - 1
        # Two flattened matrices of num_nodes x num_nodes entries each
        # (assumes that is the adjacency-matrix shape -- TODO confirm).
        self.input_dim = 2 * self.num_nodes * self.num_nodes
        self.hidden_dim = hidden_dim
        self.device = device

        self.mlp = BaseNN(self.input_dim, self.hidden_dim, 2, 1).to(self.device)

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return the network output for each ``(state_t, state_tp1)`` pair."""
        current_flat = batch_state_t.adjacency_matrix.flatten(start_dim=1)
        next_flat = batch_state_tp1.adjacency_matrix.flatten(start_dim=1)
        transition = torch.hstack([current_flat, next_flat])
        # Drop the trailing output dimension.
        return self.mlp(transition).squeeze(dim=1)
        
class ForwardPolicyMLP(ForwardPolicyMeta):
    """Forward policy that embeds the flattened adjacency matrix with an MLP.

    A linear layer maps the embedding to action logits and a second linear
    layer maps it to a scalar flow value.
    """

    def __init__(self, hidden_dim, num_leaves, eps=.3, device='cpu'):
        """Build the embedding network and the two output heads.

        Args:
            hidden_dim: width of the state embedding.
            num_leaves: number of leaves; the node count is taken to be
                ``2 * num_leaves - 1``.
            eps: forwarded to the base-class initialiser.
            device: device all layers are moved to.
        """
        # NOTE(review): ForwardPolicy also forwards `device` to the base
        # initialiser; this class does not -- confirm that is intended.
        super().__init__(eps=eps)
        self.hidden_dim = hidden_dim
        self.num_leaves = num_leaves
        self.device = device
        self.num_nodes = 2 * self.num_leaves - 1

        flat_dim = self.num_nodes * self.num_nodes
        num_logits = num_leaves * (num_leaves - 1) // 2

        self.mlp = BaseNN(flat_dim, hidden_dim, 2).to(self.device)
        self.mlp_logit = nn.Linear(hidden_dim, num_logits).to(self.device)
        self.mlp_flows = nn.Linear(hidden_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return the embedding of each state's flattened adjacency matrix."""
        flat_adjacency = batch_state.adjacency_matrix.flatten(start_dim=1)
        return self.mlp(flat_adjacency)

    def get_pol(self, latent_emb, mask):
        """Return ``(policy, flows)`` computed from a state embedding.

        Entries where ``mask`` is 0 have their logit replaced by
        ``self.masked_value`` (not defined in this class; presumably
        inherited from the base class) before the softmax.
        """
        raw_logits = self.mlp_logit(latent_emb)
        masked_logits = raw_logits * mask + (1 - mask) * self.masked_value
        policy = torch.softmax(masked_logits, dim=-1)
        flows = self.mlp_flows(latent_emb).squeeze(dim=-1)
        return policy, flows

class ForwardPolicyLA(ForwardPolicyMeta):
    """Forward policy that scores each child state of the current state.

    Every child returned by ``batch_state.get_children()`` is embedded with
    the same network as its parent; the concatenated parent/child embedding
    is mapped to one logit. Flow values come from a separate network applied
    to the parent's flattened adjacency matrix.
    """

    def __init__(self, hidden_dim, num_leaves, eps=.3, device='cpu'):
        """Build the embedding, logit and flow networks.

        Args:
            hidden_dim: width of a single state embedding.
            num_leaves: number of leaves; the node count is taken to be
                ``2 * num_leaves - 1``.
            eps: forwarded to the base-class initialiser.
            device: device all networks are moved to.
        """
        # NOTE(review): ForwardPolicy also forwards `device` to the base
        # initialiser; this class does not -- confirm that is intended.
        super().__init__(eps=eps)
        self.hidden_dim = hidden_dim
        self.num_leaves = num_leaves
        self.device = device
        self.num_nodes = 2 * self.num_leaves - 1

        flat_dim = self.num_nodes * self.num_nodes
        pair_dim = 2 * self.hidden_dim  # parent embedding + child embedding

        self.mlp_emb = BaseNN(flat_dim, hidden_dim, 2).to(self.device)
        self.mlp_logit = BaseNN(pair_dim, pair_dim, 1, 1).to(self.device)
        self.mlp_flows = BaseNN(flat_dim, hidden_dim, 2, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return ``(logits, flows)`` for the batch.

        ``logits`` holds one entry per child along dimension 1, in the order
        ``batch_state.get_children()`` yields them.
        """
        parent_flat = batch_state.adjacency_matrix.flatten(start_dim=1)
        parent_emb = self.mlp_emb(parent_flat)

        # One (batch, 1, 2 * hidden) slab per child: [parent | child].
        pair_slabs = [
            torch.hstack(
                [parent_emb, self.mlp_emb(child.adjacency_matrix.flatten(start_dim=1))]
            ).unsqueeze(dim=1)
            for child in batch_state.get_children()
        ]
        stacked_pairs = torch.cat(pair_slabs, dim=1)

        logits = self.mlp_logit(stacked_pairs).squeeze(dim=-1)
        flows = self.mlp_flows(
            batch_state.adjacency_matrix.flatten(start_dim=1)
        ).squeeze(dim=-1)
        return logits, flows

    def get_pol(self, logits_flows, mask):
        """Return ``(policy, flows)`` from the pair built by ``get_latent_emb``.

        Entries where ``mask`` is 0 have their logit replaced by
        ``self.masked_value`` (not defined in this class; presumably
        inherited from the base class) before the softmax.
        """
        logits, flows = logits_flows
        masked_logits = mask * logits + (1 - mask) * self.masked_value
        return masked_logits.softmax(dim=-1), flows

class BackwardPolicy(nn.Module):
    """Uniform backward policy over the actions allowed by the backward mask.

    Only the mask columns from index ``num_leaves`` onwards are considered.
    """

    # Logit given to disallowed actions; its softmax weight underflows to 0.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        """Store ``device``; the policy has no learnable parameters."""
        super(BackwardPolicy, self).__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Pick a backward action per state and return its negative log-prob.

        Args:
            batch_state: object exposing ``backward_mask``, ``num_leaves``,
                ``batch_size``, ``batch_ids`` and ``next_node``.
            actions: if ``None``, actions are sampled from the uniform policy.
                Otherwise the action is derived from the state as
                ``next_node - num_leaves - 1``.
                NOTE(review): the values passed in ``actions`` are never
                read, only whether it is ``None`` -- confirm with callers.

        Returns:
            ``(actions, neg_log_prob)`` where ``neg_log_prob`` is
            ``-log`` of the uniform policy at the chosen action.
        """
        allowed = batch_state.backward_mask[:, batch_state.num_leaves:] == 1.
        uniform_pol = torch.where(allowed, 1., self.masked_value)
        uniform_pol = uniform_pol.softmax(dim=-1)
        if actions is None:
            actions = torch.multinomial(uniform_pol, num_samples=1, replacement=True)
            actions = actions.squeeze(dim=-1)
        else:
            # Allocate next to the policy tensor so this branch yields indices
            # on the same device as the sampling branch, not always on CPU.
            actions = torch.ones(
                (batch_state.batch_size,),
                dtype=torch.long,
                device=uniform_pol.device,
            ) * (batch_state.next_node - batch_state.num_leaves - 1)
        return actions, - torch.log(uniform_pol[batch_state.batch_ids, actions])
