"""Model tracking utilities for DriveGuard evaluation system.

This module provides functionality to track which models are used across
different components in the DriveGuard workflow, enabling comprehensive
model comparison and performance analysis.
"""

import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, Optional, Any

from .settings import settings


@dataclass
class ComponentModelInfo:
    """Information about a model used by a specific component.

    One instance is stored per component name in ``ModelTracker.components``.
    """
    component_name: str  # Component identifier (e.g., 'video_annotation')
    model_key: str  # Full model identifier string (e.g., 'openai:gpt-4o')
    model_name: str  # Model name; ModelTracker records it as the same string as model_key by default
    model_type: str = "llm"  # 'llm' or 'embedding'
    custom_override: bool = False  # True if overridden from the default model


class ModelTracker:
    """Track models used across all DriveGuard components.

    Each component name maps to a single ``ComponentModelInfo``; recording a
    component again replaces its previous entry. The tracker also carries a
    creation timestamp (UTC, ISO 8601 with a trailing ``Z``) that is included
    in ``get_detailed_info()`` and JSON exports.
    """

    def __init__(self) -> None:
        self.components: Dict[str, ComponentModelInfo] = {}
        # datetime.utcnow() is deprecated (Python 3.12+). Take an aware UTC
        # time, then drop tzinfo so isoformat() keeps the historical
        # "YYYY-MM-DDTHH:MM:SS[.ffffff]Z" shape rather than "+00:00Z".
        self.timestamp: str = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        )

    def record_component_model(
        self,
        component_name: str,
        model_key: str,
        model_type: str = "llm",
        custom_override: bool = False,
        model_name: Optional[str] = None,
    ) -> None:
        """Record which model a component is using.

        Recording the same component twice overwrites the earlier entry.

        Args:
            component_name: Name of the component (e.g., 'video_annotation')
            model_key: Full model identifier string (e.g., 'openai:gpt-4o')
            model_type: Type of model ('llm' or 'embedding')
            custom_override: Whether this model was explicitly overridden
            model_name: Model name to store. Defaults to ``model_key``, since
                with the current format the key already is the full model
                string; pass it explicitly to keep a distinct stored name
                (e.g. when reloading an exported file).
        """
        self.components[component_name] = ComponentModelInfo(
            component_name=component_name,
            model_key=model_key,
            model_name=model_key if model_name is None else model_name,
            model_type=model_type,
            custom_override=custom_override,
        )

    def get_models_summary(self) -> Dict[str, str]:
        """Get a simple summary of models used by each component.

        Returns:
            Dict mapping component names to model names
        """
        return {
            comp_name: info.model_name
            for comp_name, info in self.components.items()
        }

    def get_detailed_info(self) -> Dict[str, Any]:
        """Get detailed model information including overrides and types.

        Returns:
            Dict with keys ``"timestamp"``, ``"components"`` (per-component
            ``model_key``/``model_name``/``model_type``/``custom_override``)
            and ``"summary"`` (same as ``get_models_summary()``).
        """
        return {
            "timestamp": self.timestamp,
            "components": {
                comp_name: {
                    "model_key": info.model_key,
                    "model_name": info.model_name,
                    "model_type": info.model_type,
                    "custom_override": info.custom_override
                }
                for comp_name, info in self.components.items()
            },
            "summary": self.get_models_summary()
        }

    def get_model_for_component(self, component_name: str) -> Optional[str]:
        """Get the model name used by a specific component.

        Args:
            component_name: Name of the component

        Returns:
            Model name if recorded, None otherwise
        """
        info = self.components.get(component_name)
        return info.model_name if info else None

    def has_custom_overrides(self) -> bool:
        """Check if any components have custom model overrides.

        Returns:
            True if any component uses a custom model override
        """
        return any(info.custom_override for info in self.components.values())

    def get_custom_overrides(self) -> Dict[str, str]:
        """Get components that have custom model overrides.

        Returns:
            Dict of component names to model names for overridden components
        """
        return {
            comp_name: info.model_name
            for comp_name, info in self.components.items()
            if info.custom_override
        }

    def export_to_json(self, filepath: str) -> None:
        """Export model tracking information to a UTF-8 JSON file.

        The file contents are ``get_detailed_info()``; an existing file at
        ``filepath`` is overwritten.

        Args:
            filepath: Path to save the JSON file
        """
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.get_detailed_info(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load_from_json(cls, filepath: str) -> 'ModelTracker':
        """Load model tracking information from a JSON file.

        Accepts the format written by ``export_to_json``. Only ``model_key`` is
        required per component; ``model_type``, ``custom_override`` and
        ``model_name`` fall back to the defaults of ``record_component_model``.

        Args:
            filepath: Path to the JSON file

        Returns:
            ModelTracker instance with loaded information

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a component entry has no ``"model_key"``.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        tracker = cls()
        tracker.timestamp = data.get("timestamp", tracker.timestamp)

        for comp_name, comp_info in data.get("components", {}).items():
            tracker.record_component_model(
                component_name=comp_name,
                model_key=comp_info["model_key"],
                # Optional fields default like record_component_model so
                # older or hand-written files still load.
                model_type=comp_info.get("model_type", "llm"),
                custom_override=comp_info.get("custom_override", False),
                # Preserve the stored name instead of replacing it with the key.
                model_name=comp_info.get("model_name"),
            )

        return tracker


def get_default_models() -> Dict[str, str]:
    """Return the baseline model assignment for every tracked component.

    Values come from ``settings.app.llm`` (keys ``'multimodal'``, ``'fast'``
    and ``'main'``) and from ``settings.app.embedding['main']``.

    Returns:
        Dict mapping component names to default model keys
    """
    llm_models = settings.app.llm
    multimodal_model = llm_models['multimodal']
    fast_model = llm_models['fast']
    main_model = llm_models['main']

    return {
        "video_annotation": multimodal_model,
        "scene_extraction": fast_model,
        "traffic_rule_checking": fast_model,
        "accident_analysis_retrieval": fast_model,
        "accident_analysis_main": main_model,
        "driving_assessment": main_model,
        "embeddings": settings.app.embedding['main'],
    }


def create_default_tracker() -> ModelTracker:
    """Build a ModelTracker pre-populated with the default model assignments.

    Every entry from ``get_default_models()`` is recorded with
    ``custom_override=False``. The ``"embeddings"`` component is tagged with
    model type ``"embedding"``; every other component is tagged ``"llm"``.

    Returns:
        ModelTracker instance with default models recorded
    """
    tracker = ModelTracker()

    for component_name, default_key in get_default_models().items():
        is_embedding = component_name == "embeddings"
        tracker.record_component_model(
            component_name=component_name,
            model_key=default_key,
            model_type="embedding" if is_embedding else "llm",
            custom_override=False,
        )

    return tracker