import torch
import torch.nn as nn
import torch.nn.functional as F
from emb.emb_module import *


def squeezedim(x, dim):
    """Drop the leading axis of ``x`` when ``x`` has exactly ``dim`` dimensions.

    Tensors of any other rank are returned unchanged. As with
    ``Tensor.squeeze(0)``, nothing is removed if the leading axis is not of
    size 1.
    """
    return x.squeeze(0) if x.dim() == dim else x


class StateEmb(nn.Module):
    """Per-player state auto-encoders trained with a reconstruction loss.

    One encoder/decoder pair is built for every name in ``args.player_list``.
    ``"tennis"``/``"box"`` use ``ConvStateEncoder``/``ConvStateDecoder``;
    ``"connect4"`` uses ``ClassicStateEncoder``/``ClassicStateDecoder``.

    NOTE(review): for any other ``env_name`` no encoders/decoders are created,
    so ``state_embed``/``forward`` fail later with ``AttributeError``.
    """

    # Environments whose states are image-like, channels-last tensors that are
    # rescaled by 1/255 before encoding and when computing the loss target.
    _IMAGE_ENVS = ("tennis", "box")
    # Environments using the classic (non-convolutional) encoder/decoder.
    _CLASSIC_ENVS = ("connect4",)

    def __init__(self, args):
        super(StateEmb, self).__init__()
        self.env = args.env_name
        self.agent_num = args.agent_num
        self.adv_num = args.adv_num
        self.player_list = list(args.player_list)
        # One action slot per adversary followed by one per agent.
        self.action_dim_list = [1] * self.adv_num + [1] * self.agent_num

        self.action_dim = args.action_dim
        self.latent_dim = args.model.latent_dim
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

        if self.env in self._IMAGE_ENVS:
            encoder_cls, decoder_cls = ConvStateEncoder, ConvStateDecoder
        elif self.env in self._CLASSIC_ENVS:
            encoder_cls, decoder_cls = ClassicStateEncoder, ClassicStateDecoder
        else:
            encoder_cls = decoder_cls = None

        if encoder_cls is not None:
            # Encoders are built before decoders, as in the original code, so
            # parameter initialisation consumes the RNG in the same order.
            self.state_encoders = nn.ModuleDict({
                name: encoder_cls(latent_dim=self.latent_dim)
                for name in self.player_list
            })
            self.state_decoders = nn.ModuleDict({
                name: decoder_cls(latent_dim=self.latent_dim)
                for name in self.player_list
            })

    # ========================== Embedding ===============================
    def state_embed(self, state, player_name):
        """Encode ``state`` with the encoder that belongs to ``player_name``.

        ``state`` is first moved to ``self.device``.

        For image envs it is then scaled by 1/255 and converted from
        channels-last to channels-first:
        - a 3-D input gets a leading batch axis;
        - 4-D input is permuted ``(0, 3, 1, 2)``, e.g. ``[b, 84, 84, 6]``
          becomes ``[b, 6, 84, 84]``;
        - 5-D input is permuted ``(0, 1, 4, 2, 3)``.

        For connect4 a 3-D input only gets a leading batch axis.
        """
        state = state.to(self.device)
        if self.env in self._IMAGE_ENVS:
            state = state / 255.0  # rescale to [0, 1] for 0-255 inputs
            if state.dim() == 3:
                state = state.unsqueeze(0)
            if state.dim() == 4:
                state = state.permute(0, 3, 1, 2)
            elif state.dim() == 5:
                state = state.permute(0, 1, 4, 2, 3)
        elif self.env in self._CLASSIC_ENVS:
            if state.dim() == 3:
                state = state.unsqueeze(0)
        return self.state_encoders[player_name](state)

    def state_recon(self, x, dec_name):
        """Decode latent ``x`` with decoder ``dec_name``.

        For image envs the decoder output is assumed channels-first, e.g.
        ``[b, 6, 84, 84]``. It is permuted back to channels-last
        (``[b, 84, 84, 6]``) so it lines up with the raw state layout.
        """
        recon = self.state_decoders[dec_name](x)
        if self.env in self._IMAGE_ENVS:
            return recon.permute(0, 2, 3, 1)
        return recon

    # ========================== Train ===============================

    def forward(self, state):
        """Return the reconstruction loss summed over all players.

        For connect4, every player encodes and reconstructs the full ``state``.
        For other envs, ``state`` is split along dim 1. The i-th slice is
        handled by the i-th player in ``player_list``; extra slices are ignored.

        Raises:
            ValueError: if ``state`` has fewer slices along dim 1 than there
                are players (non-connect4 envs).
        """
        if self.env in self._CLASSIC_ENVS:
            targets = [state] * len(self.player_list)
        else:
            targets = list(torch.unbind(state, dim=1))
            if len(targets) < len(self.player_list):
                raise ValueError(
                    f"expected at least {len(self.player_list)} player slices "
                    f"along dim 1, got {len(targets)}"
                )
        pairs = list(zip(self.player_list, targets))

        # Embed every player first, then reconstruct, keeping the original
        # call order (relevant for stochastic layers such as dropout).
        s_embed = [self.state_embed(target, name) for name, target in pairs]
        sloss = 0.0
        for (name, target), emb in zip(pairs, s_embed):
            sloss += self.recon_mse(self.state_recon(emb, name), target)
        return sloss

    def recon_mse(self, recon_s, state):
        """Return the MSE between reconstruction ``recon_s`` and raw ``state``.

        The target is moved to ``recon_s``'s device and dtype first. The
        reconstruction lives on ``self.device`` (its input went through
        ``state_embed``), while ``state`` may still be on the caller's device
        or have an integer dtype. For image envs the target is also scaled by
        1/255, matching the scaling applied in ``state_embed``.
        """
        target = state.to(device=recon_s.device, dtype=recon_s.dtype)
        if self.env in self._IMAGE_ENVS:
            target = target / 255
        return F.mse_loss(recon_s, target)
    