"""Abstract base class for datasets."""
from abc import ABC, abstractmethod

import torch


class Dataset(ABC):
    """Abstract interface for a dataset.

    A concrete subclass must provide:
        * the observation and latent dimensionalities (``obs_dims``,
          ``latent_dims``);
        * dataloader construction for nominal and target data
          (``configure_nominal_data``, ``configure_target_data``);
        * a per-particle ELBO (``single_particle_elbo``);
        * a visualization routine (``visualize``).

    ``score_scale`` has a default of ``1.0`` and may be overridden.
    """

    @property
    @abstractmethod
    def obs_dims(self):
        """Number of observation dimensions (must be overridden)."""
        ...

    @property
    @abstractmethod
    def latent_dims(self):
        """Number of latent dimensions (must be overridden)."""
        ...

    @property
    def score_scale(self) -> float:
        """Factor that converts an ELBO value into a score-function value.

        Defaults to ``1.0``; subclasses may override it.
        """
        return 1.0

    @abstractmethod
    def configure_nominal_data(self, device: torch.device) -> None:
        """Build the dataloaders for the nominal data.

        Implementations are expected to set two attributes on ``self``:
            - ``self.nominal_train_loader``: DataLoader over nominal training data.
            - ``self.nominal_test_loader``: DataLoader over nominal test data.

        Args:
            device (torch.device): Device on which the data should live.
        """
        ...

    @abstractmethod
    def configure_target_data(self, device: torch.device) -> None:
        """Build the dataloaders for the target data.

        Implementations are expected to set two attributes on ``self``:
            - ``self.target_train_loader``: DataLoader over target training data.
            - ``self.target_test_loader``: DataLoader over target test data.

        Args:
            device (torch.device): Device on which the data should live.
        """
        ...

    @abstractmethod
    def single_particle_elbo(
        self, dist: torch.distributions.Distribution, n: int, obs: torch.Tensor
    ) -> torch.Tensor:
        """Evaluate the ELBO contributed by a single particle.

        Args:
            dist (torch.distributions.Distribution): Distribution over the
                latent variables.
            n (int): Number of observations.
            obs (torch.Tensor): The observations.

        Returns:
            torch.Tensor: The single-particle ELBO.
        """
        ...

    @abstractmethod
    def visualize(
        self,
        nominal_dist: torch.distributions.Distribution,
        target_dist: torch.distributions.Distribution,
    ) -> None:
        """Plot the nominal and target distributions over latent variables.

        Implementations should log the figure to wandb without committing
        the step, then close the figure.

        Args:
            nominal_dist (torch.distributions.Distribution): Latent-variable
                distribution associated with the nominal data.
            target_dist (torch.distributions.Distribution): Latent-variable
                distribution associated with the target data.
        """
        ...
