from abc import abstractmethod
from typing import Dict, Optional, Tuple, Union

import torch
from jaxtyping import Bool
from torch import Tensor

# Message carried by the NotImplementedError that every un-overridden
# BaseFlowMatcher method raises.
ABS_CLASS_ERR_MSG = "Method not implemented in abstract class"


class BaseFlowMatcher:
    """Interface that concrete flow matchers are expected to implement.

    The constructor only stores ``guidance_enabled`` and ``dim``. Every other
    method raises ``NotImplementedError`` and must be overridden by a subclass.

    NOTE(review): the methods are decorated with ``@abstractmethod`` but the
    class does not use ``abc.ABCMeta`` (it does not derive from ``abc.ABC``),
    so the decorator is not enforced: the class can be instantiated and a
    missing override only fails when the method is actually called. Deriving
    from ``abc.ABC`` would enforce the interface at instantiation time, but
    would also reject subclasses that implement only part of it - confirm
    against the existing subclasses before changing this.
    """

    def __init__(self, guidance_enabled: bool, dim: int):
        """Store the configuration shared by all flow matchers.

        Args:
            guidance_enabled: Flag stored as-is; how it is used is defined by
                subclasses (presumably toggles guided sampling - TODO confirm).
            dim: Integer stored as-is (presumably the dimensionality of the
                data being modelled - TODO confirm).
        """
        self.guidance_enabled = guidance_enabled
        self.dim = dim

    @abstractmethod
    def mask_n_zero_com(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return a processed version of ``x``; must be overridden.

        The name suggests "apply the mask and zero the center of mass", but
        the exact contract is defined by subclasses - TODO confirm.

        Args:
            x: Input tensor.
            mask: Optional mask tensor; ``None`` by default.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def sample_noise(
        self,
        n: int,
        device: torch.device,
        shape: Tuple = tuple(),
        mask: Optional[Bool[Tensor, "* n"]] = None,
    ) -> torch.Tensor:
        """Return a noise tensor; must be overridden.

        Args:
            n: Integer size; the ``mask`` annotation uses ``n`` as the name of
                its last axis, so presumably the number of elements per
                sample - TODO confirm.
            device: Torch device (presumably where the sample is created).
            shape: Tuple, empty by default (presumably leading/batch
                dimensions - TODO confirm).
            mask: Optional boolean tensor annotated with shape ``"* n"``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def interpolate(
        self,
        x_0: torch.Tensor,
        x_1: torch.Tensor,
        t: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return a tensor built from ``x_0``, ``x_1`` and ``t``; must be overridden.

        Presumably ``x_0`` is the noise sample, ``x_1`` the clean sample and
        the result the interpolant at time ``t`` - TODO confirm against
        subclasses.

        Args:
            x_0: First endpoint tensor.
            x_1: Second endpoint tensor.
            t: Time tensor.
            mask: Optional mask tensor; ``None`` by default.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def extract_clean_sample_from_batch(self, batch: Dict) -> torch.Tensor:
        """Return a tensor obtained from ``batch``; must be overridden.

        Args:
            batch: Dictionary describing a batch; its schema is defined by
                subclasses and callers.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def nn_out_add_clean_sample_prediction(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        mask: torch.Tensor,
        nn_out: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Return a network-output dictionary; must be overridden.

        The name suggests the result is ``nn_out`` extended with a
        clean-sample prediction; the key used is defined by subclasses -
        TODO confirm.

        Args:
            x_t: Tensor (presumably the sample at time ``t``).
            t: Time tensor.
            mask: Mask tensor.
            nn_out: Dictionary of network output tensors keyed by name.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def nn_out_add_simulation_tensor(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        mask: torch.Tensor,
        nn_out: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Return a network-output dictionary; must be overridden.

        The name suggests the result is ``nn_out`` extended with the tensor
        needed for simulation; details are defined by subclasses - TODO
        confirm.

        Args:
            x_t: Tensor (presumably the sample at time ``t``).
            t: Time tensor.
            mask: Mask tensor.
            nn_out: Dictionary of network output tensors keyed by name.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def compute_fm_loss(
        self,
        x_0: torch.Tensor,
        x_1: torch.Tensor,
        x_t: torch.Tensor,
        mask: torch.Tensor,
        t: torch.Tensor,
        x_1_pred: torch.Tensor,
    ) -> torch.Tensor:
        """Return a loss tensor; must be overridden.

        Presumably the flow-matching loss comparing ``x_1_pred`` with
        ``x_1``; the exact formula and reduction are defined by subclasses -
        TODO confirm.

        Args:
            x_0: Tensor (presumably the noise sample).
            x_1: Tensor (presumably the clean sample).
            x_t: Tensor (presumably the interpolated sample at time ``t``).
            mask: Mask tensor.
            t: Time tensor.
            x_1_pred: Tensor (presumably the predicted clean sample).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def nn_out_add_guided_simulation_tensor(
        self,
        nn_out: Dict[str, torch.Tensor],
        nn_out_ag: Union[Dict[str, torch.Tensor], None],
        nn_out_ucond: Union[Dict[str, torch.Tensor], None],
        guidance_w: float,
        ag_ratio: float,
    ) -> Dict[str, torch.Tensor]:
        """Return a network-output dictionary; must be overridden.

        The name suggests the result combines ``nn_out`` with the two
        optional auxiliary outputs into a guided simulation tensor; how they
        are combined is defined by subclasses - TODO confirm.

        Args:
            nn_out: Dictionary of network output tensors keyed by name.
            nn_out_ag: Second output dictionary, or ``None`` (presumably the
                output used for autoguidance - TODO confirm).
            nn_out_ucond: Third output dictionary, or ``None`` (presumably the
                unconditional output - TODO confirm).
            guidance_w: Float (presumably the guidance weight).
            ag_ratio: Float (presumably the autoguidance mixing ratio).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)

    @abstractmethod
    def simulation_step(
        self,
        x_t: torch.Tensor,
        nn_out: Dict[str, torch.Tensor],
        t: torch.Tensor,
        dt: float,
        gt: float,
        mask: torch.Tensor,
        simulation_step_params: Dict,
    ):
        """Perform one simulation step; must be overridden.

        The return type is not annotated in this base class; it is defined by
        subclasses (presumably the updated sample - TODO confirm).

        Args:
            x_t: Tensor (presumably the current sample at time ``t``).
            nn_out: Dictionary of network output tensors keyed by name.
            t: Time tensor.
            dt: Float (presumably the step size).
            gt: Float; meaning defined by subclasses - TODO confirm.
            mask: Mask tensor.
            simulation_step_params: Dictionary of extra step parameters; its
                schema is defined by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(ABS_CLASS_ERR_MSG)
