from abc import ABC, abstractmethod
from typing import Dict, Any

import numpy as np

from ..transform.transform import TabularTransformSet, TabularTransform


class EvalBase(ABC):
    """
    An abstract base class for evaluation modules that operate on tabular data.

    This class uses a `TabularTransformSet` to determine how tabular data should
    be preprocessed for evaluation, either keeping or dropping the target column.

    Subclasses should implement the `_evaluation` method, which performs the
    actual evaluation logic. Calling an instance (or calling `evaluation`)
    runs `_evaluation` and unwraps single-element array-like values in the
    returned dictionary into plain Python scalars.

    Example:
        class MyEvaluator(EvalBase):
            def _evaluation(self, data: pd.DataFrame, **kwargs) -> Dict:
                # Implement your evaluation logic here and return a
                # dictionary containing metrics or other results.
                ...

        transform_set = TabularTransformSet(...)
        evaluator = MyEvaluator(transform_set, drop_target_column=False)
        results = evaluator(data_to_evaluate)
        # `results` might contain {"accuracy": 0.98, "f1_score": 0.96}
    """

    def __init__(
            self,
            transform_set: TabularTransformSet,
            drop_target_column: bool,
            **kwargs
    ):
        """
        Initialize the evaluator.

        Args:
            transform_set (TabularTransformSet):
                A collection of tabular transforms exposing a `target`
                transform and a `no_target` transform.
            drop_target_column (bool):
                If True, use `transform_set.no_target` (the transform that
                excludes the target column).
                If False, use `transform_set.target` (the transform that
                includes the target column).
            **kwargs:
                Accepted so subclasses can pass extra options through;
                this base class ignores them.
        """
        super().__init__()
        self._transform_set = transform_set
        self.transform: TabularTransform = (
            transform_set.no_target if drop_target_column else transform_set.target
        )

    @property
    def name(self) -> str:
        """
        Return the class name of this evaluator.

        Returns:
            str: The evaluator's class name.
        """
        return type(self).__name__

    @property
    def metadata(self) -> Dict[str, Any]:
        """
        Return metadata describing the columns in the transform.

        This metadata dict can be used to inform other systems about the nature
        of each column (e.g., categorical vs. numerical).

        Returns:
            dict: A metadata dictionary containing:
                - 'columns': A dict mapping each column name to metadata about
                  that column. Categorical columns map to
                  {"sdtype": "categorical"}. Numerical columns map to
                  {"sdtype": "numerical", "compute_representation": "Float"}.
        """
        columns: Dict[str, Dict[str, str]] = {
            col: {"sdtype": "categorical"}
            for col in self.transform.categorical_columns
        }
        # Numerical entries are applied after categorical ones, so a column
        # listed in both ends up marked numerical. Every numerical column gets
        # an entry; none are dropped by pairing with another sequence.
        columns.update({
            col: {"sdtype": "numerical", "compute_representation": "Float"}
            for col in self.transform.numerical_columns
        })
        return {"columns": columns}

    def evaluation(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run `_evaluation` and convert single-element array-like values to scalars.

        Any value that exposes both `.shape` and `.item()` and whose shape
        holds exactly one element is replaced by `value.item()`. Examples
        include a 0-d or one-element NumPy array.

        Args:
            *args: Forwarded unchanged to `_evaluation`.
            **kwargs: Forwarded unchanged to `_evaluation`.

        Returns:
            dict: The dictionary returned by `_evaluation`, updated in place.
        """
        result = self._evaluation(*args, **kwargs)
        for key, item in result.items():
            # Only existing keys are reassigned, so the dict's size does not
            # change during iteration. np.prod(()) == 1 covers 0-d arrays.
            if hasattr(item, 'shape') and hasattr(item, 'item') and np.prod(item.shape) == 1:
                result[key] = item.item()
        return result

    @abstractmethod
    def _evaluation(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Abstract method containing the core evaluation logic.

        Subclasses must implement this method to define how the data is
        evaluated and how the results are returned.

        Args:
            *args: Evaluation inputs, as defined by the subclass.
            **kwargs: Extra evaluation options, as defined by the subclass.

        Returns:
            dict: A dictionary of evaluation metrics or other results.
        """
        raise NotImplementedError("Subclasses must implement `_evaluation`.")

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Alias for `evaluation`.

        This dispatches through `self.evaluation` at call time, rather than
        binding the base-class function at class creation. A subclass that
        overrides `evaluation` is therefore also honored when the instance
        is called directly.
        """
        return self.evaluation(*args, **kwargs)
