"""LLM-as-judge evaluator implementation using GPT-4.1."""

import sys
import json
from pathlib import Path
from typing import Dict, Any, Optional

# Append the directory four levels above this file to sys.path so that the
# absolute `src.llm.llms` import below can be resolved.
# NOTE(review): assumes that directory is the project root containing the
# `src` package — confirm against the repository layout.
root = Path(__file__).parent.parent.parent.parent
sys.path.append(str(root))

from src.llm.llms import get_llm
from .judge_prompts import get_component_prompt
from ..config import EvaluationConfig


class LLMJudgeEvaluator:
    """LLM-as-judge evaluator for per-component evaluation.

    A judge LLM (GPT-4.1 by default) is asked to compare a system output with
    the ground truth for one component and to return a structured score
    object. The numeric score fields of that object are averaged into an
    ``overall_quality`` value.
    """

    # Component name -> (key looked up under data['ground_truth'],
    # factory producing the default used when that key is missing).
    # Factories are used instead of literal defaults so no mutable default
    # object is ever shared between calls.
    _GROUND_TRUTH_FIELDS = {
        'annotation': ('annotation', str),
        'scene': ('scenes', list),
        'violation': ('violations', list),
        'accident': ('accidents', list),
        'assessment': ('assessment', dict),
    }

    # Field names that count toward overall quality even though they do not
    # end in "_score". Non-numeric values under these keys are ignored.
    _EXTRA_SCORE_KEYS = frozenset((
        'extraction_quality', 'temporal_coherence', 'safety_relevance',
        'detection_accuracy', 'explanation_quality', 'legal_consistency',
        'risk_assessment_accuracy', 'consequence_prediction', 'context_understanding',
        'assessment_accuracy', 'advice_actionability', 'score_justification',
    ))

    def __init__(self, model_id: str = "openai:gpt-4.1"):
        """Initialize LLM judge evaluator.

        Args:
            model_id: LLM model identifier for judging

        Raises:
            Exception: Whatever ``get_llm`` raises is reported and re-raised.
        """
        try:
            # Keep the plain LLM; structured output and retry are layered on
            # per call because the score class depends on the component.
            self.base_llm = get_llm(model_id)
            self.model_id = model_id
            print(f"✅ Initialized LLM Judge with {model_id}")
        except Exception as e:
            print(f"❌ Failed to initialize LLM Judge with {model_id}: {e}")
            raise

    def _extract_component_data(self, data: Dict[str, Any], component: str) -> Any:
        """Extract component-specific data from ground truth or system output.

        Lookup order for a known component when ``data`` is a dict:

        1. ``data['ground_truth'][<component key>]`` (with an empty default),
        2. ``data['content']`` (for ``'scene'`` only if it is a list),
        3. a component-specific fallback.

        Non-dict input (for example a raw string or a list) skips the key
        lookups and goes straight to the fallback, so a string that merely
        contains the text "content" is never indexed like a mapping.

        Args:
            data: Full data dictionary (other types are tolerated)
            component: Component name

        Returns:
            Component-specific data. Unknown components return ``data``
            unchanged.
        """
        field = self._GROUND_TRUTH_FIELDS.get(component)
        if field is None:
            return data

        if isinstance(data, dict):
            if 'ground_truth' in data:
                key, make_default = field
                return data['ground_truth'].get(key, make_default())
            if 'content' in data:
                content = data['content']
                # Scenes must be a list; any other content falls through.
                if component != 'scene' or isinstance(content, list):
                    return content

        # Fallbacks when neither key yielded usable data.
        if component == 'annotation':
            return str(data)
        if component == 'scene':
            return data if isinstance(data, list) else []
        return data

    @staticmethod
    def _format_data_for_prompt(data: Any) -> str:
        """Render extracted component data as text for the judge prompt.

        Dicts and lists whose first item is a dict are dumped as indented
        JSON; other lists become a "- item" bullet list (an empty list gives
        an empty string); anything else is converted with ``str``.
        """
        if isinstance(data, dict):
            return json.dumps(data, indent=2, ensure_ascii=False)
        if isinstance(data, list):
            if data and isinstance(data[0], dict):
                return json.dumps(data, indent=2, ensure_ascii=False)
            return '\n'.join(f"- {item}" for item in data)
        return str(data)

    @classmethod
    def _compute_overall_quality(cls, evaluation_result: Dict[str, Any]) -> float:
        """Average the numeric score fields of a judge result.

        A field counts when its name ends in ``_score`` or is listed in
        ``_EXTRA_SCORE_KEYS`` and its value is an int or float. All scores
        carry equal weight, so this is a plain arithmetic mean.

        Returns:
            The mean score, or 0.0 when no numeric score field is present.
        """
        scores = [
            float(value)
            for key, value in evaluation_result.items()
            if (key.endswith('_score') or key in cls._EXTRA_SCORE_KEYS)
            and isinstance(value, (int, float))
        ]
        if not scores:
            return 0.0
        return sum(scores) / len(scores)

    def evaluate_component(self, component: str, ground_truth: Dict[str, Any],
                          system_output: Dict[str, Any], video_id: str,
                          model: str) -> Optional[Dict[str, Any]]:
        """Evaluate a component using LLM-as-judge.

        Args:
            component: Component name
            ground_truth: Ground truth data
            system_output: System output data
            video_id: Video identifier
            model: Model identifier being evaluated

        Returns:
            The judge result as a dictionary, extended with
            ``overall_quality`` and the metadata keys ``component``,
            ``video_id``, ``evaluated_model``, ``judge_model`` and
            ``evaluation_type``; or None if any step failed.
        """
        # Broad catch is deliberate: one failed evaluation is reported and
        # signalled with None instead of aborting a whole batch.
        try:
            # Get component-specific prompt and score class
            prompt_template, score_class = get_component_prompt(component)

            gt_data = self._extract_component_data(ground_truth, component)
            sys_data = self._extract_component_data(system_output, component)

            # Structured output is applied before retry so that the retried
            # unit includes producing the structured result.
            structured_llm = self.base_llm.with_structured_output(score_class).with_retry()

            messages = prompt_template.format_messages(
                ground_truth=self._format_data_for_prompt(gt_data),
                system_output=self._format_data_for_prompt(sys_data),
                video_id=video_id,
                model=model
            )
            result = structured_llm.invoke(messages)

            # The result may already be a dict; otherwise (e.g. a Pydantic
            # model or similar) it is converted via dict().
            evaluation_result = result if isinstance(result, dict) else dict(result)

            # Computed before the metadata is merged in so only judge
            # fields are inspected.
            evaluation_result['overall_quality'] = self._compute_overall_quality(evaluation_result)

            # Add metadata
            evaluation_result.update({
                'component': component,
                'video_id': video_id,
                'evaluated_model': model,
                'judge_model': self.model_id,
                'evaluation_type': 'llm_judge'
            })

            return evaluation_result

        except Exception as e:
            print(f"❌ LLM judge evaluation failed for {component}/{model}/{video_id}: {e}")
            return None

    def batch_evaluate(self, component: str, evaluation_pairs: list) -> Dict[str, Any]:
        """Batch evaluate multiple ground truth vs system output pairs.

        Args:
            component: Component name
            evaluation_pairs: List of (ground_truth, system_output, video_id, model) tuples

        Returns:
            Dictionary of successful evaluation results keyed by
            ``"{model}_{video_id}"``. Failed pairs are omitted. An empty
            input yields an empty dictionary.
        """
        results = {}
        total = len(evaluation_pairs)

        print(f"🔍 LLM Judge evaluating {total} {component} pairs...")

        for i, (gt, sys_out, video_id, model) in enumerate(evaluation_pairs, 1):
            print(f"  [{i}/{total}] Evaluating {video_id} with {model}")

            result = self.evaluate_component(component, gt, sys_out, video_id, model)
            if result:
                results[f"{model}_{video_id}"] = result
            else:
                print(f"    ❌ Failed to evaluate {video_id}")

        # Guard against division by zero when no pairs were supplied.
        success_rate = len(results) / total * 100 if total else 0.0
        print(f"✅ LLM Judge completed: {len(results)}/{total} successful ({success_rate:.1f}%)")

        return results