#!/usr/bin/env python3
"""Script to generate safety assessment for ground truth files using LLM analysis."""

import json
import sys
import hashlib
import time
from pathlib import Path
from typing import Dict, List, Any, Optional

# Add project root to path so the `src.*` imports below can be resolved.
# `root` is also used by main() to locate data/evaluation/ground_truth.
# NOTE(review): assumes this file sits exactly one directory below the
# project root — confirm if the script is ever moved.
root = Path(__file__).parent.parent
sys.path.append(str(root))

from src.llm.llms import get_llm
from src.utils.settings import settings
from src.utils.log import logger


def get_content_hash(annotation: str, violations: List[Dict], accidents: List[Dict]) -> str:
    """Return a short fingerprint of the inputs that drive an assessment.

    The three inputs are serialized as canonical JSON (sorted keys, non-ASCII
    characters kept as-is), hashed with SHA-256, and the hex digest is
    truncated to its first 16 characters. Identical inputs always produce the
    same fingerprint, which is what the assessment cache compares against.

    Args:
        annotation (str): Ground truth annotation.
        violations (List[Dict]): List of traffic violations.
        accidents (List[Dict]): List of accident risks.

    Returns:
        str: First 16 hex characters of the SHA-256 digest of the content.
    """
    # sort_keys makes the serialization independent of dict insertion order.
    payload = json.dumps(
        {'annotation': annotation, 'violations': violations, 'accidents': accidents},
        sort_keys=True,
        ensure_ascii=False,
    )
    digest = hashlib.sha256(payload.encode('utf-8'))
    return digest.hexdigest()[:16]


def load_cache(cache_dir: Path) -> Dict:
    """Load the assessment cache from ``cache_dir/assessment_cache.json``.

    A missing cache file yields a fresh, empty cache. An unreadable or
    malformed file (bad JSON, bad encoding, permission error, or JSON that is
    not an object with a ``"files"`` object) is never fatal. A warning is
    logged and a fresh, empty cache is returned, so assessments are simply
    regenerated.

    Args:
        cache_dir (Path): Directory containing cache files.

    Returns:
        Dict: Cache data with a ``"files"`` mapping of file name -> entry.
    """
    cache_file = cache_dir / "assessment_cache.json"
    if cache_file.exists():
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # OSError covers permission problems and the file vanishing.
            logger.warning(f"Invalid cache file {cache_file}, starting fresh")
        else:
            # Callers index cache["files"][name] directly, so syntactically
            # valid JSON with the wrong shape must be rejected here too.
            if isinstance(data, dict) and isinstance(data.get("files", {}), dict):
                return data
            logger.warning(f"Invalid cache file {cache_file}, starting fresh")

    return {
        "created_at": time.time(),
        "files": {}
    }


def save_cache(cache: Dict, cache_dir: Path) -> None:
    """Save the assessment cache to ``cache_dir/assessment_cache.json``.

    Saving is best-effort. Any failure, including failing to create
    ``cache_dir``, is logged as a warning and swallowed, so a cache problem
    never aborts assessment generation. The data is first written to a
    temporary sibling file and then moved over the real cache file. An
    interrupted run therefore leaves the previous cache intact rather than a
    truncated file.

    Side effect: sets ``cache["updated_at"]`` to the current time.

    Args:
        cache (Dict): Cache data to save.
        cache_dir (Path): Directory to save cache files.
    """
    cache_file = cache_dir / "assessment_cache.json"
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")

    cache["updated_at"] = time.time()

    try:
        # mkdir is inside the try so an unwritable location stays a warning
        # instead of propagating to (and aborting) the caller.
        cache_dir.mkdir(parents=True, exist_ok=True)
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(cache, f, indent=2, ensure_ascii=False)
        tmp_file.replace(cache_file)
        logger.debug(f"Cache saved to {cache_file}")
    except Exception as e:  # deliberately broad: saving the cache is best-effort
        logger.warning(f"Failed to save cache: {e}")


def get_cached_assessment(cache: Dict, file_name: str, content_hash: str) -> Optional[Dict]:
    """Look up a cached assessment, honoring it only if the content still matches.

    Args:
        cache (Dict): Cache data.
        file_name (str): Name of the ground truth file.
        content_hash (str): Hash of current content.

    Returns:
        Optional[Dict]: The stored ``"assessment"`` when an entry for
        ``file_name`` exists and its ``"content_hash"`` equals
        ``content_hash``; otherwise None.
    """
    entries = cache.get("files", {})
    if file_name not in entries:
        return None

    entry = entries[file_name]
    # A hash mismatch means the inputs changed since the entry was written.
    if entry.get("content_hash") != content_hash:
        return None
    return entry.get("assessment")


def cache_assessment(cache: Dict, file_name: str, content_hash: str, assessment: Dict) -> None:
    """Store an assessment in the in-memory cache under ``file_name``.

    Any previous entry for the same file is replaced by a new dict. The
    ``"files"`` mapping is created if missing. Nothing is written to disk
    here; see ``save_cache``.

    Args:
        cache (Dict): Cache data (mutated in place).
        file_name (str): Name of the ground truth file.
        content_hash (str): Hash of content.
        assessment (Dict): Assessment result to cache.
    """
    new_entry = {
        "content_hash": content_hash,
        "assessment": assessment,
        "generated_at": time.time(),
    }
    cache.setdefault("files", {})[file_name] = new_entry


def create_assessment_prompt(annotation: str, violations: List[Dict], accidents: List[Dict], video_path: str) -> str:
    """Build the LLM prompt that requests a structured safety assessment.

    Only entries explicitly marked as found (``violation == 'found'`` /
    ``accident == 'found'``) are listed, one ``- scene: detail`` line each.
    When none are marked, a fixed "none found" sentence is used instead. The
    prompt ends with the JSON shape the model is asked to answer in.

    Args:
        annotation (str): Ground truth annotation.
        violations (List[Dict]): List of traffic violations.
        accidents (List[Dict]): List of accident risks.
        video_path (str): Path to the video file; only its file name is shown.

    Returns:
        str: Prompt for generating safety assessment.
    """
    found_violations = [v for v in violations if v.get('violation') == 'found']
    found_accidents = [a for a in accidents if a.get('accident') == 'found']

    if found_violations:
        violation_lines = "".join(
            f"- {v.get('scene', '')}: {v.get('reason', '')}\n" for v in found_violations
        )
        violations_text = "Traffic violations found:\n" + violation_lines
    else:
        violations_text = "No traffic violations found.\n"

    if found_accidents:
        accident_lines = "".join(
            f"- {a.get('scene', '')}: {a.get('consequence', '')}\n" for a in found_accidents
        )
        accidents_text = "Accident risks identified:\n" + accident_lines
    else:
        accidents_text = "No significant accident risks identified.\n"

    video_name = Path(video_path).name

    return f"""You are an expert traffic safety analyst evaluating driving behavior from dashcam footage.

Video: {video_name}

Driving Context:
{annotation}

{violations_text}

{accidents_text}

Your task is to generate a comprehensive safety assessment that evaluates the ego vehicle's driving performance. Based on the annotation and identified violations/accidents, provide:

1. **Safety Score (1-10 scale)**:
   - 1-3: Dangerous driving with major violations or high accident risk
   - 4-6: Moderate safety concerns with some violations or accident risks
   - 7-8: Generally safe driving with minor issues
   - 9-10: Excellent driving with no significant safety concerns

2. **Risk Level**: Choose from "low" (9-10 safety score), "medium" (7-8 safety score), "high" (4-6 safety score), or "critical" (1-3 safety score)

3. **Overall Evaluation**: A concise summary (1-2 sentences) of the driving performance

4. **Strengths**: List of positive driving behaviors observed (2-4 items)

5. **Weaknesses**: List of concerning behaviors or risk factors (2-4 items)

6. **Improvement Advice**: Specific, actionable recommendations (2-4 items)

Assessment Guidelines:
- Focus on the ego vehicle's behavior and decisions
- Consider both proactive safety measures and reactive responses
- Account for traffic violations and accident risks in scoring
- Be objective and factual in evaluation
- Provide constructive improvement suggestions
- Consider the complexity and difficulty of the driving situation

Respond in JSON format:
{{
  "safety_score": <number 1-10>,
  "risk_level": "<low/medium/high/critical>",
  "overall_evaluation": "<concise evaluation summary>",
  "strengths": [
    "<positive behavior 1>",
    "<positive behavior 2>",
    "<positive behavior 3>",
    "<positive behavior 4>"
  ],
  "weaknesses": [
    "<concerning behavior 1>",
    "<concerning behavior 2>",
    "<concerning behavior 3>",
    "<concerning behavior 4>"
  ],
  "improvement_advice": [
    "<specific recommendation 1>",
    "<specific recommendation 2>",
    "<specific recommendation 3>",
    "<specific recommendation 4>"
  ]
}}"""


def _response_text(content: Any) -> str:
    """Flatten an LLM message ``content`` value into plain text."""
    # Message content is usually a str, but chat-model messages may also carry
    # a list of content parts (plain strings or dicts with a "text" key).
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "".join(parts)
    return str(content)


def _parse_assessment_json(text: str) -> Optional[Dict]:
    """Parse a JSON object out of model output; return None if there is none."""
    candidates = [text.strip()]
    # Models frequently wrap the JSON in ```json ... ``` fences or add a
    # sentence before/after it, so also try the outermost {...} span.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers use dict access (.get), so only a JSON object is acceptable.
        if isinstance(data, dict):
            return data
    return None


def generate_safety_assessment(annotation: str, violations: List[Dict], accidents: List[Dict], 
                              video_path: str, model_id: Optional[str] = None) -> Dict:
    """Generate a safety assessment for one video using an LLM.

    The reply is parsed leniently. A bare JSON object, a fenced ```json block,
    and JSON surrounded by extra prose are all accepted. If no JSON object can
    be recovered, a placeholder assessment whose ``overall_evaluation`` is
    "Assessment generation failed - manual evaluation required" is returned
    instead of raising. Errors from the LLM call itself are not caught here.

    Args:
        annotation (str): Ground truth annotation.
        violations (List[Dict]): List of traffic violations.
        accidents (List[Dict]): List of accident risks.
        video_path (str): Path to the video file.
        model_id (str, optional): LLM model to use; falls back to
            ``settings.app.llm['main']`` when empty or None.

    Returns:
        Dict: Parsed assessment, or the placeholder described above.
    """
    llm = get_llm(model_id or settings.app.llm['main'])

    prompt = create_assessment_prompt(annotation, violations, accidents, video_path)

    logger.debug("Generating safety assessment using LLM...")
    response = llm.invoke(prompt)

    assessment_data = _parse_assessment_json(_response_text(response.content))
    if assessment_data is not None:
        return assessment_data

    logger.warning("Failed to parse LLM response as JSON, using fallback assessment")
    return {
        "safety_score": 5,
        "risk_level": "medium",
        "overall_evaluation": "Assessment generation failed - manual evaluation required",
        "strengths": ["Manual evaluation required"],
        "weaknesses": ["Manual evaluation required"],
        "improvement_advice": ["Manual evaluation required"]
    }


# overall_evaluation values that mark an existing assessment as a placeholder
# that may be (re)generated rather than content to preserve.
_PLACEHOLDER_EVALUATIONS = (
    "Manual evaluation of driving behavior",
    "Manual evaluation required",
    "Assessment generation failed - manual evaluation required",
)
# overall_evaluation of the fallback that generate_safety_assessment returns
# when the LLM reply could not be parsed.
_FAILED_ASSESSMENT_EVALUATION = "Assessment generation failed - manual evaluation required"


def _is_placeholder_assessment(assessment: Dict) -> bool:
    """Return True if *assessment* only contains known placeholder content."""
    strengths = assessment.get('strengths', [])
    weaknesses = assessment.get('weaknesses', [])
    return (
        assessment.get('overall_evaluation', '') in _PLACEHOLDER_EVALUATIONS or
        (isinstance(strengths, list) and len(strengths) == 1 and
         strengths[0] in ["Manual evaluation required", "Steer slightly left to avoid collision"]) or
        (isinstance(weaknesses, list) and len(weaknesses) == 1 and
         weaknesses[0] in ["Manual evaluation required"])
    )


def _is_failed_assessment(assessment: Any) -> bool:
    """Return True if *assessment* is the LLM-parse-failure fallback."""
    return (isinstance(assessment, dict)
            and assessment.get('overall_evaluation') == _FAILED_ASSESSMENT_EVALUATION)


def process_ground_truth_file(ground_truth_file: Path, cache: Optional[Dict] = None, 
                              cache_dir: Optional[Path] = None, model_id: Optional[str] = None) -> bool:
    """Generate (or reuse from cache) the assessment for one ground truth file.

    Processing steps:
    - Files whose assessment looks manually populated (not a known
      placeholder, non-empty evaluation, positive score) are left untouched.
    - Files missing annotation, violations, or accidents data are skipped.
    - Otherwise a cached assessment with a matching content hash is reused,
      or a new one is generated. The result is written back into the file
      under ``ground_truth.assessment``.
    - The LLM-parse-failure fallback is written to the file as a placeholder
      but never cached. Such a cached entry is also ignored, so the next run
      retries the LLM instead of replaying the failure forever (the content
      hash would otherwise keep matching).

    Args:
        ground_truth_file (Path): Path to the ground truth JSON file.
        cache (Dict, optional): Assessment cache for change detection.
        cache_dir (Path, optional): Directory to save cache files.
        model_id (str, optional): LLM model to use.

    Returns:
        bool: True if file was updated, False if skipped or an error occurred.
        All exceptions are caught and logged here, so errors also yield False.
    """
    try:
        with open(ground_truth_file, 'r', encoding='utf-8') as f:
            gt_data = json.load(f)

        # Respect assessments that have already been filled in (manually or by
        # a previous successful run).
        current_assessment = gt_data.get('ground_truth', {}).get('assessment', {})
        if current_assessment:
            overall_eval = current_assessment.get('overall_evaluation', '')
            safety_score = current_assessment.get('safety_score', 0)
            if (not _is_placeholder_assessment(current_assessment)
                    and overall_eval and safety_score > 0):
                logger.info(f"Skipping {ground_truth_file.name} - assessment already populated with manual content")
                return False

        gt = gt_data.get('ground_truth', {})
        annotation = gt.get('annotation', '')
        violations = gt.get('violations', [])
        accidents = gt.get('accidents', [])
        video_path = gt_data.get('video_path', '')

        if not annotation or annotation == "MANUAL_ANNOTATION_REQUIRED":
            logger.warning(f"No annotation found in {ground_truth_file.name} - please generate annotation first")
            return False

        if not violations:
            logger.warning(f"No violations data found in {ground_truth_file.name}")
            return False

        if not accidents:
            logger.warning(f"No accidents data found in {ground_truth_file.name}")
            return False

        # NOTE: cache entries are keyed by file name + content hash only, so a
        # cached assessment is reused even if a different model_id is requested.
        content_hash = get_content_hash(annotation, violations, accidents)

        assessment = None
        if cache is not None:
            cached_assessment = get_cached_assessment(cache, ground_truth_file.name, content_hash)
            if _is_failed_assessment(cached_assessment):
                logger.info(f"Ignoring cached failed assessment for {ground_truth_file.name}, regenerating")
            elif cached_assessment:
                logger.info(f"Using cached assessment for {ground_truth_file.name} (content unchanged)")
                assessment = cached_assessment

        if assessment is None:
            logger.info(f"Generating safety assessment for {ground_truth_file.name}")
            assessment = generate_safety_assessment(annotation, violations, accidents, video_path, model_id)

            if _is_failed_assessment(assessment):
                # Not cached on purpose: a cached failure would be replayed on
                # every later run because the content hash keeps matching.
                logger.warning(f"Assessment for {ground_truth_file.name} could not be parsed; "
                               f"writing placeholder without caching it")
            elif cache is not None and assessment:
                cache_assessment(cache, ground_truth_file.name, content_hash, assessment)
                if cache_dir:
                    save_cache(cache, cache_dir)

        if not assessment:
            logger.error(f"Failed to generate assessment for {ground_truth_file.name}")
            return False

        gt_data['ground_truth']['assessment'] = assessment

        with open(ground_truth_file, 'w', encoding='utf-8') as f:
            json.dump(gt_data, f, indent=2, ensure_ascii=False)

        logger.info(f"✅ Generated assessment for {ground_truth_file.name}")
        logger.info(f"   Safety Score: {assessment.get('safety_score', 'N/A')}/10")
        logger.info(f"   Risk Level: {assessment.get('risk_level', 'N/A')}")
        # str() guards against a null/non-string evaluation from the model.
        logger.info(f"   Evaluation: {str(assessment.get('overall_evaluation', 'N/A'))[:80]}...")

        return True

    except Exception as e:
        logger.error(f"❌ Failed to process {ground_truth_file}: {e}")
        return False


def generate_all_assessments(ground_truth_dir: Path, model_id: Optional[str] = None) -> None:
    """Generate assessments for every ground truth file that needs one.

    Files are processed in sorted name order. The cache lives in
    ``<ground_truth_dir>/../cache`` and is saved after each newly generated
    assessment and once more at the end.

    Args:
        ground_truth_dir (Path): Directory containing ground truth files.
        model_id (str, optional): LLM model to use.
    """
    if not ground_truth_dir.exists():
        logger.error(f"Ground truth directory not found: {ground_truth_dir}")
        return

    cache_dir = ground_truth_dir.parent / "cache"
    cache = load_cache(cache_dir)
    cache_stats = {"loaded": len(cache.get("files", {})), "hits": 0, "misses": 0}

    # Path.glob yields entries in arbitrary order; sort for reproducible runs.
    json_files = sorted(ground_truth_dir.glob("*.json"))

    if not json_files:
        logger.warning(f"No JSON files found in {ground_truth_dir}")
        return

    logger.info(f"Found {len(json_files)} ground truth files")
    logger.info(f"Loaded {cache_stats['loaded']} cached assessments")
    logger.info(f"Using model: {model_id or settings.app.llm['main']}")
    logger.info("=" * 50)

    updated_count = 0
    skipped_count = 0
    # process_ground_truth_file catches its own exceptions and returns False,
    # so most failures end up counted as skipped; this only counts escapes.
    failed_count = 0

    for json_file in json_files:
        logger.info(f"Processing: {json_file.name}")
        entry_before = cache.get("files", {}).get(json_file.name)

        try:
            updated = process_ground_truth_file(json_file, cache, cache_dir, model_id)
        except Exception as e:
            logger.error(f"Failed to process {json_file.name}: {e}")
            failed_count += 1
        else:
            if updated:
                updated_count += 1
                # cache_assessment stores a brand-new entry dict, so if the entry
                # object is unchanged the assessment was served from the cache.
                # (An LLM call whose result was not cached is not detected.)
                entry_after = cache.get("files", {}).get(json_file.name)
                if entry_before is not None and entry_after is entry_before:
                    cache_stats["hits"] += 1
                else:
                    cache_stats["misses"] += 1
            else:
                skipped_count += 1

        logger.info("-" * 30)

    save_cache(cache, cache_dir)

    final_cache_size = len(cache.get("files", {}))

    logger.info("=" * 50)
    logger.info("SUMMARY:")
    logger.info(f"✅ Generated: {updated_count} assessments")
    logger.info(f"⏭️ Skipped: {skipped_count} files") 
    logger.info(f"❌ Failed: {failed_count} files")
    logger.info(f"📁 Total: {len(json_files)} files")
    logger.info(f"💾 Cached: {final_cache_size} assessments")
    if cache_stats["hits"] > 0:
        logger.info(f"⚡ Cache efficiency: Avoided {cache_stats['hits']} LLM calls")

    if updated_count > 0:
        logger.info("\n📝 Next steps:")
        logger.info("1. Review the generated assessments in your IDE")
        logger.info("2. Edit safety scores, risk levels, and evaluations as needed")
        logger.info("3. Refine strengths, weaknesses, and improvement advice")
        logger.info("4. Run RAGAS evaluation when all annotations are complete")


def find_ground_truth_file(file_number: str, ground_truth_dir: Path) -> Optional[Path]:
    """Find a ground truth file by its number.

    Patterns are tried from most to least specific: ``<n>_*.json``, then
    ``<n>.json``, then any ``.json`` file whose name contains ``<n>``. For the
    first pattern that matches anything, the lexicographically smallest path
    is returned. This makes the result independent of filesystem listing
    order, since ``Path.glob`` yields entries in arbitrary order.

    Args:
        file_number (str): File number (e.g., "000", "001").
        ground_truth_dir (Path): Directory containing ground truth files.

    Returns:
        Optional[Path]: Path to the found file, or None if not found (or if
        ``file_number`` is empty, which would otherwise match every file).
    """
    if not file_number:
        return None

    patterns = (
        f"{file_number}_*.json",
        f"{file_number}.json",
        f"*{file_number}*.json",
    )

    for pattern in patterns:
        first_match = min(ground_truth_dir.glob(pattern), default=None)
        if first_match is not None:
            return first_match

    return None


def process_single_file(file_number: str, ground_truth_dir: Path, model_id: Optional[str] = None) -> None:
    """Process a single ground truth file selected by its number.

    If no file matches, the available ``.json`` files are listed (sorted) to
    help the user pick a valid number. Uses the same on-disk cache as batch
    mode (``<ground_truth_dir>/../cache``).

    Args:
        file_number (str): File number to process.
        ground_truth_dir (Path): Directory containing ground truth files.
        model_id (str, optional): LLM model to use.
    """
    ground_truth_file = find_ground_truth_file(file_number, ground_truth_dir)

    if not ground_truth_file:
        logger.error(f"No ground truth file found with number: {file_number}")
        logger.error(f"Searched in: {ground_truth_dir}")

        json_files = sorted(ground_truth_dir.glob("*.json"))
        if json_files:
            logger.info("\n📁 Available files:")
            for f in json_files:
                logger.info(f"   {f.name}")
        return

    logger.info(f"Processing single file: {ground_truth_file.name}")
    logger.info(f"Using model: {model_id or settings.app.llm['main']}")
    logger.info("=" * 50)

    cache_dir = ground_truth_dir.parent / "cache"
    cache = load_cache(cache_dir)

    success = process_ground_truth_file(ground_truth_file, cache, cache_dir, model_id)

    logger.info("=" * 50)
    if success:
        logger.info("✅ Assessment generation completed successfully")
    else:
        logger.info("⏭️ Assessment generation skipped or failed")


def _print_usage() -> None:
    """Print command-line usage, options, features, and examples."""
    print("Usage: python 2_5_generate_assessment.py [file_number] [--model=model_id]")
    print("Options:")
    print("  file_number         Process only the file with this number (e.g., '000')")
    print("  --model=MODEL_ID    Override default LLM model")
    print("")
    print("Features:")
    print("  ✅ Generates comprehensive safety assessments with scores and risk levels")
    print("  ✅ Analyzes violations and accident risks for scoring")
    print("  ✅ Provides strengths, weaknesses, and improvement advice")
    print("  ✅ Uses existing annotation, violations, and accidents data")
    print("  ⚡ Intelligent caching: Only processes changed files, skips unchanged content")
    print("")
    print("Examples:")
    print("  python 2_5_generate_assessment.py")
    print("  python 2_5_generate_assessment.py 000")
    print("  python 2_5_generate_assessment.py --model=\"openai:gpt-4o\"")
    print("  python 2_5_generate_assessment.py 001 --model=\"groq:llama-3.3-70b-versatile\"")


def main():
    """Command-line entry point: generate safety assessments.

    Accepts ``[file_number] [--model=MODEL_ID]`` in any order, or
    ``-h``/``--help``. Without a file number, every file in
    ``<root>/data/evaluation/ground_truth`` is processed.
    """
    print("=" * 60)
    print("GENERATE SAFETY ASSESSMENTS FOR GROUND TRUTH FILES")
    print("=" * 60)

    ground_truth_dir = root / "data" / "evaluation" / "ground_truth"

    model_id = None
    file_number = None

    # Scan every argument so options work in any position: previously
    # `--model=X 001` silently dropped the file number.
    for arg in sys.argv[1:]:
        if arg in ('--help', '-h'):
            _print_usage()
            return
        if arg.startswith('--model='):
            # Split once: only the first '=' separates flag from value.
            model_id = arg.split('=', 1)[1]
        elif file_number is None:
            file_number = arg
        else:
            print(f"⚠️ Ignoring unexpected argument: {arg}")

    if not ground_truth_dir.exists():
        print(f"❌ Ground truth directory not found: {ground_truth_dir}")
        print("Please ensure the data/evaluation/ground_truth directory exists")
        return

    print(f"📁 Processing files in: {ground_truth_dir}")

    effective_model = model_id or settings.app.llm['main']
    print(f"🤖 Using LLM model: {effective_model}")

    if model_id:
        print("   (Overridden from command line)")
    else:
        print("   (From settings.app.llm['main'])")

    print("📊 Will generate safety assessments based on existing data")
    print("💡 Requires annotations, violations, and accidents to be populated first")
    print()

    if file_number:
        print(f"🎯 Single file mode: {file_number}")
        process_single_file(file_number, ground_truth_dir, model_id)
    else:
        generate_all_assessments(ground_truth_dir, model_id)

    print()
    print("=" * 60)
    print("ASSESSMENT GENERATION COMPLETE")
    print("=" * 60)


if __name__ == "__main__":
    main()