"""Component comparison framework for DriveGuard evaluation system.

This script provides tools for testing different models on specific components
and comparing their performance impact on the overall system.
"""

import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

# Make the project root (the directory two levels above this file) importable so
# the `src.*` imports below resolve when this file is run directly as a script.
# `root` is also reused by main() to locate the default data/dashcam videos.
root = Path(__file__).parent.parent
sys.path.append(str(root))

from src.llm.workflow.evaluate_driving_video import DrivingVideoEvaluator
from src.utils.log import logger
from src.utils.settings import settings


def get_available_models() -> Dict[str, List[str]]:
    """Return the configured model keys grouped by component type.

    The lists are read from ``settings.supported_models`` and exposed under the
    keys ``"llm"`` and ``"embedding"``.

    Returns:
        Mapping of component type to the model keys configured for it.
    """
    supported = settings.supported_models
    return dict(
        llm=supported.llm,
        embedding=supported.embedding,
    )


def _utc_timestamp() -> str:
    """Return the current UTC time in ISO-8601 form with a trailing ``Z``.

    ``datetime.utcnow()`` is deprecated since Python 3.12, so an aware UTC
    datetime is used instead; its ``+00:00`` offset is rewritten to ``Z`` so the
    string keeps the format this module has always written.
    """
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def _safe_filename_part(text: str) -> str:
    """Replace path separators so *text* stays within a single filename."""
    return text.replace("/", "_").replace("\\", "_")


def _build_evaluator(component_name: str, model_key: str,
                     baseline_config: Optional[Dict[str, str]]) -> Any:
    """Create an evaluator that uses *model_key* for *component_name*."""
    if component_name == "video_annotation":
        # Video annotation is selected through the evaluator's model_id argument.
        return DrivingVideoEvaluator(model_id=model_key)
    # Copy so the caller's baseline dict is not mutated across candidates.
    model_overrides = baseline_config.copy() if baseline_config else {}
    model_overrides[component_name] = model_key
    return DrivingVideoEvaluator(model_overrides=model_overrides)


def _evaluate_video(evaluator: Any, video_path: str, component_name: str) -> Dict[str, Any]:
    """Evaluate one video and return its result row, or an error row on failure."""
    try:
        logger.info(f"  Processing: {Path(video_path).name}")

        detailed_results = evaluator.evaluate_with_details(video_path)
        assessment = detailed_results.get('assessment', {})

        return {
            "video_id": Path(video_path).stem,
            "video_path": str(video_path),
            # Defaults to 0 when the evaluator reports no timing for this component.
            "component_timing": detailed_results.get('component_timings', {}).get(component_name, 0),
            "total_time": detailed_results.get('evaluation_time', 0),
            "safety_score": assessment.get('safety_score', 0),
            "risk_level": assessment.get('risk_level', 'unknown'),
            "models_used": detailed_results.get('models_used', {}),
            "performance_metrics": detailed_results.get('performance_metrics', {}),
        }
    except Exception as e:
        # Best-effort: one bad video must not abort the whole comparison, but the
        # traceback is kept in the log so the failure can be diagnosed.
        logger.exception(f"  Failed to process {video_path}: {e}")
        return {
            "video_id": Path(video_path).stem,
            "video_path": str(video_path),
            "error": str(e),
        }


def _aggregate_metrics(video_results: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Average the metrics of successful video rows; None if none succeeded."""
    valid = [r for r in video_results if "error" not in r]
    if not valid:
        return None
    count = len(valid)
    return {
        "avg_component_timing": sum(r["component_timing"] for r in valid) / count,
        "avg_total_time": sum(r["total_time"] for r in valid) / count,
        "avg_safety_score": sum(r["safety_score"] for r in valid) / count,
        "videos_processed": count,
        "videos_failed": len(video_results) - count,
    }


def test_component_models(component_name: str, model_candidates: List[str],
                          video_paths: List[str], output_dir: Path,
                          baseline_config: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    """Test different models for a specific component.

    Every video is evaluated once per candidate model. For components other than
    ``video_annotation``, ``baseline_config`` supplies the models of the remaining
    components and the candidate replaces only ``component_name``. For
    ``video_annotation`` the candidate is passed as ``model_id`` and
    ``baseline_config`` is not applied (a warning is logged).

    Per-model results are written to ``<output_dir>/<component>_<model>_results.json``
    and the combined summary to ``<output_dir>/<component>_comparison.json``;
    path separators in those names are replaced by ``_``.

    Args:
        component_name: Name of component to test (e.g., 'video_annotation')
        model_candidates: List of model keys to test
        video_paths: List of video files to test on
        output_dir: Directory to save results (created if missing)
        baseline_config: Baseline model configuration for other components

    Returns:
        Dictionary containing comparison results. Each model entry gets an
        ``aggregate_metrics`` dict only if at least one video succeeded.
    """
    output_dir = Path(output_dir)
    # Public entry point: create the directory rather than relying on the caller.
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Testing component '{component_name}' with {len(model_candidates)} models")
    logger.info(f"Model candidates: {model_candidates}")
    logger.info(f"Testing on {len(video_paths)} videos")

    if component_name == "video_annotation" and baseline_config:
        # NOTE(review): the evaluator is built with model_id only for this component,
        # so baseline overrides are dropped. Confirm whether DrivingVideoEvaluator
        # accepts model_id and model_overrides together before wiring both through.
        logger.warning("baseline_config is ignored when testing 'video_annotation'")

    safe_component = _safe_filename_part(component_name)
    results: Dict[str, Any] = {}

    for model_key in model_candidates:
        logger.info(f"Testing model: {model_key}")

        evaluator = _build_evaluator(component_name, model_key, baseline_config)

        model_results: Dict[str, Any] = {
            "model_key": model_key,
            "model_name": model_key,
            "component": component_name,
            "test_timestamp": _utc_timestamp(),
            "video_results": [],
        }

        for video_path in video_paths:
            model_results["video_results"].append(
                _evaluate_video(evaluator, video_path, component_name)
            )

        aggregate = _aggregate_metrics(model_results["video_results"])
        if aggregate is not None:
            model_results["aggregate_metrics"] = aggregate

        results[model_key] = model_results

        # Written inside the loop so each finished model is on disk even if a later one crashes.
        model_output_file = output_dir / f"{safe_component}_{_safe_filename_part(model_key)}_results.json"
        with open(model_output_file, 'w', encoding='utf-8') as f:
            json.dump(model_results, f, indent=2, ensure_ascii=False)

        logger.info(f"  Completed testing {model_key}")

    comparison_summary = {
        "component_name": component_name,
        "test_timestamp": _utc_timestamp(),
        "models_tested": model_candidates,
        "videos_tested": [Path(p).stem for p in video_paths],
        "results": results,
    }

    comparison_file = output_dir / f"{safe_component}_comparison.json"
    with open(comparison_file, 'w', encoding='utf-8') as f:
        json.dump(comparison_summary, f, indent=2, ensure_ascii=False)

    logger.info(f"Component comparison completed. Results saved to: {comparison_file}")

    return comparison_summary


def generate_comparison_report(comparison_results: Dict[str, Any], output_file: Path):
    """Generate a human-readable Markdown comparison report.

    The report has these parts:

    - A summary table with one row per model that processed at least one video,
      ordered by average component time (fastest first).
    - A note listing models with no successful videos.
    - Short key findings.
    - The full ``comparison_results`` as a JSON appendix.

    Args:
        comparison_results: Results from test_component_models
        output_file: Path to save the markdown report (parent directories are
            created if missing)
    """
    output_file = Path(output_file)
    component_name = comparison_results["component_name"]
    results = comparison_results["results"]

    # "video_annotation" -> "Video Annotation"; str.title() alone keeps the underscore.
    heading = component_name.replace("_", " ").title()

    report_lines = [
        f"# {heading} Component Model Comparison",
        "",
        f"**Test Date:** {comparison_results['test_timestamp']}",
        f"**Component:** {component_name}",
        f"**Models Tested:** {', '.join(comparison_results['models_tested'])}",
        f"**Videos Tested:** {', '.join(comparison_results['videos_tested'])}",
        "",
        "## Results Summary",
        "",
        "| Model | Avg Component Time (s) | Avg Total Time (s) | Avg Safety Score | Videos Processed |",
        "|-------|------------------------|-------------------|------------------|------------------|",
    ]

    # A model has no aggregate_metrics when none of its videos succeeded; such
    # models cannot be ranked, so collect them separately instead of dropping them.
    ranked = []
    failed_models = []
    for model_key, model_result in results.items():
        metrics = model_result.get("aggregate_metrics")
        if metrics is None:
            failed_models.append(model_result.get("model_name", model_key))
            continue
        ranked.append((
            model_key,
            model_result["model_name"],
            metrics["avg_component_timing"],
            metrics["avg_total_time"],
            metrics["avg_safety_score"],
            metrics["videos_processed"],
        ))

    ranked.sort(key=lambda row: row[2])  # fastest average component time first

    for _key, model_name, comp_time, total_time, safety_score, videos in ranked:
        report_lines.append(
            f"| {model_name} | {comp_time:.2f} | {total_time:.2f} | {safety_score:.1f} | {videos} |"
        )

    if failed_models:
        report_lines.extend([
            "",
            f"**Models with no successfully processed videos:** {', '.join(failed_models)}",
        ])

    report_lines.extend([
        "",
        "## Key Findings",
        "",
    ])

    if ranked:
        fastest = ranked[0]
        slowest = ranked[-1]

        # Component timing defaults to 0 when the evaluator does not report it,
        # so guard the ratio against division by zero.
        if fastest[2] > 0:
            speed_line = f"- **Speed Difference**: {slowest[2] / fastest[2]:.1f}x slower"
        else:
            speed_line = "- **Speed Difference**: n/a (fastest model reported 0s component time)"

        report_lines.extend([
            f"- **Fastest Model**: {fastest[1]} ({fastest[2]:.2f}s)",
            f"- **Slowest Model**: {slowest[1]} ({slowest[2]:.2f}s)",
            speed_line,
            "",
        ])

        # "Quality" here is the highest average safety score.
        best_quality = max(ranked, key=lambda row: row[4])
        report_lines.extend([
            f"- **Highest Quality**: {best_quality[1]} (Safety Score: {best_quality[4]:.1f})",
            "",
        ])

    report_lines.extend([
        "## Detailed Results",
        "",
        "```json",
        # ensure_ascii=False matches how the JSON result files are written.
        json.dumps(comparison_results, indent=2, ensure_ascii=False),
        "```",
    ])

    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(report_lines) + '\n')

    logger.info(f"Comparison report generated: {output_file}")


def main() -> int:
    """Command-line entry point for component model comparison.

    Steps:

    1. Parse the arguments and validate the requested models.
    2. Resolve the videos to test.
    3. Run ``test_component_models`` and write a Markdown report into the same
       output directory.

    Returns:
        Process exit status: 0 on success, 1 when an argument or input is invalid.
    """
    parser = argparse.ArgumentParser(description="DriveGuard Component Model Comparison")
    parser.add_argument("component", help="Component to test (e.g., 'video_annotation', 'scene_extraction')")
    parser.add_argument("--models", required=True,
                        help="Comma-separated list of models to test (e.g., 'gpt4o,gpt4_turbo')")
    parser.add_argument("--videos",
                        help="Comma-separated list of video paths (defaults to all videos in data/dashcam)")
    parser.add_argument("--output-dir", default="data/evaluation/component_experiments",
                        help="Directory to save results")
    parser.add_argument("--baseline-config",
                        help="JSON string of baseline model configuration for other components")

    args = parser.parse_args()

    # Drop empty entries produced by e.g. "a,,b" or a trailing comma.
    model_candidates = [m.strip() for m in args.models.split(",") if m.strip()]
    if not model_candidates:
        logger.error("No models given via --models")
        return 1

    # NOTE(review): only LLM model keys are accepted here, although
    # get_available_models() also lists embedding models. Confirm whether
    # embedding-backed components should be testable from this CLI.
    available_models = get_available_models()["llm"]
    unknown_models = [m for m in model_candidates if m not in available_models]
    for model in unknown_models:
        logger.error(f"Model '{model}' not available. Available models: {available_models}")
    if unknown_models:
        return 1

    if args.videos:
        video_paths = [v.strip() for v in args.videos.split(",") if v.strip()]
        if not video_paths:
            logger.error("No video paths given via --videos")
            return 1
    else:
        # Use default videos from data/dashcam
        video_dir = root / "data" / "dashcam"
        video_extensions = {'.mp4', '.avi', '.mov', '.mkv'}
        # Case-insensitive extension match (".MP4" too), sorted for a stable order.
        video_paths = sorted(
            str(p) for p in video_dir.glob("*")
            if p.is_file() and p.suffix.lower() in video_extensions
        )

        if not video_paths:
            logger.error(f"No video files found in {video_dir}")
            return 1

    baseline_config: Dict[str, str] = {}
    if args.baseline_config:
        try:
            parsed_config = json.loads(args.baseline_config)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON for baseline config: {e}")
            return 1
        # test_component_models copies this and assigns into it by component name,
        # so anything other than a JSON object would fail later with a TypeError.
        if not isinstance(parsed_config, dict):
            logger.error("Baseline config must be a JSON object mapping components to model keys")
            return 1
        baseline_config = parsed_config

    output_dir = Path(args.output_dir) / args.component
    output_dir.mkdir(parents=True, exist_ok=True)

    results = test_component_models(
        component_name=args.component,
        model_candidates=model_candidates,
        video_paths=video_paths,
        output_dir=output_dir,
        baseline_config=baseline_config
    )

    report_file = output_dir / f"{args.component}_comparison_report.md"
    generate_comparison_report(results, report_file)

    logger.info("Component comparison completed!")
    logger.info(f"Results directory: {output_dir}")
    logger.info(f"Comparison report: {report_file}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status (None -> 0).
    sys.exit(main())