#!/usr/bin/env python3
"""
LLM-based Text Normalizer for ASR evaluation.
Uses AWS Bedrock Claude models for intelligent text normalization.
"""

import json
import logging
import time
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
import re

from bedrock_claude import BedrockClaudeClient, BedrockConfig


@dataclass
class NormalizationResult:
    """Record describing the outcome of one text-normalization attempt.

    Plain data holder with no behavior of its own. Nothing else in this
    module constructs or returns it, so the field meanings below are taken
    from the field names - NOTE(review): confirm against external callers.
    """
    # Input text before normalization.
    original_text: str
    # Text produced by the normalization step.
    normalized_text: str
    # Label of the normalization mode that was applied; presumably one of
    # "basic", "spoken" or "written" - TODO confirm.
    normalization_type: str
    # Whether the normalization attempt succeeded.
    success: bool
    # Optional failure description; defaults to None.
    error_message: Optional[str] = None


class LLMTextNormalizer:
    """
    LLM-based text normalizer for ASR evaluation.
    
    Provides three levels of normalization:
    1. Basic formatting (case, spaces, hyphens, punctuation)
    2. Spoken form conversion (full semantic normalization)
    3. Written form conversion (reverse normalization)
    """
    
    def __init__(self, bedrock_client: Optional[BedrockClaudeClient] = None):
        """
        Initialize the text normalizer.
        
        Args:
            bedrock_client: Optional pre-configured Bedrock client
        """
        self.client = bedrock_client or BedrockClaudeClient()
        self.logger = logging.getLogger(__name__)
        
        # Cache for normalized results to avoid repeated API calls
        self._cache: Dict[str, str] = {}
    
    def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
        """
        Call LLM with retry logic for token expiration.
        
        Args:
            prompt: The prompt to send
            max_retries: Maximum number of retries for token expiration
            
        Returns:
            LLM response content
        """
        from bedrock_claude.exceptions import BedrockClaudeError
        
        for attempt in range(max_retries + 1):
            try:
                response = self.client.send_prompt(
                    prompt=prompt,
                    temperature=0.1,  # Low temperature for consistency
                    max_tokens=1000
                )
                return response.content
                
            except BedrockClaudeError as e:
                if "Token refresh failed" in str(e) or "ExpiredTokenException" in str(e):
                    if attempt < max_retries:
                        self.logger.warning(f"Token expiration in normalizer, retrying (attempt {attempt + 1}/{max_retries})")
                        time.sleep(2 ** attempt)  # Exponential backoff
                        continue
                    else:
                        self.logger.error(f"Max retries ({max_retries}) exceeded for token expiration in normalizer")
                        raise
                else:
                    # Non-token related error, re-raise immediately
                    raise
            except Exception as e:
                if attempt < max_retries:
                    self.logger.warning(f"Normalizer error, retrying (attempt {attempt + 1}/{max_retries}): {e}")
                    time.sleep(2 ** attempt)
                    continue
                else:
                    raise
        
        # Should not reach here
        raise BedrockClaudeError("Unexpected error in LLM retry logic")
    
    def normalize_basic(self, text: str, use_cache: bool = True) -> str:
        """
        Basic text normalization - formatting only, no semantic changes.
        
        Normalizes:
        - Convert to lowercase
        - Remove punctuation (periods, commas, quotes, etc.)
        - Normalize hyphens (convert to spaces)
        - Collapse multiple spaces to single spaces
        - Strip leading/trailing whitespace
        
        Args:
            text: Input text to normalize
            use_cache: Whether to use cached results
            
        Returns:
            Normalized text string
        """
        if not text or not text.strip():
            return ""
        
        cache_key = f"basic:{text}" if use_cache else None
        if cache_key and cache_key in self._cache:
            return self._cache[cache_key]
        
        prompt = self._create_basic_normalization_prompt(text)
        
        try:
            response_content = self._call_llm_with_retry(prompt)
            normalized = self._extract_normalized_text(response_content)
            
            if cache_key:
                self._cache[cache_key] = normalized
            
            return normalized
            
        except Exception as e:
            self.logger.error(f"Basic normalization failed for text '{text}': {e}")
            # Fallback to simple regex-based normalization
            return self._fallback_basic_normalization(text)
    
    def to_spoken_form(self, text: str, use_cache: bool = True) -> str:
        """
        Convert text to spoken form for ASR evaluation.
        
        Normalizes:
        - All basic normalizations
        - Expand contractions ("don't" → "do not")
        - Convert numbers to words ("$100" → "one hundred dollars")
        - Expand abbreviations ("Dr." → "doctor")
        - Handle company names and financial terms
        
        Args:
            text: Input text to normalize
            use_cache: Whether to use cached results
            
        Returns:
            Spoken form text string
        """
        if not text or not text.strip():
            return ""
        
        cache_key = f"spoken:{text}" if use_cache else None
        if cache_key and cache_key in self._cache:
            return self._cache[cache_key]
        
        prompt = self._create_spoken_form_prompt(text)
        
        try:
            response_content = self._call_llm_with_retry(prompt)
            normalized = self._extract_normalized_text(response_content)
            
            if cache_key:
                self._cache[cache_key] = normalized
            
            return normalized
            
        except Exception as e:
            self.logger.error(f"Spoken form normalization failed for text '{text}': {e}")
            # Fallback to basic normalization
            return self.normalize_basic(text, use_cache)
    
    def to_written_form(self, text: str, use_cache: bool = True) -> str:
        """
        Convert spoken form text back to written form.
        
        Reverse normalizations:
        - Convert word numbers to digits ("one hundred" → "100")
        - Contract phrases ("do not" → "don't")
        - Convert word forms to abbreviations ("doctor" → "Dr.")
        - Add appropriate punctuation and capitalization
        
        Args:
            text: Input spoken form text to convert
            use_cache: Whether to use cached results
            
        Returns:
            Written form text string
        """
        if not text or not text.strip():
            return ""
        
        cache_key = f"written:{text}" if use_cache else None
        if cache_key and cache_key in self._cache:
            return self._cache[cache_key]
        
        prompt = self._create_written_form_prompt(text)
        
        try:
            response_content = self._call_llm_with_retry(prompt)
            normalized = self._extract_normalized_text(response_content)
            
            if cache_key:
                self._cache[cache_key] = normalized
            
            return normalized
            
        except Exception as e:
            self.logger.error(f"Written form normalization failed for text '{text}': {e}")
            # Return input text as fallback
            return text
    
    def normalize_pair(self, ground_truth: str, prediction: str, 
                      normalization_type: str = "spoken") -> tuple[str, str]:
        """
        Normalize both ground truth and prediction consistently.
        
        Args:
            ground_truth: Ground truth text
            prediction: ASR prediction text
            normalization_type: "basic", "spoken", or "written"
            
        Returns:
            Tuple of (normalized_ground_truth, normalized_prediction)
        """
        if normalization_type == "basic":
            return (
                self.normalize_basic(ground_truth),
                self.normalize_basic(prediction)
            )
        elif normalization_type == "spoken":
            return (
                self.to_spoken_form(ground_truth),
                self.to_spoken_form(prediction)
            )
        elif normalization_type == "written":
            return (
                self.to_written_form(ground_truth),
                self.to_written_form(prediction)
            )
        else:
            raise ValueError(f"Invalid normalization_type: {normalization_type}")
    
    def batch_normalize(self, texts: List[str], 
                       normalization_type: str = "spoken") -> List[str]:
        """
        Normalize a batch of texts.
        
        Args:
            texts: List of texts to normalize
            normalization_type: "basic", "spoken", or "written"
            
        Returns:
            List of normalized texts
        """
        results = []
        for text in texts:
            if normalization_type == "basic":
                results.append(self.normalize_basic(text))
            elif normalization_type == "spoken":
                results.append(self.to_spoken_form(text))
            elif normalization_type == "written":
                results.append(self.to_written_form(text))
            else:
                raise ValueError(f"Invalid normalization_type: {normalization_type}")
        
        return results
    
    def clear_cache(self) -> None:
        """Clear the normalization cache."""
        self._cache.clear()
    
    def get_cache_stats(self) -> Dict[str, int]:
        """Get cache statistics."""
        return {
            "total_entries": len(self._cache),
            "basic_entries": len([k for k in self._cache.keys() if k.startswith("basic:")]),
            "spoken_entries": len([k for k in self._cache.keys() if k.startswith("spoken:")]),
            "written_entries": len([k for k in self._cache.keys() if k.startswith("written:")])
        }
    
    def _create_basic_normalization_prompt(self, text: str) -> str:
        """Create prompt for basic normalization."""
        return f"""You are a text normalization system for ASR evaluation. Your task is to perform BASIC formatting normalization only - do NOT change the semantic meaning of words.

BASIC NORMALIZATION RULES:
1. Convert to lowercase
2. Remove ALL punctuation (periods, commas, apostrophes, quotes, colons, semicolons, etc.)
3. Convert hyphens to spaces (e.g., "interest-only" → "interest only")
4. Collapse multiple spaces into single spaces
5. Remove leading and trailing whitespace
6. DO NOT expand contractions (keep "don't" as "dont", "I've" as "ive")
7. DO NOT expand abbreviations (keep "Dr." as "dr", "TSB" as "tsb")
8. DO NOT convert numbers to words (keep "$100" as "100")

INPUT TEXT: "{text}"

Return ONLY the normalized text, nothing else:"""

    def _create_spoken_form_prompt(self, text: str) -> str:
        """Create prompt for spoken form normalization."""
        return f"""You are a text normalization system for ASR evaluation. Convert the input text to how it would be naturally spoken aloud.

SPOKEN FORM NORMALIZATION RULES:
1. Convert to lowercase
2. Remove ALL punctuation
3. Expand contractions: "don't" → "do not", "I've" → "i have", "we're" → "we are"
4. Convert numbers and currency to words: "$100" → "one hundred dollars", "3.5%" → "three point five percent"
5. Expand common abbreviations: "Dr." → "doctor", "St." → "street", "Inc." → "incorporated"
6. Convert company/brand names to spoken form: "TSB" → "t s b", "SQL" → "s q l"
7. Handle financial terms appropriately: "Basel" → "basel", "Vanguard" → "vanguard"
8. Convert hyphens to spaces: "interest-only" → "interest only"
9. Normalize whitespace (single spaces, no leading/trailing)

EXAMPLES:
- "I've been reviewing Dr. Smith's $1,000 portfolio." → "i have been reviewing doctor smiths one thousand dollars portfolio"
- "The S&P 500 is up 3.2%." → "the s and p five hundred is up three point two percent"
- "Lloyds TSB merger" → "lloyds t s b merger"

INPUT TEXT: "{text}"

Return ONLY the spoken form text, nothing else:"""

    def _create_written_form_prompt(self, text: str) -> str:
        """Create prompt for written form normalization."""
        return f"""You are a text normalization system. Convert the spoken form text back to natural written form.

WRITTEN FORM NORMALIZATION RULES:
1. Convert number words to digits: "one hundred dollars" → "$100", "three point five percent" → "3.5%"
2. Contract common phrases: "do not" → "don't", "i have" → "I've", "we are" → "we're"
3. Convert word abbreviations to standard form: "doctor" → "Dr.", "street" → "St."
4. Restore company/brand acronyms: "t s b" → "TSB", "s q l" → "SQL"
5. Add appropriate capitalization for proper nouns
6. Add appropriate punctuation (periods at end of sentences, commas in lists)
7. Handle financial terms: "basel" → "Basel", "vanguard" → "Vanguard"

EXAMPLES:
- "i have been reviewing doctor smiths one thousand dollars portfolio" → "I've been reviewing Dr. Smith's $1,000 portfolio."
- "the s and p five hundred is up three point two percent" → "The S&P 500 is up 3.2%."
- "lloyds t s b merger" → "Lloyds TSB merger"

INPUT TEXT: "{text}"

Return ONLY the written form text, nothing else:"""

    def _extract_normalized_text(self, response_content: str) -> str:
        """Extract normalized text from Claude's response."""
        # Claude should return just the normalized text, but clean it up just in case
        normalized = response_content.strip()
        
        # Remove any potential wrapper text
        if normalized.startswith('"') and normalized.endswith('"'):
            normalized = normalized[1:-1]
        
        # Remove any explanatory text that might have been added
        lines = normalized.split('\n')
        if len(lines) > 1:
            # Take the first non-empty line that looks like normalized text
            for line in lines:
                line = line.strip()
                if line and not line.startswith(('Here', 'The', 'This', 'Output:')):
                    normalized = line
                    break
        
        return normalized.strip()
    
    def _fallback_basic_normalization(self, text: str) -> str:
        """Fallback basic normalization using regex."""
        # Convert to lowercase
        result = text.lower()
        
        # Remove punctuation
        result = re.sub(r'[^\w\s-]', '', result)
        
        # Convert hyphens to spaces
        result = re.sub(r'-', ' ', result)
        
        # Collapse multiple spaces
        result = re.sub(r'\s+', ' ', result)
        
        # Strip whitespace
        result = result.strip()
        
        return result


def main():
    """Demo: run every normalization mode on sample ASR text and print it.

    Creates a default ``LLMTextNormalizer`` and prints the basic, spoken,
    written and batch results followed by the cache statistics.
    """
    demo = LLMTextNormalizer()

    # Sample reference transcript and the matching ASR hypothesis.
    reference = "I've been reviewing the mortgage portfolio for our high net worth clients in Kensington. The Lloyds TSB merger has created quite a kerfuffle with our interest-only tracker mortgages and the Basel accords compliance."

    hypothesis = "the lloyds tsb merger has created quite a kerfuffle with our interest only tracker mortgages and the basel accords compliance"

    print("=== LLM Text Normalizer Demo ===\n")

    print(f"Ground Truth: {reference}")
    print(f"Prediction: {hypothesis}\n")

    # Formatting-only pass.
    print("=== BASIC NORMALIZATION ===")
    reference_basic = demo.normalize_basic(reference)
    hypothesis_basic = demo.normalize_basic(hypothesis)
    print(f"GT Basic: {reference_basic}")
    print(f"Pred Basic: {hypothesis_basic}\n")

    # Full spoken-form pass.
    print("=== SPOKEN FORM NORMALIZATION ===")
    reference_spoken = demo.to_spoken_form(reference)
    hypothesis_spoken = demo.to_spoken_form(hypothesis)
    print(f"GT Spoken: {reference_spoken}")
    print(f"Pred Spoken: {hypothesis_spoken}\n")

    # Reverse direction: spoken form back to written form.
    print("=== WRITTEN FORM NORMALIZATION ===")
    spoken_sample = "i have been reviewing doctor smiths one thousand dollars portfolio"
    written_sample = demo.to_written_form(spoken_sample)
    print(f"Spoken: {spoken_sample}")
    print(f"Written: {written_sample}\n")

    # Several inputs through the batch helper.
    print("=== BATCH NORMALIZATION ===")
    samples = [
        "Dr. Smith has $1,000 in his account.",
        "The S&P 500 is up 3.2% today.",
        "I don't think we're ready for the IPO."
    ]

    converted_samples = demo.batch_normalize(samples, "spoken")
    for source, converted in zip(samples, converted_samples):
        print(f"Original: {source}")
        print(f"Spoken: {converted}\n")

    # Show how many results ended up in the cache.
    print("=== CACHE STATISTICS ===")
    cache_summary = demo.get_cache_stats()
    print(f"Cache stats: {cache_summary}")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
