import math

import torch
import torch.nn as nn

from gfn.utils import ForwardPolicyMeta, GammaFuncMeta, BaseNN

class ForwardPolicy(ForwardPolicyMeta):
    """MLP forward policy computed directly from ``batch_state.state``.

    Two independent ``BaseNN`` heads read the same ``seq_size + 1``-wide input:
    ``mlp_logit`` emits ``src_size + 1`` action logits and ``mlp_flows`` emits a
    single flow value per state.
    """

    def __init__(self, seq_size, src_size, hidden_dim, num_layers, device='cpu', eps=.3):
        super().__init__(eps=eps, device=device)
        self.seq_size, self.src_size = seq_size, src_size
        self.hidden_dim, self.num_layers = hidden_dim, num_layers
        self.device = device

        in_features = seq_size + 1
        # Creation order (logit head first) is kept so seeded initialisation
        # and state_dict ordering stay the same.
        self.mlp_logit = BaseNN(in_features, hidden_dim, num_layers, src_size + 1).to(device)
        self.mlp_flows = BaseNN(in_features, hidden_dim, num_layers, 1).to(device)

    def get_latent_emb(self, batch_state):
        """Return ``(logits, flows)``; the flow head's trailing singleton axis is squeezed."""
        x = batch_state.state.type(torch.get_default_dtype())
        logits = self.mlp_logit(x)
        flows = self.mlp_flows(x).squeeze(dim=-1)
        return logits, flows

    def get_pol(self, logits_flows, mask):
        """Mask the logits arithmetically, softmax over the last axis and pass flows through.

        Entries where ``mask`` is 0 are replaced by ``self.masked_value`` (defined
        outside this class) before the softmax.
        """
        logits, flows = logits_flows
        masked_logits = mask * logits + (1 - mask) * self.masked_value
        return masked_logits.softmax(dim=-1), flows

class ForwardPolicyLA(ForwardPolicyMeta):
    """Forward policy that scores each child of a state from a (parent, child) embedding pair.

    ``mlp_emb`` encodes ``unique_input`` of the parent and of every child.
    ``mlp_logit`` maps each concatenated pair to one logit, and ``mlp_flows``
    maps the parent's ``unique_input`` to a scalar flow.
    """

    def __init__(self, seq_size, src_size, hidden_dim, num_layers, device='cpu', eps=.3):
        super(ForwardPolicyLA, self).__init__(eps=eps, device=device)
        # Set explicitly (as ForwardPolicy does) instead of relying on the base
        # class, because the heads below are moved to ``self.device``.
        self.device = device
        self.seq_size = seq_size
        self.src_size = src_size
        self.input_dim = seq_size + 1
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # State encoder; its output width is assumed to be hidden_dim, because
        # mlp_logit consumes 2 * hidden_dim features (TODO confirm BaseNN default).
        self.mlp_emb = BaseNN(self.input_dim, self.hidden_dim, self.num_layers).to(self.device)
        # Maps one concatenated (parent, child) embedding to a single logit.
        self.mlp_logit = BaseNN(2 * self.hidden_dim, 2 * self.hidden_dim, 1, 1).to(self.device)
        self.mlp_flows = BaseNN(self.input_dim, self.hidden_dim, self.num_layers, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return ``(logits, flows)``: one logit per child and one flow per parent.

        Children come from ``batch_state.get_children()`` in iteration order; the
        logits' child axis (dim 1) follows that order.
        """
        dtype = torch.get_default_dtype()
        # Cast once; the parent input feeds both the encoder and the flow head.
        parent_input = batch_state.unique_input.type(dtype)
        parent_emb = self.mlp_emb(parent_input)
        # One hstack'ed (parent, child) row per child, stacked on a new axis 1
        # (equivalent to unsqueeze(1) + cat(dim=1)).
        pair_emb = torch.stack([
            torch.hstack([parent_emb, self.mlp_emb(child.unique_input.type(dtype))])
            for child in batch_state.get_children()
        ], dim=1)
        logits = self.mlp_logit(pair_emb).squeeze(dim=-1)
        flows = self.mlp_flows(parent_input).squeeze(dim=-1)
        return logits, flows

    def get_pol(self, logits_flows, mask):
        """Mask the logits with ``self.masked_value`` where ``mask`` is 0, softmax, pass flows through."""
        logits, flows = logits_flows
        pol = (logits * mask + (1 - mask) * self.masked_value).softmax(dim=-1)
        return pol, flows

def _log_num_descendants(seq_size, src_size, curr_size):
    """Return ``log((src_size ** m - 1) / (src_size - 1))`` with ``m = seq_size - curr_size + 1``.

    This is the closed form of the geometric sum ``1 + k + ... + k ** (m - 1)``
    with ``k = src_size``. It is evaluated in log space because the direct form
    ``src_size ** m`` silently overflows an integer tensor (int64 wraps past
    ``2 ** 63``), and overflows float32 for long sequences. When
    ``src_size == 1`` the closed form divides by zero, so the sum itself
    (``m``) is used.

    Args:
        seq_size: Sequence length of the batch state (assumed a Python int).
        src_size: Alphabet size of the batch state (assumed a Python int >= 1).
        curr_size: Current index tensor (e.g. ``curr_idx``), or a number.

    Returns:
        A floating tensor with the shape of ``curr_size``. Integer inputs are
        promoted to the default dtype.
    """
    exponent = torch.as_tensor(seq_size - curr_size + 1)
    if not exponent.is_floating_point():
        exponent = exponent.to(torch.get_default_dtype())
    if src_size == 1:
        return torch.log(exponent)
    log_k = math.log(src_size)
    # log(k**m - 1) = m*log(k) + log(1 - k**-m); log1p keeps precision for large m
    # and yields -inf at m == 0, matching log(0) of the direct formula.
    return exponent * log_k + torch.log1p(-torch.exp(-exponent * log_k)) - math.log(src_size - 1)


class GammaFuncDepth(GammaFuncMeta):
    """Transition weight ``log(num_descendants)`` evaluated at ``s_{t+1}``'s current index."""

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return ``log((k**m - 1) / (k - 1))``.

        ``k = src_size``, ``m = seq_size - curr_idx + 1``; ``curr_idx`` is read
        from ``batch_state_tp1`` and the sizes from ``batch_state_t``.
        """
        return _log_num_descendants(
            batch_state_t.seq_size, batch_state_t.src_size, batch_state_tp1.curr_idx
        )


class GammaFuncInvDepth(GammaFuncMeta):
    """Transition weight ``-log(num_descendants + 1)`` evaluated at ``s_{t+1}``'s current index."""

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return ``-log((k**m - 1) / (k - 1) + 1)`` with the same ``k``/``m`` as ``GammaFuncDepth``."""
        log_num = _log_num_descendants(
            batch_state_t.seq_size, batch_state_t.src_size, batch_state_tp1.curr_idx
        )
        # log(num + 1) == logaddexp(log(num), log(1)); stays finite when num == 0.
        return -torch.logaddexp(log_num, torch.zeros_like(log_num))

class LearnableGamma(GammaFuncMeta):
    """Learnable transition weight: an MLP over the concatenated inputs of ``s_t`` and ``s_{t+1}``.

    Args:
        input_dim: Width of the concatenated features, i.e. the width of
            ``batch_state_t.unique_input`` plus that of
            ``batch_state_tp1.unique_input`` (presumably equal widths, so
            unchanged from before — confirm against callers).
        hidden_dim: Hidden width passed to ``BaseNN``.
        total_iters: Forwarded to ``GammaFuncMeta``.
        device: Device the MLP is moved to.
    """

    def __init__(self, input_dim, hidden_dim, total_iters=1, device='cpu'):
        super(LearnableGamma, self).__init__(total_iters=total_iters)
        self.device = device
        self.mlp = BaseNN(input_dim, hidden_dim, 2, output_dim=1, act=nn.LeakyReLU()).to(self.device)

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return one weight per transition.

        The MLP output has its axis 1 squeezed, matching the original
        ``squeeze(dim=1)``.
        """
        dtype = torch.get_default_dtype()
        # Previously batch_state_t was concatenated with itself, which silently
        # ignored s_{t+1}; the weight now depends on both endpoints.
        features = torch.hstack([
            batch_state_t.unique_input.type(dtype),
            batch_state_tp1.unique_input.type(dtype),
        ])
        return self.mlp(features).squeeze(dim=1)

class BackwardPolicy(nn.Module):
    """Placeholder backward policy whose outputs are all zeros.

    ``forward`` ignores ``actions`` and returns a tuple of two independent zero
    tensors of shape ``(batch_state.batch_size,)`` on ``self.device``.
    """

    # Large negative fill value; not referenced inside this class.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        n = batch_state.batch_size
        # Two separate allocations so the returned tensors never alias.
        return tuple(torch.zeros((n,), device=self.device) for _ in range(2))
