import torch 
import torch.nn as nn 

class ForwardFlow(nn.Module):
    """Forward policy: a categorical distribution over ``warehouse_size`` actions.

    Each state is embedded by mean-pooling the embeddings of its item indices.
    The pooled vector goes through an MLP that produces one logit per action.
    """

    def __init__(self, emb_dim, hidden_dim, warehouse_size):
        super(ForwardFlow, self).__init__()
        # warehouse_size + 1 rows: one index beyond the action range, presumably
        # a padding / "empty" token used in states -- TODO confirm with encoder.
        self.embeddings = nn.Embedding(warehouse_size + 1, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, warehouse_size),
        )

    def forward(self, batch_state, off_policy=False, actions=None):
        """Return actions and their log-probabilities under this policy.

        Args:
            batch_state: object exposing ``state`` (integer item indices fed to
                the embedding; dim 1 is pooled, presumably (B, set_size) --
                confirm) and ``batch_ids`` (row indices into the per-state
                distribution, presumably ``arange(B)`` -- confirm).
            off_policy: if True and ``actions`` is None, sample actions
                uniformly over all ``warehouse_size`` choices instead of from
                the policy.
            actions: optional pre-chosen actions. When given, nothing is
                sampled and they are returned unchanged.

        Returns:
            ``(actions, log_probs)``. ``log_probs`` is this policy's
            log-probability of each returned action, including for
            off-policy samples.
        """
        embeddings = self.embeddings(batch_state.state)
        set_embeddings = embeddings.mean(dim=1)
        logits = self.mlp(set_embeddings)
        # log_softmax stays finite where softmax underflows to 0 (log(0) = -inf).
        log_probs = torch.log_softmax(logits, dim=-1)
        if actions is None:
            if off_policy:
                weights = torch.ones_like(log_probs)
            else:
                weights = torch.softmax(logits, dim=-1)
            # squeeze(-1), not squeeze(): keep a (B,) shape even when B == 1.
            actions = torch.multinomial(weights, num_samples=1).squeeze(-1)
        return actions, log_probs[batch_state.batch_ids, actions]

class StateFlow(nn.Module):
    """Scalar estimator of the log-flow of each state.

    Each state is embedded by sum-pooling the embeddings of its item indices.
    The pooled vector goes through an MLP with a single output unit.
    """

    def __init__(self, emb_dim, hidden_dim, warehouse_size):
        super(StateFlow, self).__init__()
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.warehouse_size = warehouse_size

        # warehouse_size + 1 rows: one index beyond the action range, presumably
        # a padding / "empty" token used in states -- TODO confirm with encoder.
        self.embeddings = nn.Embedding(warehouse_size + 1, emb_dim)

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, batch_state):
        """Return one log-flow value per state.

        Args:
            batch_state: object exposing ``state``, the integer item indices fed
                to the embedding. Dim 1 is pooled, so presumably the shape is
                (B, set_size) -- confirm.

        Returns:
            Tensor of shape ``(B,)``, which stays 1-D when B == 1.
        """
        # NOTE(review): sum-pooling, unlike the mean-pooling in ForwardFlow,
        # lets the flow depend on set size. Confirm this difference is intended.
        embeddings = self.embeddings(batch_state.state).sum(dim=1)
        # squeeze(-1), not squeeze(): drop only the size-1 output dim.
        state_log_flow = self.mlp(embeddings).squeeze(-1)
        return state_log_flow

# Backward policy. It could be parametrized (learned), though that is atypical; this version is fixed and has no parameters.
class BackwardFlow(nn.Module):
    """Fixed, parameter-free backward-policy log-probability.

    Computes ``-log(size)``, i.e. ``log(1 / size)``: the log-probability of a
    uniform choice among ``batch_state.size`` options. Presumably ``size`` is
    the number of parents of each state -- confirm against the environment.
    """

    def forward(self, batch_state, actions):
        """Return ``-log(batch_state.size)`` element-wise.

        Args:
            batch_state: object exposing ``size``. It must be a tensor, because
                ``torch.log`` does not accept Python numbers.
            actions: unused. It is accepted so the call signature mirrors the
                forward policy.
        """
        log_size = torch.log(batch_state.size)
        return log_size.neg()