#!/usr/bin/env python3
"""
Flexible Evaluation Framework for CNCRC
======================================

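This module provides flexible evaluation metrics that let users balance
traditional CP metrics (empirical coverage, average set size) against
CNCRC-specific metrics (non-coverage risk, covered ambiguity loss) according
to their application needs.

The framework supports:
1. Layered evaluation (primary approach, always computed)
2. Optional fusion scores (produced by every mode other than LAYERED)
3. Application-specific scoring presets (clinical, balanced, conservative)
   and user-supplied custom scorers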
"""

from typing import Dict, List, Tuple, Optional, Callable, Any
import numpy as np
from dataclasses import dataclass, field
from enum import Enum
import warnings

class EvaluationMode(Enum):
    """Evaluation mode enumeration.

    Only ``LAYERED`` skips fusion scoring; every other mode adds a single
    ``fusion_score`` on top of the always-computed layered evaluation.
    """
    LAYERED = "layered"           # Default: layered evaluation only
    CLINICAL = "clinical"         # Clinical decision focused
    BALANCED = "balanced"         # Balanced traditional + CNCRC
    CONSERVATIVE = "conservative" # Conservative application
    CUSTOM = "custom"             # User-supplied scoring function (config.custom_scorer)


@dataclass
class FlexibleEvaluationConfig:
    """Configuration for flexible evaluation."""

    # Primary evaluation mode
    mode: EvaluationMode = EvaluationMode.LAYERED

    # Layer 1: upper bound on the empirical non-coverage risk
    alpha: float = 0.1
    # Layer 3: max |coverage - baseline coverage| still considered compatible
    coverage_tolerance: float = 0.1
    # Layer 3: max |average set size - baseline average set size| still compatible
    set_size_tolerance: float = 2.0

    # Fusion metric weights (only used by BALANCED mode)
    weight_ambiguity: float = 0.6    # Weight for ambiguity improvement
    weight_coverage: float = 0.2     # Weight for coverage compatibility
    weight_set_size: float = 0.2     # Weight for set size compatibility

    # Custom scoring function (only used when mode == CUSTOM).
    # Called as custom_scorer(metrics, baseline_results, config) -> fusion score.
    custom_scorer: Optional[Callable] = None

    # When True and a baseline is available, metrics relative to the baseline
    # are reported under 'normalized_metrics' (scores always use raw metrics).
    normalize_metrics: bool = True
    # Key in the compare_methods() results dict that acts as the baseline
    baseline_method: str = "StandardCP"


class FlexibleEvaluator:
    """
    Flexible evaluator supporting both layered and fusion evaluation approaches.

    This class implements the "primary layered + auxiliary fusion" strategy:
    the layered evaluation (constraint -> optimization -> compatibility) is
    always computed, and a single fusion score is added for every mode other
    than ``EvaluationMode.LAYERED``.
    """

    def __init__(self, config: FlexibleEvaluationConfig):
        """Initialize the evaluator and validate ``config``.

        Raises:
            ValueError: If the mode is CUSTOM but no custom scorer is set.
        """
        self.config = config
        self._validate_config()

    def _validate_config(self) -> None:
        """Validate configuration parameters.

        Raises ValueError for CUSTOM mode without a scorer; only warns when
        the fusion weights do not sum to 1.0.
        """
        if self.config.mode == EvaluationMode.CUSTOM and self.config.custom_scorer is None:
            raise ValueError("Custom scorer must be provided when mode is CUSTOM")

        weights_sum = (self.config.weight_ambiguity +
                       self.config.weight_coverage +
                       self.config.weight_set_size)
        if abs(weights_sum - 1.0) > 1e-6:
            warnings.warn(f"Fusion weights sum to {weights_sum:.3f}, not 1.0. "
                          "This may lead to unexpected scaling.")

    def evaluate_method(
        self,
        method_results: Dict[str, Any],
        baseline_results: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Evaluate a method using the configured evaluation approach.

        Args:
            method_results: Results dictionary containing the metrics required
                by ``_extract_metrics``.
            baseline_results: Optional baseline results. When given, they drive
                the Layer 3 compatibility check and, if
                ``config.normalize_metrics`` is set, the reported
                ``normalized_metrics`` of fusion modes.

        Returns:
            The layered-evaluation keys, the fusion keys (non-LAYERED modes
            only), ``'evaluation_mode'`` and ``'raw_metrics'``.

        Raises:
            KeyError: If either results dict lacks a required metric.
        """
        metrics = self._extract_metrics(method_results)
        baseline_metrics = (
            self._extract_metrics(baseline_results) if baseline_results else None
        )

        # Layered evaluation (always computed)
        layered_eval = self._layered_evaluation(metrics, baseline_metrics)

        # Fusion evaluation (computed based on mode)
        fusion_eval = {}
        if self.config.mode != EvaluationMode.LAYERED:
            fusion_eval = self._fusion_evaluation(metrics, baseline_results)

        return {
            **layered_eval,
            **fusion_eval,
            'evaluation_mode': self.config.mode.value,
            'raw_metrics': metrics
        }

    def _extract_metrics(self, results: Dict[str, Any]) -> Dict[str, float]:
        """Extract the required metrics from ``results`` as floats.

        Raises:
            KeyError: If any required metric is missing.
        """
        required_metrics = [
            'empirical_non_coverage_risk',
            'covered_ambiguity_loss_mean',
            'empirical_coverage',
            'average_set_size'
        ]

        metrics = {}
        for metric in required_metrics:
            if metric not in results:
                raise KeyError(f"Required metric '{metric}' not found in results")
            metrics[metric] = float(results[metric])

        return metrics

    def _layered_evaluation(
        self,
        metrics: Dict[str, float],
        baseline_metrics: Optional[Dict[str, float]] = None
    ) -> Dict[str, Any]:
        """
        Perform layered evaluation following the constraint optimization paradigm.

        1. Layer 1 checks constraint satisfaction (non-coverage risk <= alpha).
        2. Layer 2 scores optimization performance (negated ambiguity risk,
           so higher is better).
        3. Layer 3 checks compatibility with the baseline: the absolute
           coverage and average-set-size differences must be within
           ``coverage_tolerance`` / ``set_size_tolerance``. Without a
           baseline, Layer 3 is not assessed: both flags stay True and
           ``layer_3_score`` is None.
        """
        non_coverage_risk = metrics['empirical_non_coverage_risk']
        ambiguity_risk = metrics['covered_ambiguity_loss_mean']

        # Layer 1: Constraint satisfaction
        constraint_satisfied = non_coverage_risk <= self.config.alpha

        # Layer 2: Optimization performance (lower risk is better)
        optimization_score = -ambiguity_risk  # Negative for "higher is better"

        # Layer 3: Compatibility assessment (needs a baseline)
        if baseline_metrics is None:
            coverage_acceptable = True
            set_size_acceptable = True
            layer_3_score = None
        else:
            coverage_gap = abs(metrics['empirical_coverage'] -
                               baseline_metrics['empirical_coverage'])
            set_size_gap = abs(metrics['average_set_size'] -
                               baseline_metrics['average_set_size'])
            coverage_acceptable = coverage_gap <= self.config.coverage_tolerance
            set_size_acceptable = set_size_gap <= self.config.set_size_tolerance
            layer_3_score = 1.0 if (coverage_acceptable and set_size_acceptable) else 0.0

        return {
            'layered_constraint_satisfied': constraint_satisfied,
            'layered_optimization_score': optimization_score,
            'layered_coverage_acceptable': coverage_acceptable,
            'layered_set_size_acceptable': set_size_acceptable,
            # Equals constraint_satisfied when no baseline is available
            'layered_overall_valid': (constraint_satisfied and
                                      coverage_acceptable and
                                      set_size_acceptable),
            'layer_1_score': 1.0 if constraint_satisfied else 0.0,
            'layer_2_score': optimization_score,
            'layer_3_score': layer_3_score
        }

    def _fusion_evaluation(
        self,
        metrics: Dict[str, float],
        baseline_results: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Perform fusion evaluation using weighted combination of metrics.

        This provides auxiliary scoring for users who prefer single-value
        rankings. Scores are always computed from the RAW metrics; the
        baseline-relative metrics are only reported under
        ``'normalized_metrics'``.
        """
        if self.config.mode == EvaluationMode.CUSTOM:
            return self._custom_evaluation(metrics, baseline_results)

        # Reporting only. Scoring must not use these values: alpha is an
        # absolute risk level and the component scores expect raw values
        # (e.g. 1 - ambiguity risk). Relative changes would be compared
        # against alpha, and the baseline itself would normalize to all zeros.
        if baseline_results and self.config.normalize_metrics:
            normalized_metrics = self._normalize_metrics(metrics, baseline_results)
        else:
            normalized_metrics = metrics.copy()

        # Apply mode-specific scoring on the raw metrics
        if self.config.mode == EvaluationMode.CLINICAL:
            fusion_score = self._clinical_scoring(metrics)
        elif self.config.mode == EvaluationMode.BALANCED:
            fusion_score = self._balanced_scoring(metrics)
        elif self.config.mode == EvaluationMode.CONSERVATIVE:
            fusion_score = self._conservative_scoring(metrics)
        else:
            raise ValueError(f"Unsupported evaluation mode: {self.config.mode}")

        return {
            'fusion_score': fusion_score,
            'fusion_mode': self.config.mode.value,
            'normalized_metrics': normalized_metrics
        }

    def _clinical_scoring(self, metrics: Dict[str, float]) -> float:
        """
        Clinical decision focused scoring.

        Returns 0.0 if the non-coverage risk exceeds alpha; otherwise
        ``max(0, 1 - ambiguity risk)``.
        """
        non_coverage_risk = metrics['empirical_non_coverage_risk']
        ambiguity_risk = metrics['covered_ambiguity_loss_mean']

        # Hard constraint: if safety violated, score is 0
        if non_coverage_risk > self.config.alpha:
            return 0.0

        # Higher ambiguity risk = lower score
        return max(0.0, 1.0 - ambiguity_risk)

    def _balanced_scoring(self, metrics: Dict[str, float]) -> float:
        """
        Balanced scoring combining CNCRC and traditional metrics.

        A weighted sum of three component scores, each clamped to [0, 1].
        The sum is then reduced by 10x the amount by which the non-coverage
        risk exceeds alpha, and the result is floored at 0.
        """
        non_coverage_risk = metrics['empirical_non_coverage_risk']
        ambiguity_risk = metrics['covered_ambiguity_loss_mean']
        coverage = metrics['empirical_coverage']
        set_size = metrics['average_set_size']

        # Constraint penalty
        constraint_penalty = max(0.0, non_coverage_risk - self.config.alpha)

        # Component scores (normalized to [0,1])
        ambiguity_score = max(0.0, 1.0 - ambiguity_risk)
        coverage_score = min(1.0, coverage)  # Assuming coverage target is 1.0
        set_size_score = max(0.0, 1.0 - set_size / 10.0)  # Rough normalization

        # Weighted combination
        fusion_score = (self.config.weight_ambiguity * ambiguity_score +
                        self.config.weight_coverage * coverage_score +
                        self.config.weight_set_size * set_size_score)

        # Apply constraint penalty
        return max(0.0, fusion_score - constraint_penalty * 10.0)

    def _conservative_scoring(self, metrics: Dict[str, float]) -> float:
        """
        Conservative scoring that requires both safety and compatibility.

        Returns 0.0 if the non-coverage risk exceeds alpha. Otherwise returns
        the minimum (weakest link) of the ambiguity score and the traditional
        score. Both are clamped to be non-negative, like the other scorers.
        """
        non_coverage_risk = metrics['empirical_non_coverage_risk']
        ambiguity_risk = metrics['covered_ambiguity_loss_mean']
        coverage = metrics['empirical_coverage']
        set_size = metrics['average_set_size']

        # Safety check
        if non_coverage_risk > self.config.alpha:
            return 0.0

        # Component scores; clamp so average set sizes above 10 give 0, not a negative score
        ambiguity_score = max(0.0, 1.0 - ambiguity_risk)
        traditional_score = max(0.0, min(coverage, 1.0 - set_size / 10.0))

        # Conservative: take minimum (weakest link)
        return min(ambiguity_score, traditional_score)

    def _custom_evaluation(
        self,
        metrics: Dict[str, float],
        baseline_results: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Apply the user-defined scorer as ``scorer(metrics, baseline_results, config)``."""
        if self.config.custom_scorer is None:
            raise ValueError("Custom scorer not configured")

        fusion_score = self.config.custom_scorer(metrics, baseline_results, self.config)

        return {
            'fusion_score': fusion_score,
            'fusion_mode': 'custom'
        }

    def _normalize_metrics(
        self,
        metrics: Dict[str, float],
        baseline_results: Dict[str, Any]
    ) -> Dict[str, float]:
        """Express each metric as a relative change versus the baseline.

        Each value is ``(baseline - value) / baseline``, i.e. a relative
        reduction. For lower-is-better metrics (risks, set size), a positive
        value means an improvement. For coverage, a positive value means
        LOWER coverage than the baseline. When the baseline value is 0, the
        raw value is returned unchanged.
        """
        baseline_metrics = self._extract_metrics(baseline_results)
        normalized = {}

        for key, value in metrics.items():
            baseline_value = baseline_metrics[key]
            if baseline_value != 0:
                normalized[key] = (baseline_value - value) / baseline_value
            else:
                normalized[key] = value

        return normalized

    def compare_methods(
        self,
        results_dict: Dict[str, Dict[str, Any]],
        include_fusion_ranking: bool = True
    ) -> Dict[str, Any]:
        """
        Compare multiple methods using the configured evaluation approach.

        Args:
            results_dict: Dictionary mapping method names to their results.
                If it contains ``config.baseline_method``, that entry is the
                baseline for every method (including itself).
            include_fusion_ranking: Whether to include fusion-based ranking
                (ignored in LAYERED mode).

        Returns:
            Per-method evaluations, the layered ranking, the fusion ranking
            (empty dict when not computed) and evaluation metadata.
        """
        # Get baseline results for compatibility checks / normalization
        baseline_results = None
        if self.config.baseline_method in results_dict:
            baseline_results = results_dict[self.config.baseline_method]

        # Evaluate each method
        evaluations = {}
        for method_name, results in results_dict.items():
            evaluations[method_name] = self.evaluate_method(results, baseline_results)

        # Create layered ranking
        layered_ranking = self._create_layered_ranking(evaluations)

        # Create fusion ranking if requested
        fusion_ranking = {}
        if include_fusion_ranking and self.config.mode != EvaluationMode.LAYERED:
            fusion_ranking = self._create_fusion_ranking(evaluations)

        return {
            'evaluations': evaluations,
            'layered_ranking': layered_ranking,
            'fusion_ranking': fusion_ranking,
            'evaluation_config': self.config,
            'primary_approach': 'layered',
            'auxiliary_approach': self.config.mode.value if self.config.mode != EvaluationMode.LAYERED else None
        }

    def _create_layered_ranking(self, evaluations: Dict[str, Dict]) -> Dict[str, Any]:
        """Rank constraint-satisfying methods first, each group by optimization score (desc)."""
        satisfying_methods = []
        violating_methods = []

        for method_name, eval_result in evaluations.items():
            if eval_result['layered_constraint_satisfied']:
                satisfying_methods.append((method_name, eval_result['layered_optimization_score']))
            else:
                violating_methods.append((method_name, eval_result['layered_optimization_score']))

        # Sort each group; higher optimization score is better
        satisfying_methods.sort(key=lambda x: x[1], reverse=True)
        violating_methods.sort(key=lambda x: x[1], reverse=True)

        final_ranking = [method[0] for method in satisfying_methods] + [method[0] for method in violating_methods]

        return {
            'ranking': final_ranking,
            'constraint_satisfying': [m[0] for m in satisfying_methods],
            'constraint_violating': [m[0] for m in violating_methods],
            'ranking_logic': 'constraint_satisfaction_then_optimization'
        }

    def _create_fusion_ranking(self, evaluations: Dict[str, Dict]) -> Dict[str, Any]:
        """Rank methods that have a ``fusion_score`` by that score (higher is better)."""
        method_scores = []
        for method_name, eval_result in evaluations.items():
            if 'fusion_score' in eval_result:
                method_scores.append((method_name, eval_result['fusion_score']))

        method_scores.sort(key=lambda x: x[1], reverse=True)

        return {
            'ranking': [method[0] for method in method_scores],
            'scores': dict(method_scores),
            'ranking_logic': f'fusion_score_{self.config.mode.value}'
        }


def create_evaluation_configs() -> Dict[str, FlexibleEvaluationConfig]:
    """Build the named evaluation presets for common use cases.

    Returns:
        A dict mapping 'layered', 'clinical', 'balanced' and 'conservative'
        (in that order) to a freshly constructed FlexibleEvaluationConfig.
    """
    return {
        # Layered evaluation only; no fusion score is produced.
        'layered': FlexibleEvaluationConfig(mode=EvaluationMode.LAYERED),
        # Clinical decisions: a tighter non-coverage risk budget.
        'clinical': FlexibleEvaluationConfig(
            mode=EvaluationMode.CLINICAL,
            alpha=0.05,
        ),
        # Weighted blend of ambiguity, coverage and set-size scores.
        'balanced': FlexibleEvaluationConfig(
            mode=EvaluationMode.BALANCED,
            weight_ambiguity=0.6,
            weight_coverage=0.2,
            weight_set_size=0.2,
        ),
        # Weakest-link scoring with tighter compatibility tolerances.
        'conservative': FlexibleEvaluationConfig(
            mode=EvaluationMode.CONSERVATIVE,
            coverage_tolerance=0.05,
            set_size_tolerance=1.0,
        ),
    }


# Example usage functions
def demo_flexible_evaluation():
    """Print rankings for a toy two-method comparison under every preset.

    For each preset from create_evaluation_configs(), this prints the
    layered ranking. It also prints the fusion ranking when the preset
    produced one.
    """
    core_results = {
        'empirical_non_coverage_risk': 0.0634,
        'covered_ambiguity_loss_mean': 0.0423,
        'empirical_coverage': 0.920,
        'average_set_size': 2.34
    }
    standard_results = {
        'empirical_non_coverage_risk': 0.0892,
        'covered_ambiguity_loss_mean': 0.0789,
        'empirical_coverage': 0.880,
        'average_set_size': 2.58
    }
    sample_results = {'CNCRC_Core': core_results, 'StandardCP': standard_results}

    presets = create_evaluation_configs()
    for preset_name, preset in presets.items():
        print(f"\n=== {preset_name.upper()} EVALUATION ===")
        report = FlexibleEvaluator(preset).compare_methods(sample_results)

        print(f"Primary ranking: {report['layered_ranking']['ranking']}")
        fusion = report['fusion_ranking']
        if fusion:
            print(f"Fusion ranking: {fusion['ranking']}")


if __name__ == '__main__':
    # Run the demonstration only when executed as a script; importing the module runs nothing.
    demo_flexible_evaluation()


