#!/usr/bin/env python3
"""
Comprehensive ASR Evaluation Pipeline - Adversarial Prompt Testing
Tests ASR robustness using intentionally incorrect domain prompts to measure performance degradation.
"""

import torch
import numpy as np
import json
import time
import argparse
import sys
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional

# Core imports
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_from_disk
import librosa

# Local imports - use improved versions
from compute_metrics_approx import ASRMetricsApprox, MetricsResultApprox
from asr_experiment_logger_improved import ASRExperimentLoggerImproved
from llm_text_normalizer import LLMTextNormalizer



class ComprehensiveASREvaluatorAdversarial:
    """
    Adversarial ASR evaluation pipeline testing domain confusion impact.
    
    Tests ASR robustness by using intentionally incorrect domain prompts:
    - Financial content gets "cooking recipes" prompts
    - Insurance content gets "sports commentary" prompts
    - Healthcare content gets "automotive repair" prompts
    - Technology content gets "gardening tips" prompts
    
    This measures how much ASR performance degrades when given misleading context.
    """
    
    def __init__(self, 
                 model_name: str = "openai/whisper-small",
                 dataset_path: str = "../final_hf_dataset/agentic_asr_normalized",
                 experiment_name: str = "whisper-small-adversarial",
                 num_beams: int = 1,
                 checkpoint_interval: int = 500,
                 enable_slot_wer: bool = False,
                 auto_resume: bool = True):
        
        self.model_name = model_name
        self.dataset_path = dataset_path
        self.experiment_name = experiment_name
        self.num_beams = num_beams
        self.checkpoint_interval = checkpoint_interval
        self.enable_slot_wer = enable_slot_wer
        self.auto_resume = auto_resume
        
        # Adversarial domain mapping - intentionally wrong domains
        self.adversarial_domain_map = {
            'financial': 'cooking recipes',
            'finance': 'cooking recipes',
            'insurance': 'sports commentary',
            'healthcare': 'automotive repair',
            'health': 'automotive repair',
            'technology': 'gardening tips',
            'tech': 'gardening tips',
            'business': 'music theory',
            'legal': 'weather forecasting',
            'education': 'fashion design',
            'retail': 'astronomy',
            'real_estate': 'marine biology',
            'travel': 'archaeology',
            'entertainment': 'chemistry',
            'default': 'abstract art discussion'  # fallback for unknown domains
        }
        
        # Initialize components
        self.model = None
        self.processor = None
        self.dataset = None
        self.logger = None
        self.metrics_calculator = None
        
        # Tracking variables
        self.processed_count = 0
        self.start_time = None
        self.results = []
        self.resume_index = 0
        
        # Performance tracking with adversarial-specific metrics
        self.performance_stats = {
            "total_samples": 0,
            "successful_samples": 0,
            "failed_samples": 0,
            "avg_processing_time": 0.0,
            "adversarial_testing": True,
            "slot_wer_disabled": not enable_slot_wer,
            "domain_confusion_count": 0,
            "adversarial_mappings_used": {}
        }
        
    def get_adversarial_prompt(self, actual_domain: str) -> Tuple[str, str]:
        """
        Generate adversarial prompt based on domain confusion mapping.
        
        Args:
            actual_domain: The real domain of the content
            
        Returns:
            Tuple of (adversarial_prompt, adversarial_domain_used)
        """
        # Normalize domain name
        domain_key = actual_domain.lower().strip()
        
        # Find adversarial domain
        adversarial_domain = self.adversarial_domain_map.get(
            domain_key, 
            self.adversarial_domain_map['default']
        )
        
        # Create adversarial prompt
        adversarial_prompt = f"This is about {adversarial_domain}"
        
        # Track usage
        if adversarial_domain not in self.performance_stats["adversarial_mappings_used"]:
            self.performance_stats["adversarial_mappings_used"][adversarial_domain] = 0
        self.performance_stats["adversarial_mappings_used"][adversarial_domain] += 1
        self.performance_stats["domain_confusion_count"] += 1
        
        return adversarial_prompt, adversarial_domain
        
    def initialize_components(self):
        """
        Build every runtime component needed for an adversarial run.

        In order: loads the Whisper processor and model, loads the dataset
        from disk, prints the domain-confusion mapping, creates a Bedrock
        client that is handed to both the text normalizer and the metrics
        calculator, creates the experiment logger, and, when auto-resume is
        on, asks the logger where a previous run stopped.
        """
        print("🚀 Initializing Adversarial ASR Evaluation Pipeline...")
        print("⚠️  ADVERSARIAL MODE: Using intentionally incorrect domain prompts")

        # --- Whisper processor + model ---------------------------------------
        checkpoint = self.model_name
        print(f"Loading Whisper model: {checkpoint}")
        self.processor = WhisperProcessor.from_pretrained(checkpoint)
        model_options = {"torch_dtype": torch.float32, "device_map": "auto"}
        self.model = WhisperForConditionalGeneration.from_pretrained(checkpoint, **model_options)
        print("✅ Model loaded successfully")

        # --- dataset -----------------------------------------------------------
        print(f"Loading dataset from: {self.dataset_path}")
        self.dataset = load_from_disk(self.dataset_path)
        print(f"✅ Dataset loaded: {len(self.dataset)} samples")

        # --- show which wrong domain each real domain will receive ------------
        print("\n🎭 Adversarial Domain Mapping Strategy:")
        mapping_lines = [
            f"   {real_domain} → {wrong_domain}"
            for real_domain, wrong_domain in self.adversarial_domain_map.items()
            if real_domain != 'default'
        ]
        for line in mapping_lines:
            print(line)
        print(f"   (unknown domains) → {self.adversarial_domain_map['default']}")

        # --- one Bedrock client reused by normalizer and metrics ---------------
        print("\nInitializing shared Bedrock client...")
        from bedrock_claude import BedrockClaudeClient
        bedrock = BedrockClaudeClient()
        print("✅ Shared Bedrock client initialized")

        # --- metrics calculator --------------------------------------------------
        print("Initializing OPTIMIZED ASR metrics calculator...")
        text_normalizer = LLMTextNormalizer(bedrock)
        self.metrics_calculator = ASRMetricsApprox(
            text_normalizer,
            bedrock,
            enable_slot_wer=self.enable_slot_wer,
        )
        slot_wer_notice = (
            "⚠️  Slot-WER enabled (slower performance)"
            if self.enable_slot_wer
            else "🚀 Slot-WER disabled for 3x performance improvement"
        )
        print(slot_wer_notice)
        print("✅ Optimized metrics calculator initialized")

        # --- experiment logger -----------------------------------------------------
        print(f"Initializing enhanced experiment logger: {self.experiment_name}")
        logger_options = {
            "experiment_id": self.experiment_name,
            "output_format": "jsonl",
            "output_dir": "asr_experiments",
            "checkpoint_interval": self.checkpoint_interval,
        }
        self.logger = ASRExperimentLoggerImproved(**logger_options)
        print("✅ Enhanced logger initialized")

        # --- optional resume detection ------------------------------------------------
        if not self.auto_resume:
            return
        print("🔍 Checking for existing adversarial experiments to resume...")
        self.resume_index = self.logger.detect_resume_point(self.experiment_name)
        if self.resume_index > 0:
            print(f"🔄 Auto-resume enabled: starting from sample {self.resume_index}")
        else:
            print("🆕 No existing adversarial experiments found - starting fresh")

        
    def process_audio_with_adversarial_prompt(self, 
                                            audio_array: np.ndarray, 
                                            sample_rate: int, 
                                            adversarial_prompt: str) -> Tuple[List[str], List[float]]:
        """
        Process audio with Whisper using adversarial prompt conditioning.

        The audio is resampled to 16 kHz when needed, converted to input
        features by the processor, and decoded with ``self.num_beams`` beams.
        The prompt is encoded with ``processor.get_prompt_ids`` so that it is
        in the form Whisper's ``generate`` accepts for ``prompt_ids``.

        If the prompted generation raises, a plain un-prompted generation is
        attempted as a best-effort fallback; in that case the transcription
        is NOT adversarially conditioned and a warning is printed.

        Args:
            audio_array: Raw waveform samples.
            sample_rate: Sampling rate of ``audio_array`` in Hz.
            adversarial_prompt: Misleading context text; empty/blank disables
                prompt conditioning.

        Returns:
            ``([transcription], [confidence])``. The confidence is a
            placeholder: 1.0 when text was decoded, 0.0 when generation
            failed or produced nothing (in which case the text is ``""``).
        """
        # Resample if needed (Whisper expects 16kHz)
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            audio_array = librosa.resample(
                audio_array,
                orig_sr=sample_rate,
                target_sr=target_sample_rate
            )
            sample_rate = target_sample_rate

        # Process audio inputs
        inputs = self.processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt"
        )

        # Move to device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Honour the configured beam count (previously hard-coded to 1 even
        # though the run logs reported self.num_beams).
        num_beams = max(1, int(self.num_beams))

        generation_kwargs = {
            "input_features": inputs["input_features"],
            "max_new_tokens": 256,
            "do_sample": False,
            "num_beams": num_beams,
            "pad_token_id": self.processor.tokenizer.pad_token_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Add adversarial prompt conditioning
        if adversarial_prompt and adversarial_prompt.strip():
            try:
                # get_prompt_ids yields the rank-1, <|startofprev|>-prefixed
                # ids that generate() documents for `prompt_ids`; plain
                # tokenizer output (rank-2, no prefix token) is not valid here.
                prompt_ids = self.processor.get_prompt_ids(
                    adversarial_prompt,
                    return_tensors="pt"
                ).to(self.model.device)

                # Only add if prompt_ids is valid
                if prompt_ids is not None and prompt_ids.numel() > 0:
                    generation_kwargs["prompt_ids"] = prompt_ids
            except Exception as e:
                # Best-effort: a prompt that cannot be encoded must not abort the sample
                print(f"  ⚠️ Adversarial prompt processing error: {e}, continuing without prompt")

        # Generate transcription with adversarial conditioning
        with torch.no_grad():
            try:
                generated_outputs = self.model.generate(**generation_kwargs)
            except Exception as e:
                print(f"  ⚠️ Generation error: {e}")
                if "prompt_ids" in generation_kwargs:
                    # Make the loss of conditioning visible instead of silent
                    print("  ⚠️ Retrying WITHOUT the adversarial prompt - this sample is not adversarially conditioned")
                # Simple fallback without any special parameters
                try:
                    generated_outputs = self.model.generate(
                        inputs["input_features"],
                        max_new_tokens=256
                    )
                except Exception as fallback_e:
                    print(f"  ❌ Fallback generation failed: {fallback_e}")
                    return [""], [0.0]

        if generated_outputs is None or generated_outputs.numel() == 0:
            return [""], [0.0]

        # Only the top hypothesis is kept, whatever the beam count
        sequence = generated_outputs[0] if len(generated_outputs.shape) > 1 else generated_outputs

        # Decode to text
        transcription = self.processor.decode(sequence, skip_special_tokens=True).strip()

        # Confidence 1.0 is a placeholder; no per-sequence score is computed
        return [transcription], [1.0]

    
    def process_single_sample(self, sample: Dict[str, Any], sample_idx: int) -> Dict[str, Any]:
        """
        Process a single sample through the adversarial pipeline.

        Steps: read metadata and ground truth from ``sample``, pick the
        adversarial prompt for its domain, transcribe the audio with that
        prompt, then score the transcription.

        Non-critical ASR or metrics failures do not raise: the ASR failure
        yields an empty transcription, the metrics failure yields a
        worst-case placeholder metrics object, and in both cases a message is
        appended to the result's ``'errors'`` list. Exceptions whose message
        contains "token" or "connection" are treated as critical and
        re-raised.

        Args:
            sample: Dataset row; must contain ``'audio'`` with ``'array'`` and
                ``'sampling_rate'``. Other fields are optional.
            sample_idx: Index of the row, used for the fallback utterance id.

        Returns:
            Result dict with sample metadata, prompts, transcriptions,
            metrics object, timing, and an ``'errors'`` list (empty on
            success).
        """
        # Collected so callers can tell a clean sample from a degraded one
        errors: List[str] = []

        # Extract sample data
        utterance_id = sample.get('utterance_id', f'sample_{sample_idx}')
        actual_domain = sample.get('domain', 'unknown')
        if actual_domain is None:
            # A present-but-null domain would otherwise break the prompt lookup
            actual_domain = 'unknown'
        voice = sample.get('voice', 'unknown')
        asr_difficulty = sample.get('asr_difficulty', 0.0)

        # Ground truth data
        truth = sample.get('truth', sample.get('text', ''))
        normalized_truth = sample.get('normalized_truth', truth)

        # Get original prompt for comparison
        original_prompt_field = sample.get('prompt', [])
        if isinstance(original_prompt_field, list) and len(original_prompt_field) > 0:
            original_prompt = ' '.join(original_prompt_field)
        elif isinstance(original_prompt_field, str):
            original_prompt = original_prompt_field
        else:
            original_prompt = ''

        # Generate adversarial prompt
        adversarial_prompt, adversarial_domain = self.get_adversarial_prompt(actual_domain)

        # Audio data
        audio_data = sample['audio']
        audio_array = np.array(audio_data['array'])
        sample_rate = audio_data['sampling_rate']

        print(f"  📝 Processing: {utterance_id} ({actual_domain}/{voice})")
        print(f"  🎭 Adversarial: {actual_domain} → {adversarial_domain}")

        # Process audio with adversarial prompt
        start_time = time.time()
        try:
            transcriptions, probabilities = self.process_audio_with_adversarial_prompt(
                audio_array, sample_rate, adversarial_prompt
            )
            processing_time = time.time() - start_time

            # Get best transcription
            best_transcription = transcriptions[0] if transcriptions else ""

            print(f"    🎯 Result: {best_transcription[:100]}...")

        except Exception as e:
            print(f"    ❌ ASR Error: {e}")

            # Check for critical errors
            if "token" in str(e).lower() or "connection" in str(e).lower():
                raise  # Re-raise critical errors

            errors.append(f"asr: {e}")
            transcriptions = [""]
            probabilities = [0.0]
            best_transcription = ""
            processing_time = 0.0

        # Compute optimized metrics
        try:
            # When the dataset ships a distinct normalized truth, use it
            # as-is; otherwise let the calculator normalize the raw truth.
            truth_is_prenormalized = normalized_truth != truth
            metrics_result = self.metrics_calculator.compute_metrics(
                ground_truth=normalized_truth if truth_is_prenormalized else truth,
                predictions=transcriptions,
                skip_norm_for_truth=truth_is_prenormalized
            )

            slot_status = "disabled" if not self.enable_slot_wer else f"{metrics_result.slot_wer:.3f}"
            print(f"    📊 WER: {metrics_result.wer:.3f}, SER: {metrics_result.ser:.3f}, SlotWER: {slot_status}")

        except Exception as e:
            print(f"    ❌ Metrics Error: {e}")

            # Check for critical errors
            if "token" in str(e).lower() or "connection" in str(e).lower():
                raise  # Re-raise critical errors

            errors.append(f"metrics: {e}")

            # Worst-case stand-in exposing the attributes read downstream
            class MockMetrics:
                def __init__(self, prediction, reference):
                    self.wer = 1.0
                    self.ser = 1.0
                    self.slot_wer = 1.0
                    self.oracle_wer = 1.0
                    self.oracle_slot_wer = 1.0
                    self.normalized_prediction = prediction
                    self.normalized_truth = reference
                    self.normalized_predictions = [prediction]
                    self.nbest_match = False
                    self.truth_slots = []
                    self.prediction_slots = []
                    self.slot_matches = []
                    self.slot_mismatches = []
                    self.word_errors = {}
                    self.slot_errors = {}
                    self.position_metrics = []

            metrics_result = MockMetrics(best_transcription, normalized_truth)

        # Compile complete result with adversarial-specific fields
        result = {
            'sample_index': sample_idx,
            'utterance_id': utterance_id,
            'actual_domain': actual_domain,
            'adversarial_domain': adversarial_domain,
            'voice': voice,
            'asr_difficulty': asr_difficulty,
            'original_prompt': original_prompt,
            'adversarial_prompt': adversarial_prompt,
            'ground_truth': truth,
            'normalized_truth': normalized_truth,
            'transcriptions': transcriptions,
            'probabilities': probabilities,
            'best_transcription': best_transcription,
            'processing_time': processing_time,
            'metrics': metrics_result,
            'timestamp': datetime.now().isoformat(),
            'adversarial_testing': True,
            'slot_wer_enabled': self.enable_slot_wer,
            # Read by run_evaluation to split successful vs failed samples
            'errors': errors
        }

        return result

    
    def save_checkpoint(self, results: List[Dict[str, Any]], checkpoint_num: int):
        """
        Save checkpoint results using enhanced logger.

        Thin pass-through: both arguments are forwarded unchanged to
        ``self.logger.save_checkpoint``; storage format and location are
        decided by the logger.

        Args:
            results: Result dicts accumulated since the previous checkpoint.
            checkpoint_num: Sequence number identifying this checkpoint.
        """
        self.logger.save_checkpoint(results, checkpoint_num)
    
    def consolidate_final_results(self):
        """Consolidate all results and generate adversarial-specific analytics."""
        print("\n🔄 Consolidating adversarial evaluation results...")
        
        # Create final directory
        final_dir = Path(f"asr_experiments/{self.logger.experiment_id}")
        final_dir.mkdir(parents=True, exist_ok=True)
        
        # Save final JSONL
        final_jsonl = final_dir / f"{self.logger.experiment_id}_FINAL.jsonl"
        with open(final_jsonl, 'w') as f:
            for result in self.results:
                f.write(json.dumps(result, default=str) + '\n')
        
        # Generate final CSV
        final_csv = final_dir / f"{self.logger.experiment_id}_FINAL.csv"
        self.generate_adversarial_csv_summary(final_csv)
        
        # Generate adversarial analytics
        self.generate_adversarial_analytics()
        
        print(f"✅ Adversarial evaluation results saved to: {final_dir}")
    
    def generate_adversarial_csv_summary(self, csv_path: Path):
        """Generate CSV summary with adversarial-specific columns."""
        import csv
        
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            
            # Header with adversarial fields
            writer.writerow([
                'utterance_id', 'actual_domain', 'adversarial_domain', 'voice', 'asr_difficulty',
                'ground_truth', 'best_transcription', 'wer', 'ser', 'slot_wer',
                'processing_time', 'confidence_score', 'original_prompt', 'adversarial_prompt'
            ])
            
            # Data rows
            for result in self.results:
                metrics = result['metrics']
                writer.writerow([
                    result['utterance_id'],
                    result['actual_domain'],
                    result['adversarial_domain'],
                    result['voice'],
                    result['asr_difficulty'],
                    result['ground_truth'],
                    result['best_transcription'],
                    metrics.wer if hasattr(metrics, 'wer') else getattr(metrics, 'wer', 1.0),
                    metrics.ser if hasattr(metrics, 'ser') else getattr(metrics, 'ser', 1.0),
                    metrics.slot_wer if hasattr(metrics, 'slot_wer') else getattr(metrics, 'slot_wer', 1.0),
                    result['processing_time'],
                    result['probabilities'][0] if result['probabilities'] else 0.0,
                    result['original_prompt'],
                    result['adversarial_prompt']
                ])
    
    def generate_adversarial_analytics(self):
        """Generate comprehensive adversarial-specific analytics."""
        analytics_dir = Path(f"asr_experiments/{self.logger.experiment_id}/final_analytics")
        analytics_dir.mkdir(parents=True, exist_ok=True)
        
        # Overall adversarial performance
        wers = [getattr(r['metrics'], 'wer', 1.0) for r in self.results]
        sers = [getattr(r['metrics'], 'ser', 1.0) for r in self.results]
        slot_wers = [getattr(r['metrics'], 'slot_wer', 1.0) for r in self.results]
        processing_times = [r['processing_time'] for r in self.results]
        
        overall_stats = {
            'experiment_type': 'adversarial_domain_confusion',
            'total_samples': len(self.results),
            'avg_wer': np.mean(wers),
            'avg_ser': np.mean(sers),
            'avg_slot_wer': np.mean(slot_wers),
            'std_wer': np.std(wers),
            'std_ser': np.std(sers),
            'std_slot_wer': np.std(slot_wers),
            'total_runtime': time.time() - self.start_time if self.start_time else 0,
            'avg_processing_time': np.mean(processing_times),
            'min_processing_time': np.min(processing_times),
            'max_processing_time': np.max(processing_times),
            'adversarial_testing': True,
            'slot_wer_enabled': self.enable_slot_wer,
            'checkpoint_interval': self.checkpoint_interval,
            'auto_resume_enabled': self.auto_resume,
            'resume_index': self.resume_index,
            'domain_confusion_mappings': self.adversarial_domain_map,
            'adversarial_mappings_used': self.performance_stats["adversarial_mappings_used"]
        }
        
        with open(analytics_dir / 'overall_adversarial_performance.json', 'w') as f:
            json.dump(overall_stats, f, indent=2, default=str)
        
        # Adversarial domain breakdown
        adversarial_domain_stats = {}
        for adv_domain in set(r['adversarial_domain'] for r in self.results):
            adv_results = [r for r in self.results if r['adversarial_domain'] == adv_domain]
            adv_wers = [getattr(r['metrics'], 'wer', 1.0) for r in adv_results]
            adv_times = [r['processing_time'] for r in adv_results]
            
            # Get actual domains that were mapped to this adversarial domain
            actual_domains_mapped = list(set(r['actual_domain'] for r in adv_results))
            
            adversarial_domain_stats[adv_domain] = {
                'count': len(adv_results),
                'avg_wer': np.mean(adv_wers),
                'avg_ser': np.mean([getattr(r['metrics'], 'ser', 1.0) for r in adv_results]),
                'avg_slot_wer': np.mean([getattr(r['metrics'], 'slot_wer', 1.0) for r in adv_results]),
                'avg_processing_time': np.mean(adv_times),
                'actual_domains_confused': actual_domains_mapped
            }
        
        with open(analytics_dir / 'adversarial_domain_breakdown.json', 'w') as f:
            json.dump(adversarial_domain_stats, f, indent=2, default=str)
        
        # Actual domain vs adversarial impact analysis
        domain_confusion_impact = {}
        for actual_domain in set(r['actual_domain'] for r in self.results):
            domain_results = [r for r in self.results if r['actual_domain'] == actual_domain]
            domain_wers = [getattr(r['metrics'], 'wer', 1.0) for r in domain_results]
            
            # Get the adversarial domain used for this actual domain
            adversarial_domains_used = list(set(r['adversarial_domain'] for r in domain_results))
            
            domain_confusion_impact[actual_domain] = {
                'count': len(domain_results),
                'avg_wer': np.mean(domain_wers),
                'avg_ser': np.mean([getattr(r['metrics'], 'ser', 1.0) for r in domain_results]),
                'avg_slot_wer': np.mean([getattr(r['metrics'], 'slot_wer', 1.0) for r in domain_results]),
                'adversarial_domains_used': adversarial_domains_used,
                'confusion_strategy': f"{actual_domain} → {adversarial_domains_used[0] if adversarial_domains_used else 'unknown'}"
            }
        
        with open(analytics_dir / 'domain_confusion_impact.json', 'w') as f:
            json.dump(domain_confusion_impact, f, indent=2, default=str)
        
        # Voice breakdown (same as original but with adversarial context)
        voice_stats = {}
        for voice in set(r['voice'] for r in self.results):
            voice_results = [r for r in self.results if r['voice'] == voice]
            voice_wers = [getattr(r['metrics'], 'wer', 1.0) for r in voice_results]
            voice_times = [r['processing_time'] for r in voice_results]
            voice_stats[voice] = {
                'count': len(voice_results),
                'avg_wer': np.mean(voice_wers),
                'avg_ser': np.mean([getattr(r['metrics'], 'ser', 1.0) for r in voice_results]),
                'avg_slot_wer': np.mean([getattr(r['metrics'], 'slot_wer', 1.0) for r in voice_results]),
                'avg_processing_time': np.mean(voice_times),
                'adversarial_testing': True
            }
        
        with open(analytics_dir / 'voice_breakdown.json', 'w') as f:
            json.dump(voice_stats, f, indent=2, default=str)
        
        print("📊 Adversarial analytics generated with domain confusion analysis")
    
    def run_evaluation(self, start_idx: int = 0, max_samples: Optional[int] = None):
        """
        Run the complete adversarial evaluation pipeline with auto-resume.

        Initializes all components, processes dataset rows from the start
        index up to the computed end index, saves a checkpoint every
        ``checkpoint_interval`` rows, then consolidates results and prints a
        summary.

        Args:
            start_idx: First dataset index to process. Replaced by the
                logger's resume index when auto-resume is on and that index
                is larger.
            max_samples: Upper bound on rows processed in this call, counted
                from the effective start index. ``None`` means "to the end
                of the dataset"; ``0`` means "process nothing".

        Notes:
            * A non-critical exception on one row is printed and the row is
              skipped. An exception whose message contains "token" or
              "connection" saves the pending checkpoint and exits the
              process with status 1.
            * NOTE(review): ``self.results`` only holds rows processed in
              this call, so after a resume the consolidated FINAL files do
              not include rows from the earlier run - confirm whether the
              logger merges them elsewhere.
        """
        self.start_time = time.time()

        # Initialize components
        self.initialize_components()

        # Use resume index if auto-resume is enabled
        if self.auto_resume and self.resume_index > start_idx:
            start_idx = self.resume_index
            print(f"🔄 Resuming adversarial evaluation from index {start_idx}")

        # Determine sample range. `is None` (not truthiness) so that an
        # explicit 0 is not mistaken for "no limit".
        total_samples = len(self.dataset)
        if max_samples is None:
            end_idx = total_samples
        else:
            end_idx = min(start_idx + max(max_samples, 0), total_samples)

        if start_idx >= end_idx:
            # Nothing left (e.g. resume index at/after the end). Returning
            # here avoids overwriting earlier FINAL outputs with empty ones.
            print(f"ℹ️  No samples to evaluate (start index {start_idx}, end index {end_idx}) - nothing to do")
            return

        print(f"\n🎯 Starting ADVERSARIAL evaluation: samples {start_idx} to {end_idx-1} ({end_idx-start_idx} total)")
        print("🎭 Adversarial strategy: Domain confusion using incorrect prompts")
        print("🚀 Performance optimizations enabled:")
        print(f"   - Slot-WER: {'ENABLED (slower)' if self.enable_slot_wer else 'DISABLED (3x faster)'}")
        print(f"   - Checkpoint interval: {self.checkpoint_interval} samples")
        print(f"   - Auto-resume: {'ENABLED' if self.auto_resume else 'DISABLED'}")

        # Process samples
        checkpoint_results = []
        stats = self.performance_stats
        timed_samples = 0     # rows that contributed to the running average
        stop_idx = end_idx    # first index NOT attempted; lowered on early abort

        for i in range(start_idx, end_idx):
            sample = self.dataset[i]

            try:
                # Check for critical errors before processing
                if self.logger.should_abort():
                    print(f"\n🛑 Critical error threshold reached - aborting at sample {i}")
                    print("💾 Current state saved for resume")
                    stop_idx = i
                    break

                # Log individual experiment with proper context
                with self.logger.start_experiment(sample, i) as exp:
                    # Process the sample with adversarial prompts
                    result = self.process_single_sample(sample, i)
                    self.results.append(result)
                    checkpoint_results.append(result)
                    self.processed_count += 1

                    # Update performance stats
                    stats["total_samples"] += 1
                    if not result.get('errors', []):
                        stats["successful_samples"] += 1
                    else:
                        stats["failed_samples"] += 1

                    # Running mean over timed rows only; rows reporting 0s
                    # (failed ASR) are excluded from numerator AND denominator.
                    if result['processing_time'] > 0:
                        timed_samples += 1
                        current_avg = stats["avg_processing_time"]
                        stats["avg_processing_time"] = (
                            current_avg + (result['processing_time'] - current_avg) / timed_samples
                        )

                    # Log ASR step with adversarial context
                    exp.log_asr_step(
                        model_name=self.model_name,
                        model_params={
                            'num_beams': self.num_beams,
                            'max_new_tokens': 256,
                            'temperature': 1.0,
                            'adversarial_testing': True,
                            'adversarial_prompt': result['adversarial_prompt']
                        },
                        predictions=result['transcriptions'],
                        confidence_scores=result['probabilities'],
                        processing_time=result['processing_time']
                    )

                    # Log normalization step
                    exp.log_normalization_step(
                        original_truth=result['ground_truth'],
                        normalized_truth=result['normalized_truth']
                    )

                    # Log metrics step with the actual MetricsResult object
                    exp.log_metrics_step(result['metrics'])

                # Checkpoint save
                if (i + 1) % self.checkpoint_interval == 0:
                    checkpoint_num = (i + 1) // self.checkpoint_interval
                    self.save_checkpoint(checkpoint_results, checkpoint_num)
                    checkpoint_results = []  # Reset for next checkpoint

                    # Progress update with adversarial performance info
                    elapsed = time.time() - self.start_time
                    rate = self.processed_count / elapsed if elapsed > 0 else 0
                    remaining = (end_idx - i - 1) / rate if rate > 0 else 0
                    avg_time = stats["avg_processing_time"]
                    confusion_count = stats["domain_confusion_count"]
                    print(f"⏱️  Progress: {i+1}/{end_idx} ({(i+1)/end_idx*100:.1f}%) - {rate:.1f} samples/sec - Avg: {avg_time:.1f}s/sample - Confusions: {confusion_count} - ETA: {remaining/60:.1f}min")

            except Exception as e:
                error_msg = str(e)
                print(f"❌ Error processing sample {i}: {error_msg}")

                # Check for critical errors
                if "token" in error_msg.lower() or "connection" in error_msg.lower():
                    print(f"🛑 Critical error detected: {error_msg}")
                    print("💾 Saving current state and aborting for safe resume")

                    # Save final checkpoint if needed
                    if checkpoint_results:
                        final_checkpoint = ((i - 1) // self.checkpoint_interval) + 1
                        self.save_checkpoint(checkpoint_results, final_checkpoint)

                    # Exit gracefully
                    print(f"🔄 To resume, run the script again - it will automatically continue from sample {i}")
                    sys.exit(1)

                # Non-critical errors - continue processing
                continue

        # Save final checkpoint if needed. Numbered from where processing
        # actually stopped, so an early abort does not label its partial
        # checkpoint as the last one of the whole range.
        if checkpoint_results:
            final_checkpoint = ((stop_idx - 1) // self.checkpoint_interval) + 1
            self.save_checkpoint(checkpoint_results, final_checkpoint)

        # Consolidate final results (skipped when every row failed: there is
        # nothing to aggregate)
        if self.results:
            self.consolidate_final_results()
        else:
            print("⚠️  No results were produced - skipping final consolidation")

        # Final summary with adversarial performance metrics
        total_time = time.time() - self.start_time
        avg_rate = self.processed_count / total_time if total_time > 0 else 0.0
        print("\n🎉 Adversarial Evaluation Complete!")
        print(f"   Processed: {self.processed_count}/{end_idx-start_idx} samples")
        print(f"   Total time: {total_time/60:.1f} minutes")
        print(f"   Average rate: {avg_rate:.1f} samples/sec")
        print(f"   Average time per sample: {stats['avg_processing_time']:.1f}s")
        print(f"   Domain confusions applied: {stats['domain_confusion_count']}")

        if self.results:
            avg_wer = np.mean([getattr(r['metrics'], 'wer', 1.0) for r in self.results])
            avg_ser = np.mean([getattr(r['metrics'], 'ser', 1.0) for r in self.results])
            avg_slot_wer = np.mean([getattr(r['metrics'], 'slot_wer', 1.0) for r in self.results])
            print(f"   Average WER: {avg_wer:.3f} (with adversarial prompts)")
            print(f"   Average SER: {avg_ser:.3f} (with adversarial prompts)")
            print(f"   Average SlotWER: {avg_slot_wer:.3f} {'(disabled)' if not self.enable_slot_wer else ''}")

        print("\n🎭 Adversarial Testing Summary:")
        print("   Domain confusion strategy: Intentionally incorrect domain prompts")
        print(f"   Mappings used: {len(stats['adversarial_mappings_used'])} different adversarial domains")
        for adv_domain, count in stats["adversarial_mappings_used"].items():
            print(f"     - '{adv_domain}': {count} samples")

        print("\n🚀 Performance Summary:")
        print(f"   Slot-WER: {'ENABLED' if self.enable_slot_wer else 'DISABLED (3x speed boost)'}")
        print(f"   Auto-resume: {'ENABLED' if self.auto_resume else 'DISABLED'}")
        print(f"   Checkpoint interval: {self.checkpoint_interval} samples")
        print(f"   Logger stats: {self.logger.get_stats()}")


def main(argv: Optional[List[str]] = None) -> None:
    """Parse command-line options and run the adversarial ASR evaluation.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse reads ``sys.argv[1:]``, so the
            plain ``main()`` call used by the script guard is unaffected.

    Raises:
        SystemExit: Raised by argparse with status 2 for unknown options or
            out-of-range numeric values (and with status 0 for ``--help``).
    """
    parser = argparse.ArgumentParser(description='Adversarial ASR Evaluation Pipeline - Domain Confusion Testing')
    parser.add_argument('--dataset-path', default='../final_hf_dataset/agentic_asr_normalized',
                       help='Path to HuggingFace dataset')
    parser.add_argument('--model-name', default='openai/whisper-small',
                       help='Whisper model name')
    parser.add_argument('--experiment-name', default='whisper-small-adversarial',
                       help='Experiment name for adversarial logging')
    parser.add_argument('--num-beams', type=int, default=1,
                       help='Number of beams for beam search (default: 1 for greedy)')
    parser.add_argument('--checkpoint-interval', type=int, default=500,
                       help='Save checkpoint every N samples (default: 500)')
    parser.add_argument('--start-idx', type=int, default=0,
                       help='Starting sample index (overridden by auto-resume)')
    parser.add_argument('--max-samples', type=int,
                       help='Maximum number of samples to process')
    parser.add_argument('--enable-slot-wer', action='store_true',
                       help='Enable slot-WER computation (slower, disabled by default)')
    parser.add_argument('--disable-auto-resume', action='store_true',
                       help='Disable automatic resume from checkpoints')
    parser.add_argument('--test-run', action='store_true',
                       help='Run on first 10 samples for testing adversarial approach')

    args = parser.parse_args(argv)

    # Reject nonsensical numeric values before the evaluator is constructed,
    # so a typo fails immediately instead of after expensive setup.
    if args.num_beams < 1:
        parser.error('--num-beams must be at least 1')
    if args.checkpoint_interval < 1:
        parser.error('--checkpoint-interval must be at least 1')
    if args.start_idx < 0:
        parser.error('--start-idx must be non-negative')
    if args.max_samples is not None and args.max_samples < 0:
        parser.error('--max-samples must be non-negative')

    # Test run override: a small sample cap and a short checkpoint interval
    # replace whatever was given on the command line.
    if args.test_run:
        args.max_samples = 10
        args.checkpoint_interval = 5
        print("🧪 Adversarial test run mode: processing first 10 samples")

    # Create adversarial evaluator
    evaluator = ComprehensiveASREvaluatorAdversarial(
        model_name=args.model_name,
        dataset_path=args.dataset_path,
        experiment_name=args.experiment_name,
        num_beams=args.num_beams,
        checkpoint_interval=args.checkpoint_interval,
        enable_slot_wer=args.enable_slot_wer,
        # The CLI flag is phrased negatively; the evaluator takes the positive form.
        auto_resume=not args.disable_auto_resume
    )

    # Run adversarial evaluation
    evaluator.run_evaluation(
        start_idx=args.start_idx,
        max_samples=args.max_samples
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
