"""
Base Evaluation Interface

This module defines the base interface for evaluators.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import pandas as pd


class EvaluationInterface(ABC):
    """
    Abstract contract shared by every evaluator.

    Concrete evaluators must implement all three methods below; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def register_metric(self, name: str, metric_function: Any) -> None:
        """
        Make a metric available under a given name.

        Args:
            name: Identifier used to refer to the metric later on
            metric_function: The function that computes the metric
        """

    @abstractmethod
    def evaluate(self, results: Any, metrics: Optional[List[str]] = None) -> Dict[str, float]:
        """
        Compute metric values for one set of results.

        Args:
            results: The results to score
            metrics: Names of the metrics to compute (None means every
                registered metric)

        Returns:
            A mapping from metric name to the computed value
        """

    @abstractmethod
    def compare(
        self,
        results_a: Any,
        results_b: Any,
        metrics: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, float]]:
        """
        Compute metric values for two sets of results side by side.

        Args:
            results_a: The first set of results
            results_b: The second set of results
            metrics: Names of the metrics to compute (None means every
                registered metric)

        Returns:
            A mapping from metric name to a dictionary of the form
            {'a': value_a, 'b': value_b, 'diff': value_diff}
        """


class BaseEvaluator(EvaluationInterface):
    """
    Base implementation of the evaluator interface.

    Metric functions are stored in the ``metrics`` dictionary, keyed by name.
    Each metric function is called with a single argument: the results
    object handed to ``evaluate`` or ``compare``. Specific evaluators can
    extend this class.
    """

    def __init__(self) -> None:
        """Initialise the evaluator with an empty metric registry."""
        # Maps metric name -> metric function; populated by register_metric().
        self.metrics: Dict[str, Any] = {}

    def register_metric(self, name: str, metric_function: Any) -> None:
        """
        Register a metric function.

        Registering a name that already exists replaces the previous
        function.

        Args:
            name: The name of the metric
            metric_function: The function to calculate the metric; it is
                called with the results object as its only argument
        """
        self.metrics[name] = metric_function

    def _select_metrics(self, metrics: Optional[List[str]]) -> List[str]:
        """
        Resolve a metric request into the registered names to compute.

        Args:
            metrics: Requested metric names, a single name, or None for all
                registered metrics

        Returns:
            The requested names that are registered, in request order.
            Names that are not registered are left out.
        """
        if metrics is None:
            return list(self.metrics)
        if isinstance(metrics, str):
            # A bare string would otherwise be iterated character by
            # character and silently match nothing; treat it as one name.
            metrics = [metrics]
        return [name for name in metrics if name in self.metrics]

    def evaluate(self, results: Any, metrics: Optional[List[str]] = None) -> Dict[str, float]:
        """
        Evaluate results using registered metrics.

        Args:
            results: The results to evaluate
            metrics: The metrics to use (if None, use all registered
                metrics). Names that are not registered are skipped.

        Returns:
            A dictionary of metric names to values
        """
        return {
            name: self.metrics[name](results)
            for name in self._select_metrics(metrics)
        }

    def compare(
        self,
        results_a: Any,
        results_b: Any,
        metrics: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, float]]:
        """
        Compare two sets of results.

        Args:
            results_a: The first set of results
            results_b: The second set of results
            metrics: The metrics to use (if None, use all registered
                metrics). Names that are not registered are skipped.

        Returns:
            A dictionary of metric names to dictionaries of
            {'a': value_a, 'b': value_b, 'diff': value_diff}, where
            value_diff is value_b - value_a
        """
        comparison = {}
        for name in self._select_metrics(metrics):
            metric_function = self.metrics[name]
            value_a = metric_function(results_a)
            value_b = metric_function(results_b)
            comparison[name] = {
                'a': value_a,
                'b': value_b,
                # Positive when results_b scores higher than results_a.
                'diff': value_b - value_a,
            }

        return comparison