"""Abstract base class for variational guides."""
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn

from calnf.datasets.dataset import Dataset


class Guide(ABC, nn.Module):
    """Abstract base class for variational guides.

    Subclasses must implement :meth:`loss`, :meth:`nominal_distribution` and
    :meth:`target_distribution`. This base class provides a default Adam
    optimizer, gradient-norm clipping, and saving of the module state.
    """

    def __init__(self, grad_clip: Optional[float] = 1.0) -> None:
        """Initialize the guide.

        Args:
            grad_clip (Optional[float]): Maximum total gradient norm used by
                :meth:`clip_grad_norm`. ``None`` disables clipping; the
                gradient norm is then still computed and reported.
        """
        super().__init__()
        self.grad_clip = grad_clip

    def configure_optimizer(self, lr: float = 1e-3) -> torch.optim.Optimizer:
        """Return the optimizer for training.

        Defaults to Adam over all parameters of the guide.

        Args:
            lr (float): Learning rate.

        Returns:
            torch.optim.Optimizer: An Adam optimizer over ``self.parameters()``.
        """
        return torch.optim.Adam(self.parameters(), lr=lr)

    def clip_grad_norm(self) -> dict[str, torch.Tensor]:
        """Clip the total gradient norm of all parameters to ``self.grad_clip``.

        Returns:
            dict[str, torch.Tensor]: ``{"grad_norm": norm}``, where ``norm`` is
            the total gradient norm returned by
            ``torch.nn.utils.clip_grad_norm_``, intended for logging.
        """
        # When clipping is disabled, an infinite max norm never triggers
        # rescaling, but clip_grad_norm_ still computes the norm to log.
        max_norm = float("inf") if self.grad_clip is None else self.grad_clip
        grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm)

        return {"grad_norm": grad_norm}

    @abstractmethod
    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the loss for a single training step.

        Args:
            dataset (Dataset): The dataset being fit. How it is used is up to
                the subclass.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations.
            obs_target (torch.Tensor): Target observations.

        Returns:
            tuple[torch.Tensor, dict[str, float]]: The loss tensor (presumably
            backpropagated by the caller; confirm against the training loop)
            and a dictionary of scalar values to log.
        """
        pass

    @abstractmethod
    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for nominal data."""
        pass

    @abstractmethod
    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the distribution over the latent variables for target data."""
        pass

    def save(self, filepath: str) -> None:
        """Save the guide's ``state_dict()`` to a file with ``torch.save``.

        Only the module state is written. Plain Python attributes such as
        ``grad_clip`` are not part of ``state_dict()`` and are not saved.

        Args:
            filepath (str): Path to the file.
        """
        torch.save(self.state_dict(), filepath)
