import torch
import torch.nn as nn
import numpy as np

class AbstractDistribution:
    """Minimal interface for distribution objects.

    Both methods raise ``NotImplementedError``; subclasses are expected to
    override them.
    """

    def sample(self):
        """Draw a sample from the distribution (must be overridden)."""
        raise NotImplementedError

    def mode(self):
        """Return the mode of the distribution (must be overridden)."""
        raise NotImplementedError


class DiracDistribution(AbstractDistribution):
    """Degenerate distribution concentrated on a single stored value."""

    def __init__(self, value):
        # The stored object is handed back as-is by both accessors below.
        self.value = value

    def sample(self):
        """Return the stored value; no randomness is involved."""
        return self.value

    def mode(self):
        """Return the stored value."""
        return self.value


class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance, parameterised by one packed tensor.

    ``parameters`` is split in half along ``dim=1``: the first half is the
    mean, the second half the log-variance (clamped to ``[-30, 20]`` before
    it is exponentiated). With ``deterministic=True`` the standard deviation
    and variance are forced to zero.

    ``kl`` reduces over dims ``[1, 2, 3]``, so it expects 4-D parameters.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def _zero(self):
        """Return a one-element zero tensor on the parameters' device/dtype."""
        return torch.zeros(1, device=self.mean.device, dtype=self.mean.dtype)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created directly with the mean's device and dtype, so
        the result is not promoted to float32 for reduced-precision
        parameters and no host-to-device copy is needed.
        """
        noise = torch.randn(
            self.mean.shape, device=self.mean.device, dtype=self.mean.dtype
        )
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence, summed over dims 1, 2 and 3.

        Args:
            other: another ``DiagonalGaussianDistribution``; when ``None``
                the divergence is taken against a standard normal.

        Returns:
            A tensor with one value per batch element, or a one-element zero
            tensor when this distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * (
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
            ).sum(dim=[1, 2, 3])
        return 0.5 * (
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar
        ).sum(dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Return the negative log-likelihood of ``sample``.

        Args:
            sample: tensor broadcastable against the mean.
            dims: dimensions to sum over.

        Returns:
            The summed negative log-likelihood, or a one-element zero tensor
            when this distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Element-wise KL divergence between two Gaussians given as (mean, logvar).

    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12

    Arguments may be tensors or Python scalars and are broadcast against
    each other; at least one of them has to be a tensor.
    """
    # The first tensor argument serves as the device/dtype reference.
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2)
         if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # torch.exp() needs a tensor, so scalar log-variances are wrapped.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)


class Encoder(nn.Module):
    """Convolutional encoder producing ``2 * latent_channels`` feature maps.

    The first convolution takes ``obs_shape[0] // 3`` input channels
    (presumably ``obs_shape[0]`` counts three stacked frames; confirm
    against the caller). Three stride-2 convolutions follow, then a final
    stride-1 output convolution.
    """

    def __init__(
        self,
        obs_shape,
        latent_channels: int,
        num_filters: int
    ) -> None:
        super().__init__()
        self.obs_shape = obs_shape
        self.latent_channels = latent_channels

        out_channels = latent_channels * 2
        # (in_channels, out_channels, stride, padding) per 3x3 convolution.
        conv_specs = (
            (obs_shape[0] // 3, num_filters, 1, 1),
            (num_filters, num_filters, 2, 1),
            (num_filters, num_filters, 2, 1),
            (num_filters, out_channels, 2, 2),
        )
        stack = []
        for c_in, c_out, stride, padding in conv_specs:
            stack.append(nn.Conv2d(c_in, c_out, 3, stride=stride, padding=padding))
            stack.append(nn.ReLU())
        self.convs = nn.Sequential(*stack)
        self.out_conv = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)

    def forward(self, obs):
        """Shift ``obs / 255`` by ``-0.5`` and run it through the conv stack."""
        normalized = obs / 255. - 0.5
        features = self.convs(normalized)
        return self.out_conv(features)
    

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latents back to image channels.

    Three stride-2 transposed convolutions are followed by a stride-1
    transposed convolution and a final ``Conv2d`` that outputs
    ``obs_shape[0] // 3`` channels. The output is passed through ``tanh``.
    """

    def __init__(
        self,
        obs_shape,
        latent_channels: int,
        num_filters: int
    ) -> None:
        super().__init__()
        self.obs_shape = obs_shape
        self.latent_channels = latent_channels

        # (in_channels, stride, padding, output_padding) per 3x3 layer;
        # every layer emits ``num_filters`` channels.
        deconv_specs = (
            (latent_channels, 2, 2, 0),
            (num_filters, 2, 1, 1),
            (num_filters, 2, 1, 1),
            (num_filters, 1, 1, 0),
        )
        stack = []
        for c_in, stride, padding, output_padding in deconv_specs:
            stack.append(
                nn.ConvTranspose2d(
                    c_in,
                    num_filters,
                    3,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
            )
            stack.append(nn.ReLU())
        self.deconvs = nn.Sequential(*stack)
        self.out_conv = nn.Conv2d(num_filters, obs_shape[0] // 3, 3, stride=1, padding=1)

    def forward(self, h):
        """Decode latent ``h`` and squash the result with ``tanh``."""
        upsampled = self.deconvs(h)
        return torch.tanh(self.out_conv(upsampled))


class VAE(nn.Module):
    """Variational autoencoder wiring an ``Encoder`` to a ``Decoder``.

    The encoder output is wrapped in a ``DiagonalGaussianDistribution``,
    from which the latent is either sampled or taken as the mode.
    """

    def __init__(
        self,
        obs_shape,
        latent_channels: int = 16,
        ae_num_filters: int = 32,
    ) -> None:
        super().__init__()
        self.encoder = Encoder(obs_shape, latent_channels, ae_num_filters)
        self.decoder = Decoder(obs_shape, latent_channels, ae_num_filters)

    def encode(self, x):
        """Return the posterior distribution built from the encoder output."""
        return DiagonalGaussianDistribution(self.encoder(x))

    def decode(self, z):
        """Run latent ``z`` through the decoder."""
        return self.decoder(z)

    def forward(self, x, sample_posterior=True, forward_decoder=False):
        """Encode ``x`` and optionally decode the resulting latent.

        Returns:
            ``(z, posterior)``, or ``(z, posterior, dec)`` when
            ``forward_decoder`` is true.
        """
        posterior = self.encode(x)
        z = posterior.sample() if sample_posterior else posterior.mode()
        if not forward_decoder:
            return z, posterior
        dec = self.decode(z)
        return z, posterior, dec
        

class Scaler(nn.Module):
    """Optionally rescale tensors by a factor estimated from a data batch.

    When ``activate`` is false the module is a no-op: ``init`` only sets the
    ``initialized`` flag and ``forward`` returns its input unchanged.

    ``scale_factor`` is a plain attribute, not a registered buffer, so it is
    neither stored in ``state_dict`` nor moved by ``Module.to``.
    """

    def __init__(self, activate=False):
        super().__init__()
        self.activate = activate
        self.initialized = False
        self.scale_factor = 1.0

    def init(self, batch):
        """Set ``scale_factor`` to the standard deviation of ``batch``.

        Args:
            batch: tensor whose flattened standard deviation is the factor.
        """
        self.initialized = True
        if not self.activate:
            return
        # Detach so the stored statistic does not keep the autograd graph of
        # ``batch`` alive; otherwise later forward passes would backpropagate
        # into whatever produced ``batch``.
        self.scale_factor = batch.detach().flatten().std()

    def forward(self, x, reverse=False):
        """Multiply ``x`` by ``scale_factor``, or divide when ``reverse``.

        NOTE(review): the forward direction multiplies by the measured std
        rather than its reciprocal; confirm this is the intended direction.
        """
        if not self.activate:
            return x
        if reverse:
            return x / self.scale_factor
        return x * self.scale_factor
        
