import logging
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.distributions as dist
import torch.nn as nn

# Module logger, explicitly pinned to INFO so that the informational messages
# logged by this module (e.g. the penalty scale in ContrastiveCVAE) pass this
# logger's own level check.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)

# Finite stand-in for log(0): masked-out entries are replaced by this value before
# logsumexp, so that even a fully masked row produces a finite result rather than -inf.
LARGE_NEGATIVE = -1e7


def pairwise_label_mask(m: torch.Tensor) -> torch.Tensor:
    """
    Build a pairwise same-class indicator from one-hot labels.

    Given ``m`` of shape (batch_size, n_classes) whose rows are one-hot class
    encodings, return the symmetric (batch_size, batch_size) tensor whose
    [i, j] entry is 1 when batch elements i and j share a class and 0 otherwise.
    """
    # Inner product of every pair of label rows: out[a, b] = sum_c m[a, c] * m[b, c].
    return torch.einsum('ac,bc->ab', m, m)


def make_subset_penalty_v2(
    posterior,
    categories,
    latent,
    penalty_scale=1.0,
    entropy_relative_scale=1.0,
    penalty_exp_factor=1.0,
):
    """Subset penalty

    For every sample i in the batch, compares the (scaled) log density of its latent under
    the uniform mixture of posteriors of the *other* samples in its own class with the log
    density under the uniform mixture of posteriors of the samples in other classes.

    Args:
        posterior (torch distribution): The q(z|x) distribution with shape=(batch_size, num_features)
        categories (torch Tensor): The categories tensor with shape=(batch_size, n_classes)
            (e.g. n_classes=2). The rows are one hot vectors.
        latent (torch Tensor): Latent samples, one row per batch element; each row is
            evaluated under every posterior in the batch.
        penalty_scale (float): The positive penalty scale factor
        entropy_relative_scale (float): The additional relative factor for the entropy term. This
            is equal to 1 if the term is balanced
        penalty_exp_factor (float): The exponent that scales each probability. This is done for each
            sample, rather than the final log probability (design choice). Value of 1 is
            ~l2 penalty; 0.5 is ~l1 penalty.

    Returns:
        (torch Tensor): The penalty scalar. A batch element whose own-class (or other-class)
            mixture is empty contributes 0 for that term, instead of making the penalty
            +/-inf (or NaN, e.g. for a batch of size 1).
    """
    batch_size = latent.shape[0]
    # Insert an axis so log_prob broadcasts to a B x B tensor:
    #     log_probs[i, j] = penalty_exp_factor * log q(z[i] | x[j])
    log_probs = posterior.log_prob(latent.unsqueeze(1)) * penalty_exp_factor
    same_class = pairwise_label_mask(categories)
    # A sample is excluded from its own-class mixture.
    mask_own = same_class - torch.eye(batch_size, device=latent.device)
    mask_other = 1.0 - same_class
    mixture_size_own = mask_own.sum(-1)
    mixture_size_other = batch_size - 1 - mixture_size_own

    # Empty mixtures would give log(0) = -inf; divide by 1 there instead (the term is
    # zeroed below anyway). Non-empty sizes are used unchanged.
    safe_size_own = torch.where(mixture_size_own > 0, mixture_size_own, torch.ones_like(mixture_size_own))
    safe_size_other = torch.where(mixture_size_other > 0, mixture_size_other, torch.ones_like(mixture_size_other))

    # log-mean-exp over each masked row; masked entries are pushed to LARGE_NEGATIVE so they
    # carry (numerically) zero weight and a fully masked row stays finite.
    log_prob_own = (mask_own * log_probs + (1 - mask_own) * LARGE_NEGATIVE).logsumexp(
        -1
    ) - torch.log(safe_size_own)
    log_prob_other = (mask_other * log_probs + (1 - mask_other) * LARGE_NEGATIVE).logsumexp(
        -1
    ) - torch.log(safe_size_other)

    # Samples with no comparison set contribute nothing to the respective term.
    log_prob_own = torch.where(mixture_size_own > 0, log_prob_own, torch.zeros_like(log_prob_own))
    log_prob_other = torch.where(mixture_size_other > 0, log_prob_other, torch.zeros_like(log_prob_other))

    contrastive_penalty = (
        penalty_scale * (entropy_relative_scale * log_prob_own - log_prob_other).mean()
    )
    return contrastive_penalty

def groupwise_subset_penalty(posterior: dist.Distribution,
                             categories: torch.Tensor,
                             groups: torch.Tensor,
                             latents: torch.Tensor,
                             penalty_scale=1.0,
                             entropy_relative_scale=1.0,
                             penalty_exp_factor=1.0) -> torch.Tensor:
    """Subset penalty restricted to pairs of batch elements that share a group.

    For every sample i, compares the (scaled) log density of its latent under the mixture
    of posteriors of the other same-group, same-class samples with the log density under
    the mixture of posteriors of same-group, other-class samples.

    Args:
        posterior: q(z|x) for the batch; ``log_prob`` must broadcast against
            ``latents.unsqueeze(1)`` to give a batch_size x batch_size tensor.
        categories: One-hot class labels, shape (batch_size, n_classes).
        groups: One-hot group labels, shape (batch_size, n_groups).
        latents: Latent samples, one row per batch element.
        penalty_scale: Overall positive scale of the penalty.
        entropy_relative_scale: Relative weight of the own-class term.
        penalty_exp_factor: Exponent applied to each per-sample probability.

    Returns:
        A scalar tensor. A batch element whose own (or other) mixture is empty within its
        group contributes 0 for that term.
    """
    # Insert a new axis into latents to broadcast up the log_prob calcs to return
    # a batch_size x batch_size tensor:
    #     log_probs[i, j] = q(z[i]|x[j])
    log_probs = posterior.log_prob(latents.unsqueeze(1)) * penalty_exp_factor
    batch_size = log_probs.shape[0]
    device = latents.device
    cat_mask = pairwise_label_mask(categories)
    group_mask = pairwise_label_mask(groups)
    mask_own = group_mask * (cat_mask - torch.eye(batch_size, device=device))
    mask_other = group_mask * (1. - cat_mask)
    mixture_size_own = mask_own.sum(-1)
    mixture_size_other = mask_other.sum(-1)
    # Take log(1 + size) of the mixture sizes so the normalisation stays finite for batch
    # elements whose own or other classes are empty in the batch.
    log_prob_own = (mask_own * log_probs + (1 - mask_own) * LARGE_NEGATIVE).logsumexp(-1) - torch.log(1.0 + mixture_size_own)
    log_prob_other = (mask_other * log_probs + (1 - mask_other) * LARGE_NEGATIVE).logsumexp(-1) - torch.log(1.0 + mixture_size_other)
    # A fully masked row's logsumexp is ~LARGE_NEGATIVE; left in, it would add a spurious
    # +/-1e7-scale constant to the penalty value. Such rows have zero gradient anyway, so
    # dropping them only corrects the reported value.
    log_prob_own = torch.where(mixture_size_own > 0, log_prob_own, torch.zeros_like(log_prob_own))
    log_prob_other = torch.where(mixture_size_other > 0, log_prob_other, torch.zeros_like(log_prob_other))

    contrastive_penalty = (
        penalty_scale * (entropy_relative_scale * log_prob_own - log_prob_other).mean()
    )
    return contrastive_penalty


class ContrastiveCVAE(pl.LightningModule):
    """Conditional VAE with flexible encoder and decoder

    Trained by maximising the ELBO minus a contrastive subset penalty. The encoder is
    applied to the concatenation of the data, the category one-hots and (optionally) the
    group one-hots. Its output is indexed so that element 0 is used as q(z|x) and element 1
    as the posterior used by the penalty; both are wrapped in ``dist.Independent(..., 1)``
    during training.
    """
    def __init__(
        self,
        encoder,
        decoder,
        z_dim,
        n_groups,
        forward_pass_use_groups=False,
        penalty_scale=1.0,
        entropy_relative_scale=1.0,
        penalty_exp_factor=1.0,
        learning_rate=0.001,
        gamma=1.0,
        beta=1.0,
    ):
        """
        Args:
            encoder: Module called on concatenated [x, c(, d)]; its output must be indexable
                with at least two distributions (q(z|x) and the penalty posterior).
            decoder: Module called on concatenated [z, c(, d)]; must return an object with
                ``log_prob(x)``.
            z_dim: Latent dimensionality (stored only).
            n_groups: Number of groups, or None when batches carry no group tensor.
            forward_pass_use_groups: Also feed the group tensor to encoder and decoder.
            penalty_scale, entropy_relative_scale, penalty_exp_factor: Forwarded to the
                subset penalty.
            learning_rate: Adam learning rate.
            gamma: Multiplicative decay factor of the ExponentialLR scheduler.
            beta: Weight of the KL term in the ELBO.
        """
        super().__init__()
        self.save_hyperparameters()

        self.z_dim = z_dim
        # n_groups is included mostly as a sanity check that the expected number
        # of tensors are set up in each batch.
        assert n_groups is None or n_groups > 0
        self.n_groups = n_groups
        if forward_pass_use_groups:
            # The check above already guarantees n_groups > 0 when it is not None; testing
            # `n_groups > 0` directly would raise TypeError for None.
            assert n_groups is not None, 'forward_pass_use_groups=True requires n_groups'
        self.forward_pass_use_groups = forward_pass_use_groups
        self.penalty_scale = penalty_scale
        self.entropy_relative_scale = entropy_relative_scale
        self.penalty_exp_factor = penalty_exp_factor
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.beta = beta

        LOGGER.info("Using penalty scale: %f", self.penalty_scale)

        # Set the encoder and decoder networks
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor, c: torch.Tensor, d: Optional[torch.Tensor] = None):
        """Compute the approximate posterior for input data x

        Args:
            x: Input data.
            c: Category one-hots, concatenated with x along the last dim.
            d: Group one-hots; required when ``forward_pass_use_groups`` is set, otherwise
                ignored.

        Returns:
            Element 0 of the encoder output, i.e. the variational posterior q(z | x)
            (presumably a torch.distributions.Normal — depends on the encoder). Note that,
            unlike in ``_execute_step``, it is not wrapped in ``dist.Independent``.
        """
        assert d is None or (self.n_groups is not None and self.n_groups > 0)
        if self.forward_pass_use_groups:
            assert d is not None
            encoder_input = [x, c, d]
        else:
            encoder_input = [x, c]
        return self.encoder(torch.cat(encoder_input, dim=-1))[0]

    def forward_decoder(self, z, c, d=None):
        """Run the decoder on concatenated [z, c] (plus d when groups are used)."""
        assert d is None or (self.n_groups is not None and self.n_groups > 0)
        if self.forward_pass_use_groups:
            assert d is not None
            inputs = [z, c, d]
        else:
            inputs = [z, c]
        return self.decoder(torch.cat(inputs, dim=-1))

    def _compute_elbo(self,
                      x: torch.Tensor,
                      categories: torch.Tensor,
                      qz: dist.Distribution,
                      qz_penalty: dist.Distribution,
                      groups: Optional[torch.Tensor]):
        """Return (batch-mean ELBO minus subset penalty, penalty) for one batch.

        A single reparameterised latent sample per element is drawn from ``qz``, and is
        used both for the reconstruction term and for the penalty.
        """
        latents = qz.rsample()
        z_prior = dist.Independent(dist.Normal(torch.zeros_like(latents), torch.ones_like(latents)), 1)
        x_hat = self.forward_decoder(latents, categories, groups)

        reconstruction = x_hat.log_prob(x)
        # Sum per-feature log-likelihoods when the decoder distribution is not Independent.
        if reconstruction.dim() == 2:
            reconstruction = reconstruction.sum(dim=1)
        kl = dist.kl.kl_divergence(qz, z_prior)
        elbo = (reconstruction - self.beta * kl).mean()

        if groups is not None:
            penalty = groupwise_subset_penalty(
                qz_penalty, categories, groups, latents, self.penalty_scale, self.entropy_relative_scale, self.penalty_exp_factor
            )
        else:
            penalty = make_subset_penalty_v2(
                qz_penalty, categories, latents, self.penalty_scale, self.entropy_relative_scale, self.penalty_exp_factor
            )
        elbo = elbo - penalty
        return elbo, penalty

    def _execute_step(self, batch, batch_idx, stage):
        """Shared train/validation step: encode, compute ELBO and penalty, log metrics."""
        # Each branch carries its own assert so a malformed batch reports the shapes it got.
        if self.n_groups is not None:
            assert len(batch) == 3, f'Expected batch of [x, categories, groups], got {[t.shape for t in batch]}'
            x, categories, groups = batch
        else:
            assert len(batch) == 2, f'Expected batch of [x, categories], got {[t.shape for t in batch]}'
            x, categories = batch
            groups = None
        # forward_pass_use_groups implies n_groups is set (checked in __init__), so groups
        # is a tensor here whenever it is used.
        if self.forward_pass_use_groups:
            enc_inputs = [x, categories, groups]
        else:
            enc_inputs = [x, categories]
        qz_temps = self.encoder(torch.cat(enc_inputs, dim=-1))
        qz = dist.Independent(qz_temps[0], 1)
        qz_penalty = dist.Independent(qz_temps[1], 1)
        elbo, penalty = self._compute_elbo(x, categories=categories, qz=qz, qz_penalty=qz_penalty, groups=groups)
        self.log(f'{stage}_elbo', elbo)
        self.log(f'{stage}_loss', -elbo)
        self.log(f'{stage}_penalty', penalty)
        return {'loss': -elbo, 'elbo': elbo, 'penalty': penalty}

    def training_step(self, batch, batch_idx):
        """Execute one training step"""
        return self._execute_step(batch, batch_idx, stage='train')

    def validation_step(self, batch, batch_idx):
        """Execute one validation step"""
        return self._execute_step(batch, batch_idx, stage='valid')

    def configure_optimizers(self):
        """Fetch the optimiser parameters - required by PyTorch Lightning"""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)
        return [optimizer], [scheduler]
