from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F


# Small constant added inside logarithms so that log(0) is never evaluated.
EPS = 1e-12

# Default latent layout: 10 continuous dimensions ('cont') and a single
# discrete variable with 10 categories ('disc').
DEFAULT_LATENT_SPEC_MNIST = dict(cont=10, disc=[10])

class JointVAE(nn.Module):
    def __init__(
        self,
        input_dim: int = 1024,
        latent_spec: Optional[dict] = None,
        temperature: float = .67,
        hidden_dim: int = 512,
        channel_spec: Optional[list] = None,
    ):
        """
        Convolutional implementation of JointVAE, but with a 1x1 Kernel for use
        with tabular data.
        Based on the implementation provided here:
        https://github.com/Schlumberger/joint-vae

        Parameters
        ----------
        input_dim : int
            Dimension of the input data. Will also be used as decoder output
            dimension. The data is assumed to have shape (N, 1, input_dim)
            where N is the batch size (single-channel input for Conv1d).

        latent_spec : dict, optional
            Specifies latent distribution. For example:
            {'cont': 10, 'disc': [10, 4, 3]} encodes 10 normal variables and
            3 gumbel softmax variables of dimension 10, 4 and 3. A latent spec
            can include both 'cont' and 'disc' or only 'cont' or only 'disc'.
            When omitted, a copy of DEFAULT_LATENT_SPEC_MNIST is used.

        temperature : float
            Temperature for gumbel softmax distribution.

        hidden_dim : int
            Width of the hidden vector from which the latent distribution
            parameters are computed.

        channel_spec : list, optional
            Number of channels of the two hidden convolutional stages. Must
            have exactly two entries. Defaults to [32, 64].

        Raises
        ------
        ValueError
            If latent_spec contains neither 'cont' nor 'disc'.
        """

        super().__init__()

        # Resolve defaults per instance instead of sharing mutable default
        # arguments between all models.
        if channel_spec is None:
            channel_spec = [32, 64]
        if latent_spec is None:
            latent_spec = {
                key: list(value) if isinstance(value, list) else value
                for key, value in DEFAULT_LATENT_SPEC_MNIST.items()
            }

        assert len(channel_spec) == 2, "The definition of the hidden layers number of channels must be of length 2."

        # Parameters
        # Kept for backward compatibility; sampling no longer depends on it
        # because noise is created directly on the device of its inputs.
        self.use_cuda = torch.cuda.is_available()
        self.input_dim = input_dim
        self.is_continuous = 'cont' in latent_spec
        self.is_discrete = 'disc' in latent_spec
        self.latent_spec = latent_spec
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.hidden_channels_1 = channel_spec[0]
        self.hidden_channels_2 = channel_spec[1]
        # (channels, length) shape the flat decoder features are viewed as
        # before entering the transpose convolutions.
        self.reshape = (self.hidden_channels_2, self.input_dim)

        if not (self.is_continuous or self.is_discrete):
            # Without any latent variable there is nothing to sample or decode.
            raise ValueError(
                "latent_spec must contain at least one of the keys 'cont' or 'disc'."
            )

        # Calculate dimensions of latent distribution
        self.latent_cont_dim = 0
        self.latent_disc_dim = 0
        self.num_disc_latents = 0
        if self.is_continuous:
            self.latent_cont_dim = self.latent_spec['cont']
        if self.is_discrete:
            self.latent_disc_dim += sum(self.latent_spec['disc'])
            self.num_disc_latents = len(self.latent_spec['disc'])
        self.latent_dim = self.latent_cont_dim + self.latent_disc_dim

        # Define encoder: three 1x1 convolutions, each followed by a ReLU.
        # kernel_size=1 keeps the sequence length equal to input_dim.
        self.img_to_features = nn.Sequential(
            nn.Conv1d(1, self.hidden_channels_1, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(self.hidden_channels_1, self.hidden_channels_2, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(self.hidden_channels_2, self.hidden_channels_2, kernel_size=1),
            nn.ReLU(),
        )

        # Map encoded features into a hidden vector which will be used to
        # encode parameters of the latent distribution
        self.features_to_hidden = nn.Sequential(
            # Input width is the flattened (channels * length) feature map.
            nn.Linear(self.hidden_channels_2 * 1 * self.input_dim, self.input_dim),
            nn.ReLU(),
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU()
        )

        # Encode parameters of latent distribution
        if self.is_continuous:
            self.fc_mean = nn.Linear(self.hidden_dim, self.latent_cont_dim)
            self.fc_log_var = nn.Linear(self.hidden_dim, self.latent_cont_dim)
        if self.is_discrete:
            # Linear layer for each of the categorical distributions
            self.fc_alphas = nn.ModuleList(
                [nn.Linear(self.hidden_dim, disc_dim) for disc_dim in self.latent_spec['disc']]
            )

        # Map latent samples to features to be used by generative model
        self.latent_to_features = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.input_dim),
            nn.ReLU(),
            nn.Linear(self.input_dim, self.hidden_channels_2 * 1 * self.input_dim),
            nn.ReLU()
        )

        # Define decoder: mirrors the encoder and ends in a single channel.
        # The final Tanh bounds the reconstruction to [-1, 1].
        self.features_to_img = nn.Sequential(
            nn.ConvTranspose1d(self.hidden_channels_2, self.hidden_channels_1, kernel_size=1),
            nn.ReLU(),
            nn.ConvTranspose1d(self.hidden_channels_1, self.hidden_channels_1, kernel_size=1),
            nn.ReLU(),
            nn.ConvTranspose1d(self.hidden_channels_1, 1, kernel_size=1),
            nn.Tanh(),
        )

    def encode(self, x: torch.Tensor) -> dict:
        """
        Encodes a batch of data into parameters of a latent distribution
        defined in self.latent_spec.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data, shape (N, 1, input_dim)

        Returns
        -------
        dict
            May contain 'cont': [mean, log_var], each of shape
            (N, latent_cont_dim), and 'disc': a list with one softmax
            probability tensor of shape (N, D_i) per discrete variable.
        """

        batch_size = x.size(0)

        # Encode data to hidden features
        features = self.img_to_features(x)
        hidden = self.features_to_hidden(features.view(batch_size, -1))

        # Output parameters of latent distribution from hidden representation
        latent_dist = {}

        if self.is_continuous:
            latent_dist['cont'] = [self.fc_mean(hidden), self.fc_log_var(hidden)]

        if self.is_discrete:
            latent_dist['disc'] = [
                F.softmax(fc_alpha(hidden), dim=1) for fc_alpha in self.fc_alphas
            ]

        return latent_dist

    def reparameterize(self, latent_dist: dict, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Samples from latent distribution using the reparameterization trick.

        Parameters
        ----------
        latent_dist : dict
            Dict with keys 'cont' or 'disc' or both, containing the parameters
            of the latent distributions as torch.Tensor instances.

        device : torch.device, optional
            Accepted for backward compatibility and forwarded to the sampling
            helpers, which create noise on the device of the distribution
            parameters themselves.

        Returns
        -------
        torch.Tensor
            Continuous sample (if any) followed by the discrete samples,
            concatenated along dimension 1.
        """

        latent_sample = []

        if self.is_continuous:
            mean, logvar = latent_dist['cont']
            latent_sample.append(self.sample_normal(mean, logvar, device=device))

        if self.is_discrete:
            latent_sample.extend(
                self.sample_gumbel_softmax(alpha, device=device)
                for alpha in latent_dist['disc']
            )

        # Concatenate continuous and discrete samples into one large sample
        return torch.cat(latent_sample, dim=1)

    def sample_normal(self, mean: torch.Tensor, logvar: torch.Tensor, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (N, D) where D is dimension
            of distribution.

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (N, D)

        device : torch.device, optional
            Ignored; kept for backward compatibility. The noise is always
            created on the device (and with the dtype) of logvar.

        Returns
        -------
        torch.Tensor
            A random sample in training mode, the mean otherwise.
        """

        if not self.training:
            # Reconstruction mode
            return mean

        std = torch.exp(0.5 * logvar)
        # randn_like matches device and dtype of std, so no manual transfer
        # (and no dependence on CUDA availability) is needed.
        eps = torch.randn_like(std)
        return mean + std * eps

    def sample_gumbel_softmax(self, alpha: torch.Tensor, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Samples from a gumbel-softmax distribution using the reparameterization
        trick.

        Parameters
        ----------
        alpha : torch.Tensor
            Parameters of the gumbel-softmax distribution. Shape (N, D)

        device : torch.device, optional
            Ignored; kept for backward compatibility. All tensors are created
            on the device (and with the dtype) of alpha.

        Returns
        -------
        torch.Tensor
            Shape (N, D). A relaxed sample in training mode, a one-hot vector
            of the most likely category otherwise.
        """

        if self.training:
            # Sample from gumbel distribution
            unif = torch.rand_like(alpha)
            gumbel = -torch.log(-torch.log(unif + EPS) + EPS)
            # Reparameterize to create gumbel softmax sample
            log_alpha = torch.log(alpha + EPS)
            logit = (log_alpha + gumbel) / self.temperature
            return F.softmax(logit, dim=1)

        # In reconstruction mode, pick most likely sample
        _, max_alpha = torch.max(alpha, dim=1, keepdim=True)
        # On axis 1, scatter the value 1 at indices max_alpha. zeros_like keeps
        # the result on alpha's device with alpha's dtype.
        return torch.zeros_like(alpha).scatter_(1, max_alpha, 1)

    def decode(self, latent_sample: torch.Tensor) -> torch.Tensor:
        """
        Decodes sample from latent distribution into a reconstruction.

        Parameters
        ----------
        latent_sample : torch.Tensor
            Sample from latent distribution. Shape (N, L) where L is dimension
            of latent distribution.

        Returns
        -------
        torch.Tensor
            Reconstruction of shape (N, 1, input_dim).
        """

        features = self.latent_to_features(latent_sample)
        # View the flat features as (N, channels, length) for the transpose convs.
        return self.features_to_img(features.view(-1, *self.reshape))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (N, 1, input_dim)

        Returns
        -------
        tuple
            (reconstruction, latent_dist) where reconstruction has shape
            (N, 1, input_dim) and latent_dist is the dict returned by encode.
        """

        latent_dist = self.encode(x)
        latent_sample = self.reparameterize(latent_dist, device=x.device)
        return self.decode(latent_sample), latent_dist
