"""Normalizing flow guide with W2 regularization."""
import torch
import zuko

from calnf.guides.guide import Guide
from calnf.datasets.dataset import Dataset


class W2RegularizedFlow(Guide):
    """Context-conditioned continuous normalizing flow with a W2-style penalty.

    A single ``zuko.flows.CNF`` is shared between the nominal and target
    distributions; the two are distinguished only by the context vector fed
    to the flow (all zeros for nominal, all ones for target). The training
    loss is the negative single-particle ELBO on both data sets plus
    ``reg_penalty`` times the mean squared norm of the flow's ODE network
    output (the "W2 regularization" term).
    """

    def __init__(
        self,
        device,
        dataset: Dataset,
        reg_penalty=1.0,
        grad_clip=1.0,
        *,
        context_size: int = 5,
        hidden_features: tuple[int, ...] = (64, 64),
    ):
        """Build the conditional CNF.

        Args:
            device: Torch device on which the flow and all sampled tensors live.
            dataset (Dataset): Dataset whose ``latent_dims`` sets the number of
                features of the flow.
            reg_penalty (float): Weight of the W2 regularization term in the loss.
            grad_clip (float): Gradient clipping value forwarded to ``Guide``.
            context_size (int): Length of the context vector used to select the
                nominal (zeros) or target (ones) distribution. Defaults to 5.
            hidden_features (tuple[int, ...]): Hidden layer widths of the ODE
                network. Defaults to ``(64, 64)``.

        Raises:
            ValueError: If ``context_size`` is smaller than 1 (the nominal and
                target contexts would then be indistinguishable).
        """
        super().__init__(grad_clip=grad_clip)
        if context_size < 1:
            raise ValueError(f"context_size must be >= 1, got {context_size}")
        self.device = device
        self.context_size = context_size
        self.flow = zuko.flows.CNF(
            features=dataset.latent_dims,
            context=self.context_size,
            hidden_features=hidden_features,
        ).to(device)
        self.reg_penalty = reg_penalty

    def w2_regularization(self, n_particles: int = 100) -> torch.Tensor:
        """Monte Carlo estimate of the mean squared norm of the ODE network output.

        Draws ``n_particles`` standard-normal inputs whose width matches the
        first layer of the flow's ODE network, and returns the mean over
        particles of the squared Euclidean norm of the network's output.

        Args:
            n_particles (int): Number of Monte Carlo samples. Must be >= 1.

        Returns:
            torch.Tensor: Scalar tensor (not detached, so it participates in
            backpropagation through the loss).

        Raises:
            ValueError: If ``n_particles`` is smaller than 1 (the mean over an
                empty batch would be NaN).
        """
        if n_particles < 1:
            raise ValueError(f"n_particles must be >= 1, got {n_particles}")
        # Width of the raw ODE-network input.
        # NOTE(review): in zuko's CNF this input presumably concatenates the
        # state with the context and time-embedding features, so drawing every
        # coordinate from N(0, I) evaluates the field at contexts/times the flow
        # never sees when sampling — confirm this approximation is intended.
        ode_dim = self.flow.transform.ode[0].in_features
        # Sample directly on the target device rather than allocating on the
        # CPU and copying over.
        samples = torch.randn(n_particles, ode_dim, device=self.device)
        velocity = self.flow.transform.ode(samples)
        return velocity.pow(2).sum(dim=-1).mean()

    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the training loss for a single step.

        Args:
            dataset (Dataset): Dataset providing ``single_particle_elbo``.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations.
            obs_target (torch.Tensor): Target observations.

        Returns:
            loss: tensor with the loss value
            dict[str, float]: Dictionary with loss values to log.
        """
        # Loss is the negative ELBO on both nominal and target data
        nominal_loss = -dataset.single_particle_elbo(
            self.nominal_distribution(), n_nominal, obs_nominal
        )
        target_loss = -dataset.single_particle_elbo(
            self.target_distribution(), n_target, obs_target
        )

        # Regularize the magnitude of the shared flow's ODE vector field
        # (the W2 term) — this is not a KL divergence between the two
        # distributions.
        w2_regularization_term = self.w2_regularization()

        total_loss = (
            nominal_loss + target_loss + self.reg_penalty * w2_regularization_term
        )
        # .item() already returns a host-side Python scalar without tracking
        # gradients, so no explicit detach()/cpu() is needed for logging.
        return total_loss, {
            "nominal_loss": nominal_loss.item(),
            "target_loss": target_loss.item(),
            "w2_regularization_term": w2_regularization_term.item(),
            "loss": total_loss.item(),
        }

    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for nominal data.

        The nominal distribution is the flow conditioned on an all-zeros
        context vector of length ``context_size``.
        """
        return self.flow(torch.zeros(self.context_size, device=self.device))

    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for target data.

        The target distribution is the flow conditioned on an all-ones
        context vector of length ``context_size``.
        """
        return self.flow(torch.ones(self.context_size, device=self.device))
