"""
Tests for the evaluator implementation.
"""

import pytest
import pandas as pd
import numpy as np

from src.evaluation.evaluator import Evaluator


class TestEvaluator:
    """Unit tests for the ``Evaluator`` class.

    Covers registration of the default metrics, registration of a custom
    metric, the private ``_calculate_*`` helpers, and the public
    ``evaluate`` / ``compare`` entry points.

    Numeric expectations on computed values are compared with
    ``pytest.approx`` so the tests do not hinge on bit-exact floating-point
    results.
    """

    def setup_method(self):
        """Create a fresh evaluator and sample inputs before each test."""
        self.evaluator = Evaluator()

        # The same five scores, once as a plain list and once as the
        # "score" column of a DataFrame.
        self.list_data = [1, 2, 3, 4, 5]
        self.df_data = pd.DataFrame({
            "score": [1, 2, 3, 4, 5],
            "other": [10, 20, 30, 40, 50],
        })
        # A DataFrame that has no "score" column at all.
        self.df_no_score = pd.DataFrame({
            "other": [10, 20, 30, 40, 50],
        })

    def _assert_degenerate_inputs_give_zero(self, calculate):
        """Assert that *calculate* returns 0.0 for every degenerate input.

        The degenerate inputs are a DataFrame without a ``score`` column, an
        empty list, and a value that is neither a list nor a DataFrame.

        Not collected by pytest (the name does not start with ``test``); it
        only removes duplication between the ``test_calculate_*`` methods.

        Args:
            calculate: The bound ``_calculate_*`` method under test.
        """
        degenerate_inputs = {
            "DataFrame without score column": self.df_no_score,
            "empty list": [],
            "non-list, non-DataFrame data": "not a list or DataFrame",
        }
        for label, value in degenerate_inputs.items():
            # The label is the assertion message so a failure names the case.
            assert calculate(value) == 0.0, label

    def test_init(self):
        """A new Evaluator exposes the four default metrics."""
        for metric_name in ("mean", "sum", "min", "max"):
            assert metric_name in self.evaluator.metrics, metric_name

    def test_register_metric(self):
        """A custom metric is stored under the name it was registered with."""
        def custom_metric(results):
            if isinstance(results, pd.DataFrame) and "score" in results.columns:
                return results["score"].std()
            elif isinstance(results, (list, tuple)):
                return np.std(results)
            else:
                return 0.0

        self.evaluator.register_metric("std", custom_metric)

        assert "std" in self.evaluator.metrics
        # Identity, not equality: the very same callable must be stored.
        assert self.evaluator.metrics["std"] is custom_metric

    def test_calculate_mean(self):
        """``_calculate_mean`` handles list, DataFrame and degenerate input."""
        calculate = self.evaluator._calculate_mean

        assert calculate(self.list_data) == pytest.approx(3.0)
        assert calculate(self.df_data) == pytest.approx(3.0)
        self._assert_degenerate_inputs_give_zero(calculate)

    def test_calculate_sum(self):
        """``_calculate_sum`` handles list, DataFrame and degenerate input."""
        calculate = self.evaluator._calculate_sum

        assert calculate(self.list_data) == 15
        assert calculate(self.df_data) == 15
        self._assert_degenerate_inputs_give_zero(calculate)

    def test_calculate_min(self):
        """``_calculate_min`` handles list, DataFrame and degenerate input."""
        calculate = self.evaluator._calculate_min

        assert calculate(self.list_data) == 1
        assert calculate(self.df_data) == 1
        self._assert_degenerate_inputs_give_zero(calculate)

    def test_calculate_max(self):
        """``_calculate_max`` handles list, DataFrame and degenerate input."""
        calculate = self.evaluator._calculate_max

        assert calculate(self.list_data) == 5
        assert calculate(self.df_data) == 5
        self._assert_degenerate_inputs_give_zero(calculate)

    def test_evaluate(self):
        """``evaluate`` returns a mapping of metric name to metric value."""
        # No metric selection: the four default metrics must be present.
        evaluation = self.evaluator.evaluate(self.list_data)
        assert evaluation["mean"] == pytest.approx(3.0)
        assert evaluation["sum"] == 15
        assert evaluation["min"] == 1
        assert evaluation["max"] == 5

        # Explicit selection: only the requested metrics are reported.
        evaluation = self.evaluator.evaluate(self.list_data, metrics=["mean", "max"])
        assert "mean" in evaluation
        assert "max" in evaluation
        assert "sum" not in evaluation
        assert "min" not in evaluation

        # Unknown metric names yield an empty result.
        evaluation = self.evaluator.evaluate(self.list_data, metrics=["non_existent"])
        assert evaluation == {}

    def test_compare(self):
        """``compare`` reports ``a``, ``b`` and ``diff`` for each metric."""
        results_a = [1, 2, 3, 4, 5]
        results_b = [2, 4, 6, 8, 10]

        comparison = self.evaluator.compare(results_a, results_b)

        # metric name -> expected (a, b, diff) for the two inputs above.
        expected = {
            "mean": (3.0, 6.0, 3.0),
            "sum": (15, 30, 15),
            "min": (1, 2, 1),
            "max": (5, 10, 5),
        }
        for metric_name, (value_a, value_b, diff) in expected.items():
            entry = comparison[metric_name]
            assert entry["a"] == pytest.approx(value_a), metric_name
            assert entry["b"] == pytest.approx(value_b), metric_name
            assert entry["diff"] == pytest.approx(diff), metric_name

        # Explicit selection: only the requested metric is compared.
        comparison = self.evaluator.compare(results_a, results_b, metrics=["mean"])
        assert "mean" in comparison
        assert "sum" not in comparison

        # Unknown metric names yield an empty comparison.
        comparison = self.evaluator.compare(results_a, results_b, metrics=["non_existent"])
        assert comparison == {}