""" Contains code for the shapes model """

import itertools
import numpy as np
import torch
from torch import nn, distributions
from torchvision.utils import make_grid

# My imports
from weighted_retraining.models import BaseVAE, UnFlatten


class ShapesVAE(BaseVAE):
    """
    Convolutional VAE for encoding/decoding 64x64 single-channel images.

    The encoder outputs ``2 * latent_dim`` values per image (mean and log
    standard deviation of a diagonal Gaussian posterior); the decoder maps a
    latent vector to per-pixel Bernoulli logits.

    ``latent_dim``, ``beta``, ``logging_prefix`` and ``log_progress_bar`` are
    read from ``self`` but never set here -- presumably they are provided by
    ``BaseVAE`` (verify against the base class).
    """

    # Convolution specs as (in_channels, out_channels, kernel_size, stride).
    # Padding is always ``kernel_size // 2``, so stride-1 layers preserve the
    # spatial size and stride-2 layers halve it (encoder) or, together with
    # ``output_padding = stride - 1``, double it (decoder).
    _ENCODER_CONVS = (
        (1, 4, 3, 1),
        (4, 8, 5, 2),
        (8, 8, 3, 1),
        (8, 16, 5, 2),
        (16, 16, 3, 1),
        (16, 16, 3, 1),
        (16, 16, 5, 2),
        (16, 16, 3, 1),
        (16, 16, 3, 1),
        (16, 16, 5, 2),
        (16, 16, 3, 1),
    )
    _DECODER_CONVS = (
        (16, 32, 3, 1),
        (32, 32, 5, 2),
        (32, 32, 3, 1),
        (32, 16, 5, 2),
        (16, 16, 3, 1),
        (16, 16, 3, 1),
        (16, 16, 5, 2),
        (16, 8, 3, 1),
        (8, 8, 3, 1),
        (8, 8, 5, 2),
        (8, 1, 3, 1),
    )

    def __init__(self, hparams):
        """
        Build the encoder, the decoder and the kernel parameters.

        Args:
            hparams: hyperparameters, forwarded unchanged to ``BaseVAE``.
        """
        super().__init__(hparams)

        # NOTE: modules are created in exactly the same order as the layers
        # appear in the Sequential containers, so ``state_dict`` keys and the
        # order in which initialisation consumes the RNG stay stable.

        # Encoder: every convolution is followed by a ReLU.
        encoder_layers = []
        for in_ch, out_ch, kernel, stride in self._ENCODER_CONVS:
            encoder_layers.append(
                nn.Conv2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel,
                    stride=stride,
                    padding=kernel // 2,
                )
            )
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(
            *encoder_layers,
            # Four stride-2 convolutions take 64x64 down to 4x4; with 16
            # channels that is 16 * 4 * 4 = 256 flattened features.
            nn.Flatten(),
            nn.Linear(in_features=256, out_features=32),
            nn.ReLU(),
            # First half of the output is mu, second half is log-std
            # (see ``encode_to_params``).
            nn.Linear(in_features=32, out_features=2 * self.latent_dim),
        )

        # Decoder: FC layers, then transposed convolutions.
        decoder_layers = [
            nn.Linear(in_features=self.latent_dim, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=256),
            nn.ReLU(),
            # Presumably reshapes (batch, 256) -> (batch, 16, 4, 4);
            # UnFlatten is a project helper -- TODO confirm.
            UnFlatten(16, 4),
        ]
        final_idx = len(self._DECODER_CONVS) - 1
        for idx, (in_ch, out_ch, kernel, stride) in enumerate(self._DECODER_CONVS):
            decoder_layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel,
                    padding=kernel // 2,
                    stride=stride,
                    output_padding=stride - 1,
                )
            )
            # The last layer emits raw logits, so it gets no activation.
            if idx != final_idx:
                decoder_layers.append(nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

        # Parameters used by ``forward_with_reconx`` / ``rbf_kernel``: a linear
        # map of a scalar input to ``latent_dim`` features, and one learnable
        # RBF lengthscale per latent dimension.
        self.fc5 = torch.nn.Linear(1, self.latent_dim)
        self.lengthscale = torch.nn.Parameter(torch.ones(self.latent_dim))

    def encode_to_params(self, x):
        """
        Encode ``x`` and split the encoder output into posterior parameters.

        Returns:
            Tuple ``(mu, logstd)``: the first ``latent_dim`` columns and the
            remaining columns of the encoder output, respectively.
        """
        enc_output = self.encoder(x)
        mu = enc_output[:, : self.latent_dim]
        logstd = enc_output[:, self.latent_dim :]
        return mu, logstd

    def decoder_loss(self, z, x_orig):
        """
        Return the negative Bernoulli log-probability of ``x_orig`` given
        ``z``, summed over all elements and divided by the batch size.
        """
        logits = self.decoder(z)
        dist = distributions.Bernoulli(logits=logits)
        return -dist.log_prob(x_orig).sum() / z.shape[0]

    def decoder_loss_with_reconx(self, z, x_orig):
        """
        Return the reconstruction loss together with a sampled reconstruction.

        Returns:
            Tuple ``(loss, reconx)`` where ``loss`` is the same quantity as
            ``decoder_loss`` and ``reconx`` is a Bernoulli sample from the
            decoder output with a straight-through gradient.
        """
        logits = self.decoder(z)
        dist = distributions.Bernoulli(logits=logits)
        probs = torch.sigmoid(logits)
        loss = -dist.log_prob(x_orig).sum() / z.shape[0]
        # Straight-through estimator: the forward value is the discrete
        # sample, while gradients flow through ``probs``.
        reconx = dist.sample() + probs - probs.detach()
        return loss, reconx

    def decode_deterministic(self, z: torch.Tensor) -> torch.Tensor:
        """ Decode ``z`` to per-pixel Bernoulli means (sigmoid of the logits) """
        logits = self.decoder(z)
        return torch.sigmoid(logits)

    def validation_step(self, *args, **kwargs):
        """
        Run the base-class validation step, then log the latent manifold.

        Returns:
            Whatever ``BaseVAE.validation_step`` returns.
        """
        # Propagate the base-class result instead of silently dropping it.
        result = super().validation_step(*args, **kwargs)

        # Visualize latent space
        self.visualize_latent_space(20)
        return result

    @staticmethod
    def _sample_latent(mu, logstd):
        """ Draw a reparameterised sample from N(mu, exp(logstd)^2) """
        encoder_distribution = torch.distributions.Normal(
            loc=mu, scale=torch.exp(logstd)
        )
        return encoder_distribution.rsample()

    @staticmethod
    def _kl_loss(mu, logstd):
        """
        KL divergence from N(mu, exp(logstd)^2) to N(0, 1), summed over all
        elements and divided by the batch size.
        """
        # Manual formula for kl divergence (more numerically stable!)
        kl_div = 0.5 * (torch.exp(2 * logstd) + mu.pow(2) - 1.0 - 2 * logstd)
        return kl_div.sum() / mu.shape[0]

    def _log_losses(self, reconstruction_loss, kl_loss, loss):
        """ Log the loss terms under the current prefix; no-op if it is None """
        if self.logging_prefix is None:
            return
        self.log(
            f"rec/{self.logging_prefix}",
            reconstruction_loss,
            prog_bar=self.log_progress_bar,
        )
        self.log(
            f"kl/{self.logging_prefix}", kl_loss, prog_bar=self.log_progress_bar
        )
        self.log(f"loss/{self.logging_prefix}", loss)

    def forward(self, x):
        """
        Calculate the negative VAE ELBO for a batch.

        Returns:
            Scalar loss ``reconstruction_loss + beta * kl_loss``.
        """
        mu, logstd = self.encode_to_params(x)
        z_sample = self._sample_latent(mu, logstd)
        reconstruction_loss = self.decoder_loss(z_sample, x)
        kl_loss = self._kl_loss(mu, logstd)

        # Final loss
        loss = reconstruction_loss + self.beta * kl_loss

        self._log_losses(reconstruction_loss, kl_loss, loss)
        return loss

    def rbf_kernel(self, x):
        """
        Compute one RBF kernel matrix per latent dimension over a batch.

        Each of the ``latent_dim`` matrices uses the same pairwise squared
        distances but its own entry of ``self.lengthscale``.

        Args:
            x: Tensor whose first dimension is the batch size ``n``; all
                remaining dimensions are flattened into features.
        Returns:
            Tensor of shape (latent_dim, n, n) holding
            ``exp(-||x_i - x_j||^2 / (2 * lengthscale_k^2))``.
        """
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        # Squared Euclidean distance between each pair of points: (n, n)
        sq_dists = torch.cdist(x, x, p=2) ** 2

        # Broadcast (1, n, n) against (latent_dim, 1, 1) rather than
        # materialising repeated copies; the result is identical.
        lengthscale = self.lengthscale.unsqueeze(-1).unsqueeze(-1)
        return torch.exp(-sq_dists.unsqueeze(0) / (2 * lengthscale ** 2))

    def forward_with_reconx(self, x, y):
        """
        Calculate the VAE loss and also return quantities derived from it.

        Args:
            x: batch of input images.
            y: one value per example; assumed to have shape (batch,) so that
                ``y.unsqueeze(1)`` matches ``fc5``'s single input feature --
                TODO confirm against the caller.
        Returns:
            Tuple ``(loss, reconx, mu, logstd, covariance)`` where ``reconx``
            is the straight-through reconstruction sample and ``covariance``
            is ``rbf_kernel`` applied to ``fc5(y)``.
        """
        mu, logstd = self.encode_to_params(x)
        z_sample = self._sample_latent(mu, logstd)
        reconstruction_loss, reconx = self.decoder_loss_with_reconx(z_sample, x)
        kl_loss = self._kl_loss(mu, logstd)

        # Final loss
        # NOTE(review): unlike ``forward`` the KL term is not scaled by
        # ``self.beta`` here -- confirm that this is intentional.
        loss = reconstruction_loss + kl_loss

        # Adding 0.0 presumably promotes integer-typed ``y`` to floating point
        # before the linear layer.
        covariance = self.rbf_kernel(self.fc5(y.unsqueeze(1) + 0.0))

        self._log_losses(reconstruction_loss, kl_loss, loss)

        return loss, reconx, mu, logstd, covariance

    def visualize_latent_space(self, nrow: int) -> None:
        """
        Decode a regular grid of latent points and log it as an image.

        Does nothing unless ``latent_dim == 2``.

        Args:
            nrow: number of grid points per axis; the grid spans [-4, 4] in
                both latent dimensions, giving ``nrow * nrow`` decoded images.
        """
        # Currently only support 2D manifold visualization
        if self.latent_dim != 2:
            return

        # Create latent manifold
        unit_line = np.linspace(-4, 4, nrow)
        latent_grid = np.array(
            list(itertools.product(unit_line, repeat=2)), dtype=np.float32
        )
        z_manifold = torch.as_tensor(latent_grid, device=self.device)

        # Decode latent manifold
        with torch.no_grad():
            img = self.decode_deterministic(z_manifold).detach().cpu()
        img = torch.clamp(img, 0.0, 1.0)

        # Make grid
        img = make_grid(img, nrow=nrow, padding=5, pad_value=0.5)

        # Log image (``experiment`` is expected to expose ``add_image``,
        # e.g. a TensorBoard writer)
        self.logger.experiment.add_image(
            "latent manifold", img, global_step=self.global_step
        )
