import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from torch import nn
from delphicORL.networks.transformer import TransformerEnc

class Encoder(nn.Module):
    """MLP encoder producing a diagonal Gaussian posterior over the latent.

    Each hidden block is ``Linear -> BatchNorm1d -> LeakyReLU``; two linear
    heads on the last hidden block then predict the mean and the log standard
    deviation of the latent code.

    Args:
        in_dim: Number of input features.
        latent_dim: Size of the latent code.
        hidden_dims: Widths of the hidden blocks, in order. Must be non-empty
            (the heads are sized from its last entry).
    """

    def __init__(self, in_dim, latent_dim, hidden_dims):
        super().__init__()
        widths = [in_dim, *hidden_dims]
        self.layers = nn.Sequential(*[
            nn.Sequential(nn.Linear(fan_in, fan_out),
                          nn.BatchNorm1d(fan_out),
                          nn.LeakyReLU())
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])

        feat_dim = hidden_dims[-1]
        self.enc_mu = nn.Linear(feat_dim, latent_dim)
        self.enc_log_sigma = nn.Linear(feat_dim, latent_dim)

    def forward(self, x):
        """Return ``Normal(mu, exp(log_sigma))`` computed from the batch ``x``."""
        features = self.layers(x)
        # exp() keeps the scale strictly positive.
        return torch.distributions.Normal(
            loc=self.enc_mu(features),
            scale=self.enc_log_sigma(features).exp(),
        )

class Decoder(nn.Module):
    """MLP decoder mapping a latent vector to a diagonal Gaussian over outputs.

    Hidden layers use ``hidden_dims`` in reverse order, so a decoder given the
    same list as an ``Encoder`` mirrors it. Each hidden layer is
    ``Linear -> BatchNorm1d -> leaky_relu``.

    Args:
        in_dim: Size of the latent input.
        out_dim: Size of the reconstructed output.
        hidden_dims: Hidden widths in encoder order. Any iterable of ints; it
            is NOT modified (an empty sequence yields no hidden layers).
    """

    def __init__(self, in_dim, out_dim, hidden_dims):
        super(Decoder, self).__init__()

        # Reverse a copy: reversing the caller's list in place corrupted shared
        # lists (e.g. a mutable default argument) across constructions and
        # failed outright for tuples.
        hidden_dims = list(reversed(list(hidden_dims)))

        layers = []
        bns = []
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            bns.append(nn.BatchNorm1d(h_dim))
            in_dim = h_dim
        self.layers = nn.ModuleList(layers)
        self.bns = nn.ModuleList(bns)
        # in_dim now holds the width of the last hidden layer (or the latent
        # size when there are no hidden layers).
        self.output_layer_mu = nn.Linear(in_dim, out_dim)
        # Predicts log(sigma); exponentiated in forward() so the Normal scale
        # is always strictly positive (same parameterization as Encoder).
        self.output_layer_sigma = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Decode ``x`` into a Normal distribution over outputs.

        Args:
            x: Latent tensor of shape ``(batch, in_dim)`` or
                ``(batch, seq, in_dim)``.

        Returns:
            ``torch.distributions.Normal`` whose ``loc``/``scale`` have the
            shape of ``x`` with the last dimension replaced by ``out_dim``.
        """
        for layer, bn in zip(self.layers, self.bns):
            x = layer(x)
            if x.dim() > 2:
                # BatchNorm1d normalizes dim 1, so move the features there for
                # (batch, seq, features) input and move them back afterwards.
                x = bn(x.permute(0, 2, 1)).permute(0, 2, 1)
            else:
                x = bn(x)
            x = F.leaky_relu(x)
        mu = self.output_layer_mu(x)
        sigma = torch.exp(self.output_layer_sigma(x))
        return torch.distributions.Normal(mu, sigma)

class BetaVAE(nn.Module):
    """beta-VAE with an MLP encoder and decoder over flattened inputs.

    ``loss`` pads/truncates inputs along dim 1 to ``max_len`` steps and
    flattens them to ``max_len * input_dim`` features before encoding.

    Args:
        kl_weight: Default weight (beta) on the KL term in ``loss``.
        input_dim: Feature size of one step.
        latent_dim: Size of the latent code.
        max_len: Number of steps the model sees.
        hidden_size: Hidden widths for the encoder (the decoder uses them in
            reverse). Copied before use; the caller's sequence is not modified.
    """

    def __init__(self, kl_weight, input_dim, latent_dim, max_len=1, hidden_size=(128, 64, 32)):
        super().__init__()
        self.kl_weight = kl_weight
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.max_len = max_len

        # Immutable default + a fresh list per sub-network: Decoder may reverse
        # its argument in place, which previously mutated the shared mutable
        # default and made consecutive models get different architectures.
        self.encoder = Encoder(max_len * input_dim, latent_dim, list(hidden_size))
        self.decoder = Decoder(latent_dim, max_len * input_dim, list(hidden_size))

    def forward(self, state):
        """Encode ``state``, draw a reparameterized latent sample and decode it.

        Returns:
            ``(p_x, q_z)``: the decoder's output distribution and the
            encoder's posterior distribution.
        """
        q_z = self.encoder(state)
        z = q_z.rsample()  # reparameterized, so gradients flow into the encoder
        return self.decoder(z), q_z

    def format_input(self, x):
        """Pad or truncate ``x`` along dim 1 to exactly ``self.max_len``.

        ``None`` and tensors with fewer than two dims are returned unchanged.
        Short 3-D tensors are zero-padded at the end of dim 1; other short
        tensors get their last dim zero-padded (dim 1 for 2-D inputs such as
        ``(batch, seq)`` masks). Long tensors are sliced to ``max_len`` on dim 1.
        """
        if x is None or len(x.shape) < 2:
            return x
        seq_len = x.shape[1]
        if seq_len < self.max_len:
            pad_len = self.max_len - seq_len
            if len(x.shape) == 3:
                # Pad spec is (last dim: 0, 0), (dim 1: 0, pad_len).
                x = F.pad(x, (0, 0, 0, pad_len))
            else:
                x = F.pad(x, (0, pad_len), value=0)
        elif seq_len > self.max_len:
            x = x[:, :self.max_len]
        return x

    def loss(self, x,
             kl_weight=None,
             mask=None,
             reduce='mean'):
        """Compute the (beta-weighted) negative ELBO of ``x``.

        Args:
            x: Batch of inputs; run through ``format_input`` and flattened to
                ``(batch, max_len * input_dim)``.
            kl_weight: Weight on the KL term; defaults to ``self.kl_weight``.
            mask: Optional ``(batch, seq)`` mask of valid steps. When given,
                only reconstruction terms of valid steps are kept.
            reduce: ``'mean'`` for a scalar loss, or ``None`` for per-sample
                losses (per valid step when ``mask`` is given).

        Returns:
            dict with ``'loss'``, the detached ``'Reconstruction_LL'`` and the
            detached *negated* KL under ``'KL'``.

        Raises:
            ValueError: If ``reduce`` is neither ``'mean'`` nor ``None``.
        """
        if reduce not in ('mean', None):
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(f"reduce must be 'mean' or None, got {reduce!r}")
        if kl_weight is None:
            kl_weight = self.kl_weight

        x = self.format_input(x)
        if mask is not None:
            # Only convert a real mask: format_input(None) returns None, which
            # has no .to(), so the default mask=None used to crash here.
            mask = self.format_input(mask).to(torch.bool)
        x = x.reshape(x.shape[0], -1)
        p_x, q_z = self.forward(x)
        log_likelihood = p_x.log_prob(x)

        # KL(q(z|x) || N(0, 1)), summed over latent dims.
        kl = torch.distributions.kl_divergence(
            q_z,
            torch.distributions.Normal(0, 1.)
        ).sum(-1)

        if reduce == 'mean':
            if mask is not None:
                bs, max_len = mask.shape
                assert max_len == self.max_len
                # Keep only valid steps: (n_valid, input_dim).
                log_likelihood = log_likelihood.reshape(bs, self.max_len, -1)[mask]

            log_likelihood = log_likelihood.sum(-1).mean()
            kl = kl.mean()
            loss = -(log_likelihood - kl_weight * kl)

        elif mask is not None:
            bs, max_len = mask.shape
            # Per-step LL (bs, max_len); each sequence's KL is repeated per step.
            log_likelihood = log_likelihood.reshape(bs, self.max_len, -1).sum(-1)
            loss = -(log_likelihood - kl_weight * kl.unsqueeze(-1).expand((-1, max_len)))
            loss = loss[mask]
            log_likelihood = log_likelihood[mask]

        else:
            log_likelihood = log_likelihood.sum(-1)
            loss = -(log_likelihood - kl_weight * kl)

        return {'loss': loss, 'Reconstruction_LL': log_likelihood.detach(), 'KL': -kl.detach()}
    


class Transformer_VAE(BetaVAE):
    """VAE with a Transformer encoder architecture.

    Reuses ``BetaVAE``'s MLP decoder, but replaces the encoder with a
    ``TransformerEnc`` built with ``d_model=latent_dim`` and
    ``d_hid=hidden_size[0]``.

    Args:
        input_dim: Feature size of one step.
        latent_dim: Size of the latent code.
        max_len: Number of steps the model sees.
        hidden_size: Hidden widths; the first one sizes the Transformer, the
            decoder uses them as in ``BetaVAE``. Not modified.
        kl_weight: Weight on the KL term; 1.0 gives the unweighted negative
            ELBO in ``BetaVAE.loss``.
    """

    def __init__(self, input_dim, latent_dim, max_len=1, hidden_size=(128, 64, 32), kl_weight=1.0):
        # Read the width before handing the sizes to BetaVAE, whose Decoder
        # may reverse a list it is given.
        d_hid = hidden_size[0]
        # BetaVAE's first positional parameter is kl_weight; passing the
        # arguments positionally shifted every one of them by a slot (max_len
        # received the hidden_size list), so construction always failed.
        super().__init__(kl_weight=kl_weight, input_dim=input_dim,
                         latent_dim=latent_dim, max_len=max_len,
                         hidden_size=list(hidden_size))

        # NOTE(review): max_history_len is fixed at 10 independent of
        # max_len — confirm this is intended.
        self.encoder = TransformerEnc(input_dim, d_model=latent_dim,
                                      d_hid=d_hid,
                                      max_history_len=10)

    def forward(self, state, mask):
        """Encode ``state`` with padding ``mask``, sample a latent and decode it.

        Returns:
            ``(p_x, q_z)``: the decoder's output distribution and the
            encoder's posterior distribution.

        NOTE(review): ``BetaVAE.loss`` calls ``self.forward(x)`` without a
        mask, so ``loss`` cannot drive this class as written — confirm the
        intended training path.
        """
        q_z = self.encoder(state, pad_mask=mask)
        z = q_z.rsample()
        return self.decoder(z), q_z
