import torch 
import torch.nn as nn 

from var_red_gfn.utils import ForwardPolicyMeta 

class ForwardPolicy(ForwardPolicyMeta):
    """MLP forward policy: a shared LeakyReLU trunk feeding a logit head and a flow head.

    The trunk maps ``src_size`` inputs to ``hidden_dim`` features through one input
    projection plus ``num_layers`` hidden Linear/LeakyReLU blocks. Two linear heads
    sit on top: ``mlp_logit`` (``src_size`` outputs) and ``mlp_flows`` (one output).
    """

    def __init__(self, src_size, hidden_dim=128, num_layers=3, eps=.3, device='cpu'):
        """Build the trunk and both heads and move them to ``device``.

        Args:
            src_size: width of the input features and of the logit head's output.
            hidden_dim: width of every hidden layer.
            num_layers: number of extra hidden Linear/LeakyReLU blocks after the
                input projection.
            eps: forwarded unchanged to ``ForwardPolicyMeta``.
            device: forwarded to ``ForwardPolicyMeta`` and used as the target of
                ``.to(...)`` for every submodule built here.
        """
        super().__init__(eps=eps, device=device)
        self.src_size, self.hidden_dim, self.num_layers = src_size, hidden_dim, num_layers
        self.device = device

        # Layers are instantiated in the same order as before (input projection, then
        # each hidden block, then the heads) so random initialisation is unchanged.
        trunk = [nn.Linear(src_size, hidden_dim), nn.LeakyReLU()]
        for _ in range(num_layers):
            trunk += [nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()]
        self.mlp = nn.Sequential(*trunk).to(self.device)

        # Heads: per-action logits and a single scalar flow estimate.
        self.mlp_logit = nn.Linear(hidden_dim, src_size).to(self.device)
        self.mlp_flows = nn.Linear(hidden_dim, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Run the shared trunk on ``batch_state.unique_input`` and return its features."""
        features = batch_state.unique_input
        return self.mlp(features)

    def get_pol(self, latent_emb, mask):
        """Turn trunk features into a masked action distribution and a flow value.

        Args:
            latent_emb: output of ``get_latent_emb``.
            mask: tensor broadcastable against the logits; 1 keeps an action's
                logit, 0 replaces it with ``self.masked_value`` (an attribute not
                defined in this class, presumably supplied by ``ForwardPolicyMeta``).

        Returns:
            ``(pol, flows)``: softmax over the last dimension of the masked logits,
            and the flow head's output with its trailing size-1 dimension removed.
        """
        logits = self.mlp_logit(latent_emb)
        masked_logits = mask * logits + self.masked_value * (1 - mask)
        pol = masked_logits.softmax(dim=-1)
        flows = self.mlp_flows(latent_emb).squeeze(dim=-1)
        return pol, flows
    
class BackwardPolicy(nn.Module):
    """Parameter-free backward policy that samples among actions allowed by a mask.

    The policy is a softmax over ``mask * unique_input`` with disallowed entries
    pushed to ``masked_value``. It is uniform over the allowed actions only when
    ``unique_input`` has the same value on every allowed entry.
    """

    # Logit given to masked-out actions; after softmax their probability is
    # effectively zero.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        super(BackwardPolicy, self).__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Sample (or score) backward actions for a batch of states.

        Args:
            batch_state: object exposing ``backward_mask``, ``unique_input`` and
                ``batch_ids``. Assumed (TODO confirm) to be ``(batch, num_actions)``
                tensors plus a ``(batch,)`` row-index tensor; the 2-D
                ``multinomial`` call and the two-tensor indexing below require a
                2-D policy.
            actions: optional ``(batch,)`` tensor of action indices to score. When
                ``None``, one action per state is sampled from the policy.

        Returns:
            ``(actions, log_probs)``: the chosen action indices, shape ``(batch,)``,
            and the log-probability of each under the policy.
        """
        mask = batch_state.backward_mask
        # NOTE(review): logits on allowed entries are ``unique_input`` values, not a
        # constant. This is only a uniform policy if ``unique_input`` is constant
        # (e.g. 1) wherever the mask is 1 — confirm with the state encoding.
        uniform_pol = (mask * batch_state.unique_input + self.masked_value * (1 - mask)).softmax(dim=-1)
        if actions is None:
            actions = torch.multinomial(uniform_pol, num_samples=1, replacement=True)
            # Drop only the num_samples dimension. A bare ``squeeze()`` also drops the
            # batch dimension when batch == 1, returning a 0-d tensor instead of (1,).
            actions = actions.squeeze(dim=-1)
        return actions, torch.log(uniform_pol[batch_state.batch_ids, actions])