import torch 
import torch.nn as nn 

# Index of the "stop" action along the forward policy's action dimension;
# ForwardFlow.forward reports the log-probability of this column separately.
STOP_ACTION_INDEX = 0

class ForwardFlow(nn.Module):
    """Forward policy: an MLP over 2-D positions yielding a masked 3-way action distribution.

    The network maps ``batch_state.pos`` to three logits per state. Logits whose
    ``batch_state.mask`` entry is not 1 are replaced by ``masked_value`` before
    normalisation, so those actions receive (numerically) zero probability.
    """

    def __init__(self, hidden_dim, masked_value=-1e5):
        """Build the policy network.

        Args:
            hidden_dim: Width of the two hidden layers.
            masked_value: Logit substituted for actions whose mask entry is
                not 1. Kept finite so the log-probabilities stay finite.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 3),
        )
        self.masked_value = masked_value

    def forward(self, batch_state, off_policy=False, actions=None):
        """Score given actions, or sample new ones, under the forward policy.

        Args:
            batch_state: Object exposing ``pos``, ``mask`` and ``batch_ids``.
                Assumes ``pos`` is (batch, 2), ``mask`` is (batch, 3) and
                ``batch_ids`` indexes the batch rows -- TODO confirm with caller.
            off_policy: When sampling, draw uniformly over the unmasked actions
                instead of from the learned distribution. Ignored if ``actions``
                is given.
            actions: Optional pre-chosen action indices. When provided they are
                scored and returned unchanged; nothing is sampled.

        Returns:
            Tuple ``(actions, log_prob_of_actions, log_prob_of_stop)`` where the
            last entry is the log-probability of column ``STOP_ACTION_INDEX``.
        """
        allowed = batch_state.mask == 1
        logits = torch.where(allowed, self.mlp(batch_state.pos), self.masked_value)

        # log_softmax instead of log(softmax(.)): masked / very unlikely actions
        # get a large finite negative value rather than -inf from underflow,
        # which would otherwise poison losses and gradients with inf/NaN.
        log_probs = torch.log_softmax(logits, dim=-1)

        if actions is None:
            if off_policy:
                # Equal logits on every allowed action -> uniform exploration.
                sampling_logits = torch.where(allowed, 1, self.masked_value)
            else:
                sampling_logits = logits
            sampled = torch.multinomial(
                torch.softmax(sampling_logits, dim=-1), num_samples=1
            )
            # Drop only the sample axis; a bare squeeze() would also collapse a
            # batch of size one into a 0-dim tensor.
            actions = sampled.squeeze(-1)

        rows = batch_state.batch_ids
        return actions, log_probs[rows, actions], log_probs[rows, STOP_ACTION_INDEX]

class BackwardFlow(nn.Module):
    """Parameter-free backward policy that is uniform over the available backward moves."""

    def forward(self, batch_state, actions=None):
        """Return the backward log-probability ``-log(k)`` for each state.

        ``k`` is the number of coordinates of ``batch_state.pos`` that are
        strictly greater than the matching coordinate of
        ``batch_state.bottomleft``.

        Args:
            batch_state: Object exposing ``pos``, ``bottomleft`` and ``stopped``.
                Assumes ``stopped`` holds one value per state -- TODO confirm.
            actions: Unused; accepted so the call signature stays compatible
                with existing callers.

        Returns:
            Tensor with ``-log(k)`` where ``stopped < 1`` and 0 elsewhere.
        """
        num_backward_actions = torch.gt(batch_state.pos, batch_state.bottomleft).sum(dim=-1)
        # With no backward move available the count is 0 and log(0) = -inf would
        # be negated into a +inf "log-probability". Clamp to 1 so such states
        # contribute log(1) = 0 instead.
        log_counts = torch.log(num_backward_actions.clamp(min=1))
        not_stopped = batch_state.stopped < 1.
        return -torch.where(not_stopped, log_counts, torch.zeros_like(log_counts))

class StateFlow(nn.Module):
    """MLP that maps each 2-D position to a single scalar.

    NOTE(review): presumably this scalar is the (log-)flow of the state --
    confirm against the training loss that consumes it.
    """

    def __init__(self, hidden_dim):
        """Build the network.

        Args:
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, batch_state):
        """Return one scalar per state in ``batch_state.pos``.

        Args:
            batch_state: Object exposing ``pos`` whose last dimension is 2.

        Returns:
            The MLP output with its trailing size-1 feature axis removed.
        """
        # squeeze(-1) removes only the feature axis; a bare squeeze() would also
        # collapse a batch of size one into a 0-dim tensor.
        return self.mlp(batch_state.pos).squeeze(-1)