"""Metrics collection system for tracking evaluation performance."""

import time
import json
import logging
from pathlib import Path
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
import numpy as np


@dataclass
class Metric:
    """A single named measurement, stamped with the moment it was created.

    ``tags`` holds free-form string labels attached when the value was
    recorded (the collector below uses it for model/task identifiers).
    """
    name: str  # identifier of the measured quantity
    value: float  # the measured value itself
    timestamp: float = field(default_factory=time.time)  # time.time() at creation
    tags: Dict[str, str] = field(default_factory=dict)  # optional string labels

    def to_dict(self) -> Dict[str, Any]:
        """Return this metric's fields as a plain dictionary.

        Built with ``dataclasses.asdict``, so nested containers such as
        ``tags`` are copied rather than shared with this instance.
        """
        serialized: Dict[str, Any] = asdict(self)
        return serialized


class MetricsCollector:
    """Collects and aggregates metrics during evaluation.

    Every recorded value is kept as a timestamped ``Metric`` in
    ``self.metrics`` (keyed by metric name). Alongside that raw history the
    collector maintains three running aggregates:

    * ``counters`` - integer totals adjusted by :meth:`increment_counter`,
    * ``timers``   - lists of observed durations in seconds,
    * ``gauges``   - the most recent value set for each name.

    :meth:`record_evaluation_metrics` uses the counter names
    ``"<model>.total"``, ``"<model>.correct"`` and ``"<model>.follow_format"``
    plus the timer ``"<model>.duration"``; :meth:`get_summary` relies on that
    naming convention to build per-model statistics.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Initialize metrics collector.

        Args:
            logger: Logger instance; defaults to this module's logger.
        """
        self.logger = logger or logging.getLogger(__name__)
        self.metrics: Dict[str, List[Metric]] = {}
        # Wall-clock reference for the 'total_time' reported by get_summary().
        self.start_time = time.time()

        # Running aggregates, updated alongside the raw metric history.
        self.counters: Dict[str, int] = {}
        self.timers: Dict[str, List[float]] = {}
        self.gauges: Dict[str, float] = {}

    def record_metric(self, name: str, value: float, tags: Optional[Dict[str, str]] = None):
        """Record a single metric value.

        Args:
            name: Name of the metric
            value: Metric value
            tags: Optional tags for the metric. The dict is copied, so
                mutating it afterwards does not affect the stored metric.
        """
        # Copy the tags: callers (e.g. record_evaluation_metrics) may reuse one
        # dict for several metrics, and sharing it would let a later mutation
        # retroactively change every metric that references it.
        metric = Metric(name=name, value=value, tags=dict(tags) if tags else {})
        self.metrics.setdefault(name, []).append(metric)

    def increment_counter(self, name: str, value: int = 1):
        """Increment a counter metric.

        The new running total (not the increment) is also recorded as a
        ``"counter.<name>"`` metric sample.

        Args:
            name: Name of the counter
            value: Amount to increment by
        """
        self.counters[name] = self.counters.get(name, 0) + value
        self.record_metric(f"counter.{name}", self.counters[name])

    def record_duration(self, name: str, duration: float):
        """Record a duration/timing metric.

        Args:
            name: Name of the timer
            duration: Duration in seconds
        """
        self.timers.setdefault(name, []).append(duration)
        self.record_metric(f"timer.{name}", duration)

    def set_gauge(self, name: str, value: float):
        """Set a gauge metric (current value).

        Args:
            name: Name of the gauge
            value: Current value
        """
        self.gauges[name] = value
        self.record_metric(f"gauge.{name}", value)

    def start_timer(self, name: str) -> float:
        """Start a timer and return the start time.

        The start time is also stored as the gauge ``"<name>.start"``.

        Args:
            name: Name of the timer

        Returns:
            Start time (``time.time()``); pass it to :meth:`stop_timer`.
        """
        start_time = time.time()
        self.set_gauge(f"{name}.start", start_time)
        return start_time

    def stop_timer(self, name: str, start_time: float) -> float:
        """Stop a timer and record the duration.

        Args:
            name: Name of the timer
            start_time: Start time from start_timer()

        Returns:
            Duration in seconds
        """
        duration = time.time() - start_time
        self.record_duration(name, duration)
        return duration

    def record_evaluation_metrics(
        self,
        task_id: str,
        model_name: str,
        is_correct: bool,
        duration: float,
        tokens_used: int,
        tool_calls: int,
        follow_format: bool
    ):
        """Record metrics for a single evaluation.

        Args:
            task_id: Task identifier
            model_name: Model name
            is_correct: Whether the answer was correct
            duration: Evaluation duration in seconds
            tokens_used: Number of tokens used
            tool_calls: Number of tool calls made
            follow_format: Whether the answer was formatted correctly
        """
        tags = {'model': model_name, 'task': task_id}

        self.record_metric('evaluation.correct', float(is_correct), tags)
        self.record_metric('evaluation.duration', duration, tags)
        self.record_metric('evaluation.tokens', float(tokens_used), tags)
        self.record_metric('evaluation.tool_calls', float(tool_calls), tags)
        self.record_metric('evaluation.follow_format', float(follow_format), tags)

        # Per-model aggregates consumed by get_summary().
        self.increment_counter(f'{model_name}.total')
        if is_correct:
            self.increment_counter(f'{model_name}.correct')
        if follow_format:
            self.increment_counter(f'{model_name}.follow_format')
        self.record_duration(f'{model_name}.duration', duration)

    @staticmethod
    def _timer_stats(values: List[float]) -> Dict[str, Any]:
        """Return count/total/mean/median/std/min/max for a non-empty list."""
        # numpy scalars are converted to plain floats for clean JSON/repr output.
        return {
            'count': len(values),
            'total': sum(values),
            'mean': float(np.mean(values)),
            'median': float(np.median(values)),
            'std': float(np.std(values)),
            'min': min(values),
            'max': max(values),
        }

    def _model_summaries(self) -> Dict[str, Dict[str, Any]]:
        """Build per-model statistics from the ``"<model>.total"`` counters."""
        suffix = '.total'
        models: Dict[str, Dict[str, Any]] = {}
        for counter_name, total in self.counters.items():
            # Exact suffix match: a substring test would also pick up names
            # such as 'x.totals', and str.replace would mangle model names
            # that themselves contain '.total'.
            if not counter_name.endswith(suffix):
                continue
            model_name = counter_name[:-len(suffix)]

            correct = self.counters.get(f'{model_name}.correct', 0)
            follow_format = self.counters.get(f'{model_name}.follow_format', 0)
            stats: Dict[str, Any] = {
                'total': total,
                'correct': correct,
                'follow_format': follow_format,
                'accuracy': correct / total if total > 0 else 0,
                'follow_format_rate': follow_format / total if total > 0 else 0,
            }

            # Add timing info if available.
            durations = self.timers.get(f'{model_name}.duration')
            if durations:
                stats['avg_duration'] = float(np.mean(durations))

            models[model_name] = stats
        return models

    def get_summary(self) -> Dict[str, Any]:
        """Get summary of all collected metrics.

        Returns:
            Dictionary with keys ``total_time`` (seconds since construction),
            ``counters``, ``gauges``, ``timers`` (per-timer statistics for
            every non-empty timer) and, when any ``"<model>.total"`` counter
            exists, ``models`` (per-model total/correct/follow_format counts,
            ``accuracy``, ``follow_format_rate`` and optional
            ``avg_duration``).
        """
        summary: Dict[str, Any] = {
            'total_time': time.time() - self.start_time,
            'counters': dict(self.counters),
            'gauges': dict(self.gauges),
            'timers': {
                name: self._timer_stats(values)
                for name, values in self.timers.items()
                if values
            },
        }

        model_metrics = self._model_summaries()
        if model_metrics:
            summary['models'] = model_metrics

        return summary

    def export_metrics(self, output_file: str):
        """Export the summary and full metric history to a JSON file.

        Parent directories are created as needed. Serialization or write
        failures are logged rather than raised (best-effort export).

        Args:
            output_file: Path to output file (``~`` is expanded)
        """
        output_path = Path(output_file).expanduser()
        output_path.parent.mkdir(parents=True, exist_ok=True)

        export_data = {
            'timestamp': datetime.now().isoformat(),
            'summary': self.get_summary(),
            'metrics': {
                name: [m.to_dict() for m in metric_list]
                for name, metric_list in self.metrics.items()
            },
        }

        try:
            # Serialize before touching the file so an unserializable value
            # cannot leave a truncated, half-written JSON file behind.
            payload = json.dumps(export_data, indent=2)
            output_path.write_text(payload, encoding='utf-8')
        except Exception as e:
            # Best-effort: log with traceback instead of propagating.
            self.logger.exception("Failed to export metrics: %s", e)
        else:
            self.logger.info("Exported metrics to %s", output_path)

    def print_summary(self):
        """Print a formatted summary of metrics to console."""
        summary = self.get_summary()
        models = summary.get('models', {})

        print("\n" + "="*60)
        print("METRICS SUMMARY")
        print("="*60)

        print(f"\nTotal Evaluation Time: {summary['total_time']:.2f} seconds")

        if models:
            print("\nModel Performance:")
            for model_name, metrics in models.items():
                print(f"\n  {model_name}:")
                print(f"    Total: {metrics['total']}")
                print(f"    Correct: {metrics['correct']}")
                print(f"    Accuracy: {metrics['accuracy']:.2%}")
                # A rate, not the raw count: the count formatted as a
                # percentage would print e.g. 300.00% for 3 answers.
                print(f"    Follow Format: {metrics['follow_format_rate']:.2%}")
                if 'avg_duration' in metrics:
                    print(f"    Avg Duration: {metrics['avg_duration']:.2f}s")

        # Per-model duration timers are already shown as 'Avg Duration'; skip
        # exactly those, not every timer whose name merely contains a model
        # name as a substring.
        model_timer_names = {f'{model}.duration' for model in models}
        general_timers = {
            name: stats
            for name, stats in summary['timers'].items()
            if name not in model_timer_names
        }
        if general_timers:
            print("\nTiming Statistics:")
            for name, stats in general_timers.items():
                print(f"\n  {name}:")
                print(f"    Mean: {stats['mean']:.3f}s")
                print(f"    Median: {stats['median']:.3f}s")
                print(f"    Min/Max: {stats['min']:.3f}s / {stats['max']:.3f}s")

        print("\n" + "="*60)
