"""
Module containing the main VAE class.
"""
import torch
from torch import nn, optim
from torch.nn import functional as F

from disvae.utils.initialization import weights_init
from .encoders import get_encoder
from .decoders import get_decoder

# Architecture names accepted by `init_specific_model` (compared after the
# requested name has been normalised with `.lower().capitalize()`).
MODELS = [
    "Burgess",
    "Locatello",
]


def init_specific_model(model_type, img_size, latent_dim, group, width=256):
    """
    Return an instance of a VAE with encoder and decoder from `model_type`.

    Parameters
    ----------
    model_type : str
        Name of the architecture. It is normalised with
        ``.lower().capitalize()`` before being looked up in `MODELS`, so the
        match is case-insensitive.

    img_size : tuple of ints
        Size of images. E.g. (1, 32, 32) or (3, 64, 64).

    latent_dim : int
        Dimensionality of the latent space, forwarded to the model.

    group : int
        Number of encoder groups. A value ``<= 1`` builds a plain `VAE`; any
        larger value builds an `FVAE` with that many encoders.

    width : int, optional
        Forwarded to `FVAE` only; ignored when a plain `VAE` is built.

    Returns
    -------
    VAE or FVAE
        The freshly built model, with the normalised name stored on its
        `model_type` attribute.

    Raises
    ------
    ValueError
        If the normalised `model_type` is not listed in `MODELS`.
    """
    model_type = model_type.lower().capitalize()
    if model_type not in MODELS:
        err = "Unknown model_type={}. Possible values: {}"
        raise ValueError(err.format(model_type, MODELS))

    encoder = get_encoder(model_type)
    decoder = get_decoder(model_type)
    if group <= 1:
        model = VAE(img_size, encoder, decoder, latent_dim)
    else:
        model = FVAE(img_size, encoder, decoder, latent_dim, group, width)
    model.model_type = model_type  # store to help reloading
    return model

class VAE(nn.Module):
    """
    Variational auto-encoder assembled from an encoder and a decoder factory.
    """

    def __init__(self, img_size, encoder, decoder, latent_dim, ):
        """
        Build the encoder and decoder and initialise their weights.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        encoder : callable
            Called as ``encoder(img_size, latent_dim)``. The object it returns
            is later called on a batch and must yield the ``(mean, logvar)``
            pair consumed by `reparameterize`.

        decoder : callable
            Called as ``decoder(img_size, latent_dim)``. The object it returns
            is later called on a latent sample.

        latent_dim : int
            Dimensionality of the latent space.

        Raises
        ------
        RuntimeError
            If the spatial size of `img_size` is neither 32x32 nor 64x64.
        """
        super(VAE, self).__init__()

        supported_sizes = ([32, 32], [64, 64])
        if list(img_size[1:]) not in supported_sizes:
            message = "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!"
            raise RuntimeError(message.format(img_size))

        self.img_size = img_size
        self.latent_dim = latent_dim
        self.num_pixels = img_size[1] * img_size[2]

        # Encoder is registered before the decoder so that sub-module order
        # stays the same as it has always been.
        self.encoder = encoder(img_size, latent_dim)
        self.decoder = decoder(img_size, latent_dim)

        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """
        Draw a latent sample with the reparameterization trick.

        In training mode the result is ``mean + std * eps`` with
        ``std = exp(0.5 * logvar)`` and ``eps`` standard normal noise. In
        evaluation mode `mean` itself is returned, without any noise.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)
        """
        if not self.training:
            # Reconstruction mode: deterministic output.
            return mean

        std = logvar.mul(0.5).exp()
        return mean + std * torch.randn_like(std)

    def forward(self, x):
        """
        Encode `x`, sample a latent code and decode it.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(reconstruction, latent_dist, latent_sample)`` where
            `latent_dist` is whatever the encoder returned.
        """
        dist_params = self.encoder(x)
        z = self.reparameterize(*dist_params)
        return self.decoder(z), dist_params, z

    def reset_parameters(self):
        """Apply `weights_init` to this module and all of its sub-modules."""
        self.apply(weights_init)

    def sample_latent(self, x):
        """
        Return latent distribution and samples.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(latent_dist, latent_sample)``; the decoder is not run.
        """
        dist_params = self.encoder(x)
        z = self.reparameterize(*dist_params)
        return dist_params, z


def reparameterize(mean, logvar):
    """
    Samples from a normal distribution using the reparameterization trick.

    Unlike the model methods of the same name, this function always adds
    noise: it has no training/evaluation switch.

    Parameters
    ----------
    mean : torch.Tensor
        Mean of the normal distribution.

    logvar : torch.Tensor
        Diagonal log variance of the normal distribution.

    Returns
    -------
    torch.Tensor
        ``mean + exp(0.5 * logvar) * eps`` with ``eps`` drawn from a standard
        normal distribution.
    """
    scale = logvar.mul(0.5).exp()
    return mean + scale * torch.randn_like(scale)


class Adjuster(nn.Module):
    """
    Pass a ``(mu, logvar)`` pair through unchanged while rescaling gradients.

    The forward pass returns its inputs as they are. For every input tensor
    that requires grad, a hook is registered that multiplies the gradient
    flowing back through that tensor by `p`.

    Parameters
    ----------
    p : float, optional
        Gradient scaling factor. Stored as the one-element buffer `p`.
    """

    def __init__(self, p=1):
        super().__init__()
        self.register_buffer('p', torch.ones(1) * p)

    def forward(self, x):
        """
        Register the gradient-scaling hooks and return the inputs.

        Parameters
        ----------
        x : tuple of torch.Tensor
            The pair ``(mu, logvar)``.

        Returns
        -------
        tuple of torch.Tensor
            The same ``mu`` and ``logvar`` objects that were passed in.
        """
        mu, logvar = x

        def scale_grad(grad):
            # self.p is read when the hook fires, i.e. during backward.
            return grad * self.p

        # Each tensor is checked on its own: `register_hook` raises on a
        # tensor that does not require grad, and a tensor that does require
        # grad must not be skipped just because its partner does not.
        for tensor in (mu, logvar):
            if tensor.requires_grad:
                tensor.register_hook(scale_grad)
        return mu, logvar


class FVAE(nn.Module):
    """
    VAE whose latent code is produced by several independent encoders.

    `group` encoders are built, each with ``latent_dim // group`` latent
    units. Their means and log variances are concatenated along the last
    dimension and the sampled code is fed to one decoder built with
    `latent_dim`.

    Attributes
    ----------
    phase : int
        Index of the last encoder evaluated by `forward`; encoders after it
        contribute zeros. Initialised to 100.
    """

    def __init__(self, img_size, encoder, decoder, latent_dim, group, width=256):
        """
        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        encoder : callable
            Called ``group`` times as ``encoder(img_size, latent_dim // group)``.

        decoder : callable
            Called once as ``decoder(img_size, latent_dim, width)``.

        latent_dim : int
            Total dimensionality of the latent space. Must be a multiple of
            `group`.

        group : int
            Number of encoders. Must be at least 1.

        width : int, optional
            Forwarded unchanged to `decoder`.

        Raises
        ------
        RuntimeError
            If the spatial size of `img_size` is neither 32x32 nor 64x64.
        ValueError
            If `group` is smaller than 1 or does not divide `latent_dim`.
        """
        super(FVAE, self).__init__()
        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(
                    img_size))
        # The encoders jointly emit group * (latent_dim // group) values while
        # the decoder is built for latent_dim, so the two only agree when
        # group divides latent_dim. Fail here rather than in the first forward.
        if group < 1 or latent_dim % group != 0:
            raise ValueError(
                "latent_dim={} must be a positive multiple of group={}".format(
                    latent_dim, group))

        self.encoder_type = encoder
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.group = group

        # nn.Sequential is used purely as an indexable/sliceable container;
        # the encoders are never chained.
        self.encoders = nn.Sequential(*[encoder(img_size, self.latent_dim // group)
                                        for _ in range(group)])

        self.phase = 100
        self.decoder = decoder(img_size, self.latent_dim, width)
        self.reset_parameters()

    def encoder(self, x, keep=100):
        """
        Encode `x` with the first ``keep + 1`` encoders and zero-pad the rest.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        keep : int, optional
            Index of the last encoder to evaluate. Expected to be
            non-negative; values of ``group - 1`` or more evaluate every
            encoder.

        Returns
        -------
        tuple of torch.Tensor
            ``(mean, logvar)``, each the concatenation along the last
            dimension of one chunk per group. Chunks of skipped encoders are
            zeros shaped like the first encoder's mean.
        """
        dists = [enc(x) for enc in self.encoders[:keep + 1]]

        n_skipped = self.group - (keep + 1)
        if n_skipped > 0:
            # torch.cat copies its inputs, so one zero tensor can stand in
            # for every skipped mean and log variance.
            zeros = torch.zeros_like(dists[0][0])
            dists.extend([(zeros, zeros)] * n_skipped)

        mean = torch.cat([mu for mu, _ in dists], -1)
        logvar = torch.cat([lv for _, lv in dists], -1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)

        Returns
        -------
        torch.Tensor
            A noisy sample in training mode; `mean` itself otherwise.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode
            return mean

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(reconstruction, latent_dist, latent_sample)`` where
            `latent_dist` is the ``(mean, logvar)`` pair from `encoder`,
            evaluated with ``keep = self.phase``.
        """
        latent_dist = self.encoder(x, self.phase)
        z = self.reparameterize(*latent_dist)
        reconstruct = self.decoder(z)
        return reconstruct, latent_dist, z

    def reset_parameters(self):
        """Apply `weights_init` to the whole model, then again to each encoder."""
        self.apply(weights_init)
        # Module.apply already recurses into the encoders; the second pass is
        # kept so the resulting initialisation is unchanged.
        for m in self.encoders:
            m.apply(weights_init)

class FVAE1(nn.Module):
    """
    VAE whose latent code is produced by several independent encoders.

    `group` encoders are built, each with ``latent_dim // group`` latent
    units. Their means and log variances are concatenated along the last
    dimension and the sampled code is fed to one decoder built with
    `latent_dim`. The decoder is constructed without a `width` argument.

    Attributes
    ----------
    phase : int
        Index of the last encoder evaluated by `forward`; encoders after it
        contribute zeros. Initialised to 100.
    """

    def __init__(self, img_size, encoder, decoder, latent_dim, group):
        """
        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        encoder : callable
            Called ``group`` times as ``encoder(img_size, latent_dim // group)``.

        decoder : callable
            Called once as ``decoder(img_size, latent_dim)``.

        latent_dim : int
            Total dimensionality of the latent space. Must be a multiple of
            `group`.

        group : int
            Number of encoders. Must be at least 1.

        Raises
        ------
        RuntimeError
            If the spatial size of `img_size` is neither 32x32 nor 64x64.
        ValueError
            If `group` is smaller than 1 or does not divide `latent_dim`.
        """
        super(FVAE1, self).__init__()
        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(
                    img_size))
        # The encoders jointly emit group * (latent_dim // group) values while
        # the decoder is built for latent_dim, so the two only agree when
        # group divides latent_dim. Fail here rather than in the first forward.
        if group < 1 or latent_dim % group != 0:
            raise ValueError(
                "latent_dim={} must be a positive multiple of group={}".format(
                    latent_dim, group))

        self.encoder_type = encoder
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.group = group

        # nn.Sequential is used purely as an indexable/sliceable container;
        # the encoders are never chained.
        self.encoders = nn.Sequential(*[encoder(img_size, self.latent_dim // group)
                                        for _ in range(group)])

        self.phase = 100
        self.decoder = decoder(img_size, self.latent_dim)
        self.reset_parameters()

    def encoder(self, x, keep=100):
        """
        Encode `x` with the first ``keep + 1`` encoders and zero-pad the rest.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        keep : int, optional
            Index of the last encoder to evaluate. Expected to be
            non-negative; values of ``group - 1`` or more evaluate every
            encoder.

        Returns
        -------
        tuple of torch.Tensor
            ``(mean, logvar)``, each the concatenation along the last
            dimension of one chunk per group. Chunks of skipped encoders are
            zeros shaped like the first encoder's mean.
        """
        dists = [enc(x) for enc in self.encoders[:keep + 1]]

        n_skipped = self.group - (keep + 1)
        if n_skipped > 0:
            # torch.cat copies its inputs, so one zero tensor can stand in
            # for every skipped mean and log variance.
            zeros = torch.zeros_like(dists[0][0])
            dists.extend([(zeros, zeros)] * n_skipped)

        mean = torch.cat([mu for mu, _ in dists], -1)
        logvar = torch.cat([lv for _, lv in dists], -1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)

        Returns
        -------
        torch.Tensor
            A noisy sample in training mode; `mean` itself otherwise.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode
            return mean

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(reconstruction, latent_dist, latent_sample)`` where
            `latent_dist` is the ``(mean, logvar)`` pair from `encoder`,
            evaluated with ``keep = self.phase``.
        """
        latent_dist = self.encoder(x, self.phase)
        z = self.reparameterize(*latent_dist)
        reconstruct = self.decoder(z)
        return reconstruct, latent_dist, z

    def reset_parameters(self):
        """Apply `weights_init` to the whole model, then again to each encoder."""
        self.apply(weights_init)
        # Module.apply already recurses into the encoders; the second pass is
        # kept so the resulting initialisation is unchanged.
        for m in self.encoders:
            m.apply(weights_init)