import torch 
import torch.nn as nn 

from gfn.utils import ForwardPolicyMeta, BaseNN, GammaFuncMeta 

class ForwardPolicy(ForwardPolicyMeta):
    """Forward policy with two MLP heads reading ``batch_state.unique_input``.

    One head produces ``src_size`` logits per state; the other produces one
    scalar flow per state.
    """

    def __init__(self, src_size, hidden_dim=128, num_layers=3, eps=.3, device='cpu'):
        super().__init__(eps=eps, device=device)
        self.device = device
        # Logit head: src_size features -> src_size logits.
        self.mlp_logit = BaseNN(src_size, hidden_dim, num_layers, src_size).to(self.device)
        # Flow head: src_size features -> one scalar per state.
        self.mlp_flows = BaseNN(src_size, hidden_dim, num_layers, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return ``(logits, flows)`` for every state in ``batch_state``.

        ``flows`` has its trailing size-1 output dimension (dim 1) removed.
        """
        features = batch_state.unique_input
        logits = self.mlp_logit(features)
        flows = self.mlp_flows(features).squeeze(dim=1)
        return logits, flows

    def get_pol(self, logit_flows, mask):
        """Turn ``(logits, flows)`` into ``(policy, flows)``.

        Positions where ``mask`` is 0 get ``self.masked_value`` as their logit
        before the softmax over the last dimension. ``masked_value`` is
        presumably provided by ``ForwardPolicyMeta``, and ``mask`` is assumed
        to be a 0/1 tensor. TODO confirm.
        """
        logit, flows = logit_flows
        penalty = self.masked_value * (1 - mask)
        masked_logit = mask * logit + penalty
        return masked_logit.softmax(dim=-1), flows

class ForwardPolicyLA(ForwardPolicy):
    """Look-ahead forward policy that scores each child state.

    Each child's embedding is paired with the parent embedding, and the pair
    is scored by ``mlp_logit``. Flows are still computed from the parent's
    ``unique_input``.
    """

    def __init__(self, src_size, hidden_dim=128, num_layers=3, eps=.3, device='cpu'):
        # NOTE: hidden_dim / num_layers are not forwarded here, so the parent
        # builds its two heads with its own defaults. Both heads are replaced
        # below.
        super().__init__(src_size, eps=eps, device=device)
        # Output width presumably equals hidden_dim (see pair_dim below). TODO confirm in BaseNN.
        self.mlp_emb = BaseNN(src_size, hidden_dim, num_layers).to(self.device)
        pair_dim = 2 * hidden_dim
        # Scores one concatenated (parent, child) embedding -> one logit.
        self.mlp_logit = BaseNN(pair_dim, pair_dim, 1, 1).to(self.device)
        self.mlp_flows = BaseNN(src_size, hidden_dim, num_layers, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return ``(logits, flows)`` with one logit per child of each state.

        The per-child (parent, child) embedding pairs are stacked along dim 1
        before being scored, so ``logits`` has one column per child yielded by
        ``batch_state.get_children()``.
        """
        parent_emb = self.mlp_emb(batch_state.unique_input)
        pair_feats = torch.stack(
            [
                torch.hstack([parent_emb, self.mlp_emb(child.unique_input)])
                for child in batch_state.get_children()
            ],
            dim=1,
        )
        logits = self.mlp_logit(pair_feats).squeeze(dim=-1)
        flows = self.mlp_flows(batch_state.unique_input).squeeze(dim=-1)
        return logits, flows

class LearnablePotential(nn.Module):
    """MLP potential over a transition ``s_t -> s_{t+1}``.

    The ``unique_input`` of both states is concatenated with ``torch.hstack``
    and mapped to one scalar per transition.
    """

    def __init__(self, src_size, hidden_dim=128, num_layers=3, device='cpu'):
        super().__init__()
        depth = num_layers + 1
        self.mlp = BaseNN(2 * src_size, hidden_dim, depth, output_dim=1).to(device)

    def forward(self, batch_state_t, batch_state_tp1):
        """Return the potential of each transition, with the last dim squeezed."""
        pair = (batch_state_t.unique_input, batch_state_tp1.unique_input)
        scores = self.mlp(torch.hstack(pair))
        return scores.squeeze(dim=-1)

class FixedPotential(nn.Module):
    """Non-learnable potential: the log-reward of ``s_t`` minus that of ``s_{t+1}``.

    Constructor arguments are accepted and ignored so this class can be
    swapped for ``LearnablePotential``.
    """

    def __init__(self, *args, **kwargs): # type: ignore
        super().__init__()

    @torch.no_grad()
    def forward(self, batch_state_t, batch_state_tp1):
        """Return ``log_reward(s_t) - log_reward(s_{t+1})``, computed without autograd."""
        # Query s_{t+1} first to keep the original call order.
        reward_next = batch_state_tp1.log_reward()
        reward_curr = batch_state_t.log_reward()
        return reward_curr - reward_next

class GammaFuncDepth(GammaFuncMeta):
    """Gamma weight equal to ``log C(src_size, set_size - |s_t|)``.

    ``|s_t|`` is ``batch_state_t.unique_input.sum(dim=1)``, so the weight is
    the log binomial coefficient of the source size and the number of
    elements still missing from the state.
    """

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return ``log C(n, k)`` per state, with ``n = src_size`` and ``k = set_size - |s_t|``.

        ``batch_state_tp1`` is not used.
        """
        ones = torch.ones((batch_state_t.batch_size,), device=batch_state_t.device)
        n = ones * batch_state_t.src_size
        k = batch_state_t.set_size - batch_state_t.unique_input.sum(dim=1)
        # log m! == lgamma(m + 1). This is vectorised and exact at m = 0,
        # unlike the former recursive helper. That helper only stopped once
        # *every* entry reached 1, so it recursed forever when an argument was
        # 0 or when entries differed across the batch. `+ 1.0` also promotes
        # integer counts to floating point.
        return (
            torch.lgamma(n + 1.0)
            - torch.lgamma(n - k + 1.0)
            - torch.lgamma(k + 1.0)
        )
    
class GammaFuncInvDepth(GammaFuncMeta):
    """Gamma weight equal to ``-log C(src_size, set_size - |s_t|)``.

    This is the negation of the depth weighting. ``|s_t|`` is
    ``batch_state_t.unique_input.sum(dim=1)``.
    """

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return ``-log C(n, k)`` per state, with ``n = src_size`` and ``k = set_size - |s_t|``.

        ``batch_state_tp1`` is not used.
        """
        ones = torch.ones(batch_state_t.batch_size, device=batch_state_t.device)
        n = ones * batch_state_t.src_size
        k = batch_state_t.set_size - batch_state_t.unique_input.sum(dim=1)
        # log m! == lgamma(m + 1). This is vectorised and exact at m = 0,
        # unlike the former recursive helper. That helper recursed forever
        # when an argument was 0 or when entries differed across the batch.
        log_binom = (
            torch.lgamma(n + 1.0)
            - torch.lgamma(n - k + 1.0)
            - torch.lgamma(k + 1.0)
        )
        return -log_binom

class LearnableGamma(GammaFuncMeta):
    """Learnable gamma weight computed by an MLP over a transition.

    The MLP reads the concatenated ``unique_input`` of ``s_t`` and
    ``s_{t+1}``. ``input_dim`` must therefore match the width of that
    concatenation.
    """

    def __init__(self, input_dim, hidden_dim, total_iters=1, device='cpu'):
        super().__init__(total_iters=total_iters)
        self.input_dim, self.hidden_dim, self.device = input_dim, hidden_dim, device
        self.mlp = BaseNN(input_dim, hidden_dim, 2, 1).to(self.device)

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return one MLP score per transition (output dim 1 squeezed)."""
        transition = torch.hstack(
            [batch_state_t.unique_input, batch_state_tp1.unique_input]
        )
        return self.mlp(transition).squeeze(dim=1)
        
class BackwardPolicy(nn.Module):
    """Backward policy over the positions enabled by ``batch_state.backward_mask``.

    On enabled positions the logits are ``batch_state.unique_input``.
    Disabled positions get ``masked_value``. The resulting distribution is
    uniform over the enabled positions whenever ``unique_input`` has the same
    value on all of them. NOTE(review): confirm callers rely on that.
    """
    masked_value = -1e5

    def __init__(self, device='cpu'):
        super(BackwardPolicy, self).__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Sample backward actions (or score the given ones).

        Args:
            batch_state: exposes ``backward_mask``, ``unique_input`` (the
                softmax is taken over the last dim) and ``batch_ids`` (row
                indices used to gather each row's probability).
            actions: optional actions to score. If ``None``, one action per row
                is drawn with ``torch.multinomial``.

        Returns:
            ``(actions, log_probs)``, where ``log_probs[i]`` is the log
            probability of ``actions[i]`` in row ``batch_ids[i]``.
        """
        mask = batch_state.backward_mask
        uniform_pol = (mask * batch_state.unique_input + self.masked_value * (1 - mask)).softmax(dim=-1)
        if actions is None:
            actions = torch.multinomial(uniform_pol, num_samples=1, replacement=True)
            # Drop only the sample axis. A bare squeeze() would also collapse a
            # batch of size 1 to a 0-d tensor, mismatching the (1,) log-probs.
            actions = actions.squeeze(dim=-1)
        return actions, torch.log(uniform_pol[batch_state.batch_ids, actions])
