"""Metrics for classifiers."""

import torch
import torchmetrics.classification as cm

from calnf.metrics.metric import Metric


class ClassifierMetrics(Metric):
    """Binary-classification metrics that separate nominal from target data.

    Each observation is scored with a Monte-Carlo ELBO estimate (the mean of
    ``n_particles`` calls to ``dataset.single_particle_elbo``) multiplied by
    ``dataset.score_scale``. Nominal observations are labelled 0 and target
    observations 1, and the scores are evaluated as binary-classifier outputs.
    """

    def __init__(
        self,
        dataset,
        n_particles: int = 10,
    ) -> None:
        """Initialize the classifier metrics.

        Args:
            dataset: Object providing ``single_particle_elbo(dist, n, obs)`` and
                ``score_scale``; used to score each observation.
            n_particles (int): Number of single-particle ELBO estimates averaged
                per observation. Must be at least 1.

        Raises:
            ValueError: If ``n_particles`` is less than 1.
        """
        super().__init__()
        if n_particles < 1:
            # With zero particles every score would silently collapse to 0.
            raise ValueError(f"n_particles must be >= 1, got {n_particles}")
        self.dataset = dataset
        self.n_particles = n_particles

    def _get_scores(
        self, dist: torch.distributions.Distribution, n: int, obs: torch.Tensor
    ) -> torch.Tensor:
        """Get the scores for the observations.

        Each of the first ``n`` observations is scored independently as the
        average of ``self.n_particles`` single-particle ELBO estimates, scaled
        by ``self.dataset.score_scale``.

        Args:
            dist (torch.distributions.Distribution): Distribution over latent variables.
            n (int): Number of observations to score (the first ``n`` rows of ``obs``).
            obs (torch.Tensor): Observations.

        Returns:
            torch.Tensor: Scores for the observations, stacked along a new first dim.
        """
        scores = []

        for i in range(n):
            # The ELBO helper is called with a batch of exactly one observation.
            single_obs = obs[i].unsqueeze(0)
            elbo = torch.zeros((), device=obs.device)
            for _ in range(self.n_particles):
                elbo += (
                    self.dataset.single_particle_elbo(dist, 1, single_obs)
                    / self.n_particles
                )

            scores.append(elbo * self.dataset.score_scale)

        return torch.stack(scores)

    def _score_loader(
        self,
        device: torch.device,
        dist: torch.distributions.Distribution,
        loader: torch.utils.data.DataLoader,
    ) -> torch.Tensor:
        """Score every observation yielded by ``loader`` into one 1-D tensor.

        Batches are collected and concatenated once, rather than re-concatenated
        on every iteration. An empty loader yields an empty tensor on ``device``.
        """
        chunks = []
        for obs in loader:
            obs = obs.to(device)
            chunks.append(self._get_scores(dist, len(obs), obs))
        if not chunks:
            return torch.empty(0, device=device)
        return torch.cat(chunks)

    @torch.no_grad()
    def __call__(
        self,
        device: torch.device,
        dist: torch.distributions.Distribution,
        nominal_test_loader: torch.utils.data.DataLoader,
        target_test_loader: torch.utils.data.DataLoader,
    ) -> dict[str, float]:
        """Compute the metric.

        Args:
            device (torch.device): Device to use for the data.
            dist (torch.distributions.Distribution): Distribution over latent variables.
            nominal_test_loader (torch.utils.data.DataLoader): DataLoader for nominal
                test data (labelled 0).
            target_test_loader (torch.utils.data.DataLoader): DataLoader for target
                test data (labelled 1).

        Returns:
            dict[str, float]: ``"AUROC"``, ``"AUPRC"``, and the ``"Precision"`` and
            ``"Recall"`` at the ROC operating point that maximizes Youden's J
            (TPR - FPR), together with that point's ``"Optimal Threshold"``.
        """
        # Get scores for the nominal and target data.
        nominal_scores = self._score_loader(device, dist, nominal_test_loader)
        target_scores = self._score_loader(device, dist, target_test_loader)

        # Concatenate nominal + target into single tensors with matching labels.
        scores = torch.cat([nominal_scores, target_scores])
        labels = torch.cat(
            [
                torch.zeros(
                    len(nominal_scores), dtype=torch.int32, device=scores.device
                ),
                torch.ones(len(target_scores), dtype=torch.int32, device=scores.device),
            ]
        )

        # NOTE(review): per the torchmetrics docs, float predictions outside
        # [0, 1] are treated as logits and passed through a sigmoid. If that
        # happens, the thresholds below are on the sigmoid scale, and very
        # large-magnitude scores may saturate. Confirm that score_scale keeps
        # the scores in a range where this is harmless.
        aucroc = cm.BinaryAUROC().to(scores.device)(scores, labels)
        auprc = cm.BinaryAveragePrecision().to(scores.device)(scores, labels)

        # Pick the ROC operating point that maximizes Youden's J statistic.
        fpr, tpr, thresholds = cm.BinaryROC().to(scores.device)(scores, labels)
        optimal_idx = torch.argmax(tpr - fpr)
        optimal_threshold = thresholds[optimal_idx]

        # Derive precision/recall from the same ROC point. The ROC counts
        # ``score >= threshold`` as positive, whereas BinaryPrecision and
        # BinaryRecall with ``threshold=t`` use the strict ``score > t``. Those
        # would drop the sample sitting exactly on the threshold and so
        # misreport the chosen operating point.
        n_pos = labels.sum()
        n_neg = labels.numel() - n_pos
        true_pos = tpr[optimal_idx] * n_pos
        false_pos = fpr[optimal_idx] * n_neg
        predicted_pos = true_pos + false_pos
        recall = tpr[optimal_idx]
        if predicted_pos > 0:
            precision = true_pos / predicted_pos
        else:
            # No predicted positives: report 0, matching torchmetrics' default.
            precision = torch.zeros_like(true_pos)

        return {
            "AUROC": aucroc.item(),
            "AUPRC": auprc.item(),
            "Precision": precision.item(),
            "Recall": recall.item(),
            "Optimal Threshold": optimal_threshold.item(),
        }
