import torch 
import torch.nn as nn 

from gfn.utils import ForwardPolicyMeta, GammaFuncMeta, BaseNN

class ForwardPolicy(ForwardPolicyMeta):
    """MLP forward policy over 3 actions plus a scalar flow head.

    A shared trunk embeds ``batch_state.pos`` (two input features, as fixed by
    the first ``nn.Linear(2, ...)``). Two linear heads read that embedding: one
    emits 3 action logits, the other a single flow value per state.
    """

    def __init__(self, hidden_dim, eps=.3, device='cpu'):
        super().__init__(eps=eps, device=device)
        self.hidden_dim = hidden_dim
        self.device = device
        # Sub-modules are built in this exact order so that parameter
        # initialisation consumes the RNG identically and state_dict keys
        # (mlp / mlp_logits / mlp_gflows) stay the same.
        trunk_layers = [
            nn.Linear(2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
        ]
        self.mlp = nn.Sequential(*trunk_layers).to(device)
        self.mlp_logits = nn.Linear(hidden_dim, 3).to(device)
        self.mlp_gflows = nn.Linear(hidden_dim, 1).to(device)

    def get_latent_emb(self, batch_state):
        """Embed the positions in ``batch_state`` with the shared trunk."""
        return self.mlp(batch_state.pos)

    def get_pol(self, latent_emb, mask):
        """Return ``(action_probs, flows)`` for a batch of embeddings.

        Args:
            latent_emb: output of ``get_latent_emb``, indexed as (batch, hidden).
            mask: multiplicative action mask broadcastable against the logits;
                entries equal to 1 keep a logit. Entries equal to 0 replace it
                with ``self.masked_value``, which is assumed to come from
                ForwardPolicyMeta (TODO confirm).

        Returns:
            Softmax over the last dim of the masked logits, and the flow head
            output with its trailing singleton dim squeezed away.
        """
        raw_logits = self.mlp_logits(latent_emb)
        stop_logit = raw_logits[:, -1]
        # The last column is the stop logit. Negative values are scaled by 5
        # and non-negative values are divided by 5. torch.where keeps this
        # out-of-place instead of writing into raw_logits.
        rescaled_stop = torch.where(stop_logit < 0, stop_logit * 5., stop_logit / 5.)
        logits = torch.cat([raw_logits[:, :-1], rescaled_stop.unsqueeze(-1)], dim=-1)
        masked_logits = mask * logits + self.masked_value * (1 - mask)
        flows = self.mlp_gflows(latent_emb).squeeze(dim=-1)
        return masked_logits.softmax(dim=-1), flows

class BackwardPolicy(nn.Module):
    """Uniform backward policy over the moves allowed by ``backward_mask``.

    Entries of ``batch_state.backward_mask`` equal to 1 share the probability
    mass equally. Other entries receive ``masked_value`` before the softmax,
    so their probability is effectively zero. An extra action with index 2 is
    appended with probability 1 (log-prob 0). It is the action assigned to
    stopped states.
    """

    masked_value = -1e5

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, batch_state, actions=None, initial_state=None):
        """Sample (or score) backward actions.

        Args:
            batch_state: object exposing ``backward_mask``, ``stopped``,
                ``batch_size`` and ``batch_ids``. ``batch_ids`` is presumably
                ``arange(batch_size)``; TODO confirm.
            actions: optional 1-D tensor of action indices to score. When
                omitted, actions are sampled. Stopped states then get action 2.
            initial_state: accepted for interface compatibility; unused here.

        Returns:
            ``(actions, log_probs)`` with one entry per state.
        """
        allowed = batch_state.backward_mask == 1.
        probs = torch.where(allowed, 1., self.masked_value).softmax(dim=-1)
        if actions is None:
            drawn = torch.multinomial(probs, num_samples=1, replacement=True).squeeze(dim=-1)
            actions = torch.where(batch_state.stopped == 1., 2, drawn)
        # Column 2 has probability 1, so choosing it contributes log(1) = 0.
        stop_column = torch.ones((batch_state.batch_size, 1), device=self.device)
        probs = torch.cat([probs, stop_column], dim=1)
        return actions, probs[batch_state.batch_ids, actions].log()

class GammaFuncDepth(GammaFuncMeta):
    """Gamma weighting from the positional gap between a state and the grid bounds.

    The weight is ``log((width - pos[:, 0]) * (height - pos[:, 1]) + 1)``,
    computed from the source state only. Presumably this is the log of the
    number of cells not yet passed; confirm the coordinate convention.
    """

    def weight_func(self, batch_state_t, batch_state_tp1):
        # batch_state_tp1 is part of the GammaFuncMeta interface but unused here.
        pos = batch_state_t.pos
        gap_x = batch_state_t.width - pos[:, 0]
        gap_y = batch_state_t.height - pos[:, 1]
        return torch.log(gap_x * gap_y + 1)
    
class GammaFuncInvDepth(GammaFuncMeta):
    """Negated positional-gap gamma weighting.

    The weight is ``-log((width - pos[:, 0]) * (height - pos[:, 1]) + 1)``,
    computed from the source state only.
    """

    def weight_func(self, batch_state_t, batch_state_tp1):
        # batch_state_tp1 is part of the GammaFuncMeta interface but unused here.
        pos = batch_state_t.pos
        gap_x = batch_state_t.width - pos[:, 0]
        gap_y = batch_state_t.height - pos[:, 1]
        return -torch.log(gap_x * gap_y + 1)

class LearnableGamma(GammaFuncMeta):
    """Gamma weighting produced by a learnable network over consecutive positions."""

    def __init__(self, hidden_dim, total_iters=1, device='cpu'):
        super().__init__(total_iters=total_iters)
        self.hidden_dim = hidden_dim
        self.device = device
        # NOTE(review): weight_func feeds hstack(pos_t, pos_tp1) to this network.
        # `pos` is indexed with two columns elsewhere in this file, which gives
        # 4 input features, yet the network is built with `2`. Confirm BaseNN's
        # argument order. If the first argument is the input size, this looks
        # like a mismatch.
        self.mlp = BaseNN(2, hidden_dim, 2, 1).to(device)

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Score each transition from the concatenated source/target positions."""
        pair_features = torch.hstack((batch_state_t.pos, batch_state_tp1.pos))
        scores = self.mlp(pair_features)
        return scores.squeeze(dim=1)