"""Calibrated normalizing flow."""
import torch
import zuko

from calnf.guides.guide import Guide
from calnf.datasets.dataset import Dataset


class CalibratedFlow(Guide):
    """Normalizing-flow guide that learns a calibrated context for target data.

    A single conditional neural spline flow (``zuko.flows.NSF``) with a
    ``num_subsamples``-dimensional context is shared across three roles:

    * the all-zeros context defines the nominal distribution;
    * the one-hot context ``i`` is fit to the ``i``-th bootstrap resample of
      the target observations;
    * the learnable ``calibration_label`` context defines the target
      distribution.
    """

    def __init__(
        self,
        device,
        dataset,
        num_subsamples: int = 5,
        resample: bool = False,
        calibrate_label: bool = True,
        grad_clip: float = 1.0,
    ):
        """Initialize the calibrated normalizing flow.

        Args:
            device (torch.device): Device to use for the guide.
            dataset (Dataset): Dataset object; its ``latent_dims`` sets the
                number of features modelled by the flow.
            num_subsamples (int): Number of bootstrap subsamples of the target
                data used for calibration (also the flow's context size).
            resample (bool): Whether to redraw the target subsamples on every
                call to ``loss``.
            calibrate_label (bool): If False, the calibration label does not
                receive gradients.
            grad_clip (float): Maximum gradient norm enforced by
                ``clip_grad_norm``.
        """
        super().__init__(grad_clip=grad_clip)
        self.device = device
        self.num_subsamples = num_subsamples
        # Stored under a private name so ``clip_grad_norm`` honours the
        # constructor argument without relying on how the base class keeps it.
        self._max_grad_norm = grad_clip
        # Learnable context for the target distribution, initialised uniform.
        self.calibration_label = torch.nn.Parameter(
            torch.ones(num_subsamples, device=device) / num_subsamples
        )
        self.calibration_label.requires_grad_(calibrate_label)
        self.model = zuko.flows.NSF(
            features=dataset.latent_dims,
            context=num_subsamples,
            hidden_features=(64, 64),
        ).to(device)
        self.resample = resample
        # Bootstrap index tensors into the target observations (set lazily).
        self.target_subsamples = None
        # Row i is the one-hot context associated with target subsample i.
        self.labels = torch.eye(num_subsamples, device=device)

    def flow(self, context):
        """Return the flow's distribution conditioned on ``context``."""
        return self.model(context)

    def clip_grad_norm(self) -> dict[str, torch.Tensor]:
        """Clip flow and calibration-label gradients to ``grad_clip``.

        The flow parameters and the calibration label are clipped
        independently, so each group gets its own norm budget.

        Returns:
            dict[str, torch.Tensor]: Total gradient norm of each group, as
            returned by ``torch.nn.utils.clip_grad_norm_``.
        """
        flow_grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self._max_grad_norm
        )
        label_grad_norm = torch.nn.utils.clip_grad_norm_(
            [self.calibration_label], self._max_grad_norm
        )

        return {
            "flow_grad_norm": flow_grad_norm,
            "label_grad_norm": label_grad_norm,
        }

    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the training loss for a single step.

        The loss is the sum of three negative single-particle ELBOs:

        * the nominal distribution on the nominal observations;
        * the average, over subsamples, of one-hot context ``i`` on the
          ``i``-th bootstrap resample of the target observations;
        * the target (calibrated-label) distribution on all target
          observations.

        Args:
            dataset (Dataset): Dataset providing ``single_particle_elbo``.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations.
            obs_target (torch.Tensor): Target observations; indexed along its
                first dimension by the bootstrap indices (assumed to hold
                ``n_target`` rows — confirm against callers).

        Returns:
            loss: tensor with the loss value
            dict[str, float]: Dictionary with loss values to log.
        """
        # Draw bootstrap subsamples if we have none yet, if resampling every
        # step, or if the cached indices were drawn for a different n_target
        # (they would otherwise be out of range or cover only part of the data).
        needs_subsamples = (
            self.target_subsamples is None
            or self.resample
            or any(idx.numel() != n_target for idx in self.target_subsamples)
        )
        if needs_subsamples:
            # Each subsample draws n_target indices with replacement.
            self.target_subsamples = [
                torch.randint(n_target, (n_target,), device=self.device)
                for _ in range(self.num_subsamples)
            ]

        # Negative ELBO of the nominal (zero-context) distribution.
        nominal_loss = -dataset.single_particle_elbo(
            self.nominal_distribution(), n_nominal, obs_nominal
        )

        # Average negative ELBO of each one-hot context on its own subsample.
        calibration_loss = torch.zeros((), device=self.device)
        for label, subsample in zip(self.labels, self.target_subsamples):
            subsample_loss = -dataset.single_particle_elbo(
                self.flow(label), len(subsample), obs_target[subsample]
            )
            calibration_loss = calibration_loss + subsample_loss / self.num_subsamples

        # Negative ELBO of the calibrated-label distribution on all target data.
        target_loss = -dataset.single_particle_elbo(
            self.target_distribution(), n_target, obs_target
        )

        loss = nominal_loss + calibration_loss + target_loss
        return loss, {
            "nominal_loss": nominal_loss.detach().cpu().item(),
            "target_loss": target_loss.detach().cpu().item(),
            "calibration_loss": calibration_loss.detach().cpu().item(),
            "loss": loss.detach().cpu().item(),
        }

    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for nominal data."""
        return self.flow(torch.zeros(self.num_subsamples, device=self.device))

    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for target data."""
        return self.flow(self.calibration_label)
