"""Guide for image generation based on Glow."""
import torch
import torch.nn as nn
import normflows as nf

from calnf.guides.guide import Guide
from calnf.datasets.dataset import Dataset


class ConditionalMultiscaleFlow(nn.Module):
    """
    Conditional version of nf.MultiscaleFlow.

    The context ``y`` is handed to every base distribution and, where the
    flow accepts a ``context`` keyword, to every flow layer as well.
    """

    def __init__(self, q0, flows, merges, transform=None):
        """Constructor

        Args:

          q0: List of base distribution
          flows: List of list of flows for each level
          merges: List of merge/split operations (forward pass must do merge)
          transform: Initial transformation of inputs
        """
        super().__init__()
        self.q0 = nn.ModuleList(q0)
        self.num_levels = len(self.q0)
        self.flows = torch.nn.ModuleList([nn.ModuleList(flow) for flow in flows])
        self.merges = torch.nn.ModuleList(merges)
        self.transform = transform

    def set_temperature(self, temperature):
        """Set the temperature used for temperature annealed sampling

        Args:
          temperature: Value assigned to the ``temperature`` attribute of
            every base distribution

        Raises:
          NotImplementedError: If a base distribution has no ``temperature``
            attribute
        """
        for base in self.q0:
            if not hasattr(base, "temperature"):
                raise NotImplementedError(
                    "One base function does not "
                    "support temperature annealed sampling"
                )
            base.temperature = temperature

    def reset_temperature(self):
        """Set the temperature of every base distribution back to None"""
        self.set_temperature(None)

    def sample(self, num_samples=1, y=None, temperature=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: input for the base distribution and flows
          temperature: Temperature parameter for temp annealed sampling

        Returns:
          Samples, log probability
        """
        if temperature is not None:
            self.set_temperature(temperature)
        try:
            for i, base in enumerate(self.q0):
                z_, log_q_ = base(num_samples, y)
                if i == 0:
                    log_q = log_q_
                    z = z_
                else:
                    log_q += log_q_
                    # Merge the freshly drawn latent into the running sample
                    z, log_det = self.merges[i - 1]([z, z_])
                    log_q -= log_det
                for flow in self.flows[i]:
                    # NOTE: a TypeError raised *inside* a context-aware flow
                    # is also caught here and triggers the context-free call.
                    try:
                        z, log_det = flow(z, context=y)
                    except TypeError:
                        # This module doesn't take context
                        z, log_det = flow(z)
                    log_q -= log_det
            if self.transform is not None:
                z, log_det = self.transform(z)
                log_q -= log_det
        finally:
            # Always undo the temporary temperature, even if sampling failed
            if temperature is not None:
                self.reset_temperature()
        return z, log_q

    def log_prob(self, x, y):
        """Get log probability for batch

        Args:
          x: Batch
          y: context for the base and flows

        Returns:
          log probability
        """
        log_q = 0
        z = x
        if self.transform is not None:
            z, log_det = self.transform.inverse(z)
            log_q += log_det
        # Walk the levels, and the flows inside each level, in reverse order
        for i in range(len(self.q0) - 1, -1, -1):
            for j in range(len(self.flows[i]) - 1, -1, -1):
                try:
                    z, log_det = self.flows[i][j].inverse(z, context=y)
                except TypeError:
                    # This module doesn't take context
                    z, log_det = self.flows[i][j].inverse(z)
                log_q += log_det
            if i > 0:
                # Split off the part that is scored by this level's base
                [z, z_], log_det = self.merges[i - 1].inverse(z)
                log_q += log_det
            else:
                z_ = z

            log_q += self.q0[i].log_prob(z_, y)
        return log_q


class GlowDistribution(torch.distributions.Distribution):
    """
    Wraps a MultiscaleFlow as a torch.distributions.Distribution.

    Note that the base-class constructor is not invoked, so only the methods
    defined here (and those built purely on top of them) are usable.

    Args:
        flow: MultiscaleFlow object.
        context: Context for the distribution. A 1-D tensor is treated as a
            single context; for a 2-D tensor every row along the first
            dimension is treated as a separate context.
        input_shape: Shape that data points are reshaped to before they are
            passed to the flow.
    """

    def __init__(self, flow, context, input_shape):
        self.model = flow
        self.input_shape = input_shape

        # Add a leading dimension to context if not present
        if len(context.shape) == 1:
            context = context.unsqueeze(0)
        self.context = context

    def log_prob(self, x):
        """Evaluate the log probability of ``x`` under every context row.

        Args:
            x: Data, reshaped internally to ``(-1, *input_shape)``.

        Returns:
            Log probabilities indexed as (data, context). The context axis is
            dropped when there is a single context row and the data axis is
            dropped when there is a single data point.
        """
        # The reshape does not depend on the context, so do it only once
        x = x.reshape(-1, *self.input_shape)
        n_data = x.shape[0]

        # Compute log probs for each context provided, extending the context
        # to the batch dimension of x
        log_probs = torch.stack(
            [
                self.model.log_prob(x, y=context.expand(n_data, -1))
                for context in self.context
            ]
        )

        # Swap the context and data batch dimensions
        log_probs = log_probs.transpose(0, 1)

        # If there is only one row in context, remove the context dimension
        if self.context.shape[0] == 1:
            log_probs = log_probs.squeeze(1)

        # If there was no data batch dimension, remove it. This has to happen
        # after the transpose: squeezing each per-context result first would
        # leave a 1-D tensor that cannot be transposed.
        if n_data == 1:
            log_probs = log_probs.squeeze(0)

        return log_probs

    def rsample_and_log_prob(self, sample_shape=torch.Size()):
        """Draw samples and return them together with their log probability.

        Args:
            sample_shape: Empty, or a single dimension giving the number of
                samples to draw per context row.

        Returns:
            Tuple of (samples, log_probs). Samples are flattened to vectors
            and indexed as (sample, context, feature); singleton sample and
            context axes are dropped, as they are for the log probabilities.

        Raises:
            ValueError: If ``sample_shape`` has more than one dimension.
        """
        # Can only sample with one batch dimension
        if len(sample_shape) > 1:
            raise ValueError("Sample shape must have at most one dimension")

        # If the sample shape is empty, add a batch dimension
        if len(sample_shape) == 0:
            sample_shape = torch.Size([1])
        n_samples = sample_shape[0]

        # Draw samples for each row in context
        samples, log_probs = [], []

        for context in self.context:
            # Extend the label to the sample shape
            context = context.expand(n_samples, -1)

            # Sample (this is already reparemeterized, so it's differentiable).
            # The count is passed explicitly rather than relying on the base
            # distribution inferring it from the length of the context.
            sample, log_prob = self.model.sample(num_samples=n_samples, y=context)

            # Reshape to a vector
            samples.append(sample.reshape(n_samples, -1))
            log_probs.append(log_prob)

        samples = torch.stack(samples)
        log_probs = torch.stack(log_probs)

        # Swap the context and data batch dimensions
        samples = samples.transpose(0, 1)
        log_probs = log_probs.transpose(0, 1)

        # If there is only one row in context, remove the context batch dimension
        if self.context.shape[0] == 1:
            samples = samples.squeeze(1)
            log_probs = log_probs.squeeze(1)

        # If there was no data batch dimension, remove it
        if n_samples == 1:
            samples = samples.squeeze(0)
            log_probs = log_probs.squeeze(0)

        return samples, log_probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples; see ``rsample_and_log_prob``."""
        return self.rsample_and_log_prob(sample_shape)[0]


class Glow(Guide):
    """
    A normalizing flow for image generation based on the Glow architecture

    Glow paper: https://arxiv.org/abs/1807.03039
    Reference implementation: https://github.com/VincentStimper/normalizing-flows

    Args:
        device: Device to use for the guide.
        dataset: Dataset object (not used by the constructor itself).
        input_shape: Shape of the input images (channel, width, height).
        context_size: Size of the context vector.
        hidden_channels: Number of hidden channels in the flow.
        L: Number of layers.
        K: Number of blocks per layer.
        grad_clip: Gradient clipping value.
    """

    def __init__(
        self,
        device,
        dataset,
        input_shape,
        context_size=5,
        hidden_channels=128,
        L=2,
        K=8,
        grad_clip=1.0,
    ):
        super().__init__(grad_clip=grad_clip)
        self.device = device
        self.input_shape = input_shape
        self.context_size = context_size

        # Assemble the multiscale flow one level at a time. Within a level
        # the modules are created in the order: Glow blocks, squeeze, merge
        # (all levels but the first), base distribution.
        n_channels = input_shape[0]
        base_dists, level_flows, merge_ops = [], [], []
        for level in range(L):
            blocks = [
                nf.flows.GlowBlock(
                    n_channels * 2 ** (L + 1 - level),
                    hidden_channels,
                    split_mode="channel",
                    scale=True,
                )
                for _ in range(K)
            ]
            blocks.append(nf.flows.Squeeze())
            level_flows.append(blocks)

            if level == 0:
                # The first level gets one more channel doubling than its
                # spatial down-scaling factor
                channel_factor = 2 ** (L + 1)
                spatial_factor = 2**L
            else:
                merge_ops.append(nf.flows.Merge())
                channel_factor = spatial_factor = 2 ** (L - level)

            latent_shape = (
                input_shape[0] * channel_factor,
                input_shape[1] // spatial_factor,
                input_shape[2] // spatial_factor,
            )
            base_dists.append(
                nf.distributions.ClassCondDiagGaussian(latent_shape, self.context_size)
            )

        # NOTE: the stock nf.MultiscaleFlow is used here; the
        # ConditionalMultiscaleFlow defined in this file is not wired in.
        multiscale = nf.MultiscaleFlow(base_dists, level_flows, merge_ops)
        self.model = multiscale.to(self.device)

        # Learnable context used for the target distribution
        self.label = torch.nn.Parameter(torch.ones(self.context_size, device=device))

    def flow(self, label):
        """Return the flow for the given label."""
        label_on_device = label.to(self.device)
        return GlowDistribution(self.model, label_on_device, self.input_shape)

    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Perform a single training step.

        Args:
            dataset (Dataset): Dataset whose ``single_particle_elbo`` is
                evaluated for both distributions.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations.
            obs_target (torch.Tensor): Target observations.

        Returns:
            loss: tensor with the loss value
            dict[str, float]: Dictionary with loss values to log.
        """

        # Loss is the negative ELBO on both nominal and target data
        nominal_loss = -dataset.single_particle_elbo(
            self.nominal_distribution(), n_nominal, obs_nominal.to(self.device)
        )
        target_loss = -dataset.single_particle_elbo(
            self.target_distribution(), n_target, obs_target.to(self.device)
        )
        loss = nominal_loss + target_loss

        # Plain Python floats for logging
        named_terms = (
            ("nominal_loss", nominal_loss),
            ("target_loss", target_loss),
            ("loss", loss),
        )
        logs = {name: term.detach().cpu().item() for name, term in named_terms}
        return loss, logs

    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for nominal data."""
        # Nominal data is conditioned on an all-zero context vector
        zero_context = torch.zeros(self.context_size).to(self.device)
        return GlowDistribution(self.model, zero_context, self.input_shape)

    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for target data."""
        # Target data is conditioned on the learnable label
        return GlowDistribution(self.model, self.label, self.input_shape)
