"""Metric for the evidence lower bound (ELBO)."""

import torch

from calnf.datasets.dataset import Dataset
from calnf.metrics.metric import Metric


class ELBOMetric(Metric):
    """Report the per-latent-dimension ELBO on nominal and target test data.

    For each test split, ``dataset.single_particle_elbo`` is evaluated
    ``n_elbo_particles`` times per batch. All of these estimates are averaged,
    so every batch carries equal weight regardless of its size. The mean is then
    divided by ``dataset.latent_dims``.
    """

    def __init__(self, dataset: Dataset, n_elbo_particles: int) -> None:
        """Initialize the ELBO metric.

        Args:
            dataset (Dataset): Dataset providing ``single_particle_elbo`` and
                ``latent_dims``.
            n_elbo_particles (int): Number of single-particle ELBO evaluations
                per batch. Must be positive for the metric to be computable.
        """
        super().__init__()
        self.dataset = dataset
        self.n_elbo_particles = n_elbo_particles

    def _mean_elbo(
        self,
        device: torch.device,
        dist: torch.distributions.Distribution,
        loader: torch.utils.data.DataLoader,
        split_name: str,
    ) -> float:
        """Average the single-particle ELBO over one loader, per latent dimension.

        Args:
            device (torch.device): Device each batch is moved to.
            dist (torch.distributions.Distribution): Distribution over latent
                variables, forwarded to ``dataset.single_particle_elbo``.
            loader (torch.utils.data.DataLoader): Loader whose batches support
                ``.to(device)`` and ``len()``.
            split_name (str): Label used in the error message ("nominal" or
                "target").

        Returns:
            float: Mean over all (batch, particle) estimates, divided by
                ``dataset.latent_dims``.

        Raises:
            ValueError: If no estimate was computed, i.e. the loader yielded
                no batches or ``n_elbo_particles`` is not positive.
        """
        elbo_sum = 0.0
        n_estimates = 0  # counts (batch, particle) pairs, not just batches

        for obs in loader:
            obs = obs.to(device)
            batch_size = len(obs)

            for _ in range(self.n_elbo_particles):
                elbo_sum += self.dataset.single_particle_elbo(dist, batch_size, obs)
                n_estimates += 1

        if n_estimates == 0:
            # Without this guard the division below fails with an opaque
            # ZeroDivisionError.
            raise ValueError(
                f"Cannot compute ELBO ({split_name}): the test loader yielded no "
                f"batches or n_elbo_particles={self.n_elbo_particles} is not "
                "positive."
            )

        mean_elbo = elbo_sum / n_estimates
        # Normalize by the number of latent dimensions.
        mean_elbo = mean_elbo / self.dataset.latent_dims
        # float() handles both 0-d tensors (on any device) and plain numbers.
        return float(mean_elbo)

    @torch.no_grad()
    def __call__(
        self,
        device: torch.device,
        dist: torch.distributions.Distribution,
        nominal_test_loader: torch.utils.data.DataLoader,
        target_test_loader: torch.utils.data.DataLoader,
    ) -> dict[str, float]:
        """Compute the metric.

        Args:
            device (torch.device): Device to use for the data.
            dist (torch.distributions.Distribution): Distribution over latent variables.
            nominal_test_loader (torch.utils.data.DataLoader): DataLoader for nominal
                test data.
            target_test_loader (torch.utils.data.DataLoader): DataLoader for target
                test data.

        Returns:
            dict[str, float]: ``{"ELBO (nominal)": ..., "ELBO (target)": ...}``.
                Each value is the mean single-particle ELBO over all batches and
                particles of that split, divided by ``dataset.latent_dims``.

        Raises:
            ValueError: If either loader yields no batches or
                ``n_elbo_particles`` is not positive.
        """
        # Nominal is evaluated before target, matching the original sampling order.
        elbo_nominal = self._mean_elbo(device, dist, nominal_test_loader, "nominal")
        elbo_target = self._mean_elbo(device, dist, target_test_loader, "target")

        return {
            "ELBO (nominal)": elbo_nominal,
            "ELBO (target)": elbo_target,
        }
