from abc import ABC, abstractmethod
from typing import Any

import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch


class GenerativeModel(ABC, nn.Module):
    """Abstract base class for generative models over PyG graph batches.

    Concrete subclasses implement:

    * :meth:`sample_t` -- draw random training times,
    * :meth:`forward` -- the training forward pass,
    * :meth:`loss` -- the framework-specific training objective,
    * :meth:`sample` -- generation of an edge-probability trajectory.

    NOTE: ``@torch.no_grad()`` on :meth:`sample_t` and
    ``@torch.inference_mode()`` on :meth:`sample` wrap only the abstract
    bodies defined here. Attribute lookup resolves to a subclass's override,
    and that override is NOT wrapped automatically. Subclasses must re-apply
    these decorators themselves to get the same gradient-tracking behavior.
    """

    def __init__(self) -> None:
        """Initialize the generative model (sets up ``nn.Module`` state)."""
        super().__init__()

    @abstractmethod
    @torch.no_grad()
    def sample_t(self, data: Batch) -> Float[Tensor, "n_nodes"]:
        """Sample a random time in [0, 1] for each graph in the batch.

        :param data: PyG batch object.
        :return: Sampled time expressed per node, shape ``(n_nodes,)``.
            The per-graph sample is expected to be shared by all nodes of
            that graph (TODO confirm against concrete implementations).
        """

    @abstractmethod
    def forward(self, data: Batch, use_ema: bool = False) -> Float[Tensor, "n_edges"]:
        """Run the generative model's forward pass during training.

        :param data: PyG batch object.
        :param use_ema: Whether to run with exponential-moving-average
            weights. This meaning is inferred from the name; the base class
            defines no EMA weights itself, so subclasses manage them.
            Defaults to ``False``.
        :return: Per-edge predictions, shape ``(n_edges,)``.
        """

    @abstractmethod
    def loss(
        self,
        data: Batch,
        pred: Float[Tensor, "n_edges"],
        target: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "1"]:
        """Compute the loss of the respective generative framework.

        :param data: PyG batch object.
        :param pred: Predicted edge probabilities, shape ``(n_edges,)``.
        :param target: One-hot encoded target edges, shape ``(n_edges,)``.
        :return: Loss value.
        """

    @abstractmethod
    @torch.inference_mode()
    def sample(
        self, data: Any, n_steps: int, t_start: float = 0.0, use_ema: bool = False
    ) -> Float[Tensor, "n_steps n_edges"]:
        """Generate samples using the generative framework.

        :param data: Input to generate from; typically a PyG batch object
            (annotated as ``Any`` so implementations may accept other
            containers).
        :param n_steps: Number of generation steps.
        :param t_start: Starting time. Defaults to ``0.0``.
        :param use_ema: Whether to generate with exponential-moving-average
            weights (inferred from the name; managed by subclasses).
            Defaults to ``False``.
        :return: Trajectory of edge probabilities, shape
            ``(n_steps, n_edges)``.
        """
