import torch
import torch.nn as nn
import torch.nn.functional as F


def reparameterize(mu, logvar):
    """Sample ``mu + eps * sigma`` with ``eps ~ N(0, I)``.

    Args:
        mu: Tensor of means.
        logvar: Tensor of log-variances; ``sigma = exp(0.5 * logvar)``.

    Returns:
        A tensor holding one sample per element. The noise is drawn with
        ``torch.randn_like`` so it matches ``sigma`` in shape, dtype and
        device, and gradients can flow back to ``mu`` and ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

# ============================ state module ============================

# atari
class ConvStateEncoder(nn.Module):
    def __init__(self, latent_dim):
        super(ConvStateEncoder, self).__init__()
        # input: # [16, 84, 84] 
        self.fc = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=3, stride=2, padding=1),  # [16, 42, 42] 
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # [32, 21, 21]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # [64, 11, 11]
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 11 * 11, latent_dim)
        )

        
    def forward(self, x):
        if x.dim() == 5:
            B, T, C, H, W =x.shape
            x = x.reshape(B * T, C, H, W)
            x = self.fc(x)
            x = x.reshape(B, T, *x.shape[1:])
        else:
            x = self.fc(x).squeeze(0)
        return x
    
        """z = torch.sqrt(
            torch.tensor(self.latent_dim, dtype=torch.int, device=self.device)
        ) * torch.nn.functional.normalize(z, dim=x.dim()-1)"""
    
    
class ConvStateDecoder(nn.Module):
    """Transposed-convolution decoder from a latent vector to a 6x84x84 image.

    A linear layer expands the latent to a 64 x 11 x 11 feature map, which
    three stride-2 transposed convolutions upsample to 84 x 84. The final
    sigmoid keeps every output value in ``[0, 1]``.

    Args:
        latent_dim: Size of the input latent vector.
    """

    def __init__(self, latent_dim):
        super(ConvStateDecoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 64 * 11 * 11),
            nn.ReLU(),
            nn.Unflatten(1, (64, 11, 11)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=0),  # [32, 21, 21]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # [16, 42, 42]
            nn.ReLU(),
            nn.ConvTranspose2d(16, 6, kernel_size=3, stride=2, padding=1, output_padding=1),  # [6, 84, 84]
            nn.Sigmoid()
        )

    def forward(self, x):
        """Decode latents into images.

        Args:
            x: Tensor of shape ``(..., latent_dim)``. The usual input is
                ``(B, latent_dim)``; a single latent ``(latent_dim,)`` and
                sequences ``(B, T, latent_dim)`` are also accepted.

        Returns:
            Tensor of shape ``(..., 6, 84, 84)`` with the same leading
            dimensions as ``x``.
        """
        if x.dim() == 2:
            return self.fc(x)
        # nn.Unflatten(1, ...) needs exactly one batch axis, so collapse any
        # leading dims into it and restore them afterwards.
        lead = x.shape[:-1]
        out = self.fc(x.reshape(-1, x.shape[-1]))
        return out.reshape(*lead, *out.shape[1:])

# classic
class ClassicStateEncoder(nn.Module):
    """MLP encoder for observations that flatten to ``2 * 6 * 7`` features.

    Args:
        latent_dim: Size of the output latent vector.
    """

    def __init__(self, latent_dim):
        super(ClassicStateEncoder, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(2 * 6 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim)
        )

    def forward(self, x):
        """Encode a batch of observations or a batch of observation sequences.

        Args:
            x: Either a 5-D tensor ``(B, T, ...)`` whose trailing three dims
                hold 84 values, or a batched tensor ``(B, ...)`` holding 84
                values per sample.

        Returns:
            ``(B, T, latent_dim)`` for 5-D input. Otherwise
            ``(B, latent_dim)``, except that a batch of size 1 is squeezed
            to ``(latent_dim,)``. The output is not normalised.
        """
        if x.dim() == 5:
            # reshape (not view) so non-contiguous inputs, e.g. permuted
            # tensors, are accepted; matches ConvStateEncoder.
            x = x.reshape(x.size(0), x.size(1), -1)
        else:
            # squeeze(0) only drops the batch axis when it has size 1.
            x = self.flatten(x).squeeze(0)
        return self.fc(x)


class ClassicStateDecoder(nn.Module):
    """MLP decoder from a latent vector to a ``(2, 6, 7)`` observation.

    Args:
        latent_dim: Size of the input latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        hidden_dim = 128
        flat_dim = 2 * 6 * 7
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, flat_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latents; all leading dims are merged into one batch axis.

        Returns:
            Tensor of shape ``(N, 2, 6, 7)``, where ``N`` is the product of
            the leading dimensions of ``x`` (1 for a single latent vector).
        """
        flat = self.mlp(x)
        return flat.view(-1, 2, 6, 7)

# mpe
class StateEncoder(nn.Module):
    """Two-layer MLP encoder from a flat state vector to a latent vector.

    The encoder is deterministic: no mean/log-variance heads are defined
    and ``reparameterize`` is not called.

    Args:
        input_dim: Size of the input state vector.
        latent_dim: Width of the hidden layer and of the output latent.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent for ``x`` of shape ``(..., input_dim)``."""
        return self.fc(x)
      

class StateDecoder(nn.Module):
    """Two-layer MLP decoder from a latent vector back to a flat state.

    Args:
        latent_dim: Size of the input latent and width of the hidden layer.
        output_dim: Size of the reconstructed state vector.
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, output_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the reconstruction for ``x`` of shape ``(..., latent_dim)``."""
        reconstruction = self.fc(x)
        return reconstruction


# ============================ action module ============================

class MultiActionEncoder(nn.Module):
    """Encode the joint discrete action of several agents into one latent.

    Each agent's action index is looked up in a shared embedding table of
    width ``latent_dim // 4``. The per-agent embeddings are concatenated and
    projected to ``latent_dim`` by a linear layer.

    Args:
        action_dim: Number of rows in the embedding table (valid action
            indices are ``0 .. action_dim - 1``).
        latent_dim: Size of the output latent vector.
        agent_num: Number of agents whose actions are concatenated.
    """

    def __init__(self, action_dim, latent_dim, agent_num):
        super(MultiActionEncoder, self).__init__()
        self.emb = nn.Embedding(action_dim, latent_dim // 4)
        self.fc = nn.Linear((latent_dim // 4) * agent_num, latent_dim)
        # Legacy attribute kept for backward compatibility. forward() no
        # longer reads or writes it: the old sticky flag made every call
        # after the first 3-D input fail on 2-D input.
        self.three_dim = False
        self.num = agent_num

    def forward(self, x):
        """Embed joint actions.

        Args:
            x: Integer tensor of action indices with the agent axis last:
                ``(agent_num,)``, ``(B, agent_num)`` or ``(B, K, agent_num)``.
                Any number of leading dimensions is accepted.

        Returns:
            Tensor of shape ``(..., latent_dim)`` with the leading dims of
            ``x``. A 1-D input is treated as a batch of one and yields
            ``(1, latent_dim)``.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # nn.Embedding accepts index tensors of any shape, so every agent can
        # be looked up in one call: (..., agent_num, latent_dim // 4).
        per_agent = self.emb(x[..., :self.num])
        # Merging the last two axes equals concatenating the per-agent
        # embeddings in agent order.
        combined_act = per_agent.flatten(start_dim=-2)
        return self.fc(combined_act)
    

class MultiActionDecoder(nn.Module):
    """Decode a latent into one categorical action distribution per agent.

    Every agent owns an independent two-layer MLP head applied to the same
    latent.

    Args:
        latent_dim: Size of the input latent and width of each hidden layer.
        output_dim: Number of discrete actions per agent.
        agent_num: Number of heads (one per agent).
        temperature: Divisor applied to the logits before the log-softmax.
    """

    def __init__(self, latent_dim, output_dim, agent_num, temperature=1.0):
        super().__init__()
        self.agent_num = agent_num
        self.temperature = temperature
        heads = []
        for _ in range(agent_num):
            heads.append(
                nn.Sequential(
                    nn.Linear(latent_dim, latent_dim),
                    nn.ReLU(),
                    nn.Linear(latent_dim, output_dim),
                )
            )
        self.fc = nn.ModuleList(heads)

    def forward(self, x):
        """Run every head on ``x``.

        Returns:
            A tuple ``(probs, recon)``. ``probs`` holds the per-agent
            log-probabilities stacked on a new last axis, shape
            ``(..., output_dim, agent_num)``. ``recon`` holds the per-agent
            argmax action indices, shape ``(..., agent_num)``.
        """
        log_probs = [
            F.log_softmax(head(x) / self.temperature, dim=-1) for head in self.fc
        ]
        greedy = [lp.argmax(dim=-1) for lp in log_probs]
        recon = torch.stack(greedy, dim=-1)
        probs = torch.stack(log_probs, dim=-1)
        return probs, recon


class ActionEncoder(nn.Module):
    """Embed a single discrete action index into a latent vector.

    Args:
        action_dim: Number of rows in the embedding table.
        latent_dim: Width of each embedding vector.
    """

    def __init__(self, action_dim, latent_dim):
        super().__init__()
        self.fc = nn.Embedding(action_dim, latent_dim)

    def forward(self, x):
        """Look up embeddings for ``x`` and drop axis 1 if it has size 1.

        For indices of shape ``(B, 1)`` the result is ``(B, latent_dim)``.
        """
        embedded = self.fc(x)
        return embedded.squeeze(dim=1)
    

class ActionDecoder(nn.Module):
    """Decode a latent into a categorical distribution over discrete actions.

    Args:
        latent_dim: Size of the input latent and width of the hidden layer.
        output_dim: Number of discrete actions.
        temperature: Divisor applied to the logits before the log-softmax.
    """

    def __init__(self, latent_dim, output_dim, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, output_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(probs, recon)`` for ``x`` of shape ``(..., latent_dim)``.

        ``probs`` holds log-probabilities from the temperature-scaled
        logits, shape ``(..., output_dim)``. ``recon`` is the argmax index
        of the raw (unscaled) logits, shape ``(...)``.
        """
        logits = self.fc(x)
        greedy_action = logits.argmax(dim=-1)
        log_probs = F.log_softmax(logits / self.temperature, dim=-1)
        return log_probs, greedy_action
    

# ============================ other module ============================
class SASDecoder(nn.Module):
    """Predict per-agent and per-adversary actions from a state transition.

    The difference ``next_state - state`` is embedded by a shared MLP. Two
    groups of independent two-layer heads (agents and adversaries) each
    turn the embedding into logits over ``output_dim`` actions.

    Args:
        input_dim: Size of the state vectors.
        hidden_dim: Width of the shared embedding and of each head's hidden
            layer.
        output_dim: Number of discrete actions per head.
        agent_num: Number of agent heads.
        adv_num: Number of adversary heads.
        temperature: Divisor applied to the logits before the log-softmax.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, agent_num, adv_num, temperature=1.0):
        super(SASDecoder, self).__init__()
        self.agent_num = agent_num
        self.temperature = temperature
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.agent_fc = nn.ModuleList(
            [self._make_head(hidden_dim, output_dim) for _ in range(agent_num)]
        )
        self.adv_fc = nn.ModuleList(
            [self._make_head(hidden_dim, output_dim) for _ in range(adv_num)]
        )

    @staticmethod
    def _make_head(hidden_dim, output_dim):
        """Build one two-layer action head."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def _decode_heads(self, heads, x):
        """Apply every head in ``heads`` to ``x`` and stack on a new last axis.

        Returns ``(recon, log_probs, logits)``. ``torch.stack`` requires at
        least one head in the group.
        """
        logits = [head(x) for head in heads]
        log_probs = [
            F.log_softmax(head_logits / self.temperature, dim=-1)
            for head_logits in logits
        ]
        recon = torch.stack([lp.argmax(dim=-1) for lp in log_probs], dim=-1)
        return recon, torch.stack(log_probs, dim=-1), torch.stack(logits, dim=-1)

    def agent_recon(self, x):
        """Decode the agent heads from the embedded transition ``x``.

        Returns:
            ``(recon, probs, logits)``: argmax indices of shape
            ``(..., agent_num)``, log-probabilities and raw logits of shape
            ``(..., output_dim, agent_num)``.
        """
        return self._decode_heads(self.agent_fc, x)

    def adv_recon(self, x):
        """Decode the adversary heads from the embedded transition ``x``.

        Returns:
            ``(recon, probs, logits)``: argmax indices of shape
            ``(..., adv_num)``, log-probabilities and raw logits of shape
            ``(..., output_dim, adv_num)``.
        """
        return self._decode_heads(self.adv_fc, x)

    def forward(self, next_state, state):
        """Infer actions from the transition ``state -> next_state``.

        Args:
            next_state: Tensor of shape ``(..., input_dim)``.
            state: Tensor broadcastable against ``next_state``.

        Returns:
            A dict with keys ``'agent'`` and ``'adv'``; each maps to a dict
            with ``'recon'``, ``'log_prob'`` and ``'logits'`` tensors.
        """
        x = self.embedding(next_state - state)
        ag_recon, ag_prob, ag_logits = self.agent_recon(x)
        ad_recon, ad_prob, ad_logits = self.adv_recon(x)

        return {'agent': {'recon': ag_recon, 'log_prob': ag_prob, 'logits': ag_logits},
                'adv': {'recon': ad_recon, 'log_prob': ad_prob, 'logits': ad_logits}}
        