"""Performance timing utilities for DriveGuard evaluation system.

This module provides functionality to measure and track execution time
for different components in the DriveGuard workflow, enabling performance
analysis and optimization.
"""

import json
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional


@dataclass
class TimingRecord:
    """Timing data captured for one component or operation.

    A record is "open" until ``finish`` stamps ``end_time``; from then on
    ``duration`` holds the elapsed seconds between the two timestamps.
    """
    # Timestamp at which timing began; finish() compares it against time.time().
    start_time: float
    # Timestamp at which timing stopped; None while the record is still open.
    end_time: Optional[float] = None
    # end_time - start_time, filled in by finish(); None until then.
    duration: Optional[float] = None
    # Sub-operation name -> duration in seconds.
    sub_operations: Dict[str, float] = field(default_factory=dict)
    # Free-form details supplied by whoever created the record.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def finish(self):
        """Close the record and compute its duration.

        Only the first call has any effect; once ``end_time`` is set the
        stored values are left untouched.
        """
        if self.end_time is not None:
            return
        stopped_at = time.time()
        self.end_time = stopped_at
        self.duration = stopped_at - self.start_time

    def is_finished(self) -> bool:
        """Return True once an end time has been stamped on the record."""
        return not (self.end_time is None)

    def get_summary(self) -> Dict[str, Any]:
        """Build a dictionary describing this record.

        Returns:
            A dict with ``"duration"`` and a shallow copy of
            ``"sub_operations"``; a shallow copy of ``"metadata"`` is added
            only when the metadata dict is non-empty.
        """
        summary: Dict[str, Any] = {}
        summary["duration"] = self.duration
        summary["sub_operations"] = self.sub_operations.copy()
        if not self.metadata:
            return summary
        summary["metadata"] = self.metadata.copy()
        return summary


class PerformanceTimer:
    """Track execution time for DriveGuard components and operations.

    Each component name maps to exactly one ``TimingRecord`` in ``records``;
    starting a component again replaces its previous record. The reporting
    methods (``get_component_timing``, ``get_all_timings``,
    ``get_simple_timings``) finish any record that is still open, so reading
    a report also stops the clocks it reports on.

    NOTE(review): timestamps come from ``time.time()``, which follows the
    system clock and can be affected by clock adjustments; confirm whether a
    monotonic clock would be preferable for the records and the session.
    """

    def __init__(self):
        self.records: Dict[str, TimingRecord] = {}
        self.session_start = time.time()
        self.session_timestamp = self._utc_timestamp()

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO 8601 string ending in "Z"."""
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware
        # UTC time and drop tzinfo so the rendered text keeps the same
        # "...Z" shape instead of gaining a "+00:00" suffix.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        return now.isoformat() + "Z"

    def start_component(self, component_name: str, metadata: Optional[Dict[str, Any]] = None):
        """Start timing a component.

        Any existing record for ``component_name`` is replaced by a fresh one.

        Args:
            component_name: Name of the component to time
            metadata: Optional metadata about the component operation
        """
        previous = self.records.get(component_name)
        if previous is not None and not previous.is_finished():
            # Close the running record before it is replaced so that anything
            # still holding a reference to it sees a completed record.
            previous.finish()

        self.records[component_name] = TimingRecord(
            start_time=time.time(),
            metadata=metadata or {}
        )

    def end_component(self, component_name: str) -> Optional[float]:
        """End timing a component and return the duration.

        Args:
            component_name: Name of the component to stop timing

        Returns:
            Duration in seconds, or None if the component is unknown or its
            record was already finished
        """
        record = self.records.get(component_name)
        if record is None or record.is_finished():
            return None
        record.finish()
        return record.duration

    def record_sub_operation(self, component_name: str, operation_name: str, duration: float):
        """Record timing for a sub-operation within a component.

        The call is ignored when ``component_name`` has no record. Recording
        the same ``operation_name`` again overwrites the earlier value.

        Args:
            component_name: Name of the parent component
            operation_name: Name of the sub-operation
            duration: Duration of the sub-operation in seconds
        """
        record = self.records.get(component_name)
        if record is not None:
            record.sub_operations[operation_name] = duration

    @contextmanager
    def time_operation(self, component_name: str, operation_name: str):
        """Context manager for timing sub-operations.

        The elapsed time is recorded even when the body raises; the
        exception itself is not suppressed.

        Args:
            component_name: Name of the parent component
            operation_name: Name of the sub-operation

        Example:
            with timer.time_operation("video_annotation", "preprocessing"):
                # Your preprocessing code here
                pass
        """
        start_time = time.time()
        try:
            yield
        finally:
            duration = time.time() - start_time
            self.record_sub_operation(component_name, operation_name, duration)

    def get_component_timing(self, component_name: str) -> Optional[Dict[str, Any]]:
        """Get timing information for a specific component.

        A record that is still open is finished first, so the summary always
        carries a duration.

        Args:
            component_name: Name of the component

        Returns:
            Timing summary dictionary, or None if component not found
        """
        record = self.records.get(component_name)
        if record is None:
            return None
        if not record.is_finished():
            record.finish()
        return record.get_summary()

    def get_all_timings(self) -> Dict[str, Any]:
        """Get timing information for all components.

        Every record that is still open is finished before reporting.

        Returns:
            Dictionary with the session timestamp, total session time, the
            sum of all component durations, per-component summaries, and a
            fastest/slowest/count summary
        """
        for record in self.records.values():
            if not record.is_finished():
                record.finish()

        total_session_time = time.time() - self.session_start

        component_timings = {
            comp_name: record.get_summary()
            for comp_name, record in self.records.items()
        }
        # Start from 0.0 so the total is a float even when nothing was timed.
        total_measured_time = sum(
            (
                record.duration
                for record in self.records.values()
                if record.duration is not None
            ),
            0.0,
        )

        return {
            "session_timestamp": self.session_timestamp,
            "total_session_time": total_session_time,
            "total_measured_time": total_measured_time,
            "component_timings": component_timings,
            "summary": {
                "fastest_component": self._get_fastest_component(),
                "slowest_component": self._get_slowest_component(),
                "total_components": len(self.records)
            }
        }

    def get_simple_timings(self) -> Dict[str, float]:
        """Get simple component name to duration mapping.

        Records that are still open are finished first; a missing duration
        is reported as 0.0.

        Returns:
            Dict mapping component names to durations in seconds
        """
        timings = {}
        for comp_name, record in self.records.items():
            if not record.is_finished():
                record.finish()
            timings[comp_name] = record.duration or 0.0
        return timings

    def _extreme_component(self, select) -> Optional[Dict[str, Any]]:
        """Pick one finished component by duration using ``select`` (min or max)."""
        # Compare against None rather than truthiness: a component that
        # finished in 0.0 seconds is a valid (and the fastest) measurement.
        durations = {
            name: record.duration
            for name, record in self.records.items()
            if record.is_finished() and record.duration is not None
        }
        if not durations:
            return None

        chosen = select(durations, key=durations.get)
        return {"component": chosen, "duration": durations[chosen]}

    def _get_fastest_component(self) -> Optional[Dict[str, Any]]:
        """Get information about the fastest finished component, if any."""
        return self._extreme_component(min)

    def _get_slowest_component(self) -> Optional[Dict[str, Any]]:
        """Get information about the slowest finished component, if any."""
        return self._extreme_component(max)

    def export_to_json(self, filepath: str):
        """Export timing information to JSON file.

        Args:
            filepath: Path to save the JSON file

        Raises:
            TypeError: If the timings (for example component metadata)
                contain values that ``json`` cannot serialize
        """
        # Serialize before opening the file: opening in 'w' mode truncates
        # it, so a serialization error must surface before that happens.
        payload = json.dumps(self.get_all_timings(), indent=2, ensure_ascii=False)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

    def reset(self):
        """Reset all timing records and start a new session."""
        self.records.clear()
        self.session_start = time.time()
        self.session_timestamp = self._utc_timestamp()


class ComponentPerformanceContext:
    """Scope the timing of one component to a ``with`` block.

    Entering the block starts timing ``component_name`` on the supplied
    timer; leaving it ends the timing, whether or not the body raised.
    """

    def __init__(self, timer: PerformanceTimer, component_name: str, 
                 metadata: Optional[Dict[str, Any]] = None):
        """Remember the timer, the component name and its optional metadata.

        Args:
            timer: Timer that will hold the component's record
            component_name: Name of the component to time
            metadata: Optional metadata passed on when timing starts
        """
        self.timer = timer
        self.component_name = component_name
        self.metadata = metadata

    def __enter__(self):
        """Start timing the component and return this context object."""
        begin = self.timer.start_component
        begin(self.component_name, self.metadata)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing the component."""
        # Nothing is returned here, so an exception raised inside the
        # ``with`` body keeps propagating to the caller.
        finish = self.timer.end_component
        finish(self.component_name)

    def time_sub_operation(self, operation_name: str):
        """Return a context manager that times a sub-operation of this component."""
        sub_timer = self.timer.time_operation
        return sub_timer(self.component_name, operation_name)