import torch 
import torch.nn as nn 

class ForwardFlow(nn.Module):
    """Forward policy network: a categorical distribution over ``vocab_size + 1`` actions.

    The token sequence in ``batch_state.state`` is embedded, multiplied by
    ``batch_state.mask_padding`` and encoded by a single-layer bidirectional
    LSTM. The LSTM output at the last time step is mapped by an MLP to one
    logit per action. Actions whose ``batch_state.mask`` entry is 0 receive
    ``masked_value`` instead of the MLP logit.

    Args:
        emb_dim: Token embedding size.
        vocab_size: Number of regular tokens. The embedding table and the
            action space both have ``vocab_size + 1`` entries.
        hidden_dim: LSTM hidden size (per direction) and MLP width.
        masked_value: Logit assigned to invalid actions.
        tol: Retained for backward compatibility only. Log-probabilities are
            computed with ``log_softmax``, which needs no additive epsilon.
    """

    def __init__(self, emb_dim, vocab_size, hidden_dim, masked_value=-1e5, tol=1e-8):
        super().__init__()
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # `vocab_size` + token for initialization
        self.embeddings = nn.Embedding(vocab_size + 1, emb_dim)
        self.rnn = nn.LSTM(
            input_size=self.emb_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            num_layers=1,
            bidirectional=True,
        )
        # Layer order is kept identical so existing state_dicts still load.
        self.mlp = nn.Sequential(
            nn.Linear(2 * self.hidden_dim, self.hidden_dim), nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim), nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, self.vocab_size + 1),
        )

        self.masked_value = masked_value
        self.tol = tol

    def forward(self, batch_state, off_policy=False, actions=None):
        """Sample (or score) one action per state.

        Args:
            batch_state: Object exposing ``state`` (token ids; ``(N, S)``
                judging by the embedding shape comment), ``mask_padding``
                (multiplied onto the embeddings), ``mask`` (1 for valid
                actions, 0 otherwise, one column per action) and
                ``batch_ids`` (row indices used to gather per-state values).
            off_policy: If True, sample uniformly among valid actions instead
                of from the policy. The returned log-probabilities are still
                those of the policy.
            actions: Optional pre-chosen actions. When given, nothing is
                sampled and these actions are scored instead.

        Returns:
            Tuple ``(actions, log_pf, log_pf_0)``: the sampled or given actions,
            their log-probabilities under the policy, and the policy
            log-probability of action index 0 (presumably the stop action;
            TODO confirm against the environment).
        """
        embeddings = self.embeddings(batch_state.state)  # (N, S, D)
        embeddings = embeddings * batch_state.mask_padding[..., None]
        out, _ = self.rnn(embeddings)

        # NOTE(review): with bidirectional=True, the second half of
        # out[:, -1, :] is the backward direction after reading only the final
        # time step. If a whole-sequence summary is intended, consider the
        # forward output at -1 concatenated with the backward output at 0
        # (equivalently the final h_n). Left unchanged so trained weights keep
        # their meaning.
        out = out[:, -1, :]
        logits = batch_state.mask * self.mlp(out) + (1 - batch_state.mask) * self.masked_value

        # log_softmax is numerically stable. torch.log(torch.softmax(...))
        # gives -inf (and NaN gradients) wherever a probability underflows to
        # 0, which always happens for masked actions.
        log_probs = torch.log_softmax(logits, dim=-1)

        if actions is None:
            actions = self._sample(batch_state, log_probs, off_policy)

        return (
            actions,
            log_probs[batch_state.batch_ids, actions],
            log_probs[batch_state.batch_ids, 0],
        )

    def _sample(self, batch_state, log_probs, off_policy):
        """Draw one action index per row, returned with shape ``(N,)``."""
        if off_policy:
            # Uniform over valid actions: masked entries get a logit so low
            # that their softmax weight is 0.
            uniform = torch.where(batch_state.mask == 1., 1., self.masked_value)
            weights = torch.softmax(uniform, dim=-1)
        else:
            weights = log_probs.exp()
        # squeeze(-1) rather than squeeze(), so a batch of one stays (1,).
        return torch.multinomial(weights, num_samples=1).squeeze(-1)

class BackwardFlow(nn.Module):
    """Backward policy that assigns log-probability 0 to every transition.

    NOTE(review): a constant log P_B = 0 presumes each state has exactly one
    parent (e.g. strictly left-to-right sequence construction). Confirm this
    holds for the environment using this module.
    """

    def forward(self, batch_state, actions):
        """Return zeros of shape ``(batch_state.batch_size,)``.

        ``actions`` is accepted for interface compatibility and ignored. When
        ``batch_state.state`` is a tensor, the zeros are allocated on its
        device. This lets them be combined with per-state log-probabilities
        computed on a GPU (non-scalar CPU and CUDA tensors cannot be mixed).
        """
        state = getattr(batch_state, "state", None)
        device = state.device if isinstance(state, torch.Tensor) else None
        return torch.zeros((batch_state.batch_size,), device=device)

class StateFlow(nn.Module):
    """Predict one scalar per state (presumably the log state flow; TODO confirm).

    Token ids in ``batch_state.state`` are embedded and encoded by a
    bidirectional LSTM. The output at the last time step is mapped by a small
    MLP to a single value.

    Args:
        hidden_dim: LSTM hidden size (per direction) and MLP width.
        emb_dim: Token embedding size.
        vocab_size: Number of regular tokens; the embedding table has
            ``vocab_size + 1`` rows.
    """

    def __init__(self, hidden_dim, emb_dim, vocab_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Module creation order is unchanged so existing state_dicts still load.
        self.embeddings = nn.Embedding(num_embeddings=vocab_size + 1, embedding_dim=emb_dim)
        self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
        self.mlp = nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def forward(self, batch_state):
        """Return a tensor of shape ``(N,)`` with one value per state in the batch."""
        embeddings = self.embeddings(batch_state.state)
        # NOTE(review): padded positions are not masked here
        # (batch_state.mask_padding is not applied). Confirm this is intended.
        out, _ = self.rnn(embeddings)
        # NOTE(review): with bidirectional=True, the backward half of
        # out[:, -1, :] has only read the final time step.
        out = self.mlp(out[:, -1, :])  # (N, 1)
        # squeeze(-1) drops only the trailing output dim, so a batch of one
        # stays (1,) instead of collapsing to a 0-d tensor.
        return out.squeeze(-1)