"""Script to update existing system output files to the new format.

This script updates existing system output JSON files to include the new
model tracking and performance timing information structure.
"""

import sys
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, Any

# Add the project root (the parent of this script's directory) to sys.path so
# that the `src.*` imports below can be resolved when the script is run
# directly. `root` is also used by main() to locate the outputs directory.
root = Path(__file__).parent.parent
sys.path.append(str(root))

from src.utils.log import logger
from src.utils.settings import settings


def get_default_models_used() -> Dict[str, str]:
    """Return the default model configured for each pipeline component.

    Values are read from ``settings.app.llm`` (keys ``multimodal``, ``fast``
    and ``main``) and ``settings.app.embedding`` (key ``main``).

    Returns:
        Mapping of component name to the configured model entry.
    """
    llm_models = settings.app.llm

    # Look each entry up once, in the same order the mapping below uses them.
    multimodal_model = llm_models['multimodal']
    fast_model = llm_models['fast']
    main_model = llm_models['main']
    embedding_model = settings.app.embedding['main']

    defaults: Dict[str, str] = {}
    defaults["video_annotation"] = multimodal_model
    defaults["scene_extraction"] = fast_model
    defaults["traffic_rule_checking"] = fast_model
    defaults["accident_analysis_retrieval"] = fast_model
    defaults["accident_analysis_main"] = main_model
    defaults["driving_assessment"] = main_model
    defaults["embeddings"] = embedding_model
    return defaults


def get_model_names(model_keys: Dict[str, str]) -> Dict[str, str]:
    """Map each component to its model name.

    In the current format every value is already the full model string
    (e.g. "openai:gpt-4o"), so no translation is needed and the result is
    simply a new dictionary with the same entries.

    Args:
        model_keys: Mapping of component name to model string.

    Returns:
        A fresh dictionary holding the same component/model pairs.
    """
    return dict(model_keys.items())


def estimate_component_timings(total_time: float) -> Dict[str, float]:
    """Split a total evaluation time into per-component estimates.

    Each component receives a fixed share of ``total_time``; the shares add
    up to 1.0.

    Args:
        total_time: Total evaluation time

    Returns:
        Dictionary mapping each component name to its estimated timing
    """
    # (component, share of the total) - video annotation is typically the
    # slowest step, so it is given the largest share.
    shares = (
        ("video_annotation", 0.65),
        ("scene_extraction", 0.08),
        ("traffic_rule_checking", 0.10),
        ("accident_analysis", 0.12),
        ("driving_mentor", 0.05),
    )

    estimates: Dict[str, float] = {}
    for name, share in shares:
        estimates[name] = total_time * share
    return estimates


def _summarize_component_timings(component_timings: Dict[str, float]) -> Dict[str, Any]:
    """Return the fastest/slowest/count summary for a component-timing mapping."""
    if not component_timings:
        return {
            "fastest_component": None,
            "slowest_component": None,
            "total_components": 0,
        }

    # min()/max() return the first matching key on ties, i.e. the earliest
    # component in insertion order.
    fastest = min(component_timings, key=component_timings.get)
    slowest = max(component_timings, key=component_timings.get)
    return {
        "fastest_component": {
            "component": fastest,
            "duration": component_timings[fastest],
        },
        "slowest_component": {
            "component": slowest,
            "duration": component_timings[slowest],
        },
        "total_components": len(component_timings),
    }


def update_system_output_file(file_path: Path) -> bool:
    """Update a single system output file to the new format.

    The file is read, rebuilt with the added ``models_used``,
    ``component_timings`` and ``performance_metrics`` sections, and written
    back in place. Files that already contain both ``models_used`` and
    ``component_timings`` are left untouched. Only the keys of the rebuilt
    structure are written; other top-level keys of the original file are not
    carried over.

    Component timings are estimated from ``evaluation_time``. When that value
    is missing, null or not positive, all component timings are set to 0.

    Args:
        file_path: Path to the system output JSON file

    Returns:
        True if the file was updated (or was already in the new format),
        False if any error occurred; the error is logged, not raised.
    """
    # Imported locally so this block does not depend on a file-level import.
    from datetime import timezone

    try:
        # Read existing file
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Check if already in new format
        if 'models_used' in data and 'component_timings' in data:
            logger.info(f"File already in new format: {file_path.name}")
            return True

        logger.info(f"Updating file: {file_path.name}")

        # Get default models
        model_names = get_model_names(get_default_models_used())

        # Treat a missing or null evaluation time as "unknown" (0) so the
        # zero-timings fallback below is actually reachable for such files.
        total_time = data.get('evaluation_time') or 0
        if total_time > 0:
            component_timings = estimate_component_timings(total_time)
        else:
            component_timings = {
                "video_annotation": 0,
                "scene_extraction": 0,
                "traffic_rule_checking": 0,
                "accident_analysis": 0,
                "driving_mentor": 0
            }

        # One timezone-aware "now" shared by both timestamp fields; tzinfo is
        # dropped before formatting to keep the existing "...Z" layout
        # (datetime.utcnow() is deprecated since Python 3.12).
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        timestamp = now_utc.isoformat() + "Z"

        # Create updated data structure
        updated_data = {
            "video_id": data["video_id"],
            "video_path": data["video_path"],
            "timestamp": timestamp,
            "experiment_name": None,
            "models_used": model_names,
            "component_timings": component_timings,
            "performance_metrics": {
                "session_timestamp": timestamp,
                "total_session_time": total_time,
                "total_measured_time": sum(component_timings.values()),
                "component_timings": {
                    comp: {"duration": timing}
                    for comp, timing in component_timings.items()
                },
                "summary": _summarize_component_timings(component_timings)
            },
            "system_outputs": data["system_outputs"],
            # Keep the original value when present; fall back to the 0 used
            # above when the key is absent.
            "evaluation_time": data.get("evaluation_time", total_time),
            "summary": data["summary"]
        }

        # Write updated file
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(updated_data, f, indent=2, ensure_ascii=False)

        logger.info(f"Successfully updated: {file_path.name}")
        return True

    except Exception as e:
        # Per-file boundary: report the failure and let the caller continue
        # with the remaining files.
        logger.error(f"Failed to update {file_path}: {e}")
        return False


def update_all_system_outputs(system_outputs_dir: Path) -> Dict[str, int]:
    """Update every ``*.json`` file directly inside a directory.

    Files that already contain both ``models_used`` and ``component_timings``
    are counted as skipped; the rest are passed to
    ``update_system_output_file``. A file that cannot be read or parsed is
    logged and counted as failed.

    Args:
        system_outputs_dir: Directory containing system output files

    Returns:
        Dictionary with the counters ``total``, ``updated``, ``failed`` and
        ``skipped``
    """
    candidates = list(system_outputs_dir.glob("*.json"))
    stats = {"total": len(candidates), "updated": 0, "failed": 0, "skipped": 0}

    if len(candidates) == 0:
        logger.warning(f"No JSON files found in {system_outputs_dir}")
        return stats

    logger.info(f"Found {len(candidates)} system output files to update")

    for path in candidates:
        try:
            # Peek at the file to decide whether it still needs migrating.
            with path.open('r', encoding='utf-8') as handle:
                payload = json.load(handle)

            already_migrated = 'models_used' in payload and 'component_timings' in payload
            if already_migrated:
                outcome = "skipped"
            elif update_system_output_file(path):
                outcome = "updated"
            else:
                outcome = "failed"
        except Exception as e:
            logger.error(f"Error processing {path}: {e}")
            outcome = "failed"

        # Exactly one counter is bumped per file.
        stats[outcome] += 1

    return stats


def main():
    """Parse command-line arguments and run the update.

    With ``--file`` a single file is updated. Otherwise every JSON file in
    ``--system-outputs-dir`` (resolved against the project root) is
    processed and a summary of the counters is logged.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Update system output files to new format")
    parser.add_argument(
        "--system-outputs-dir",
        default="data/evaluation/system_outputs",
        help="Directory containing system output files",
    )
    parser.add_argument("--file", help="Update a specific file instead of all files")
    options = parser.parse_args()

    if not options.file:
        # Directory mode: update everything under the outputs directory.
        outputs_dir = root / options.system_outputs_dir
        if not outputs_dir.exists():
            logger.error(f"Directory not found: {outputs_dir}")
            return

        counters = update_all_system_outputs(outputs_dir)

        logger.info("Update completed:")
        logger.info(f"  Total files: {counters['total']}")
        logger.info(f"  Updated: {counters['updated']}")
        logger.info(f"  Already updated (skipped): {counters['skipped']}")
        logger.info(f"  Failed: {counters['failed']}")
        return

    # Single-file mode.
    target = Path(options.file)
    if not target.exists():
        logger.error(f"File not found: {target}")
        return

    if update_system_output_file(target):
        logger.info("File updated successfully")
    else:
        logger.error("Failed to update file")


# Script entry point: run the updater only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()