#!/usr/bin/env python3
"""
Comprehensive ASR Evaluation Pipeline - Oracle Prompt Testing
Tests ASR upper bound performance using the correct answer (normalized_truth) as the prompt.
This creates an oracle experiment to measure maximum possible ASR performance with perfect context.
"""

import torch
import numpy as np
import json
import time
import argparse
import sys
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional

# Core imports
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_from_disk
import librosa

# Local imports - use improved versions
from compute_metrics_approx import ASRMetricsApprox, MetricsResultApprox
from asr_experiment_logger_improved import ASRExperimentLoggerImproved
from llm_text_normalizer import LLMTextNormalizer



class ComprehensiveASREvaluatorOracle:
    """
    Oracle ASR evaluation pipeline testing upper bound performance.

    Tests ASR maximum potential by using the correct answer (normalized_truth) as the prompt:
    - Provides the ground truth as context to the ASR system
    - Measures the upper bound of ASR performance with perfect prompting
    - Evaluates how much ASR can improve when given ideal context

    This creates an oracle experiment to understand the ceiling of ASR performance.

    The entry point is ``run_evaluation()``. It loads the Whisper model, the
    dataset, the metrics calculator and the experiment logger. It then decodes
    every sample with that sample's own normalized ground truth as the Whisper
    prompt and saves checkpoints through the logger. Finally it writes
    consolidated JSONL/CSV results and analytics under
    ``asr_experiments/<experiment_id>/``.
    """

    # Substrings that mark an exception message as "critical". Such errors abort
    # the run (so it can be resumed later) instead of being recorded as a failed
    # sample. NOTE(review): presumably aimed at auth-token / network failures of
    # the Bedrock-backed metrics; it is a plain case-insensitive substring test.
    _CRITICAL_ERROR_MARKERS = ("token", "connection")

    # Whisper expects 16 kHz input audio.
    _TARGET_SAMPLE_RATE = 16000

    # Generation length cap; also reported to the experiment logger.
    _MAX_NEW_TOKENS = 256

    class _FallbackMetrics:
        """Worst-case stand-in used when metric computation fails.

        Every error rate is pinned to 1.0 and no word/slot-level detail is
        available. The attribute names mirror those read from a real metrics
        result elsewhere in this class (``wer``, ``ser``, ``slot_wer``).
        """

        def __init__(self, prediction: str, normalized_truth: str):
            self.wer = 1.0
            self.ser = 1.0
            self.slot_wer = 1.0
            self.oracle_wer = 1.0
            self.oracle_slot_wer = 1.0
            self.normalized_prediction = prediction
            self.normalized_truth = normalized_truth
            self.normalized_predictions = [prediction]
            self.nbest_match = False
            self.truth_slots = []
            self.prediction_slots = []
            self.slot_matches = []
            self.slot_mismatches = []
            self.word_errors = {}
            self.slot_errors = {}
            self.position_metrics = []

    def __init__(self,
                 model_name: str = "openai/whisper-small",
                 dataset_path: str = "../final_hf_dataset/agentic_asr_normalized",
                 experiment_name: str = "whisper-small-oracle",
                 num_beams: int = 1,
                 checkpoint_interval: int = 500,
                 enable_slot_wer: bool = False,
                 auto_resume: bool = True):
        """
        Store the configuration. Heavy components are built in initialize_components().

        Args:
            model_name: Hugging Face id of the Whisper model to load.
            dataset_path: Directory readable by ``datasets.load_from_disk``.
            experiment_name: Experiment id used by the logger and output paths.
            num_beams: Beam width passed to ``generate`` (1 = greedy).
            checkpoint_interval: Save a checkpoint every N dataset indices.
            enable_slot_wer: Whether the metrics calculator computes slot-WER.
            auto_resume: Resume from the logger's detected resume point.

        Raises:
            ValueError: If ``num_beams`` or ``checkpoint_interval`` is < 1. A
                zero interval would otherwise fail much later with a
                ZeroDivisionError in the checkpoint modulo check.
        """
        if num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {num_beams}")
        if checkpoint_interval < 1:
            raise ValueError(f"checkpoint_interval must be >= 1, got {checkpoint_interval}")

        self.model_name = model_name
        self.dataset_path = dataset_path
        self.experiment_name = experiment_name
        self.num_beams = num_beams
        self.checkpoint_interval = checkpoint_interval
        self.enable_slot_wer = enable_slot_wer
        self.auto_resume = auto_resume

        # Components, created by initialize_components()
        self.model = None
        self.processor = None
        self.dataset = None
        self.logger = None
        self.metrics_calculator = None

        # Tracking variables
        self.processed_count = 0
        self.start_time = None
        self.results = []
        self.resume_index = 0
        # Number of samples contributing to the running mean processing time.
        self._timed_samples = 0

        # Performance tracking with oracle-specific metrics
        self.performance_stats = {
            "total_samples": 0,
            "successful_samples": 0,
            "failed_samples": 0,
            "avg_processing_time": 0.0,
            "oracle_testing": True,
            "slot_wer_disabled": not enable_slot_wer,
            "oracle_prompts_used": 0,
            "perfect_context_provided": True,
            # Samples whose oracle prompt could not be applied, so they were
            # decoded WITHOUT it. These are not genuine oracle results.
            "oracle_prompt_fallbacks": 0,
        }

    @classmethod
    def _is_critical_error(cls, error: BaseException) -> bool:
        """Return True if the error message contains one of the critical markers."""
        message = str(error).lower()
        return any(marker in message for marker in cls._CRITICAL_ERROR_MARKERS)

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """JSON fallback serializer for result records.

        - Values exposing ``tolist()`` (numpy / torch) become Python values.
        - Plain objects, such as metrics results, become their attribute dict,
          so metric fields stay machine-readable.
        - Anything else becomes ``str(obj)``.
        """
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        try:
            return vars(obj)
        except TypeError:
            return str(obj)

    @staticmethod
    def _metric_values(results: List[Dict[str, Any]], name: str) -> List[float]:
        """Collect metric ``name`` from each result's metrics object, defaulting to 1.0."""
        return [getattr(r['metrics'], name, 1.0) for r in results]

    def _record_prompt_fallback(self) -> None:
        """Count a sample that had to be decoded without its oracle prompt."""
        self.performance_stats["oracle_prompt_fallbacks"] = (
            self.performance_stats.get("oracle_prompt_fallbacks", 0) + 1
        )

    def get_oracle_prompt(self, normalized_truth: str) -> str:
        """
        Generate oracle prompt using the normalized ground truth.

        Args:
            normalized_truth: The correct normalized transcription

        Returns:
            The oracle prompt (normalized_truth)
        """
        # Track usage
        self.performance_stats["oracle_prompts_used"] += 1

        # The oracle prompt is the normalized truth itself
        return normalized_truth

    def initialize_components(self):
        """Load the model, dataset, metrics calculator and logger; detect the resume point."""
        print("🚀 Initializing Oracle ASR Evaluation Pipeline...")
        print("🔮 ORACLE MODE: Using ground truth as prompts for upper bound testing")

        # Load model and processor
        print(f"Loading Whisper model: {self.model_name}")
        self.processor = WhisperProcessor.from_pretrained(self.model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.float32,
            device_map="auto"
        )
        print("✅ Model loaded successfully")

        # Load dataset
        print(f"Loading dataset from: {self.dataset_path}")
        self.dataset = load_from_disk(self.dataset_path)
        print(f"✅ Dataset loaded: {len(self.dataset)} samples")

        # Display oracle strategy
        print("\n🔮 Oracle Testing Strategy:")
        print("   Using normalized_truth as prompt to measure ASR upper bound performance")
        print("   This provides perfect context to test maximum possible accuracy")

        # One Bedrock client is shared by the normalizer and the metrics calculator
        print("\nInitializing shared Bedrock client...")
        from bedrock_claude import BedrockClaudeClient
        shared_bedrock_client = BedrockClaudeClient()
        print("✅ Shared Bedrock client initialized")

        # Initialize optimized metrics calculator
        print("Initializing OPTIMIZED ASR metrics calculator...")
        normalizer = LLMTextNormalizer(shared_bedrock_client)
        self.metrics_calculator = ASRMetricsApprox(
            normalizer,
            shared_bedrock_client,
            enable_slot_wer=self.enable_slot_wer
        )
        if self.enable_slot_wer:
            print("⚠️  Slot-WER enabled (slower performance)")
        else:
            print("🚀 Slot-WER disabled for 3x performance improvement")
        print("✅ Optimized metrics calculator initialized")

        # Initialize enhanced experiment logger
        print(f"Initializing enhanced experiment logger: {self.experiment_name}")
        self.logger = ASRExperimentLoggerImproved(
            experiment_id=self.experiment_name,
            output_format="jsonl",
            output_dir="asr_experiments",
            checkpoint_interval=self.checkpoint_interval
        )
        print("✅ Enhanced logger initialized")

        # Auto-resume detection
        if self.auto_resume:
            print("🔍 Checking for existing oracle experiments to resume...")
            self.resume_index = self.logger.detect_resume_point(self.experiment_name)
            if self.resume_index > 0:
                print(f"🔄 Auto-resume enabled: starting from sample {self.resume_index}")
            else:
                print("🆕 No existing oracle experiments found - starting fresh")

    def process_audio_with_oracle_prompt(self,
                                        audio_array: np.ndarray,
                                        sample_rate: int,
                                        oracle_prompt: str) -> Tuple[List[str], List[float]]:
        """Transcribe one utterance with Whisper, conditioned on the oracle prompt.

        Args:
            audio_array: Raw audio samples; resampled to 16 kHz if needed.
            sample_rate: Sampling rate of ``audio_array``.
            oracle_prompt: Text passed to Whisper as ``prompt_ids``. Empty or
                blank text means no prompt.

        Returns:
            ``([transcription], [confidence])``. Confidence is a placeholder:
            1.0 when a sequence was decoded, 0.0 when generation produced
            nothing or failed entirely.

        If the prompt cannot be built or the prompted generation fails, the
        sample is decoded again without the prompt. That fallback is counted in
        ``performance_stats["oracle_prompt_fallbacks"]``.
        """
        target_sample_rate = self._TARGET_SAMPLE_RATE
        if sample_rate != target_sample_rate:
            audio_array = librosa.resample(
                audio_array,
                orig_sr=sample_rate,
                target_sr=target_sample_rate
            )
            sample_rate = target_sample_rate

        # Extract log-mel input features and move them to the model's device
        inputs = self.processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        generation_kwargs = {
            "input_features": inputs["input_features"],
            "max_new_tokens": self._MAX_NEW_TOKENS,
            "do_sample": False,
            # Honour the configured beam width (1 = greedy decoding)
            "num_beams": self.num_beams,
            "pad_token_id": self.processor.tokenizer.pad_token_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        prompt_applied = False
        if oracle_prompt and oracle_prompt.strip():
            try:
                # get_prompt_ids builds the 1-D id sequence Whisper's generate()
                # expects: <|startofprev|> followed by the prompt text. Plain
                # tokenizer output (2-D, no <|startofprev|>) is rejected by
                # generate().
                prompt_ids = self.processor.get_prompt_ids(oracle_prompt, return_tensors="pt")
                if prompt_ids is not None and prompt_ids.numel() > 0:
                    generation_kwargs["prompt_ids"] = prompt_ids.to(self.model.device)
                    prompt_applied = True
            except Exception as e:
                print(f"  ⚠️ Oracle prompt processing error: {e}, continuing without prompt")
                self._record_prompt_fallback()

        with torch.no_grad():
            try:
                generated_outputs = self.model.generate(**generation_kwargs)
            except Exception as e:
                print(f"  ⚠️ Generation error: {e}")
                if prompt_applied:
                    print("  ⚠️ Retrying WITHOUT the oracle prompt - this sample is not an oracle result")
                    self._record_prompt_fallback()
                # Simple fallback without any special parameters
                try:
                    generated_outputs = self.model.generate(
                        inputs["input_features"],
                        max_new_tokens=self._MAX_NEW_TOKENS
                    )
                except Exception as fallback_e:
                    print(f"  ❌ Fallback generation failed: {fallback_e}")
                    return [""], [0.0]

        if generated_outputs is None or generated_outputs.numel() == 0:
            return [""], [0.0]

        # generate() returns one sequence per input here; keep the first (best) one
        sequence = generated_outputs[0] if len(generated_outputs.shape) > 1 else generated_outputs
        transcription = self.processor.decode(sequence, skip_special_tokens=True).strip()
        return [transcription], [1.0]

    def process_single_sample(self, sample: Dict[str, Any], sample_idx: int) -> Dict[str, Any]:
        """Decode one dataset sample with its oracle prompt and score it.

        Args:
            sample: One dataset row. Keys read: 'utterance_id', 'domain',
                'voice', 'asr_difficulty', 'truth' (or 'text'),
                'normalized_truth', 'prompt', and 'audio' (a mapping with
                'array' and 'sampling_rate').
            sample_idx: Index of the sample in the dataset.

        Returns:
            Result dict holding the sample metadata, prompts, transcriptions,
            metrics object, and an 'errors' list. The list is empty when both
            decoding and metric computation succeeded.

        Raises:
            Exception: ASR or metrics errors that ``_is_critical_error`` flags
                are re-raised so the caller can abort the run.
        """
        utterance_id = sample.get('utterance_id', f'sample_{sample_idx}')
        actual_domain = sample.get('domain', 'unknown')
        voice = sample.get('voice', 'unknown')
        asr_difficulty = sample.get('asr_difficulty', 0.0)

        # Use `or` rather than .get() defaults so columns that exist but hold
        # None still fall back; a None truth would break prompting and metrics.
        truth = sample.get('truth') or sample.get('text') or ''
        normalized_truth = sample.get('normalized_truth') or truth

        # Original dataset prompt, kept only for comparison in the outputs
        original_prompt_field = sample.get('prompt', [])
        if isinstance(original_prompt_field, list) and len(original_prompt_field) > 0:
            original_prompt = ' '.join(original_prompt_field)
        elif isinstance(original_prompt_field, str):
            original_prompt = original_prompt_field
        else:
            original_prompt = ''

        oracle_prompt = self.get_oracle_prompt(normalized_truth)

        audio_data = sample['audio']
        audio_array = np.array(audio_data['array'])
        sample_rate = audio_data['sampling_rate']

        print(f"  📝 Processing: {utterance_id} ({actual_domain}/{voice})")
        print(f"  🔮 Oracle: Using ground truth as prompt")

        errors: List[str] = []

        start_time = time.time()
        try:
            transcriptions, probabilities = self.process_audio_with_oracle_prompt(
                audio_array, sample_rate, oracle_prompt
            )
            processing_time = time.time() - start_time
            best_transcription = transcriptions[0] if transcriptions else ""
            print(f"    🎯 Result: {best_transcription[:100]}...")
        except Exception as e:
            print(f"    ❌ ASR Error: {e}")
            if self._is_critical_error(e):
                raise
            errors.append(f"asr: {e}")
            transcriptions = [""]
            probabilities = [0.0]
            best_transcription = ""
            processing_time = 0.0

        try:
            # If the dataset already supplies a distinct normalized truth, skip
            # re-normalizing it. Otherwise (identical strings) normalize both sides.
            pre_normalized = normalized_truth != truth
            metrics_result = self.metrics_calculator.compute_metrics(
                ground_truth=normalized_truth,
                predictions=transcriptions,
                skip_norm_for_truth=pre_normalized
            )
            slot_status = "disabled" if not self.enable_slot_wer else f"{metrics_result.slot_wer:.3f}"
            print(f"    📊 WER: {metrics_result.wer:.3f}, SER: {metrics_result.ser:.3f}, SlotWER: {slot_status}")
        except Exception as e:
            print(f"    ❌ Metrics Error: {e}")
            if self._is_critical_error(e):
                raise
            errors.append(f"metrics: {e}")
            metrics_result = self._FallbackMetrics(best_transcription, normalized_truth)

        return {
            'sample_index': sample_idx,
            'utterance_id': utterance_id,
            'domain': actual_domain,
            'voice': voice,
            'asr_difficulty': asr_difficulty,
            'original_prompt': original_prompt,
            'oracle_prompt': oracle_prompt,
            'ground_truth': truth,
            'normalized_truth': normalized_truth,
            'transcriptions': transcriptions,
            'probabilities': probabilities,
            'best_transcription': best_transcription,
            'processing_time': processing_time,
            'metrics': metrics_result,
            'timestamp': datetime.now().isoformat(),
            'oracle_testing': True,
            'slot_wer_enabled': self.enable_slot_wer,
            'errors': errors,
        }

    def save_checkpoint(self, results: List[Dict[str, Any]], checkpoint_num: int):
        """Save checkpoint results using enhanced logger."""
        self.logger.save_checkpoint(results, checkpoint_num)

    def consolidate_final_results(self):
        """Write FINAL JSONL/CSV files for all results and generate the analytics."""
        print("\n🔄 Consolidating oracle evaluation results...")

        final_dir = Path(f"asr_experiments/{self.logger.experiment_id}")
        final_dir.mkdir(parents=True, exist_ok=True)

        # Metrics objects are serialized by attribute, not by repr
        final_jsonl = final_dir / f"{self.logger.experiment_id}_FINAL.jsonl"
        with open(final_jsonl, 'w', encoding='utf-8') as f:
            for result in self.results:
                f.write(json.dumps(result, default=self._json_default) + '\n')

        final_csv = final_dir / f"{self.logger.experiment_id}_FINAL.csv"
        self.generate_oracle_csv_summary(final_csv)

        self.generate_oracle_analytics()

        print(f"✅ Oracle evaluation results saved to: {final_dir}")

    def generate_oracle_csv_summary(self, csv_path: Path):
        """Write one CSV row per result, with oracle-specific columns.

        The file is UTF-8 encoded because transcripts may contain non-ASCII text.
        """
        import csv

        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow([
                'utterance_id', 'domain', 'voice', 'asr_difficulty',
                'ground_truth', 'best_transcription', 'wer', 'ser', 'slot_wer',
                'processing_time', 'confidence_score', 'original_prompt', 'oracle_prompt'
            ])
            for result in self.results:
                metrics = result['metrics']
                writer.writerow([
                    result['utterance_id'],
                    result['domain'],
                    result['voice'],
                    result['asr_difficulty'],
                    result['ground_truth'],
                    result['best_transcription'],
                    getattr(metrics, 'wer', 1.0),
                    getattr(metrics, 'ser', 1.0),
                    getattr(metrics, 'slot_wer', 1.0),
                    result['processing_time'],
                    result['probabilities'][0] if result['probabilities'] else 0.0,
                    result['original_prompt'],
                    result['oracle_prompt']
                ])

    def _group_breakdown(self, key: str) -> Dict[Any, Dict[str, Any]]:
        """Aggregate mean metrics and processing time per distinct value of result[key]."""
        groups: Dict[Any, List[Dict[str, Any]]] = {}
        for r in self.results:
            groups.setdefault(r[key], []).append(r)
        return {
            value: {
                'count': len(members),
                'avg_wer': np.mean(self._metric_values(members, 'wer')),
                'avg_ser': np.mean(self._metric_values(members, 'ser')),
                'avg_slot_wer': np.mean(self._metric_values(members, 'slot_wer')),
                'avg_processing_time': np.mean([r['processing_time'] for r in members]),
                'oracle_testing': True
            }
            for value, members in groups.items()
        }

    def generate_oracle_analytics(self):
        """Write overall, per-domain, per-voice and upper-bound analytics JSON files.

        Nothing is written when there are no results, because min/max
        statistics are undefined for an empty run.
        """
        if not self.results:
            print("⚠️  No oracle results to analyse - skipping analytics")
            return

        analytics_dir = Path(f"asr_experiments/{self.logger.experiment_id}/final_analytics")
        analytics_dir.mkdir(parents=True, exist_ok=True)

        wers = self._metric_values(self.results, 'wer')
        sers = self._metric_values(self.results, 'ser')
        slot_wers = self._metric_values(self.results, 'slot_wer')
        processing_times = [r['processing_time'] for r in self.results]

        overall_stats = {
            'experiment_type': 'oracle_upper_bound_testing',
            'total_samples': len(self.results),
            'avg_wer': np.mean(wers),
            'avg_ser': np.mean(sers),
            'avg_slot_wer': np.mean(slot_wers),
            'std_wer': np.std(wers),
            'std_ser': np.std(sers),
            'std_slot_wer': np.std(slot_wers),
            'total_runtime': time.time() - self.start_time if self.start_time else 0,
            'avg_processing_time': np.mean(processing_times),
            'min_processing_time': np.min(processing_times),
            'max_processing_time': np.max(processing_times),
            'oracle_testing': True,
            'slot_wer_enabled': self.enable_slot_wer,
            'checkpoint_interval': self.checkpoint_interval,
            'auto_resume_enabled': self.auto_resume,
            'resume_index': self.resume_index,
            'oracle_prompts_used': self.performance_stats["oracle_prompts_used"],
            'oracle_prompt_fallbacks': self.performance_stats.get("oracle_prompt_fallbacks", 0),
            'perfect_context_provided': self.performance_stats["perfect_context_provided"]
        }
        with open(analytics_dir / 'overall_oracle_performance.json', 'w', encoding='utf-8') as f:
            json.dump(overall_stats, f, indent=2, default=self._json_default)

        with open(analytics_dir / 'oracle_domain_breakdown.json', 'w', encoding='utf-8') as f:
            json.dump(self._group_breakdown('domain'), f, indent=2, default=self._json_default)

        with open(analytics_dir / 'oracle_voice_breakdown.json', 'w', encoding='utf-8') as f:
            json.dump(self._group_breakdown('voice'), f, indent=2, default=self._json_default)

        oracle_comparison = {
            'oracle_experiment': True,
            'perfect_context_provided': True,
            'upper_bound_performance': {
                'avg_wer': np.mean(wers),
                'min_wer': np.min(wers),
                'max_wer': np.max(wers),
                'perfect_transcriptions': sum(1 for w in wers if w == 0.0),
                'near_perfect_transcriptions': sum(1 for w in wers if w <= 0.1)
            },
            'analysis_notes': [
                'This oracle experiment provides upper bound ASR performance',
                'Results show maximum possible accuracy with perfect prompting',
                'Compare with baseline/adversarial experiments to measure prompt impact'
            ]
        }
        with open(analytics_dir / 'oracle_upper_bound_analysis.json', 'w', encoding='utf-8') as f:
            json.dump(oracle_comparison, f, indent=2, default=self._json_default)

        print("📊 Oracle analytics generated with upper bound performance analysis")

    def _update_performance_stats(self, result: Dict[str, Any]) -> None:
        """Update the success/failure counters and the running mean processing time."""
        stats = self.performance_stats
        stats["total_samples"] += 1
        if result.get('errors'):
            stats["failed_samples"] += 1
        else:
            stats["successful_samples"] += 1

        # Running mean over timed samples only. Samples whose ASR step failed
        # report 0.0 and are excluded from both numerator and denominator.
        processing_time = result['processing_time']
        if processing_time > 0:
            self._timed_samples += 1
            current_avg = stats["avg_processing_time"]
            stats["avg_processing_time"] = current_avg + (processing_time - current_avg) / self._timed_samples

    def _log_experiment_steps(self, exp, result: Dict[str, Any]) -> None:
        """Send the ASR, normalization and metrics steps of one result to the logger."""
        exp.log_asr_step(
            model_name=self.model_name,
            model_params={
                'num_beams': self.num_beams,
                'max_new_tokens': self._MAX_NEW_TOKENS,
                'temperature': 1.0,
                'oracle_testing': True,
                'oracle_prompt': result['oracle_prompt']
            },
            predictions=result['transcriptions'],
            confidence_scores=result['probabilities'],
            processing_time=result['processing_time']
        )
        exp.log_normalization_step(
            original_truth=result['ground_truth'],
            normalized_truth=result['normalized_truth']
        )
        exp.log_metrics_step(result['metrics'])

    def _print_final_summary(self, planned: int) -> None:
        """Print the end-of-run summary; ``planned`` is the number of samples in range."""
        total_time = time.time() - self.start_time
        print(f"\n🎉 Oracle Evaluation Complete!")
        print(f"   Processed: {self.processed_count}/{planned} samples")
        print(f"   Total time: {total_time/60:.1f} minutes")
        print(f"   Average rate: {self.processed_count/total_time:.1f} samples/sec")
        print(f"   Average time per sample: {self.performance_stats['avg_processing_time']:.1f}s")
        print(f"   Oracle prompts used: {self.performance_stats['oracle_prompts_used']}")
        fallbacks = self.performance_stats.get('oracle_prompt_fallbacks', 0)
        if fallbacks:
            print(f"   ⚠️ Samples decoded WITHOUT oracle prompt: {fallbacks}")

        if self.results:
            wers = self._metric_values(self.results, 'wer')
            avg_wer = np.mean(wers)
            avg_ser = np.mean(self._metric_values(self.results, 'ser'))
            avg_slot_wer = np.mean(self._metric_values(self.results, 'slot_wer'))
            perfect_count = sum(1 for w in wers if w == 0.0)
            print(f"   Average WER: {avg_wer:.3f} (with oracle prompts)")
            print(f"   Average SER: {avg_ser:.3f} (with oracle prompts)")
            print(f"   Average SlotWER: {avg_slot_wer:.3f} {'(disabled)' if not self.enable_slot_wer else ''}")
            print(f"   Perfect transcriptions: {perfect_count}/{len(self.results)} ({perfect_count/len(self.results)*100:.1f}%)")

        print(f"\n🔮 Oracle Testing Summary:")
        print(f"   Upper bound testing: Using ground truth as prompts")
        print(f"   Perfect context provided: {self.performance_stats['perfect_context_provided']}")

        print(f"\n🚀 Performance Summary:")
        print(f"   Slot-WER: {'ENABLED' if self.enable_slot_wer else 'DISABLED (3x speed boost)'}")
        print(f"   Auto-resume: {'ENABLED' if self.auto_resume else 'DISABLED'}")
        print(f"   Checkpoint interval: {self.checkpoint_interval} samples")
        print(f"   Logger stats: {self.logger.get_stats()}")

    def run_evaluation(self, start_idx: int = 0, max_samples: Optional[int] = None):
        """Run the complete oracle evaluation pipeline with auto-resume.

        Args:
            start_idx: First dataset index to process. A later auto-resume point wins.
            max_samples: Maximum number of samples to process from the start
                index. None means run to the end of the dataset.

        On a critical error, any pending checkpoint is saved and the process
        exits with status 1 so the run can be resumed later.
        """
        self.start_time = time.time()
        self.initialize_components()

        if self.auto_resume and self.resume_index > start_idx:
            start_idx = self.resume_index
            print(f"🔄 Resuming oracle evaluation from index {start_idx}")

        total_samples = len(self.dataset)
        # `is not None` so an explicit max_samples=0 means "nothing" rather
        # than falling through to the whole dataset.
        if max_samples is not None:
            end_idx = min(start_idx + max_samples, total_samples)
        else:
            end_idx = total_samples
        planned = max(0, end_idx - start_idx)

        print(f"\n🎯 Starting ORACLE evaluation: samples {start_idx} to {end_idx-1} ({planned} total)")
        print(f"🔮 Oracle strategy: Using normalized_truth as prompt for upper bound testing")
        print(f"🚀 Performance optimizations enabled:")
        print(f"   - Slot-WER: {'ENABLED (slower)' if self.enable_slot_wer else 'DISABLED (3x faster)'}")
        print(f"   - Checkpoint interval: {self.checkpoint_interval} samples")
        print(f"   - Auto-resume: {'ENABLED' if self.auto_resume else 'DISABLED'}")

        checkpoint_results = []

        for i in range(start_idx, end_idx):
            sample = self.dataset[i]

            try:
                if self.logger.should_abort():
                    print(f"\n🛑 Critical error threshold reached - aborting at sample {i}")
                    print("💾 Current state saved for resume")
                    break

                with self.logger.start_experiment(sample, i) as exp:
                    result = self.process_single_sample(sample, i)
                    self.results.append(result)
                    checkpoint_results.append(result)
                    self.processed_count += 1
                    self._update_performance_stats(result)
                    self._log_experiment_steps(exp, result)

                # Checkpoints are aligned on absolute dataset indices
                if (i + 1) % self.checkpoint_interval == 0:
                    checkpoint_num = (i + 1) // self.checkpoint_interval
                    self.save_checkpoint(checkpoint_results, checkpoint_num)
                    checkpoint_results = []

                    elapsed = time.time() - self.start_time
                    rate = self.processed_count / elapsed
                    remaining = (end_idx - i - 1) / rate if rate > 0 else 0
                    avg_time = self.performance_stats["avg_processing_time"]
                    oracle_count = self.performance_stats["oracle_prompts_used"]
                    print(f"⏱️  Progress: {i+1}/{end_idx} ({(i+1)/end_idx*100:.1f}%) - {rate:.1f} samples/sec - Avg: {avg_time:.1f}s/sample - Oracle prompts: {oracle_count} - ETA: {remaining/60:.1f}min")

            except Exception as e:
                error_msg = str(e)
                print(f"❌ Error processing sample {i}: {error_msg}")

                if not self._is_critical_error(e):
                    # Non-critical errors - continue processing
                    continue

                print(f"🛑 Critical error detected: {error_msg}")
                print("💾 Saving current state and aborting for safe resume")
                if checkpoint_results:
                    # Pending results end at index i-1
                    final_checkpoint = ((i - 1) // self.checkpoint_interval) + 1
                    self.save_checkpoint(checkpoint_results, final_checkpoint)
                print(f"🔄 To resume, run the script again - it will automatically continue from sample {i}")
                sys.exit(1)

        if checkpoint_results:
            final_checkpoint = ((end_idx - 1) // self.checkpoint_interval) + 1
            self.save_checkpoint(checkpoint_results, final_checkpoint)

        self.consolidate_final_results()
        self._print_final_summary(planned)


def main():
    """Parse command-line arguments and run the oracle evaluation.

    Numeric options are validated after the ``--test-run`` override is applied.
    Out-of-range values exit through ``argparse`` (status 2) instead of failing
    later inside the evaluation loop; for example, a checkpoint interval of 0
    would otherwise raise ZeroDivisionError.
    """
    parser = argparse.ArgumentParser(description='Oracle ASR Evaluation Pipeline - Upper Bound Testing')
    parser.add_argument('--dataset-path', default='../final_hf_dataset/agentic_asr_normalized',
                       help='Path to HuggingFace dataset')
    parser.add_argument('--model-name', default='openai/whisper-small',
                       help='Whisper model name')
    parser.add_argument('--experiment-name', default='whisper-small-oracle',
                       help='Experiment name for oracle logging')
    parser.add_argument('--num-beams', type=int, default=1,
                       help='Number of beams for beam search (default: 1 for greedy)')
    parser.add_argument('--checkpoint-interval', type=int, default=500,
                       help='Save checkpoint every N samples (default: 500)')
    parser.add_argument('--start-idx', type=int, default=0,
                       help='Starting sample index (overridden by auto-resume)')
    parser.add_argument('--max-samples', type=int,
                       help='Maximum number of samples to process')
    parser.add_argument('--enable-slot-wer', action='store_true',
                       help='Enable slot-WER computation (slower, disabled by default)')
    parser.add_argument('--disable-auto-resume', action='store_true',
                       help='Disable automatic resume from checkpoints')
    parser.add_argument('--test-run', action='store_true',
                       help='Run on first 10 samples for testing oracle approach')

    args = parser.parse_args()

    # Test run override
    if args.test_run:
        args.max_samples = 10
        args.checkpoint_interval = 5
        print("🧪 Oracle test run mode: processing first 10 samples")

    # Fail fast on values the pipeline cannot handle
    if args.num_beams < 1:
        parser.error("--num-beams must be >= 1")
    if args.checkpoint_interval < 1:
        parser.error("--checkpoint-interval must be >= 1")
    if args.start_idx < 0:
        parser.error("--start-idx must be >= 0")
    if args.max_samples is not None and args.max_samples < 1:
        parser.error("--max-samples must be >= 1")

    evaluator = ComprehensiveASREvaluatorOracle(
        model_name=args.model_name,
        dataset_path=args.dataset_path,
        experiment_name=args.experiment_name,
        num_beams=args.num_beams,
        checkpoint_interval=args.checkpoint_interval,
        enable_slot_wer=args.enable_slot_wer,
        auto_resume=not args.disable_auto_resume
    )

    evaluator.run_evaluation(
        start_idx=args.start_idx,
        max_samples=args.max_samples
    )


if __name__ == "__main__":
    main()
