import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """State/action autoencoder assembled from small two-layer MLPs.

    ``encode`` embeds each action vector with one shared linear layer,
    concatenates the flattened embeddings with the state and maps the result
    to a latent vector.  ``decode`` maps a latent vector back to a
    reconstructed state and, through one head per agent, to log-probabilities
    over actions.

    Args:
        state_dim: Size of the last axis of the state tensors.
        args: Configuration object.  The attributes ``n_actions``,
            ``n_agents``, ``action_embed_dim`` and ``encoder_embed_dim`` are
            read from it; the object itself is kept as ``self.args``.
    """

    def __init__(self, state_dim, args):
        super().__init__()
        self.args = args
        self.state_dim = state_dim

        n_actions = args.n_actions
        n_agents = args.n_agents
        embed_dim = args.action_embed_dim
        latent_dim = args.encoder_embed_dim

        def mlp(in_features, out_features):
            # Linear -> ReLU -> Linear with a hidden width of ``latent_dim``.
            return nn.Sequential(
                nn.Linear(in_features, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, out_features),
            )

        # NOTE: layers are created in a fixed order (embedding, encoder,
        # state decoder, action heads) so seeded initialisation stays stable.

        # Shared across agents: applied along the last axis of ``actions``.
        self.action_embed_net = nn.Linear(n_actions, embed_dim)

        # Input is the state concatenated with every agent's action embedding.
        self.encoder = mlp(state_dim + embed_dim * n_agents, latent_dim)

        # Latent -> reconstructed state.
        self.decoder = mlp(latent_dim, state_dim)

        # Latent -> action logits, one independent head per agent.
        self.action_decoder = nn.ModuleList(
            [mlp(latent_dim, n_actions) for _ in range(n_agents)]
        )

    def encode(self, states, actions):
        """Map states and joint actions to latent vectors.

        Args:
            states: Tensor whose last axis has size ``state_dim``.
            actions: Tensor whose last axis has size ``n_actions`` and whose
                second-to-last axis indexes agents (size ``n_agents``), so
                that the flattened embeddings match the encoder input width.
                Whether these are one-hot vectors is decided by the caller.

        Returns:
            Tensor with the leading axes of ``states`` and a last axis of
            size ``encoder_embed_dim``.
        """
        embedded = self.action_embed_net(actions)
        # Merge the (agent, embedding) axes into a single feature axis.
        joint_embedding = embedded.flatten(start_dim=-2)
        encoder_input = torch.cat([states, joint_embedding], dim=-1)
        return self.encoder(encoder_input)

    def decode(self, latent_states):
        """Reconstruct states and per-agent action log-probabilities.

        Args:
            latent_states: Tensor whose last axis has size
                ``encoder_embed_dim``.

        Returns:
            A tuple ``(recon_states, recon_actions)``: ``recon_states`` has a
            last axis of size ``state_dim``; ``recon_actions`` has trailing
            axes ``(n_agents, n_actions)`` and holds log-softmax values taken
            over the action axis.
        """
        recon_states = self.decoder(latent_states)
        per_agent_logits = [
            self.action_decoder[agent](latent_states)
            for agent in range(self.args.n_agents)
        ]
        # Insert the agent axis just before the action axis.
        stacked_logits = torch.stack(per_agent_logits, dim=-2)
        recon_actions = stacked_logits.log_softmax(dim=-1)
        return recon_states, recon_actions