"""Timing metrics for model performance evaluation."""

from typing import Dict, List, Any, Optional
import statistics


class TimingMetrics:
    """Calculates timing-related metrics for model performance evaluation.

    The class is stateless; every method works only on its arguments.
    """

    @staticmethod
    def _empty_stats() -> Dict[str, float]:
        """Return the all-zero statistics used when there are no valid samples."""
        return {
            'mean_time': 0.0,
            'median_time': 0.0,
            'min_time': 0.0,
            'max_time': 0.0,
            'std_time': 0.0,
            'total_time': 0.0,
            'count': 0
        }

    @staticmethod
    def _is_valid_time(value: Any) -> bool:
        """Return True for a real, finite, non-negative number (bools excluded)."""
        # bool is a subclass of int, but True/False are not measurements.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        # The chained comparison rejects negatives, NaN (all comparisons are
        # False) and +inf in one step, without converting big ints to float.
        return 0 <= value < float('inf')

    def extract_generation_time(self, system_output: Dict[str, Any]) -> Optional[float]:
        """Extract generation time from system output.

        Args:
            system_output: System output dictionary containing generation_time

        Returns:
            Generation time in seconds, or None if it is missing, cannot be
            converted to a float, is not a finite number, or if
            ``system_output`` is not a mapping-like object.
        """
        try:
            gen_time = system_output.get('generation_time')
        except AttributeError:
            # Not a dict-like object (e.g. None): nothing to extract.
            return None

        if gen_time is None:
            return None

        try:
            value = float(gen_time)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: integers too large to be represented as a float.
            return None

        # NaN and +/-inf parse successfully (e.g. from the string "nan") but
        # are not usable timings.
        if not (float('-inf') < value < float('inf')):
            return None
        return value

    def calculate_timing_stats(self, generation_times: List[float]) -> Dict[str, float]:
        """Calculate timing statistics from a list of generation times.

        Entries that are not real numbers, are negative, or are not finite
        (NaN / infinity) are ignored, as are booleans.

        Args:
            generation_times: List of generation times in seconds

        Returns:
            Dictionary with the keys ``mean_time``, ``median_time``,
            ``min_time``, ``max_time``, ``std_time`` (sample standard
            deviation, 0.0 for fewer than two samples), ``total_time`` and
            ``count`` (number of valid samples). All values are zero when
            there are no valid samples.
        """
        if not generation_times:
            return self._empty_stats()

        valid_times = [t for t in generation_times if self._is_valid_time(t)]
        if not valid_times:
            return self._empty_stats()

        return {
            'mean_time': statistics.mean(valid_times),
            'median_time': statistics.median(valid_times),
            'min_time': min(valid_times),
            'max_time': max(valid_times),
            # statistics.stdev needs at least two data points.
            'std_time': statistics.stdev(valid_times) if len(valid_times) > 1 else 0.0,
            'total_time': sum(valid_times),
            'count': len(valid_times)
        }

    def compare_model_timing(self, model_timings: Dict[str, List[float]]) -> Dict[str, Any]:
        """Compare timing performance across multiple models.

        Args:
            model_timings: Dictionary mapping model names to lists of generation times

        Returns:
            Empty dictionary when ``model_timings`` is empty or no model has
            any valid timing. Otherwise a dictionary with:

            * ``model_stats``: per-model output of ``calculate_timing_stats``
              (including models without valid timings),
            * ``fastest_model`` / ``slowest_model``: ``{'model', 'mean_time'}``
              for the lowest / highest mean time,
            * ``speed_ratio``: slowest mean divided by fastest mean, or 1.0
              when the fastest mean is zero.
        """
        if not model_timings:
            return {}

        model_stats = {
            model: self.calculate_timing_stats(times)
            for model, times in model_timings.items()
        }

        # Only models with at least one valid sample take part in the ranking.
        mean_times = {
            model: stats['mean_time']
            for model, stats in model_stats.items()
            if stats['count'] > 0
        }
        if not mean_times:
            return {}

        fastest_model = min(mean_times, key=mean_times.get)
        slowest_model = max(mean_times, key=mean_times.get)
        fastest_mean = mean_times[fastest_model]
        slowest_mean = mean_times[slowest_model]

        return {
            'model_stats': model_stats,
            'fastest_model': {
                'model': fastest_model,
                'mean_time': fastest_mean
            },
            'slowest_model': {
                'model': slowest_model,
                'mean_time': slowest_mean
            },
            # Guard against division by zero when the fastest mean is 0.
            'speed_ratio': slowest_mean / fastest_mean if fastest_mean > 0 else 1.0
        }

    def format_time(self, seconds: float) -> str:
        """Format time duration in a human-readable way.

        The unit is chosen after rounding, so a value that rounds up to the
        next unit is shown in that unit (e.g. 119.6 -> "2m 0s", never
        "1m 60s").

        Args:
            seconds: Time duration in seconds

        Returns:
            Formatted time string: milliseconds below one second, seconds
            with one decimal below one minute, "Xm Ys" below one hour and
            "Xh Ym" (minutes truncated) from one hour upwards.
        """
        if seconds < 1:
            millis_text = f"{seconds * 1000:.0f}"
            if millis_text != "1000":
                return f"{millis_text}ms"
            # Rounds up to a full second: fall through to the seconds format.

        if seconds < 60:
            seconds_text = f"{seconds:.1f}"
            if seconds_text != "60.0":
                return f"{seconds_text}s"
            # Rounds up to a full minute: fall through to the minutes format.

        if seconds < 3600:
            whole_seconds = round(seconds)
            if whole_seconds < 3600:
                minutes, remaining_seconds = divmod(whole_seconds, 60)
                return f"{minutes}m {remaining_seconds}s"
            # Rounds up to a full hour: continue with the rounded value.
            seconds = whole_seconds

        hours = int(seconds // 3600)
        remaining_minutes = int((seconds % 3600) // 60)
        return f"{hours}h {remaining_minutes}m"

    def calculate_efficiency_score(self, quality_score: float, generation_time: float,
                                   baseline_time: float = 10.0) -> float:
        """Calculate an efficiency score combining quality and speed.

        Args:
            quality_score: Quality score (0-1 or 0-10 scale); values above
                1.0 are treated as being on the 0-10 scale
            generation_time: Generation time in seconds
            baseline_time: Baseline time for comparison (default: 10 seconds)

        Returns:
            Efficiency score (higher is better). 0.0 when the generation
            time, the baseline time or the quality score is not positive.
        """
        if generation_time <= 0 or baseline_time <= 0:
            return 0.0

        # Normalize quality score to 0-1 range
        if quality_score > 1.0:
            normalized_quality = quality_score / 10.0
        else:
            normalized_quality = quality_score

        # A negative product raised to 0.5 would yield a complex number
        # instead of a float, so non-positive quality scores score zero.
        if normalized_quality <= 0:
            return 0.0

        # Time efficiency factor (1.0 at baseline, higher for faster, lower for slower)
        time_efficiency = baseline_time / generation_time

        # Combined efficiency score (geometric mean of quality and time efficiency)
        return (normalized_quality * time_efficiency) ** 0.5