#!/usr/bin/env python3
"""
Approximate ASR Metrics Calculator - Optimized for Performance
Minimal LLM calls with slot-WER disabled and fast WER computation.
"""

import json
import logging
import re
from typing import List, Dict, Any, Optional, Union, Tuple
from dataclasses import dataclass, field
from collections import Counter

from llm_text_normalizer import LLMTextNormalizer
from bedrock_claude import BedrockClaudeClient


@dataclass
class MetricsResultApprox:
    """Lightweight ASR evaluation results produced by ``ASRMetricsApprox``.

    Only the ground truth and the best (first) prediction go through the LLM
    normalizer; n-best positions 2+ use a cheap rule-based approximation.
    The slot-related fields keep their "disabled" defaults unless the
    calculator was created with ``enable_slot_wer=True``.
    """

    # Original inputs
    original_truth: str
    original_predictions: List[str]

    # Normalized texts
    normalized_truth: str
    normalized_prediction: str  # Best prediction, LLM-normalized
    # Entry 0 equals normalized_prediction (LLM-normalized); entries 1+ are
    # rule-based approximations, not LLM output.
    normalized_predictions: List[str]

    # Core metrics for the best prediction (slot metrics are separate, below)
    wer: float  # Word error rate; may exceed 1.0 when there are many insertions
    ser: float  # Sentence Error Rate: 1.0 if wer > 0 else 0.0

    # N-best metrics
    nbest_match: bool  # True if some n-best entry equals the truth after cleaning
    match_position: Optional[int]  # 1-indexed position of the first match, else None
    nbest_size: int  # Number of predictions supplied
    oracle_wer: float  # Lowest WER over all n-best entries

    # Word-level error breakdown for the best prediction
    # (keys: substitutions, insertions, deletions, total_errors)
    word_errors: Dict[str, int] = field(default_factory=dict)

    # One dict per n-best position: 'position' (1-indexed int), 'wer',
    # 'slot_wer', 'ser'
    position_metrics: List[Dict[str, Union[int, float]]] = field(default_factory=list)

    # Slot metrics. The defaults below mean "slot-WER disabled"; they are only
    # populated when slot-WER is enabled on the calculator.
    truth_slots: List[str] = field(default_factory=list)
    prediction_slots: List[str] = field(default_factory=list)
    slot_wer: float = 1.0  # 1.0 when disabled
    oracle_slot_wer: float = 1.0  # 1.0 when disabled
    slot_matches: List[Tuple[str, str]] = field(default_factory=list)
    slot_mismatches: List[Tuple[str, str]] = field(default_factory=list)
    slot_errors: Dict[str, int] = field(default_factory=dict)


class ASRMetricsApprox:
    """
    High-performance ASR metrics calculator with minimal LLM usage.

    Key optimizations:
    - Slot-WER computation disabled by default (opt in with ``enable_slot_wer``)
    - Only the ground truth and the best prediction go through the LLM
      normalizer; n-best positions 2+ use a rule-based approximation
    - WER itself is an exact word-level Levenshtein distance on lightly
      cleaned text; the approximation lives in the normalization step
    - LLM normalizations are memoized per instance (the cache is unbounded;
      call ``clear_cache`` between large batches if memory matters)
    """

    # Magnitude suffixes understood by the rule-based normalizer ("5k", "$2m").
    _MAGNITUDES = {'k': 'thousand', 'm': 'million', 'b': 'billion'}

    # "$5m" must be handled before the plain "$5" rule; otherwise it turns
    # into "5 dollarsm" and the suffix is never expanded.
    _CURRENCY_WITH_SUFFIX_RE = re.compile(r'\$(\d+(?:\.\d+)?)([kmb])\b')
    _CURRENCY_RE = re.compile(r'\$(\d+(?:\.\d+)?)')
    _PERCENT_RE = re.compile(r'(\d+)%')
    _NUMBER_WITH_SUFFIX_RE = re.compile(r'(\d+)([kmb])\b')

    # Applied in order: the irregular forms must precede the generic "n't".
    # The trailing \b stops e.g. a quoted "'red" from becoming " ared".
    _CONTRACTIONS = (
        (re.compile(r"won't"), "will not"),
        (re.compile(r"can't"), "cannot"),
        (re.compile(r"n't"), " not"),
        (re.compile(r"'re\b"), " are"),
        (re.compile(r"'ve\b"), " have"),
        (re.compile(r"'ll\b"), " will"),
        (re.compile(r"'d\b"), " would"),
    )

    # Characters that join words ("interest-only", "and/or", en/em dashes)
    # become spaces so the joined words are scored as separate tokens.
    _WORD_JOINER_RE = re.compile(r'[-_/\\\u2013\u2014]')
    # Every other non-word, non-space character is dropped ("lloyd's" -> "lloyds").
    _PUNCTUATION_RE = re.compile(r'[^\w\s]')
    _WHITESPACE_RE = re.compile(r'\s+')

    # Keyword vocabulary for the rule-based slot extractor.
    _FINANCIAL_SLOT_TERMS = frozenset({
        'revenue', 'profit', 'loss', 'mortgage', 'loan', 'bond', 'stock', 'share', 'dividend',
    })
    _TECH_SLOT_TERMS = frozenset({'database', 'server', 'api', 'software', 'system', 'platform'})
    _SLOT_TERMS = _FINANCIAL_SLOT_TERMS | _TECH_SLOT_TERMS

    def __init__(self, normalizer: Optional['LLMTextNormalizer'] = None,
                 bedrock_client: Optional['BedrockClaudeClient'] = None,
                 enable_slot_wer: bool = False):
        """
        Initialize approximate ASR metrics calculator.

        Args:
            normalizer: Optional pre-configured text normalizer. When given,
                it is used as-is and ``bedrock_client`` is ignored.
            bedrock_client: Optional shared Bedrock client, used to build a
                normalizer when ``normalizer`` is not given.
            enable_slot_wer: If True, enable slot-WER computation (slower)
        """
        if normalizer is not None:
            self.normalizer = normalizer
        elif bedrock_client is not None:
            self.normalizer = LLMTextNormalizer(bedrock_client)
        else:
            self.normalizer = LLMTextNormalizer()

        self.enable_slot_wer = enable_slot_wer
        self.logger = logging.getLogger(__name__)

        # Memoizes LLM normalizations keyed by the raw input text.
        self._norm_cache: Dict[str, str] = {}

    def compute_metrics(self, ground_truth: str,
                        predictions: Union[str, List[str]],
                        skip_norm_for_truth: bool = False) -> 'MetricsResultApprox':
        """
        Compute approximate ASR metrics with minimal LLM calls.

        Args:
            ground_truth: Reference text
            predictions: Single prediction or n-best list (best first). Any
                iterable of strings is accepted; None is treated as empty.
            skip_norm_for_truth: If True, skip normalization for ground_truth

        Returns:
            MetricsResultApprox with computed metrics
        """
        if predictions is None:
            predictions = []
        elif isinstance(predictions, str):
            predictions = [predictions]
        else:
            # Copy so later mutation of the caller's list cannot alter the result.
            predictions = list(predictions)

        if skip_norm_for_truth:
            normalized_truth = ground_truth
        else:
            normalized_truth = self._normalize_with_cache(ground_truth)

        # Only the BEST prediction goes through the LLM; the rest are approximated.
        best_prediction = predictions[0] if predictions else ""
        normalized_best_prediction = self._normalize_with_cache(best_prediction)
        normalized_predictions = [normalized_best_prediction] + [
            self._fast_normalize_approximation(pred) for pred in predictions[1:]
        ]

        # One WER per n-best position, reused for the best WER and the oracle WER.
        position_wers = [
            self._compute_wer_fast(normalized_truth, pred) for pred in normalized_predictions
        ]
        wer = position_wers[0]
        ser = 1.0 if wer > 0 else 0.0

        nbest_match, match_position = self._find_nbest_match_fast(
            normalized_truth, normalized_predictions
        )
        oracle_wer = min(position_wers)

        word_errors = self._analyze_word_errors_fast(normalized_truth, normalized_best_prediction)

        position_metrics = [
            {
                'position': position,
                'wer': pos_wer,
                'slot_wer': 1.0,  # Per-position slot-WER is not computed
                'ser': 1.0 if pos_wer > 0 else 0.0,
            }
            for position, pos_wer in enumerate(position_wers, start=1)
        ]

        slot_fields = self._compute_slot_metrics(
            normalized_truth, normalized_best_prediction, normalized_predictions
        )

        return MetricsResultApprox(
            original_truth=ground_truth,
            original_predictions=predictions,
            normalized_truth=normalized_truth,
            normalized_prediction=normalized_best_prediction,
            normalized_predictions=normalized_predictions,
            wer=wer,
            ser=ser,
            nbest_match=nbest_match,
            match_position=match_position,
            nbest_size=len(predictions),
            oracle_wer=oracle_wer,
            word_errors=word_errors,
            position_metrics=position_metrics,
            **slot_fields,
        )

    def _compute_slot_metrics(self, normalized_truth: str,
                              normalized_best_prediction: str,
                              normalized_predictions: List[str]) -> Dict[str, Any]:
        """Return the slot-related MetricsResultApprox fields.

        When slot-WER is disabled, or its computation raises, every field gets
        its "disabled" default (a failure is logged, not propagated).
        """
        disabled: Dict[str, Any] = {
            'truth_slots': [],
            'prediction_slots': [],
            'slot_wer': 1.0,
            'oracle_slot_wer': 1.0,
            'slot_matches': [],
            'slot_mismatches': [],
            'slot_errors': {},
        }
        if not self.enable_slot_wer:
            return disabled

        try:
            truth_slots = self._extract_noun_slots_fast(normalized_truth)
            prediction_slots = self._extract_noun_slots_fast(normalized_best_prediction)
            slot_matches, slot_mismatches = self._analyze_slot_matches_fast(
                truth_slots, prediction_slots
            )
            return {
                'truth_slots': truth_slots,
                'prediction_slots': prediction_slots,
                'slot_wer': self._compute_slot_wer_fast(truth_slots, prediction_slots),
                'oracle_slot_wer': self._compute_oracle_slot_wer_fast(
                    truth_slots, normalized_predictions
                ),
                'slot_matches': slot_matches,
                'slot_mismatches': slot_mismatches,
                'slot_errors': self._analyze_slot_errors_fast(truth_slots, prediction_slots),
            }
        except Exception as e:  # best-effort: slot metrics must never break core metrics
            self.logger.warning(f"Slot-WER computation failed: {e}")
            return disabled

    def _normalize_with_cache(self, text: str) -> str:
        """Normalize text with the LLM, memoizing results to avoid duplicate calls.

        Blank text short-circuits to "" without an LLM call. On normalizer
        failure, a lowercased/stripped fallback is returned and not cached,
        so a later call can retry.
        """
        if not text.strip():
            return ""

        if text in self._norm_cache:
            return self._norm_cache[text]

        try:
            normalized = self.normalizer.to_written_form(text)
        except Exception as e:
            self.logger.warning(f"Normalization failed for '{text}': {e}")
            return text.lower().strip()

        self._norm_cache[text] = normalized
        return normalized

    def _fast_normalize_approximation(self, text: str) -> str:
        """
        Fast rule-based approximation of normalization, without LLM calls.

        Used for n-best positions 2+. It lowercases, expands currency,
        percent, and k/m/b magnitude suffixes, and expands common English
        contractions (curly apostrophes are treated like straight ones).
        """
        text = text.replace('\u2019', "'").lower().strip()

        magnitudes = self._MAGNITUDES
        text = self._CURRENCY_WITH_SUFFIX_RE.sub(
            lambda m: f"{m.group(1)} {magnitudes[m.group(2)]} dollars", text
        )
        text = self._CURRENCY_RE.sub(r'\1 dollars', text)
        text = self._PERCENT_RE.sub(r'\1 percent', text)
        text = self._NUMBER_WITH_SUFFIX_RE.sub(
            lambda m: f"{m.group(1)} {magnitudes[m.group(2)]}", text
        )

        for pattern, replacement in self._CONTRACTIONS:
            text = pattern.sub(replacement, text)

        return self._WHITESPACE_RE.sub(' ', text).strip()

    def _clean_text_for_wer(self, text: str) -> str:
        """Lowercase, split word-joining punctuation, drop other punctuation, squeeze spaces."""
        text = text.lower()
        text = self._WORD_JOINER_RE.sub(' ', text)
        text = self._PUNCTUATION_RE.sub('', text)
        return self._WHITESPACE_RE.sub(' ', text).strip()

    @staticmethod
    def _align_counts(ref_words: List[str], hyp_words: List[str]) -> Tuple[int, int, int]:
        """Return (substitutions, deletions, insertions) of a minimum-cost word alignment.

        The three counts always sum to the Levenshtein distance between the
        two word sequences. Ties are broken in favor of match/substitution,
        then deletion, then insertion.
        """
        n_ref, n_hyp = len(ref_words), len(hyp_words)
        dist = [[0] * (n_hyp + 1) for _ in range(n_ref + 1)]
        for i in range(n_ref + 1):
            dist[i][0] = i
        for j in range(n_hyp + 1):
            dist[0][j] = j

        for i in range(1, n_ref + 1):
            ref_word = ref_words[i - 1]
            row, prev = dist[i], dist[i - 1]
            for j in range(1, n_hyp + 1):
                if ref_word == hyp_words[j - 1]:
                    row[j] = prev[j - 1]
                else:
                    row[j] = 1 + min(prev[j], row[j - 1], prev[j - 1])

        # Walk back from the bottom-right corner to classify each edit.
        subs = dels = ins = 0
        i, j = n_ref, n_hyp
        while i > 0 or j > 0:
            if i > 0 and j > 0 and ref_words[i - 1] == hyp_words[j - 1] \
                    and dist[i][j] == dist[i - 1][j - 1]:
                i, j = i - 1, j - 1
            elif i > 0 and j > 0 and dist[i][j] == dist[i - 1][j - 1] + 1:
                subs += 1
                i, j = i - 1, j - 1
            elif i > 0 and dist[i][j] == dist[i - 1][j] + 1:
                dels += 1
                i -= 1
            else:
                ins += 1
                j -= 1
        return subs, dels, ins

    def _compute_wer_fast(self, reference: str, hypothesis: str) -> float:
        """Word error rate (S + D + I) / N over cleaned text.

        An empty reference gives 0.0 if the hypothesis is also empty, else 1.0.
        """
        ref_words = self._clean_text_for_wer(reference).split()
        hyp_words = self._clean_text_for_wer(hypothesis).split()

        if not ref_words:
            return 1.0 if hyp_words else 0.0

        return sum(self._align_counts(ref_words, hyp_words)) / len(ref_words)

    def _find_nbest_match_fast(self, reference: str,
                               predictions: List[str]) -> Tuple[bool, Optional[int]]:
        """Return (True, 1-indexed position) of the first prediction equal to the
        reference after cleaning, or (False, None) if none matches."""
        reference_cleaned = self._clean_text_for_wer(reference)
        for position, prediction in enumerate(predictions, start=1):
            if self._clean_text_for_wer(prediction) == reference_cleaned:
                return True, position
        return False, None

    def _compute_oracle_wer_fast(self, reference: str, predictions: List[str]) -> float:
        """Lowest WER over the predictions; 1.0 for an empty list."""
        if not predictions:
            return 1.0
        return min(self._compute_wer_fast(reference, pred) for pred in predictions)

    def _analyze_word_errors_fast(self, reference: str, hypothesis: str) -> Dict[str, int]:
        """Substitution/insertion/deletion counts from a real word alignment.

        Uses the same cleaning as the WER computation, so ``total_errors``
        equals WER times the cleaned reference length.
        """
        ref_words = self._clean_text_for_wer(reference).split()
        hyp_words = self._clean_text_for_wer(hypothesis).split()
        subs, dels, ins = self._align_counts(ref_words, hyp_words)
        return {
            'substitutions': subs,
            'insertions': ins,
            'deletions': dels,
            'total_errors': subs + ins + dels,
        }

    # Slot-related methods (only active if enable_slot_wer=True)
    def _extract_noun_slots_fast(self, text: str) -> List[str]:
        """Extract domain-term "slots" by keyword lookup (no LLM).

        Returns the distinct matching terms in order of first appearance, so
        downstream (order-sensitive) slot-WER is deterministic. Returns []
        when slot-WER is disabled.
        """
        if not self.enable_slot_wer:
            return []

        cleaned_words = (self._PUNCTUATION_RE.sub('', word) for word in text.lower().split())
        # dict.fromkeys de-duplicates while keeping first-seen order (a set would not).
        return list(dict.fromkeys(w for w in cleaned_words if w in self._SLOT_TERMS))

    def _compute_slot_wer_fast(self, ref_slots: List[str], hyp_slots: List[str]) -> float:
        """WER between the space-joined slot sequences; 1.0 when disabled."""
        if not self.enable_slot_wer:
            return 1.0

        if not ref_slots:
            return 1.0 if hyp_slots else 0.0

        return self._compute_wer_fast(' '.join(ref_slots), ' '.join(hyp_slots))

    def _compute_oracle_slot_wer_fast(self, ref_slots: List[str],
                                      predictions: List[str]) -> float:
        """Lowest slot-WER over the predictions; 1.0 when disabled or empty."""
        if not self.enable_slot_wer:
            return 1.0

        if not predictions:
            return 1.0

        return min(
            self._compute_slot_wer_fast(ref_slots, self._extract_noun_slots_fast(pred))
            for pred in predictions
        )

    def _analyze_slot_matches_fast(self, ref_slots: List[str],
                                   hyp_slots: List[str]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        """Pair up shared slots and list missing/extra ones, in input order.

        Returns (matches, mismatches) where mismatches holds (slot, "MISSING")
        for reference-only slots followed by ("EXTRA", slot) for
        hypothesis-only slots. Returns ([], []) when disabled.
        """
        if not self.enable_slot_wer:
            return [], []

        ref_unique = list(dict.fromkeys(ref_slots))
        hyp_unique = list(dict.fromkeys(hyp_slots))
        ref_set, hyp_set = set(ref_unique), set(hyp_unique)

        matches = [(slot, slot) for slot in ref_unique if slot in hyp_set]
        mismatches = [(slot, "MISSING") for slot in ref_unique if slot not in hyp_set] + \
                     [("EXTRA", slot) for slot in hyp_unique if slot not in ref_set]

        return matches, mismatches

    def _analyze_slot_errors_fast(self, ref_slots: List[str],
                                  hyp_slots: List[str]) -> Dict[str, int]:
        """Order-independent slot error counts; {} when disabled.

        Slots are compared as multisets: a missing slot paired with an extra
        slot counts as a substitution; the unpaired remainder is counted as
        deletions (missing) or insertions (extra).
        """
        if not self.enable_slot_wer:
            return {}

        ref_counts, hyp_counts = Counter(ref_slots), Counter(hyp_slots)
        missing = sum((ref_counts - hyp_counts).values())
        extra = sum((hyp_counts - ref_counts).values())
        substitutions = min(missing, extra)
        deletions = missing - substitutions
        insertions = extra - substitutions
        return {
            'substitutions': substitutions,
            'insertions': insertions,
            'deletions': deletions,
            'total_errors': substitutions + insertions + deletions,
        }

    def clear_cache(self) -> None:
        """Clear the normalization cache."""
        self._norm_cache.clear()

    def get_cache_stats(self) -> Dict[str, int]:
        """Get cache statistics."""
        return {
            'norm_cache_entries': len(self._norm_cache)
        }


def main():
    """Run a small example evaluation with ASRMetricsApprox and print the metrics.

    The ground truth and best prediction are normalized through
    LLMTextNormalizer; the remaining n-best entries use the rule-based
    approximation.
    """
    # Slot-WER stays disabled for maximum performance. LLMTextNormalizer is
    # already imported at module level.
    normalizer = LLMTextNormalizer()
    metrics = ASRMetricsApprox(normalizer, enable_slot_wer=False)

    # Example evaluation
    ground_truth = "The Lloyds TSB merger has created quite a kerfuffle with our interest-only tracker mortgages."
    predictions = [
        "the lloyds tsb merger has created quite a kerffle with our interest only tracker mortgages",
        "the lloyds tsb merger has created quite a kerfuffle with our interest-only tracker mortgages",
        "the lloyd's tsb merger has created quite a kerfuffle with our interest only tracker mortgages"
    ]

    print("=== ASR Metrics Evaluation (Approximate/Fast) ===\n")

    result = metrics.compute_metrics(ground_truth, predictions)

    print(f"Ground Truth: {result.original_truth}")
    print(f"Best Prediction: {result.original_predictions[0]}")
    print(f"N-best size: {result.nbest_size}\n")

    print("=== CORE METRICS ===")
    print(f"WER: {result.wer:.4f}")
    print(f"SER: {result.ser:.4f}")
    print(f"SlotWER: {result.slot_wer:.4f} (disabled for performance)")
    print(f"Word errors: {result.word_errors}")

    print("\n=== N-BEST METRICS ===")
    print(f"N-best match: {result.nbest_match}")
    print(f"Match position: {result.match_position}")
    print(f"Oracle WER: {result.oracle_wer:.4f}")
    for entry in result.position_metrics:
        print(f"  #{entry['position']}: WER={entry['wer']:.4f} SER={entry['ser']:.0f}")

    print("\n=== PERFORMANCE INFO ===")
    print(f"Cache stats: {metrics.get_cache_stats()}")
    print("Slot-WER disabled for ~3x performance improvement")


if __name__ == "__main__":
    main()
