"""Encoder and Decoder for images."""

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from pdisvae.models.model import Encoder


class MLPDecoder(nn.Module):
    def __init__(self, img_size: tuple[int, int, int], n_components: int):
        """MLP Decoder.

        A two-layer perceptron mapping a latent vector to an image: one hidden
        layer of 128 units with a tanh activation, followed by an output layer
        with a sigmoid so that every predicted pixel lies in [0, 1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32).
        n_components : int
            Dimensionality of the latent input.
        """
        super().__init__()

        # Layer parameters
        hidden_dim = 128
        self.n_components = n_components
        self.img_size = img_size

        # np.prod returns a NumPy scalar; cast it so that nn.Linear stores a
        # plain Python int as its ``out_features``.
        n_pixels = int(np.prod(self.img_size))

        # Fully connected layers
        self.fc1 = nn.Linear(self.n_components, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_pixels)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into predicted image means.

        Parameters
        ----------
        z : torch.Tensor of shape (*, n_components)
            Latent vectors. Any number of leading dimensions is accepted.

        Returns
        -------
        torch.Tensor of shape (*, n_channels, height, width)
            Predicted pixel means in [0, 1].
        """
        # tanh on the hidden layer, sigmoid on the output layer
        hidden = torch.tanh(self.fc1(z))
        x_pred_mean = torch.sigmoid(self.fc2(hidden))

        # Keep every leading dimension of ``z`` and unfold the flat pixel axis
        # back into the image shape.
        return x_pred_mean.view(*z.shape[:-1], *self.img_size)

    def log_prob(self, x_pred_mean: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """The log probability of the observation given the predicted observation.

        Computed as the negative element-wise binary cross entropy (i.e. a
        Bernoulli log-likelihood per pixel), summed over the last three
        dimensions.

        Parameters
        ----------
        x_pred_mean : torch.Tensor of shape (*, n_channels, height, width)
            The mean of the predicted observation.
        x : torch.Tensor of shape (*, n_channels, height, width)
            The observation.

        Returns
        -------
        torch.Tensor of shape (*,)
            The log probability of the observation given the predicted observation.
        """
        per_pixel_nll = F.binary_cross_entropy(x_pred_mean, x, reduction="none")
        return -per_pixel_nll.sum(dim=(-1, -2, -3))


class MLPEncoder(Encoder):
    def __init__(
        self,
        img_size: tuple[int, int, int],
        n_components: int,
        n_total_samples: int | None = None,
    ):
        """MLP Encoder.

        A two-layer perceptron mapping a flattened image to the mean of the
        latent: one hidden layer of 128 units with a tanh activation followed
        by a linear output layer.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32).
        n_components : int
            Dimensionality of the latent output. Forwarded to the base
            ``Encoder`` constructor.
        n_total_samples : int, optional
            Forwarded unchanged to the base ``Encoder`` constructor, which
            defines its meaning.
        """
        super().__init__(n_components, n_total_samples)

        # Layer parameters
        hidden_dim = 128
        self.img_size = img_size

        # np.prod returns a NumPy scalar; cast it so that nn.Linear stores a
        # plain Python int as its ``in_features``.
        n_pixels = int(np.prod(self.img_size))

        # Fully connected layers.
        # NOTE: self.n_components is not assigned in this class; it is
        # presumably set by Encoder.__init__ -- confirm against the base class.
        self.fc1 = nn.Linear(n_pixels, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, self.n_components)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images whose first dimension is the batch dimension; all
            remaining dimensions are flattened before the first linear layer.

        Returns
        -------
        z_pred_mean : torch.Tensor of shape (batch_size, n_components)
            Predicted latent mean.
        z_pred_log_std : torch.Tensor
            ``self.log_std`` expanded to the shape of ``z_pred_mean``.
        """
        batch_size = x.size(0)

        # reshape (rather than view) also accepts non-contiguous inputs
        flat = x.reshape(batch_size, -1)
        hidden = torch.tanh(self.fc1(flat))
        z_pred_mean = self.fc2(hidden)

        # NOTE: an alternative, self.compute_log_std(z_pred_mean), was left
        # disabled here; the input-independent self.log_std is used instead.
        # Both attributes are presumably provided by the Encoder base class.
        return z_pred_mean, self.log_std[None, :].expand_as(z_pred_mean)


class BurgessDecoder(nn.Module):
    def __init__(self, img_size: tuple[int, int, int], n_components: int):
        r"""Decoder of the model proposed in [1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        n_components : int
            Dimensionality of the latent input.

        Model Architecture
        ------------------
        - 3 fully connected layers (256, 256 and 32 * 4 * 4 units)
        - 3 transposed convolutional layers (4 x 4 kernel, stride of 2,
          padding of 1), the first two with 32 output channels and the last
          with ``img_size[0]`` output channels
        - 1 additional transposed convolutional layer (32 channels), applied
          first, only when the image is 64 x 64

        Each transposed convolution doubles the spatial size, so the output is
        32 x 32, or 64 x 64 when the additional layer is present. Other image
        sizes get no special handling.

        References:
            [1] Burgess, Christopher P., et al. "Understanding disentangling in
            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
        """
        super().__init__()

        # Layer parameters
        hid_channels = 32
        kernel_size = 4
        hidden_dim = 256
        self.n_components = n_components
        self.img_size = img_size
        # Shape required to start transpose convs
        self.reshape = (hid_channels, kernel_size, kernel_size)
        n_channels = self.img_size[0]

        # np.prod returns a NumPy scalar; cast it so that nn.Linear stores a
        # plain Python int as its ``out_features``.
        n_reshape_units = int(np.prod(self.reshape))

        # Fully connected layers
        self.lin1 = nn.Linear(n_components, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        self.lin3 = nn.Linear(hidden_dim, n_reshape_units)

        # Convolutional layers
        cnn_kwargs = dict(stride=2, padding=1)
        # If input image is 64x64 do fourth convolution
        if self.img_size[1] == self.img_size[2] == 64:
            self.convT_64 = nn.ConvTranspose2d(
                hid_channels, hid_channels, kernel_size, **cnn_kwargs
            )

        self.convT1 = nn.ConvTranspose2d(
            hid_channels, hid_channels, kernel_size, **cnn_kwargs
        )
        self.convT2 = nn.ConvTranspose2d(
            hid_channels, hid_channels, kernel_size, **cnn_kwargs
        )
        self.convT3 = nn.ConvTranspose2d(
            hid_channels, n_channels, kernel_size, **cnn_kwargs
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into predicted image means.

        Parameters
        ----------
        z : torch.Tensor of shape (*, n_components)
            Latent vectors. Any number of leading dimensions is accepted.

        Returns
        -------
        torch.Tensor of shape (*, n_channels, height, width)
            Predicted pixel means in [0, 1].
        """
        # The conv layers need exactly one batch axis: fold all leading
        # dimensions into it here and restore them at the end.
        lead_shape = z.shape[:-1]
        x = z.reshape(-1, z.shape[-1])

        # Fully connected layers with ReLu activations
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))
        x = x.view(x.size(0), *self.reshape)

        # Convolutional layers with ReLu activations
        if self.img_size[1] == self.img_size[2] == 64:
            x = torch.relu(self.convT_64(x))
        x = torch.relu(self.convT1(x))
        x = torch.relu(self.convT2(x))
        # Sigmoid activation for final conv layer
        x_pred_mean = torch.sigmoid(self.convT3(x))

        # (*, n_channels, height, width)
        return x_pred_mean.reshape(*lead_shape, *x_pred_mean.shape[1:])

    def log_prob(self, x_pred_mean: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """The log probability of the observation given the predicted observation.

        Computed as the negative element-wise binary cross entropy (i.e. a
        Bernoulli log-likelihood per pixel), summed over the last three
        dimensions.

        Parameters
        ----------
        x_pred_mean : torch.Tensor of shape (*, n_channels, height, width)
            The mean of the predicted observation.
        x : torch.Tensor of shape (*, n_channels, height, width)
            The observation.

        Returns
        -------
        torch.Tensor of shape (*,)
            The log probability of the observation given the predicted observation.
        """
        per_pixel_nll = F.binary_cross_entropy(x_pred_mean, x, reduction="none")
        return -per_pixel_nll.sum(dim=(-1, -2, -3))


class BurgessEncoder(Encoder):
    def __init__(
        self,
        img_size: tuple[int, int, int],
        n_components: int,
        n_total_samples: int | None = None,
    ):
        r"""Encoder of the model proposed in [1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
        n_components : int
            Dimensionality of the latent output. Forwarded to the base
            ``Encoder`` constructor.
        n_total_samples : int, optional
            Forwarded unchanged to the base ``Encoder`` constructor, which
            defines its meaning.

        Model Architecture
        ------------------
        - 3 convolutional layers (each with 32 channels), (4 x 4 kernel),
          (stride of 2, padding of 1)
        - 1 additional convolutional layer of the same kind, applied last,
          only when the image is 64 x 64
        - 2 fully connected layers (each of 256 units)
        - 1 fully connected layer of ``n_components`` units for the latent mean

        Each convolution halves the spatial size, so 32 x 32 and 64 x 64
        inputs both reach the 32 x 4 x 4 feature map that the first fully
        connected layer expects. Other image sizes get no special handling.

        References:
            [1] Burgess, Christopher P., et al. "Understanding disentangling in
            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
        """
        super().__init__(n_components, n_total_samples)

        # Layer parameters
        hid_channels = 32
        kernel_size = 4
        hidden_dim = 256
        self.img_size = img_size
        # Shape of the feature map expected at the end of the conv stack
        self.reshape = (hid_channels, kernel_size, kernel_size)
        n_channels = self.img_size[0]

        # Convolutional layers
        cnn_kwargs = dict(stride=2, padding=1)
        self.conv1 = nn.Conv2d(n_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.conv2 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.conv3 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)

        # If input image is 64x64 do fourth convolution
        if self.img_size[1] == self.img_size[2] == 64:
            self.conv_64 = nn.Conv2d(
                hid_channels, hid_channels, kernel_size, **cnn_kwargs
            )

        # np.prod returns a NumPy scalar; cast it so that nn.Linear stores a
        # plain Python int as its ``in_features``.
        n_reshape_units = int(np.prod(self.reshape))

        # Fully connected layers
        self.lin1 = nn.Linear(n_reshape_units, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)

        # Fully connected layer for the latent mean.
        # NOTE: self.n_components is not assigned in this class; it is
        # presumably set by Encoder.__init__ -- confirm against the base class.
        self.fc_mean = nn.Linear(hidden_dim, self.n_components)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images.

        Parameters
        ----------
        x : torch.Tensor of shape (batch_size, n_channels, height, width)
            Batch of images.

        Returns
        -------
        z_pred_mean : torch.Tensor of shape (batch_size, n_components)
            Predicted latent mean.
        z_pred_log_std : torch.Tensor
            ``self.log_std`` expanded to the shape of ``z_pred_mean``.
        """
        batch_size = x.size(0)

        # Convolutional layers with ReLu activations
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        if self.img_size[1] == self.img_size[2] == 64:
            x = torch.relu(self.conv_64(x))

        # Fully connected layers with ReLu activations.
        # reshape (rather than view) also works when the conv output is not
        # contiguous, e.g. for channels_last inputs.
        x = x.reshape(batch_size, -1)
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))

        # Fully connected layer for the latent mean
        z_pred_mean = self.fc_mean(x)

        # NOTE: an alternative, self.compute_log_std(z_pred_mean), was left
        # disabled here; the input-independent self.log_std is used instead.
        # Both attributes are presumably provided by the Encoder base class.
        return z_pred_mean, self.log_std[None, :].expand_as(z_pred_mean)
