"""Traditional evaluation metrics for component evaluation."""

import re
import sys
from pathlib import Path
from typing import Dict, List, Any, Optional, Union, Tuple
from collections import Counter
import numpy as np

# Make the directory four levels above this file importable
# (presumably the project root — TODO confirm against the repository layout).
root = Path(__file__).parent.parent.parent.parent
sys.path.append(str(root))

# Optional dependencies. Each group is imported inside one try block, so a
# single missing package disables the whole group.
# NOTE: NLTK_AVAILABLE also gates rouge_score — it is False when either
# nltk or rouge_score cannot be imported.
try:
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    from rouge_score import rouge_scorer
    NLTK_AVAILABLE = True
except ImportError:
    NLTK_AVAILABLE = False
    print("⚠️  NLTK/ROUGE not available. Using simplified text metrics.")

# NOTE: sentence_transformers and sklearn share this try block, so
# SENTENCE_TRANSFORMER_AVAILABLE and SKLEARN_AVAILABLE always hold the same
# value: if either package is missing, both flags are False.
try:
    from sentence_transformers import SentenceTransformer
    from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
    from sklearn.metrics import mean_absolute_error, mean_squared_error
    from sklearn.metrics.pairwise import cosine_similarity
    SENTENCE_TRANSFORMER_AVAILABLE = True
    SKLEARN_AVAILABLE = True
except ImportError as e:
    SENTENCE_TRANSFORMER_AVAILABLE = False
    SKLEARN_AVAILABLE = False
    print(f"⚠️  SentenceTransformers or sklearn not available: {e}")


class TextSimilarityMetrics:
    """Traditional text similarity metrics for annotation and scene evaluation."""
    
    def __init__(self):
        """Load the sentence-embedding model when the optional dependency exists.

        ``self.sentence_model`` stays ``None`` when sentence-transformers
        could not be imported or when loading the model raises.
        """
        self.sentence_model = None
        if not SENTENCE_TRANSFORMER_AVAILABLE:
            return

        try:
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            print(f"⚠️  Failed to load sentence transformer: {e}")
    
    @staticmethod
    def simple_bleu(reference: str, candidate: str) -> float:
        """Simplified BLEU score implementation when NLTK is not available.
        
        Args:
            reference: Reference text
            candidate: Candidate text
            
        Returns:
            BLEU score (0-1)
        """
        # Tokenize by splitting on whitespace and punctuation
        ref_tokens = re.findall(r'\w+', reference.lower())
        cand_tokens = re.findall(r'\w+', candidate.lower())
        
        if not cand_tokens:
            return 0.0
        
        # Simple unigram precision
        ref_counter = Counter(ref_tokens)
        cand_counter = Counter(cand_tokens)
        
        overlap = 0
        for token in cand_counter:
            overlap += min(cand_counter[token], ref_counter.get(token, 0))
        
        precision = overlap / len(cand_tokens)
        
        # Add brevity penalty
        ref_len = len(ref_tokens)
        cand_len = len(cand_tokens)
        brevity_penalty = min(1.0, np.exp(1 - ref_len / cand_len)) if cand_len > 0 else 0.0
        
        return precision * brevity_penalty
    
    def calculate_bleu(self, reference: str, candidate: str) -> float:
        """Score a candidate against a reference with sentence-level BLEU.

        Uses NLTK's ``sentence_bleu`` (smoothing ``method1``) on lowercased,
        whitespace-split tokens. Falls back to ``simple_bleu`` when the
        optional NLTK/ROUGE imports are unavailable or the NLTK call raises.

        Args:
            reference: Reference text
            candidate: Candidate text

        Returns:
            BLEU score (0-1)
        """
        if NLTK_AVAILABLE:
            try:
                reference_tokens = reference.lower().split()
                hypothesis_tokens = candidate.lower().split()

                # Nothing to score when the candidate has no tokens
                if hypothesis_tokens:
                    bleu = sentence_bleu(
                        [reference_tokens],
                        hypothesis_tokens,
                        smoothing_function=SmoothingFunction().method1,
                    )
                    return float(bleu)
                return 0.0
            except Exception as e:
                print(f"⚠️  BLEU calculation failed: {e}, using simple method")
                return self.simple_bleu(reference, candidate)

        return self.simple_bleu(reference, candidate)
    
    @staticmethod
    def simple_rouge_l(reference: str, candidate: str) -> float:
        """Simplified ROUGE-L score when ROUGE is not available.
        
        Args:
            reference: Reference text
            candidate: Candidate text
            
        Returns:
            ROUGE-L score (0-1)
        """
        ref_words = reference.lower().split()
        cand_words = candidate.lower().split()
        
        if not cand_words or not ref_words:
            return 0.0
        
        # Find LCS (Longest Common Subsequence)
        def lcs_length(seq1, seq2):
            m, n = len(seq1), len(seq2)
            dp = [[0] * (n + 1) for _ in range(m + 1)]
            
            for i in range(1, m + 1):
                for j in range(1, n + 1):
                    if seq1[i-1] == seq2[j-1]:
                        dp[i][j] = dp[i-1][j-1] + 1
                    else:
                        dp[i][j] = max(dp[i-1][j], dp[i][j-1])
            
            return dp[m][n]
        
        lcs_len = lcs_length(ref_words, cand_words)
        
        precision = lcs_len / len(cand_words)
        recall = lcs_len / len(ref_words)
        
        if precision + recall == 0:
            return 0.0
        
        f1 = 2 * precision * recall / (precision + recall)
        return float(f1)
    
    def calculate_rouge_l(self, reference: str, candidate: str) -> float:
        """Score a candidate against a reference with ROUGE-L F-measure.

        Uses ``rouge_scorer.RougeScorer`` with stemming enabled. Falls back to
        ``simple_rouge_l`` when the optional NLTK/ROUGE imports are
        unavailable or the scorer raises.

        Args:
            reference: Reference text
            candidate: Candidate text

        Returns:
            ROUGE-L score (0-1)
        """
        if NLTK_AVAILABLE:
            try:
                rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
                result = rouge.score(reference, candidate)
                return float(result['rougeL'].fmeasure)
            except Exception as e:
                print(f"⚠️  ROUGE calculation failed: {e}, using simple method")
                return self.simple_rouge_l(reference, candidate)

        return self.simple_rouge_l(reference, candidate)
    
    def calculate_semantic_similarity(self, reference: str, candidate: str) -> float:
        """Calculate semantic similarity using sentence embeddings.

        With a loaded sentence model the score is the cosine similarity of
        the two embeddings, clamped to [0, 1]. Without a model it falls back
        to the Jaccard overlap of the lowercased, whitespace-split word sets.

        Args:
            reference: Reference text
            candidate: Candidate text

        Returns:
            Semantic similarity score (0-1); 0.0 when either text has no
            words (fallback path), when an embedding has zero norm, or when
            the embedding computation fails.
        """
        # Compare against None explicitly: the model object's own truthiness
        # is not a reliable "is it loaded" signal.
        if self.sentence_model is None:
            ref_words = set(reference.lower().split())
            cand_words = set(candidate.lower().split())

            if not ref_words or not cand_words:
                return 0.0

            # The union is non-empty here, so the division is safe.
            return len(ref_words & cand_words) / len(ref_words | cand_words)

        try:
            embeddings = self.sentence_model.encode([reference, candidate])
            ref_vec = np.asarray(embeddings[0])
            cand_vec = np.asarray(embeddings[1])

            norm_product = float(np.linalg.norm(ref_vec) * np.linalg.norm(cand_vec))
            if norm_product == 0.0:
                # A zero vector has no direction; avoid a 0/0 division.
                return 0.0

            similarity = float(np.dot(ref_vec, cand_vec)) / norm_product
            if not np.isfinite(similarity):
                return 0.0

            # Negative cosine counts as "no similarity"; rounding error can
            # push the value marginally above 1, so clamp both ends.
            return min(1.0, max(0.0, similarity))
        except Exception as e:
            print(f"⚠️  Semantic similarity calculation failed: {e}")
            return 0.0
    
    def evaluate_text_pair(self, reference: str, candidate: str) -> Dict[str, float]:
        """Run every text metric on one reference/candidate pair.

        Args:
            reference: Reference text
            candidate: Candidate text

        Returns:
            Dictionary with 'bleu', 'rouge_l' and 'semantic_similarity'
            scores; all three are 0.0 when either text is empty or None.
        """
        metric_names = ('bleu', 'rouge_l', 'semantic_similarity')

        # A missing side cannot be compared: report zeros across the board
        if not (reference and candidate):
            return dict.fromkeys(metric_names, 0.0)

        scorers = (
            self.calculate_bleu,
            self.calculate_rouge_l,
            self.calculate_semantic_similarity,
        )
        return {
            name: scorer(reference, candidate)
            for name, scorer in zip(metric_names, scorers)
        }


class ClassificationMetrics:
    """Enhanced classification metrics for violation and accident detection."""
    
    def __init__(self):
        """Load the sentence-embedding model used for reasoning similarity.

        ``self.sentence_model`` stays ``None`` when sentence-transformers
        could not be imported or when loading the model raises.
        """
        self.sentence_model = None
        if not SENTENCE_TRANSFORMER_AVAILABLE:
            return

        try:
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            print(f"⚠️  Failed to load sentence transformer: {e}")
    
    @staticmethod
    def calculate_binary_metrics(y_true: List[bool], y_pred: List[bool]) -> Dict[str, float]:
        """Calculate binary classification metrics.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            
        Returns:
            Dictionary of classification metrics
        """
        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")
        
        if not y_true:
            return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 0.0}
        
        # Calculate confusion matrix components
        tp = sum(1 for yt, yp in zip(y_true, y_pred) if yt and yp)
        fp = sum(1 for yt, yp in zip(y_true, y_pred) if not yt and yp)
        fn = sum(1 for yt, yp in zip(y_true, y_pred) if yt and not yp)
        tn = sum(1 for yt, yp in zip(y_true, y_pred) if not yt and not yp)
        
        # Calculate metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        accuracy = (tp + tn) / len(y_true)
        
        return {
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'accuracy': float(accuracy)
        }
    
    @staticmethod
    def parse_violation_detection(violation_data: Any) -> bool:
        """Parse violation detection result.
        
        Args:
            violation_data: Violation data (could be dict, string, or boolean)
            
        Returns:
            Boolean indicating if violation was detected
        """
        if isinstance(violation_data, bool):
            return violation_data
        elif isinstance(violation_data, dict):
            if 'violation' in violation_data:
                violation = violation_data['violation']
                if isinstance(violation, str):
                    return violation.lower() in ['found', 'true', 'yes', '1']
                return bool(violation)
            elif 'found' in violation_data:
                return bool(violation_data['found'])
        elif isinstance(violation_data, str):
            return violation_data.lower() in ['found', 'true', 'yes', '1']
        
        return False
    
    @staticmethod
    def parse_accident_detection(accident_data: Any) -> bool:
        """Parse accident detection result.
        
        Args:
            accident_data: Accident data (could be dict, string, or boolean)
            
        Returns:
            Boolean indicating if accident was detected
        """
        if isinstance(accident_data, bool):
            return accident_data
        elif isinstance(accident_data, dict):
            if 'accident' in accident_data:
                accident = accident_data['accident']
                if isinstance(accident, str):
                    return accident.lower() in ['found', 'true', 'yes', '1']
                return bool(accident)
            elif 'found' in accident_data:
                return bool(accident_data['found'])
        elif isinstance(accident_data, str):
            return accident_data.lower() in ['found', 'true', 'yes', '1']
        
        return False
    
    def calculate_semantic_classification_metrics(self, ground_truth: List[Dict[str, Any]], 
                                                predicted: List[Dict[str, Any]], 
                                                detection_type: str = 'violation') -> Dict[str, float]:
        """Calculate classification metrics using semantic similarity for reasoning.

        Items are matched by their 'scene' value; only scenes present in both
        lists are scored. Each matched scene contributes a binary label pair
        and one semantic score:

          * both sides detected and a sentence model is loaded: cosine
            similarity of the 'reason' texts (falling back to 'consequence'),
            or 0.5 when a text is missing or the similarity call fails;
          * otherwise 1.0 when the binary labels agree and 0.0 when not.

        If scene alignment fails unexpectedly, the lists are compared
        position by position instead (over their common length) and the
        semantic fields are reported as 0.0.

        Args:
            ground_truth: List of ground truth violation/accident items
            predicted: List of predicted violation/accident items
            detection_type: Type of detection ('violation' or 'accident')

        Returns:
            Dictionary with 'precision', 'recall', 'f1', 'accuracy',
            'semantic_similarity', 'semantic_f1' and 'reasoning_quality'.
            Every return path provides all seven keys.
        """
        no_semantics = {'semantic_similarity': 0.0, 'semantic_f1': 0.0, 'reasoning_quality': 0.0}

        if not ground_truth or not predicted:
            return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 0.0, **no_semantics}

        if detection_type == 'violation':
            parse_label = self.parse_violation_detection
        else:
            parse_label = self.parse_accident_detection

        def index_by_scene(items):
            """Map scene id -> item, ignoring entries without a 'scene' key."""
            return {
                item['scene']: item
                for item in items
                if isinstance(item, dict) and 'scene' in item
            }

        def reason_text(item):
            """Return the item's stripped 'reason', else 'consequence', else ''."""
            for key in ('reason', 'consequence'):
                value = item.get(key)
                # Non-string values (None, lists, ...) count as "no text"
                # instead of aborting the whole scene alignment.
                if isinstance(value, str) and value.strip():
                    return value.strip()
            return ''

        try:
            gt_by_scene = index_by_scene(ground_truth)
            pred_by_scene = index_by_scene(predicted)

            semantic_scores = []
            binary_gt = []
            binary_pred = []

            for scene, gt_item in gt_by_scene.items():
                if scene not in pred_by_scene:
                    continue
                pred_item = pred_by_scene[scene]

                gt_binary = parse_label(gt_item)
                pred_binary = parse_label(pred_item)
                binary_gt.append(gt_binary)
                binary_pred.append(pred_binary)

                if gt_binary and pred_binary and self.sentence_model is not None:
                    gt_reason = reason_text(gt_item)
                    pred_reason = reason_text(pred_item)

                    if gt_reason and pred_reason:
                        try:
                            gt_embedding = self.sentence_model.encode([gt_reason])
                            pred_embedding = self.sentence_model.encode([pred_reason])
                            similarity = cosine_similarity(gt_embedding.reshape(1, -1), 
                                                          pred_embedding.reshape(1, -1))[0][0]
                            semantic_scores.append(float(similarity))
                        except Exception as e:
                            print(f"⚠️  Semantic similarity calculation failed: {e}")
                            semantic_scores.append(0.5)  # Partial credit
                    else:
                        semantic_scores.append(0.5)  # Partial credit for detection without reasoning
                elif gt_binary == pred_binary:
                    semantic_scores.append(1.0)  # Labels agree (both found or both not found)
                else:
                    semantic_scores.append(0.0)  # Mismatch

            metrics = self.calculate_binary_metrics(binary_gt, binary_pred)

            if semantic_scores:
                avg_semantic = float(np.mean(semantic_scores))
                metrics.update({
                    'semantic_similarity': avg_semantic,
                    'semantic_f1': float(avg_semantic * metrics['f1']),
                    'reasoning_quality': avg_semantic
                })
            else:
                metrics.update(no_semantics)

            return metrics

        except Exception as e:
            print(f"⚠️  Semantic classification failed: {e}")
            # Fallback to simple positional binary classification. The lists
            # may differ in length, so compare only the common prefix;
            # calculate_binary_metrics raises on a length mismatch.
            gt_labels = [parse_label(item) for item in ground_truth]
            pred_labels = [parse_label(item) for item in predicted]
            common = min(len(gt_labels), len(pred_labels))

            metrics = self.calculate_binary_metrics(gt_labels[:common], pred_labels[:common])
            metrics.update(no_semantics)
            return metrics


class StructuredMetrics:
    """Enhanced metrics for structured data comparison (assessments, scores)."""
    
    def __init__(self):
        """Load the sentence-embedding model used for free-text comparison.

        ``self.sentence_model`` stays ``None`` when sentence-transformers
        could not be imported or when loading the model raises.
        """
        self.sentence_model = None
        if not SENTENCE_TRANSFORMER_AVAILABLE:
            return

        try:
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            print(f"⚠️  Failed to load sentence transformer: {e}")
    
    @staticmethod
    def calculate_score_correlation(true_scores: List[Union[int, float]], 
                                  pred_scores: List[Union[int, float]]) -> Dict[str, float]:
        """Calculate correlation between predicted and true scores.
        
        Args:
            true_scores: Ground truth scores
            pred_scores: Predicted scores
            
        Returns:
            Dictionary of correlation metrics
        """
        if len(true_scores) != len(pred_scores):
            raise ValueError("Score lists must have the same length")
        
        if len(true_scores) < 2:
            return {'pearson': 0.0, 'spearman': 0.0, 'mae': 0.0, 'rmse': 0.0}
        
        # Convert to numpy arrays
        true_arr = np.array(true_scores, dtype=float)
        pred_arr = np.array(pred_scores, dtype=float)
        
        # Calculate correlations
        pearson_corr = np.corrcoef(true_arr, pred_arr)[0, 1]
        if np.isnan(pearson_corr):
            pearson_corr = 0.0
        
        # Spearman correlation (rank-based)
        true_ranks = np.argsort(np.argsort(true_arr))
        pred_ranks = np.argsort(np.argsort(pred_arr))
        spearman_corr = np.corrcoef(true_ranks, pred_ranks)[0, 1]
        if np.isnan(spearman_corr):
            spearman_corr = 0.0
        
        # Error metrics
        mae = np.mean(np.abs(true_arr - pred_arr))
        rmse = np.sqrt(np.mean((true_arr - pred_arr) ** 2))
        
        return {
            'pearson': float(pearson_corr),
            'spearman': float(spearman_corr),
            'mae': float(mae),
            'rmse': float(rmse)
        }
    
    @staticmethod
    def calculate_risk_level_accuracy(true_levels: List[str], pred_levels: List[str]) -> Dict[str, float]:
        """Calculate multi-class accuracy for risk levels.
        
        Args:
            true_levels: Ground truth risk levels
            pred_levels: Predicted risk levels
            
        Returns:
            Dictionary of accuracy metrics
        """
        if len(true_levels) != len(pred_levels):
            raise ValueError("Risk level lists must have the same length")
        
        if not true_levels:
            return {'accuracy': 0.0, 'class_accuracy': {}}
        
        # Overall accuracy
        correct = sum(1 for true, pred in zip(true_levels, pred_levels) if true == pred)
        accuracy = correct / len(true_levels)
        
        # Per-class accuracy
        classes = set(true_levels)
        class_accuracy = {}
        
        for cls in classes:
            true_cls = [t == cls for t in true_levels]
            pred_cls = [p == cls for p in pred_levels]
            
            if sum(true_cls) > 0:
                tp = sum(1 for t, p in zip(true_cls, pred_cls) if t and p)
                class_accuracy[cls] = tp / sum(true_cls)
            else:
                class_accuracy[cls] = 0.0
        
        return {
            'accuracy': float(accuracy),
            'class_accuracy': class_accuracy
        }
    
    def calculate_assessment_metrics(self, ground_truth: Dict[str, Any], 
                                   predicted: Dict[str, Any]) -> Dict[str, float]:
        """Calculate comprehensive assessment metrics.

        A metric is produced only for fields present in BOTH inputs:

          * 'safety_score' (numeric): 'safety_score_mae', '_mse', '_rmse'.
          * 'risk_level' (str): 'risk_level_accuracy'.
          * With a loaded sentence model, cosine similarity of the texts:
            'evaluation_similarity' ('overall_evaluation', strings only),
            'advice_similarity' ('improvement_advice', string or list),
            'strengths_similarity' / 'weaknesses_similarity' (lists only).
            Blank texts are skipped.
          * 'strengths_coverage', 'weaknesses_coverage', 'advice_coverage':
            predicted item count over ground-truth item count, capped at 1.0
            (non-list values count as zero items).

        The computation is best effort: an unexpected error is printed and
        the metrics gathered up to that point are returned.

        Args:
            ground_truth: Ground truth assessment data
            predicted: Predicted assessment data

        Returns:
            Dictionary with assessment metrics
        """
        metrics = {}

        def both_have(field):
            """True when the field exists in ground truth and prediction."""
            return field in ground_truth and field in predicted

        def as_text(value):
            """Join list items into one string; return other values unchanged."""
            if isinstance(value, list):
                # str() keeps a stray non-string item from aborting the run
                return ' '.join(str(part) for part in value)
            return value

        def add_similarity(metric_name, label, gt_text, pred_text):
            """Store the embedding cosine similarity of two non-blank strings."""
            if not (isinstance(gt_text, str) and isinstance(pred_text, str)):
                return
            if not gt_text.strip() or not pred_text.strip():
                return
            try:
                gt_emb = self.sentence_model.encode([gt_text])
                pred_emb = self.sentence_model.encode([pred_text])
                sim = cosine_similarity(gt_emb.reshape(1, -1), pred_emb.reshape(1, -1))[0][0]
                metrics[metric_name] = float(sim)
            except Exception as e:
                print(f"⚠️  {label} similarity failed: {e}")

        def add_coverage(metric_name, field):
            """Store predicted/ground-truth item-count ratio, capped at 1.0."""
            if not both_have(field):
                return
            gt_items = ground_truth[field]
            pred_items = predicted[field]
            gt_count = len(gt_items) if isinstance(gt_items, list) else 0
            pred_count = len(pred_items) if isinstance(pred_items, list) else 0

            if gt_count > 0:
                metrics[metric_name] = min(pred_count / gt_count, 1.0)
            else:
                # Nothing expected: full marks only if nothing was predicted
                metrics[metric_name] = 1.0 if pred_count == 0 else 0.0

        try:
            # Safety score error. With a single pair, MAE and RMSE both equal
            # the absolute difference and MSE is its square, so no library
            # call is needed.
            if both_have('safety_score'):
                gt_score = ground_truth['safety_score']
                pred_score = predicted['safety_score']

                if isinstance(gt_score, (int, float)) and isinstance(pred_score, (int, float)):
                    diff = abs(gt_score - pred_score)
                    metrics['safety_score_mae'] = float(diff)
                    metrics['safety_score_mse'] = float(diff ** 2)
                    metrics['safety_score_rmse'] = float(diff)

            # Risk level accuracy
            if both_have('risk_level'):
                gt_risk = ground_truth['risk_level']
                pred_risk = predicted['risk_level']

                if isinstance(gt_risk, str) and isinstance(pred_risk, str):
                    risk_acc = self.calculate_risk_level_accuracy([gt_risk], [pred_risk])
                    metrics['risk_level_accuracy'] = risk_acc['accuracy']

            # Semantic similarity for the free-text fields
            if self.sentence_model is not None:
                if both_have('overall_evaluation'):
                    add_similarity('evaluation_similarity', 'Evaluation',
                                   ground_truth['overall_evaluation'],
                                   predicted['overall_evaluation'])

                # Advice may arrive as a list or as a single string
                if both_have('improvement_advice'):
                    add_similarity('advice_similarity', 'Advice',
                                   as_text(ground_truth['improvement_advice']),
                                   as_text(predicted['improvement_advice']))

                # Strengths and weaknesses are compared only when both are lists
                for field, label in (('strengths', 'Strengths'), ('weaknesses', 'Weaknesses')):
                    if not both_have(field):
                        continue
                    gt_items = ground_truth[field]
                    pred_items = predicted[field]
                    if isinstance(gt_items, list) and isinstance(pred_items, list):
                        add_similarity(f'{field}_similarity', label,
                                       as_text(gt_items), as_text(pred_items))

            # Completeness metrics
            add_coverage('strengths_coverage', 'strengths')
            add_coverage('weaknesses_coverage', 'weaknesses')
            add_coverage('advice_coverage', 'improvement_advice')

        except Exception as e:
            print(f"⚠️  Assessment metrics calculation failed: {e}")

        return metrics


class DrivingSafetyMetrics:
    """Domain-specific metrics for driving safety evaluation.

    Events (violations and accidents) are scored by the most severe safety
    keyword found in their free-text description, and the
    violations → accidents → assessment chain is checked for consistency.
    """

    def __init__(self):
        """Initialize with driving safety knowledge."""
        # Define violation severity levels based on driving safety impact
        self.violation_severity = {
            # Critical violations (high collision risk)
            'lane_change': 'critical',
            'cutting_off': 'critical',
            'unsafe_merging': 'critical',
            'wrong_way': 'critical',
            'collision': 'critical',
            'hit_and_run': 'critical',

            # Moderate violations (moderate risk)
            'speeding': 'moderate',
            'following_too_close': 'moderate',
            'tailgating': 'moderate',
            'failure_to_yield': 'moderate',
            'red_light': 'moderate',
            'stop_sign': 'moderate',

            # Low violations (minor safety issues)
            'turn_signal': 'low',
            'parking': 'low',
            'lane_discipline': 'low'
        }

        # Define accident severity levels
        self.accident_severity = {
            'collision_accident': 'critical',
            'near_miss': 'moderate',
            'risky_behavior': 'moderate',
            'environmental_hazard': 'low',
            'safe_maneuver': 'low'
        }

        # Severity weights for scoring
        self.severity_weights = {
            'critical': 1.0,
            'moderate': 0.6,
            'low': 0.3
        }

    def _extract_safety_keywords(self, text: str) -> List[str]:
        """Extract safety-related keywords from text.

        A keyword matches when its space-separated form (underscores replaced
        by spaces) occurs anywhere in the lower-cased text.

        Args:
            text: Text to analyze (violation reason, accident consequence, etc.)

        Returns:
            List of identified safety keywords: violation keywords first, then
            accident keywords, each in table order. Empty for empty or
            non-string input.
        """
        if not text or not isinstance(text, str):
            return []

        text_lower = text.lower()
        vocabulary = (*self.violation_severity, *self.accident_severity)
        return [kw for kw in vocabulary if kw.replace('_', ' ') in text_lower]

    def _resolve_max_severity(self, keywords: List[str],
                              primary: Dict[str, str],
                              fallback: Dict[str, str]) -> str:
        """Return the highest severity label among the given keywords.

        Keywords come from both severity tables, so each one is looked up in
        ``primary`` first and then in ``fallback``; otherwise a keyword from
        the other table (e.g. 'collision' in an accident consequence) would
        silently be treated as 'low'.
        """
        severities = {
            primary.get(kw) or fallback.get(kw, 'low') for kw in keywords
        }
        for level in ('critical', 'moderate'):
            if level in severities:
                return level
        return 'low'

    def _score_events(self, events: Optional[List[Dict[str, Any]]],
                      text_field: str,
                      primary: Dict[str, str],
                      fallback: Dict[str, str]) -> Tuple[int, float, int]:
        """Score one category of events by keyword severity.

        Returns:
            Tuple of (number of dict events, summed severity weight, number of
            events whose maximum severity is 'critical'). Non-dict entries are
            ignored; events without any keyword count towards the total but
            add no weight.
        """
        total = 0
        weighted = 0.0
        critical = 0

        for event in events or []:
            if not isinstance(event, dict):
                continue
            total += 1

            keywords = self._extract_safety_keywords(event.get(text_field, ''))
            if not keywords:
                continue

            severity = self._resolve_max_severity(keywords, primary, fallback)
            weighted += self.severity_weights[severity]
            if severity == 'critical':
                critical += 1

        return total, weighted, critical

    @staticmethod
    def _coerce_safety_score(value: Any) -> Optional[float]:
        """Convert a raw safety score to float, or None if it is not numeric."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def calculate_safety_criticality_score(self, violations: List[Dict[str, Any]],
                                          accidents: List[Dict[str, Any]]) -> Dict[str, float]:
        """Calculate weighted safety criticality score based on violation/accident severity.

        Args:
            violations: List of violation detections with reasons
                (text read from the 'reason' key)
            accidents: List of accident assessments with consequences
                (text read from the 'consequence' key)

        Returns:
            Dictionary with 'safety_criticality' (mean severity weight per
            event), 'critical_events', 'total_events' and
            'critical_event_ratio'. All four keys are always present, so
            callers can index them even when there are no events.
        """
        v_total, v_weight, v_critical = self._score_events(
            violations, 'reason', self.violation_severity, self.accident_severity
        )
        a_total, a_weight, a_critical = self._score_events(
            accidents, 'consequence', self.accident_severity, self.violation_severity
        )

        total_events = v_total + a_total
        if total_events == 0:
            return {
                'safety_criticality': 0.0,
                'critical_events': 0,
                'total_events': 0,
                'critical_event_ratio': 0.0
            }

        critical_events = v_critical + a_critical
        return {
            'safety_criticality': float((v_weight + a_weight) / total_events),
            'critical_events': critical_events,
            'total_events': total_events,
            'critical_event_ratio': critical_events / total_events
        }

    def calculate_temporal_causality_score(self, violations: List[Dict[str, Any]],
                                          accidents: List[Dict[str, Any]],
                                          assessment: Dict[str, Any]) -> Dict[str, float]:
        """Calculate how well the system detects causal relationships between violations → accidents → assessment.

        Args:
            violations: List of violation detections; an entry counts when its
                'violation' value equals 'found'
            accidents: List of accident assessments; an entry counts when its
                'accident' value equals 'found'
            assessment: Final driving assessment; 'safety_score' defaults to 5
                and 'risk_level' to 'medium' when the key is missing

        Returns:
            Dictionary with causality metrics. Checks that depend on the
            safety score are skipped when the score is None or not numeric.
        """
        if not isinstance(assessment, dict):
            assessment = {}

        # Extract safety score from assessment
        safety_score = self._coerce_safety_score(assessment.get('safety_score', 5))
        risk_level = assessment.get('risk_level', 'medium')

        # Count detected violations and accidents, ignoring malformed entries
        violation_count = sum(
            1 for v in violations or []
            if isinstance(v, dict) and v.get('violation') == 'found'
        )
        accident_count = sum(
            1 for a in accidents or []
            if isinstance(a, dict) and a.get('accident') == 'found'
        )

        causality_score = 0.0
        consistency_checks = 0

        # Check causality: violations → accidents
        if violation_count > 0 or accident_count > 0:
            consistency_checks += 1
            if violation_count > 0 and accident_count > 0:
                causality_score += 0.4  # Violations lead to accidents
            elif violation_count > 0:
                causality_score += 0.2  # Some violations may not lead to accidents
            else:
                causality_score += 0.1  # Accidents without detected violations

        # Check causality: accidents → safety score
        if accident_count > 0 and safety_score is not None:
            consistency_checks += 1
            if safety_score <= 3:  # Low safety score with accidents
                causality_score += 0.3
            elif safety_score >= 7:  # High safety score despite accidents (inconsistent)
                causality_score += 0.05
            else:  # Medium safety score with accidents
                causality_score += 0.2

        # Check causality: safety score → risk level
        if safety_score is not None:
            consistency_checks += 1
            if safety_score <= 3 and risk_level in ['critical', 'high']:
                causality_score += 0.3  # Low score → high risk (consistent)
            elif safety_score >= 7 and risk_level in ['low', 'minimal']:
                causality_score += 0.3  # High score → low risk (consistent)
            elif 4 <= safety_score <= 6 and risk_level in ['medium', 'moderate']:
                causality_score += 0.3  # Medium score → medium risk (consistent)
            else:
                causality_score += 0.1  # Inconsistent score-risk mapping

        # Normalize by number of consistency checks performed
        final_score = causality_score / consistency_checks if consistency_checks > 0 else 0.0

        return {
            'temporal_causality': float(final_score),
            'violation_accident_consistency': float(violation_count > 0 and accident_count > 0),
            'accident_assessment_consistency': float(
                accident_count > 0 and safety_score is not None and safety_score <= 5
            ),
            'overall_causal_chain': float(final_score)
        }

    def calculate_comprehensive_safety_metrics(self, violations: List[Dict[str, Any]],
                                             accidents: List[Dict[str, Any]],
                                             assessment: Dict[str, Any]) -> Dict[str, float]:
        """Calculate comprehensive safety metrics combining criticality and causality.

        Args:
            violations: List of violation detections
            accidents: List of accident assessments
            assessment: Final driving assessment

        Returns:
            Dictionary with every criticality and causality metric plus
            'overall_safety_quality', a weighted blend of safety criticality
            (0.4), temporal causality (0.4) and critical event ratio (0.2).
        """
        # Calculate individual metric components
        criticality_metrics = self.calculate_safety_criticality_score(violations, accidents)
        causality_metrics = self.calculate_temporal_causality_score(violations, accidents, assessment)

        # Combine metrics
        comprehensive_metrics = {**criticality_metrics, **causality_metrics}

        # Calculate overall safety evaluation quality
        safety_quality = (
            criticality_metrics['safety_criticality'] * 0.4 +
            causality_metrics['temporal_causality'] * 0.4 +
            criticality_metrics['critical_event_ratio'] * 0.2
        )

        comprehensive_metrics['overall_safety_quality'] = float(safety_quality)

        return comprehensive_metrics


class SceneEvaluationMetrics:
    """Enhanced scene evaluation metrics with temporal order and coherence.

    All metrics are keyword heuristics over free-text scene descriptions.
    """

    # Vehicle terms used by the coherence heuristic (substring matched)
    _VEHICLE_TERMS = ('ego vehicle', 'sedan', 'suv', 'truck', 'car', 'vehicle')

    def __init__(self):
        """Initialize scene evaluation keyword tables."""
        # Critical scene types that require special attention
        self.critical_scene_types = {
            'collision', 'accident', 'cut_off', 'near_miss',
            'unsafe_lane_change', 'tailgating', 'emergency_braking'
        }

        # Semantic transition words indicating temporal flow
        self.temporal_indicators = {
            'first': 0, 'initially': 0, 'starts': 0, 'begins': 0,
            'then': 1, 'next': 1, 'subsequently': 1, 'after': 1,
            'meanwhile': 1, 'simultaneously': 1, 'during': 1,
            'finally': 2, 'eventually': 2, 'ultimately': 2, 'ends': 2
        }

    def _temporal_regex(self) -> Optional["re.Pattern"]:
        """Build a whole-word regex over the current temporal indicator words.

        Returns None when no indicator words are configured. The pattern is
        rebuilt from ``self.temporal_indicators`` on each call so later edits
        to that table are honoured.
        """
        # Longest words first so a longer alternative is never shadowed
        words = sorted(self.temporal_indicators, key=len, reverse=True)
        if not words:
            return None
        return re.compile(r'\b(?:' + '|'.join(map(re.escape, words)) + r')\b')

    def _extract_temporal_indicators(self, scene_text: str) -> List[int]:
        """Extract temporal order indicators from scene description.

        Indicator words are matched as whole words (so 'then' does not match
        inside 'strengthen') and reported in the order they appear in the
        text, once per occurrence.

        Args:
            scene_text: Scene description text

        Returns:
            List of temporal order scores (0=start, 1=middle, 2=end)
        """
        if not scene_text or not isinstance(scene_text, str):
            return []

        pattern = self._temporal_regex()
        if pattern is None:
            return []

        return [
            self.temporal_indicators[match.group(0)]
            for match in pattern.finditer(scene_text.lower())
        ]

    def _critical_scene_indices(self, scenes: List[str]) -> set:
        """Return the positions of scenes that mention a critical scene type.

        Scene types are stored with underscores ('near_miss') while
        descriptions are natural language, so underscores and hyphens are
        normalised to spaces on both sides before substring matching.
        Non-string entries are treated as non-critical.
        """
        phrases = [t.lower().replace('_', ' ') for t in self.critical_scene_types]
        indices = set()
        for position, scene in enumerate(scenes):
            if not isinstance(scene, str):
                continue
            text = scene.lower().replace('_', ' ').replace('-', ' ')
            if any(phrase in text for phrase in phrases):
                indices.add(position)
        return indices

    @staticmethod
    def _is_non_decreasing(sequence: List[int]) -> bool:
        """Return True when every element is <= its successor."""
        return all(a <= b for a, b in zip(sequence, sequence[1:]))

    def calculate_temporal_order_accuracy(self, predicted_scenes: List[str],
                                        ground_truth_scenes: List[str]) -> Dict[str, float]:
        """Calculate temporal order accuracy between predicted and ground truth scenes.

        Args:
            predicted_scenes: List of predicted scene descriptions
            ground_truth_scenes: List of ground truth scene descriptions

        Returns:
            Dictionary with 'sequence_similarity' (position-wise agreement of
            the indicator sequences over their common prefix, 0.5 when either
            side has no indicators), 'order_preservation' (1.0 only when both
            sequences are non-decreasing) and 'temporal_order_accuracy' (their
            mean). All zeros when either scene list is empty.
        """
        if not predicted_scenes or not ground_truth_scenes:
            return {
                'temporal_order_accuracy': 0.0,
                'sequence_similarity': 0.0,
                'order_preservation': 0.0
            }

        # Extract temporal indicators from both sets
        pred_indicators = [
            order for scene in predicted_scenes
            for order in self._extract_temporal_indicators(scene)
        ]
        gt_indicators = [
            order for scene in ground_truth_scenes
            for order in self._extract_temporal_indicators(scene)
        ]

        # Calculate sequence similarity
        if not pred_indicators or not gt_indicators:
            sequence_similarity = 0.5  # Neutral if no temporal indicators found
        else:
            # zip truncates to the shorter sequence, i.e. the common prefix
            pairs = list(zip(pred_indicators, gt_indicators))
            matches = sum(1 for pred, gt in pairs if pred == gt)
            sequence_similarity = matches / len(pairs)

        # Calculate order preservation (are sequences monotonically increasing?)
        order_preservation = float(
            self._is_non_decreasing(pred_indicators)
            and self._is_non_decreasing(gt_indicators)
        )

        # Overall temporal accuracy
        temporal_accuracy = (sequence_similarity + order_preservation) / 2

        return {
            'temporal_order_accuracy': float(temporal_accuracy),
            'sequence_similarity': float(sequence_similarity),
            'order_preservation': float(order_preservation)
        }

    def calculate_critical_scene_detection(self, predicted_scenes: List[str],
                                          ground_truth_scenes: List[str]) -> Dict[str, float]:
        """Calculate critical scene detection accuracy with safety-based weighting.

        A predicted critical scene counts as a true positive only when the
        ground-truth scene at the same list position is also critical.

        Args:
            predicted_scenes: List of predicted scene descriptions
            ground_truth_scenes: List of ground truth scene descriptions

        Returns:
            Dictionary with critical scene recall, precision, F1 and a
            safety-weighted detection score capped at 1.0. All zeros when
            either scene list is empty.
        """
        if not predicted_scenes or not ground_truth_scenes:
            return {
                'critical_scene_recall': 0.0,
                'critical_scene_precision': 0.0,
                'critical_scene_f1': 0.0,
                'safety_weighted_detection': 0.0
            }

        # Identify critical scenes in both sets
        pred_critical = self._critical_scene_indices(predicted_scenes)
        gt_critical = self._critical_scene_indices(ground_truth_scenes)

        # Calculate precision, recall, F1 for critical scene detection
        if not gt_critical:
            # No critical scenes in ground truth: perfect only if none predicted
            critical_recall = critical_precision = 1.0 if not pred_critical else 0.0
        else:
            # Intersection based on scene position alignment
            true_positives = len(pred_critical & gt_critical)
            critical_recall = true_positives / len(gt_critical)
            critical_precision = true_positives / len(pred_critical) if pred_critical else 0.0

        pr_sum = critical_precision + critical_recall
        critical_f1 = (2 * critical_precision * critical_recall / pr_sum) if pr_sum > 0 else 0.0

        # Safety-weighted detection score (higher weight for critical scenes)
        critical_weight = len(gt_critical) / len(ground_truth_scenes)
        safety_weighted = critical_f1 * (1 + critical_weight)  # Boost score if many critical scenes

        return {
            'critical_scene_recall': float(critical_recall),
            'critical_scene_precision': float(critical_precision),
            'critical_scene_f1': float(critical_f1),
            'safety_weighted_detection': float(min(safety_weighted, 1.0))  # Cap at 1.0
        }

    def calculate_scene_coherence(self, predicted_scenes: List[str]) -> Dict[str, float]:
        """Calculate semantic coherence of scene sequences.

        For each pair of consecutive scenes, two heuristics are scored:
        the share of the current scene's vehicle terms that recur in the next
        scene (0.5 when either scene mentions none), and whether the next
        scene contains a temporal indicator word (1.0 if so, otherwise 0.5).

        Args:
            predicted_scenes: List of predicted scene descriptions

        Returns:
            Dictionary with coherence metrics, each capped at 1.0. Fewer than
            two scenes yields 1.0 for every metric.
        """
        if not predicted_scenes or len(predicted_scenes) < 2:
            return {
                'scene_coherence': 1.0,  # Single scene is perfectly coherent
                'semantic_transitions': 1.0,
                'narrative_flow': 1.0
            }

        coherence_score = 0.0
        transition_score = 0.0

        for current_raw, next_raw in zip(predicted_scenes, predicted_scenes[1:]):
            current_scene = current_raw.lower()
            next_scene = next_raw.lower()

            # 1. Vehicle mentioned in current scene should be referenced in next
            current_vehicles = {v for v in self._VEHICLE_TERMS if v in current_scene}
            next_vehicles = {v for v in self._VEHICLE_TERMS if v in next_scene}

            if current_vehicles and next_vehicles:
                coherence_score += len(current_vehicles & next_vehicles) / len(current_vehicles)
            else:
                coherence_score += 0.5  # Neutral score if no vehicle references

            # 2. Check for temporal flow indicators (whole-word match)
            has_temporal = bool(self._extract_temporal_indicators(next_scene))
            transition_score += 1.0 if has_temporal else 0.5

        # Normalize scores
        num_transitions = len(predicted_scenes) - 1
        avg_coherence = coherence_score / num_transitions
        avg_transitions = transition_score / num_transitions

        # Overall narrative flow
        narrative_flow = (avg_coherence + avg_transitions) / 2

        return {
            'scene_coherence': float(min(avg_coherence, 1.0)),
            'semantic_transitions': float(min(avg_transitions, 1.0)),
            'narrative_flow': float(min(narrative_flow, 1.0))
        }

    def calculate_comprehensive_scene_metrics(self, predicted_scenes: List[str],
                                            ground_truth_scenes: List[str]) -> Dict[str, float]:
        """Calculate comprehensive scene evaluation metrics.

        Args:
            predicted_scenes: List of predicted scene descriptions
            ground_truth_scenes: List of ground truth scene descriptions

        Returns:
            Dictionary with all temporal, critical-scene and coherence metrics
            plus 'overall_scene_quality', a weighted blend of temporal order
            accuracy (0.25), critical scene F1 (0.35), safety-weighted
            detection (0.25) and narrative flow (0.15).
        """
        # Calculate individual metric components
        temporal_metrics = self.calculate_temporal_order_accuracy(predicted_scenes, ground_truth_scenes)
        critical_metrics = self.calculate_critical_scene_detection(predicted_scenes, ground_truth_scenes)
        coherence_metrics = self.calculate_scene_coherence(predicted_scenes)

        # Combine all metrics
        comprehensive_metrics = {**temporal_metrics, **critical_metrics, **coherence_metrics}

        # Calculate overall scene quality
        scene_quality = (
            temporal_metrics['temporal_order_accuracy'] * 0.25 +
            critical_metrics['critical_scene_f1'] * 0.35 +
            critical_metrics['safety_weighted_detection'] * 0.25 +
            coherence_metrics['narrative_flow'] * 0.15
        )

        comprehensive_metrics['overall_scene_quality'] = float(scene_quality)

        return comprehensive_metrics
