from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, MutableMapping, Optional, Sequence, Tuple

from laplace import BaseLaplace
from pandas import Series
from problems.data_processor import DataProcessor
from torch import Tensor
from torch.nn import Module

from utils.configs import LaplaceConfig


class FinetunedSurrogate(ABC):
    """
    Abstract base class for LLM BayesOpt surrogate variants whose predictive
    distribution comes from a Laplace approximation.

    Subclasses implement `train_model`, `posterior`, and
    `condition_on_observations`.

    Parameters:
    -----------
    training_set: sequence of pd.Series
        Initial training data. E.g. obtained via random search.

    initialize_nn: function of None -> torch.nn.Module
        A function returning the freshly-initialized network that
        `train_model` trains (e.g. a regression NN attached on top of an
        LLM feature extractor).

    data_processor: DataProcessor
        Data processor to process the pandas training set. Its
        `num_outputs` attribute is exposed as `self.num_outputs`.

    laplace: BaseLaplace, optional, default=None
        When creating a new model from scratch, leave this at None; the
        constructor then calls `train_model()`. Pass an existing Laplace
        object only to update this model with a new observation during a
        BayesOpt run; it is then stored as `self.laplace` as-is and no
        training is performed.

    laplace_config: LaplaceConfig, optional, default=None
        Override configs for Laplace. When None, a fresh `LaplaceConfig()`
        is created for this instance.

    device: str, optional, default="cuda"
        Device string stored as `self.device`. This base class does not use
        it directly.
    """

    def __init__(
        self,
        training_set: Sequence[Series],
        initialize_nn: Callable[[], Module],
        data_processor: DataProcessor,
        laplace: Optional[BaseLaplace] = None,
        laplace_config: Optional[LaplaceConfig] = None,
        device: str = "cuda",
    ) -> None:
        self.training_set: Sequence[Series] = training_set
        self.initialize_nn: Callable[[], Module] = initialize_nn
        self.data_processor = data_processor
        # A new default config is built on every call, so instances never
        # share one default config object.
        self.laplace_config: LaplaceConfig = (
            laplace_config if laplace_config is not None else LaplaceConfig()
        )
        self.device: str = device

        # Declared but only assigned in the `else` branch: when no Laplace
        # object is supplied, `train_model()` is responsible for setting it.
        # NOTE: `train_model()` runs inside this base constructor, after the
        # attributes above are set; attributes a subclass assigns after
        # calling `super().__init__()` are not yet available to it.
        self.laplace: BaseLaplace
        if laplace is None:
            self.train_model()
        else:
            self.laplace = laplace

    @abstractmethod
    def train_model(self) -> None:
        """Train the network from `self.initialize_nn` using `self.training_set`.

        Called by `__init__` when no `laplace` object is given. The
        constructor does not assign `self.laplace` on that path, so
        implementations are expected to set it.
        """
        raise NotImplementedError

    @abstractmethod
    def posterior(self, inputs: MutableMapping[str, Any]) -> Tuple[Tensor, Tensor]:
        """Given NN inputs X, obtain the Laplace posterior predictive distribution
        p(f(X) | X_train, y_train), which is a Gaussian.

        Args:
        -----
        inputs:
            A dict-like object containing the underlying NN's inputs.

        Returns:
        --------
        mean, variance:
            Parameters of an independent Gaussian distribution; both are
            (batch_size, n_tasks).
            NOTE(review): the (mean, variance) order of the returned tuple
            is assumed from this description — confirm against implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def condition_on_observations(self, obs: Series) -> FinetunedSurrogate:
        """Return a surrogate that additionally incorporates the observation `obs`.

        Args:
        -----
        obs:
            The new observation as a pd.Series (the same type as the entries
            of `self.training_set`).

        Returns:
        --------
        FinetunedSurrogate
            NOTE(review): whether a new instance is returned or `self` is
            updated in place is left to implementations — confirm.
        """
        raise NotImplementedError

    @property
    def num_outputs(self) -> int:
        """The number of outputs of the model, as reported by `self.data_processor`."""
        return self.data_processor.num_outputs
