"""Normalizing flow guide with KL regularization."""
import torch
import zuko

from calnf.guides.guide import Guide
from calnf.datasets.dataset import Dataset
from calnf.utils.divergence import kl_divergence


class KLRegularizedFlow(Guide):
    """Conditional neural spline flow shared between nominal and target data.

    A single ``zuko.flows.NSF`` is conditioned on a constant context vector:
    an all-zeros context yields the nominal distribution and an all-ones
    context yields the target distribution. The training loss is the negative
    ELBO on both datasets plus a weighted KL term between the target and
    nominal distributions.
    """

    def __init__(
        self,
        device,
        dataset,
        reg_penalty=1.0,
        grad_clip=1.0,
        context_size=5,
        hidden_features=(64, 64),
        kl_num_particles=100,
    ):
        """Build the conditional flow and store the regularization settings.

        Args:
            device: Torch device the flow and its context vectors are placed on.
            dataset: Object exposing ``latent_dims``, used as the number of
                features modeled by the flow.
            reg_penalty (float): Weight of the KL term in the total loss.
            grad_clip (float): Forwarded to ``Guide.__init__``.
            context_size (int): Length of the constant context vector fed to
                the flow. Defaults to 5.
            hidden_features (tuple[int, ...]): Forwarded to ``zuko.flows.NSF``
                as its ``hidden_features``. Defaults to ``(64, 64)``.
            kl_num_particles (int): ``num_particles`` passed to
                ``kl_divergence`` when estimating the regularizer.
                Defaults to 100.
        """
        super().__init__(grad_clip=grad_clip)
        self.device = device
        self.context_size = context_size
        self.flow = zuko.flows.NSF(
            features=dataset.latent_dims,
            context=self.context_size,
            hidden_features=hidden_features,
        ).to(device)
        self.reg_penalty = reg_penalty
        self.kl_num_particles = kl_num_particles

    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the KL-regularized training loss.

        The loss is
        ``-ELBO(nominal) - ELBO(target) + reg_penalty * KL(target || nominal)``.

        Args:
            dataset (Dataset): Dataset whose ``single_particle_elbo`` scores a
                latent distribution against observations.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations.
            obs_target (torch.Tensor): Target observations.

        Returns:
            loss: tensor with the loss value (to be backpropagated by the caller).
            dict[str, float]: Dictionary with loss values to log.
        """
        # Condition the flow once per call; the same distribution objects are
        # reused by the ELBO terms and the KL regularizer instead of re-running
        # the flow's conditioning for every use.
        target_dist = self.target_distribution()
        nominal_dist = self.nominal_distribution()

        # Negative ELBO on both target and nominal data (evaluation order kept
        # as target first, then nominal, then KL, matching the original code).
        target_loss = -dataset.single_particle_elbo(
            target_dist, n_target, obs_target
        )
        nominal_loss = -dataset.single_particle_elbo(
            nominal_dist, n_nominal, obs_nominal
        )

        # Regularize with the KL divergence from the target to the nominal
        # distribution; presumably a Monte Carlo estimate using
        # ``kl_num_particles`` samples — confirm against calnf.utils.divergence.
        kl = kl_divergence(
            target_dist, nominal_dist, num_particles=self.kl_num_particles
        )

        loss = nominal_loss + target_loss + self.reg_penalty * kl

        # ``.item()`` already returns a detached Python number on the host.
        return loss, {
            "nominal_loss": nominal_loss.item(),
            "target_loss": target_loss.item(),
            "kl_divergence": kl.item(),
            "loss": loss.item(),
        }

    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for nominal data.

        The flow is conditioned on an all-zeros context vector.
        """
        return self.flow(torch.zeros(self.context_size, device=self.device))

    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for target data.

        The flow is conditioned on an all-ones context vector.
        """
        return self.flow(torch.ones(self.context_size, device=self.device))
