from abc import ABC, abstractmethod
from typing import Any, MutableMapping, Sequence, Union

from datasets.utils.logging import disable_progress_bar
from pandas import DataFrame, Series
from torch import Tensor
from torch.utils.data import DataLoader

# Module-level side effect: executed once when this module is imported.
# Turns off the Hugging Face `datasets` progress bars (a library-wide setting,
# not scoped to this module).
disable_progress_bar()


class FeatureExtractor(ABC):
    """Base class for feature extractors over Bayesian optimization datasets.

    Concrete subclasses implement :meth:`get_dataloader` and
    :meth:`_get_columns_to_remove`, and must assign the attributes that are
    only declared (not set) in ``__init__``.
    """

    def __init__(self, num_outputs: int):
        # Number of regression targets per example (see `_get_targets`).
        self.num_outputs = num_outputs

        # Declared for type checkers only -- no value is assigned here, so
        # every subclass is responsible for setting these.
        self.input_col: str  # NOTE(review): presumably the model-input column; confirm
        self.target_cols: Sequence[str]  # columns read, in order, by `_get_targets`
        self.obj_str: Sequence[str]  # NOTE(review): objective name(s); confirm usage
        self.maximization: bool  # NOTE(review): presumably True if maximizing; confirm

    @abstractmethod
    def get_dataloader(
        self,
        pandas_dataset: DataFrame,
        batch_size: int = 16,
        shuffle: bool = False,
        append_eos: bool = True,
        standardize_y: bool = False,
    ) -> DataLoader:
        """Build a DataLoader over ``pandas_dataset``.

        Arguments:
        ----------
        pandas_dataset:
            The raw dataset to featurize.
        batch_size:
            Batch size for the returned loader (interpreted by subclasses).
        shuffle:
            Whether examples should be shuffled (interpreted by subclasses).
        append_eos:
            Subclass-specific; presumably appends an end-of-sequence token to
            the inputs -- confirm in the implementation.
        standardize_y:
            Subclass-specific; presumably standardizes the targets -- confirm
            in the implementation.

        Returns:
        --------
        dataloader:
            DataLoader yielding the preprocessed examples.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_columns_to_remove(self) -> Sequence[str]:
        """List the columns to drop from the final, preprocessed data.

        Useful when the pandas DataFrame carries auxiliary information that is
        irrelevant to the model.

        Returns:
        --------
        cols:
            Names of the columns to remove from the dataset.
        """
        raise NotImplementedError

    def _get_targets(self, row: Union[Series, MutableMapping[str, Any]]) -> Tensor:
        """Collect the target value(s) from one raw example.

        Arguments:
        ----------
        row:
            One example of the raw dataset: either a DataFrame row (a
            ``Series``) or a dict-like record.

        Returns:
        --------
        targets:
            1-D tensor holding ``row[col]`` for each ``col`` in
            ``self.target_cols``, in that order.
            Its length is ``len(self.target_cols)``, which is expected to
            equal ``self.num_outputs``.
        """
        values = [row[column] for column in self.target_cols]
        # The legacy `Tensor` constructor always produces the default floating
        # dtype (float32 unless changed globally), even for integer inputs.
        return Tensor(values)
