"""Adaptive version of calibrated normalizing flow."""
import torch
import zuko

from calnf.guides.guide import Guide
from calnf.datasets.dataset import Dataset
from calnf.utils.divergence import kl_divergence


class BalancedCalibratedFlow(Guide):
    """Conditional normalizing flow calibrated on bootstrap subsamples of target data.

    A single zuko neural spline flow is conditioned on a context vector of length
    ``num_subsamples``:

    * the all-zeros context is used for the nominal data;
    * the i-th one-hot context is used for the i-th bootstrap subsample of the
      target data;
    * a learnable ``calibration_label`` context is used for the full target data.

    The training loss (see :meth:`loss`) sums the negative ELBOs of all three terms
    plus a weighted KL penalty between the flows of random pairs of subsample labels.
    """

    def __init__(
        self,
        device,
        dataset: Dataset,
        num_subsamples: int = 5,
        reg_penalty: float = 1.0,
        calibrate_label: bool = True,
        num_kl_particles: int = 10,
        grad_clip: float = 1.0,
        num_reg_pairs: int = 1,
    ):
        """Initialize the calibrated normalizing flow.

        Args:
            device (torch.device): Device for the flow, the labels and the
                subsample indices.
            dataset (Dataset): Dataset object; only ``dataset.latent_dims`` (the
                number of flow features) is read here.
            num_subsamples (int): Number of bootstrap subsamples of the target data.
                This is also the length of the flow's context vector.
            reg_penalty (float): Weight of the KL regularization term.
            calibrate_label (bool): If False, the calibration label keeps its
                initial value and is not updated.
            num_kl_particles (int): Number of particles used to estimate each KL
                divergence in the regularization term.
            grad_clip (float): Maximum gradient norm used by :meth:`clip_grad_norm`.
                It is applied separately to the flow parameters and to the label.
            num_reg_pairs (int): Number of distinct random label pairs drawn per
                loss evaluation for the KL regularization. It is capped at the
                number of available pairs. Defaults to 1.

        Raises:
            ValueError: If ``num_reg_pairs`` is negative.
        """
        if num_reg_pairs < 0:
            raise ValueError(f"num_reg_pairs must be non-negative, got {num_reg_pairs}")
        super().__init__(grad_clip=grad_clip)
        self.device = device
        # Stored explicitly so clip_grad_norm honours the configured value.
        self.grad_clip = grad_clip
        self.model = zuko.flows.NSF(
            features=dataset.latent_dims,
            context=num_subsamples,
            hidden_features=(64, 64),
        ).to(device)

        self.num_subsamples = num_subsamples
        # Learnable context for the full target data. It is initialised to the mean
        # of the one-hot subsample labels, i.e. 1/num_subsamples in every entry.
        self.calibration_label = torch.nn.Parameter(
            torch.ones(num_subsamples, device=device) / num_subsamples
        )
        self.calibration_label.requires_grad_(calibrate_label)
        # Bootstrap index tensors into the target data. They are drawn lazily on
        # the first call to `loss`, because n_target is only known then.
        self.target_subsamples = None
        # Row i is the one-hot context for target subsample i.
        self.labels = torch.eye(num_subsamples, device=device)
        self.reg_penalty = reg_penalty
        self.num_kl_particles = num_kl_particles
        self.num_reg_pairs = num_reg_pairs

    def flow(self, context):
        """Return the flow's distribution over latents for the given context vector."""
        return self.model(context)

    def clip_grad_norm(self) -> dict[str, torch.Tensor]:
        """Clip flow and label gradients in place to a norm of ``self.grad_clip``.

        The flow parameters and the calibration label are clipped as two separate
        groups, so a large flow gradient does not shrink the label's gradient.

        Returns:
            dict[str, torch.Tensor]: The total gradient norm of each group before
            clipping, under ``"flow_grad_norm"`` and ``"label_grad_norm"``.
        """
        flow_grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_clip
        )
        label_grad_norm = torch.nn.utils.clip_grad_norm_(
            [self.calibration_label], self.grad_clip
        )

        return {
            "flow_grad_norm": flow_grad_norm,
            "label_grad_norm": label_grad_norm,
        }

    def _ensure_target_subsamples(self, n_target: int) -> None:
        """Draw the bootstrap index sets for the target data if none exist yet.

        Each subsample holds ``n_target`` indices drawn uniformly with replacement
        from ``range(n_target)``. The sets are drawn once and then reused on every
        later call. They are NOT redrawn if a later call passes a different
        ``n_target``.
        """
        if self.target_subsamples is not None:
            return
        self.target_subsamples = [
            torch.randint(n_target, (n_target,), device=self.device)
            for _ in range(self.num_subsamples)
        ]

    def _subsample_regularization(self) -> torch.Tensor:
        """Sum the KL divergences between subsample flows for random label pairs.

        Up to ``self.num_reg_pairs`` distinct unordered pairs of subsample labels
        are drawn. For each pair, KL(flow(label_a) || flow(label_b)) is estimated
        with ``self.num_kl_particles`` particles. Non-finite estimates are skipped
        so that a single bad estimate cannot turn the whole loss into NaN/inf.
        """
        regularization = torch.zeros((), device=self.device)
        pairs = torch.combinations(torch.arange(self.num_subsamples))
        pairs = pairs[torch.randperm(pairs.shape[0])[: self.num_reg_pairs]]
        for first, second in pairs.tolist():
            kl = kl_divergence(
                self.flow(self.labels[first]),
                self.flow(self.labels[second]),
                num_particles=self.num_kl_particles,
            )
            if torch.isfinite(kl):
                regularization = regularization + kl
        return regularization

    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the training loss for a single step.

        The loss is the sum of four terms:

        * the negative ELBO of the nominal data under the all-zeros context;
        * the mean negative ELBO over the bootstrap target subsamples, each
          evaluated under its own one-hot label;
        * the negative ELBO of the full target data under the calibration label,
          with the flow weights frozen so that only the label receives gradient;
        * ``reg_penalty`` times the KL regularization between subsample flows.

        Args:
            dataset (Dataset): Provides ``single_particle_elbo``.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations. It is also used as the
                index range and size of each bootstrap subsample.
            obs_target (torch.Tensor): Target observations, indexed by the
                subsample indices.

        Returns:
            loss: tensor with the loss value
            dict[str, float]: Dictionary with loss values to log.
        """
        obs_nominal = obs_nominal.to(self.device)
        obs_target = obs_target.to(self.device)
        self._ensure_target_subsamples(n_target)

        nominal_loss = -dataset.single_particle_elbo(
            self.nominal_distribution(), n_nominal, obs_nominal
        )

        # Average negative ELBO over the bootstrap subsamples.
        calibration_loss = torch.zeros((), device=self.device)
        for label, subsample in zip(self.labels, self.target_subsamples):
            subsample_loss = -dataset.single_particle_elbo(
                self.flow(label), len(subsample), obs_target[subsample]
            )
            calibration_loss = calibration_loss + subsample_loss / self.num_subsamples

        # The full-target term should optimize only the calibration label, so the
        # flow weights are frozen while it is built. The whole ELBO call stays
        # inside the frozen region, in case the flow evaluates its conditioning
        # network only when sampled or scored. Each parameter's previous flag is
        # restored even if the ELBO raises.
        previous_flags = [param.requires_grad for param in self.model.parameters()]
        self.model.requires_grad_(False)
        try:
            target_loss = -dataset.single_particle_elbo(
                self.target_distribution(), n_target, obs_target
            )
        finally:
            for param, flag in zip(self.model.parameters(), previous_flags):
                param.requires_grad_(flag)

        regularization = self._subsample_regularization()

        loss = (
            nominal_loss
            + calibration_loss
            + target_loss
            + self.reg_penalty * regularization
        )
        return loss, {
            "nominal_loss": nominal_loss.item(),
            "target_loss": target_loss.item(),
            "calibration_loss": calibration_loss.item(),
            "regularization": regularization.item(),
            "loss": loss.item(),
        }

    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the latent distribution for nominal data (all-zeros context)."""
        return self.flow(torch.zeros(self.num_subsamples, device=self.device))

    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the latent distribution for target data (calibration-label context)."""
        return self.flow(self.calibration_label)
