"""
ViolationChecker system output generator.
"""

import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Make the project root importable: walk five directory levels up from this
# file and append that directory to sys.path. Presumably this is what lets the
# absolute `src.` imports below resolve when the package is not installed —
# confirm against the repository layout.
project_root = Path(__file__).parent.parent.parent.parent.parent
sys.path.append(str(project_root))

from src.llm.agent.traffic_rule_checker import traffic_rule_agent
from src.llm.llms import get_llm
from .base_generator import BaseSystemGenerator


class ViolationChecker(BaseSystemGenerator):
    """Generator for ViolationChecker system outputs."""
    
    def __init__(self, config):
        """Set up the generator and attach the shared traffic rule agent.

        Args:
            config: Generator configuration. It is forwarded to the base class
                and read for ``driveguard_model_id`` and ``model_id``.
        """
        super().__init__(config)

        # Reuse the agent object imported at module level instead of building
        # a new one per generator.
        self.traffic_rule_agent = traffic_rule_agent

        # The component-specific model ID wins; the general one is used only
        # when the specific one is unset or otherwise falsy.
        preferred_model = config.driveguard_model_id
        self.model_id = preferred_model if preferred_model else config.model_id
        
    def get_component_name(self) -> str:
        """Return the component name."""
        return "violation"
    
    def get_ground_truth_list(self, video_filter: Optional[List[str]] = None) -> List[Path]:
        """Get list of ground truth files to process instead of video files."""
        ground_truth_dir = self.config.ground_truth_dir
        all_ground_truth = list(ground_truth_dir.glob("*.json"))
        all_ground_truth.sort()  # Ensure consistent ordering
        
        if video_filter:
            # Filter by video IDs (e.g., ["0000", "0001"])
            filtered_files = []
            for gt_file in all_ground_truth:
                video_id = gt_file.stem.split('_')[0]  # Extract ID from filename
                if video_id in video_filter:
                    filtered_files.append(gt_file)
            return filtered_files
        
        return all_ground_truth
    
    def get_ground_truth_path(self, video_path: Path) -> Path:
        """Get ground truth annotation file path for a video."""
        video_id = video_path.stem
        ground_truth_dir = self.config.project_root / "data" / "evaluation" / "ground_truth"
        ground_truth_path = ground_truth_dir / f"{video_id}.json"
        return ground_truth_path
    
    def generate_output_from_ground_truth(self, ground_truth_path: Path) -> Tuple[List[Dict], float]:
        """
        Generate violation checking output from ground truth scenes.
        
        Args:
            ground_truth_path: Path to the ground truth JSON file
            
        Returns:
            Tuple of (violation results list, total processing time)
        """
        start_time = time.time()
        
        try:
            with open(ground_truth_path, 'r') as f:
                ground_truth_data = json.load(f)
            
            # Extract scenes from ground truth
            if 'ground_truth' in ground_truth_data and 'scenes' in ground_truth_data['ground_truth']:
                scenes = ground_truth_data['ground_truth']['scenes']
            else:
                raise ValueError(f"Invalid ground truth format in {ground_truth_path} - missing scenes")
            
            # Process each scene individually through traffic_rule_agent
            violation_results = []
            scene_times = []
            
            for scene in scenes:
                scene_start = time.time()
                
                # Run scene through traffic rule agent
                result = self.traffic_rule_agent.invoke({'query': scene})
                
                scene_time = time.time() - scene_start
                scene_times.append(scene_time)
                
                violation_results.append({
                    'scene': scene,
                    'violation': result['result']['violation'],  # 'found' or 'not_found'
                    'reason': result['result']['reason'],
                    'processing_time': scene_time
                })
            
            total_time = time.time() - start_time
            return violation_results, total_time
            
        except Exception as e:
            raise Exception(f"Failed to generate violation checking: {e}")
    
    def generate_output(self, video_path: Path) -> List[Dict]:
        """
        Generate violation checking output from ground truth annotation.
        
        Args:
            video_path: Path to the video file (used to find corresponding ground truth)
            
        Returns:
            List of violation checking results
        """
        # Find corresponding ground truth file
        ground_truth_path = self.get_ground_truth_path(video_path)
        violation_results, _ = self.generate_output_from_ground_truth(ground_truth_path)
        return violation_results
    
    def create_output_metadata(
        self,
        video_path: Path,
        content: Any,
        generation_time: float,
        additional_metadata: Dict = None
    ) -> Dict[str, Any]:
        """Build output metadata enriched with ViolationChecker statistics.

        Args:
            video_path: Passed through to the base implementation.
            content: Generated results. Scene statistics are computed only
                when this is a list; otherwise they stay at zero.
            generation_time: Passed through to the base implementation.
            additional_metadata: Optional extra entries; they override the
                computed fields on key collisions.

        Returns:
            Whatever the base class ``create_output_metadata`` returns for the
            merged metadata.
        """
        # Anything that is not a list contributes no scenes.
        results = content if isinstance(content, list) else []

        found_total = sum(1 for entry in results if entry.get('violation') == 'found')

        # Mean per-scene time; entries without a timing count as 0.0.
        per_scene_times = [entry.get('processing_time', 0.0) for entry in results]
        if per_scene_times:
            mean_scene_time = sum(per_scene_times) / len(per_scene_times)
        else:
            mean_scene_time = 0.0

        violation_metadata = {
            "scene_count": len(results),
            "violation_count": found_total,
            "avg_scene_processing_time": round(mean_scene_time, 3),
            "model_type": "text",
            "prompt_type": "traffic_rule_checking",
            "source": "ground_truth_scenes"
        }

        # Caller-supplied entries are applied last so they take precedence.
        if additional_metadata:
            violation_metadata.update(additional_metadata)

        return super().create_output_metadata(
            video_path,
            content,
            generation_time,
            violation_metadata
        )
    
    def process_videos(
        self, 
        video_filter: Optional[List[str]] = None,
        progress_callback: Optional[callable] = None
    ) -> Dict[str, Any]:
        """
        Process multiple ground truth files for violation checking.
        Override the base method to work with ground truth files instead of video files.
        """
        # Get list of ground truth files to process
        ground_truth_files = self.get_ground_truth_list(video_filter)
        
        # For each ground truth file, we need to create a corresponding "video path"
        # for the output file naming system to work correctly
        videos_to_process = []
        for gt_file in ground_truth_files:
            # Create a fake video path based on the ground truth filename
            video_name = gt_file.stem + ".mp4"  # Convert 0000_something.json -> 0000_something.mp4
            fake_video_path = self.config.dashcam_videos_dir / video_name
            
            # Check if we should process this file
            if self.config.should_process_video(fake_video_path):
                videos_to_process.append((fake_video_path, gt_file))
        
        self.stats["total_videos"] = len(ground_truth_files)
        self.stats["skipped"] = len(ground_truth_files) - len(videos_to_process)
        self.stats["start_time"] = self._get_current_time()
        
        print(f"Found {len(ground_truth_files)} total ground truth files")
        print(f"Processing {len(videos_to_process)} files")
        print(f"Skipping {self.stats['skipped']} existing files")
        print(f"Component: {self.get_component_name()}")
        print(f"Model: {self.config.model_id}")
        print()
        
        # Process each ground truth file
        for i, (video_path, gt_file) in enumerate(videos_to_process, 1):
            print(f"[{i}/{len(videos_to_process)}] Processing {gt_file.name}...")
            
            success = self.process_ground_truth_file(video_path, gt_file)
            
            if success:
                self.stats["processed"] += 1
            else:
                self.stats["failed"] += 1
            
            # Call progress callback if provided
            if progress_callback:
                progress_callback(i, len(videos_to_process), success)
        
        self.stats["end_time"] = self._get_current_time()
        
        # Print final statistics
        self.print_summary()
        
        return self.stats
    
    def process_ground_truth_file(self, video_path: Path, ground_truth_path: Path) -> bool:
        """
        Process a single ground truth file and save output.
        
        Args:
            video_path: Fake video path for output file naming
            ground_truth_path: Path to the ground truth file
            
        Returns:
            True if successful, False otherwise
        """
        try:
            # Generate the output from ground truth file (with timing)
            content, generation_time = self.generate_output_from_ground_truth(ground_truth_path)
            
            # Create standardized output using the fake video path for file naming
            output_data = self.create_output_metadata(
                video_path, 
                content, 
                generation_time
            )
            
            # Save to file
            if self.save_output(video_path, output_data):
                print(f"✓ Completed {ground_truth_path.name} in {generation_time:.2f}s")
                return True
            else:
                print(f"✗ Failed to save output for {ground_truth_path.name}")
                return False
                
        except Exception as e:
            print(f"✗ Error processing {ground_truth_path.name}: {e}")
            return False
    
    def _get_current_time(self) -> str:
        """Get current time as ISO string."""
        from datetime import datetime
        return datetime.now().isoformat()