#!/usr/bin/env python3
"""
Enhanced ASR Experiment Logger with State Management and Auto-Resume
Optimized for robust long-running experiments with interruption recovery.
"""

import json
import csv
import sqlite3
import os
import time
import glob
from datetime import datetime
from typing import Dict, List, Any, Optional, Union
from dataclasses import dataclass, asdict
from contextlib import contextmanager
import logging
from pathlib import Path


@dataclass
class ExperimentState:
    """Track experiment state for resume functionality.

    The logger serializes instances with ``dataclasses.asdict`` and JSON when
    it persists its state, so field values should be JSON-serializable.
    """
    # Identifier of the experiment this state belongs to.
    experiment_id: str
    # Index of the most recently logged sample; the logger starts it at -1.
    last_processed_index: int
    # The logger initializes this to 0 and does not update it in this file;
    # presumably the caller sets it to the dataset size (TODO confirm).
    total_samples: int
    # ISO-8601 timestamp taken when the logger was constructed.
    start_time: str
    # ISO-8601 timestamp; the logger refreshes it every time a sample's state
    # is recorded, not only when a checkpoint file is written.
    last_checkpoint_time: str
    # Error entries; the logger initializes this to an empty list.
    processing_errors: List[Dict[str, Any]]
    # Running statistics; the logger stores "avg_processing_time" and
    # "sample_count" here.
    performance_stats: Dict[str, float]


class ExperimentContextImproved:
    """Enhanced context manager for ASR experiments with error resilience."""
    
    def __init__(self, logger: 'ASRExperimentLoggerImproved', dataset_record: dict, sample_index: int):
        """Build the empty per-sample record and start the overall timer.

        Args:
            logger: Owning logger; supplies ``experiment_id`` and later
                receives the finished record.
            dataset_record: Dataset row whose metadata is copied into the
                record (the raw audio samples are not stored).
            sample_index: Position of this sample in the run.
        """
        # Timing bookkeeping: overall start plus per-step durations.
        self.start_time = time.time()
        self.step_times = {}
        self.current_step_start = None

        self.logger = logger
        self.sample_index = sample_index

        created_at = datetime.now().isoformat()
        dataset_info = self._extract_dataset_info(dataset_record)
        record_metadata = {
            "logger_version": "2.0.0",
            "performance_optimized": True,
            "auto_resume_enabled": True
        }

        # Skeleton of the record; the log_* methods fill in the empty sections.
        self.data = {
            "experiment_id": logger.experiment_id,
            "sample_index": sample_index,
            "timestamp": created_at,
            "dataset_info": dataset_info,
            "asr_model": {},
            "predictions": {},
            "ground_truth": {},
            "metrics": {},
            "processing_times": {},
            "errors": [],
            "metadata": record_metadata
        }
    
    def _extract_dataset_info(self, dataset_record: dict) -> dict:
        """Extract complete dataset metadata including all HuggingFace dataset fields."""
        audio_info = {}
        if "audio" in dataset_record and dataset_record["audio"]:
            audio_data = dataset_record["audio"]
            if isinstance(audio_data, dict):
                sampling_rate = audio_data.get("sampling_rate", 22050)
                audio_array = audio_data.get("array", [])
                audio_info = {
                    "sampling_rate": sampling_rate,
                    "audio_length_seconds": len(audio_array) / sampling_rate if audio_array else None,
                    "audio_samples": len(audio_array) if audio_array else None,
                    "audio_path": audio_data.get("path")  # Include audio file path
                }
        
        return {
            # Core identifiers
            "utterance_id": dataset_record.get("utterance_id"),
            "domain": dataset_record.get("domain"),
            "voice": dataset_record.get("voice"),
            "asr_difficulty": dataset_record.get("asr_difficulty"),
            
            # Content fields
            "truth": dataset_record.get("truth"),
            "normalized_truth": dataset_record.get("normalized_truth"),
            "text": dataset_record.get("text"),  # Full original text
            "utterance": dataset_record.get("utterance"),  # Full utterance
            "raw_form": dataset_record.get("raw_form"),  # Raw form
            
            # Additional metadata fields
            "audio_wav_path": dataset_record.get("audio_wav_path"),  # Original audio file path
            "profile": dataset_record.get("profile"),  # Speaker profile
            "sample_rate": dataset_record.get("sample_rate"),  # Original sample rate
            
            # Structured content
            "sentences": dataset_record.get("sentences", []),  # List of sentences
            "prompt": dataset_record.get("prompt", []),  # Prompt context
            
            # Audio info (without the actual audio data)
            "audio_info": audio_info
        }
    
    def _start_step_timer(self, step_name: str):
        """Start timing a processing step."""
        self.current_step_start = time.time()
        self.current_step_name = step_name
    
    def _end_step_timer(self):
        """End timing the current processing step."""
        if self.current_step_start:
            duration = time.time() - self.current_step_start
            self.step_times[self.current_step_name] = duration
            self.current_step_start = None
    
    def log_asr_step(self, model_name: str, model_params: dict, 
                     predictions: Union[str, List[str]], 
                     confidence_scores: Optional[List[float]] = None,
                     processing_time: Optional[float] = None) -> 'ExperimentContextImproved':
        """Record the ASR model, its parameters and its n-best output.

        Args:
            model_name: Name of the ASR model that produced the predictions.
            model_params: Parameters the model was run with (stored as given).
            predictions: A single hypothesis or an n-best list; a bare string
                is wrapped in a one-element list.
            confidence_scores: Optional scores stored next to the n-best list;
                an empty list is stored when omitted.
            processing_time: Caller-measured inference time (presumably
                seconds). When given, this is the value stored under
                ``"asr_inference"``; otherwise the time spent inside this
                call is stored.

        Returns:
            ``self``, so logging calls can be chained.

        Raises:
            Exception: Failures are appended to ``data["errors"]``; only those
                whose message mentions "token" or "connection" are re-raised.
        """
        self._start_step_timer("asr_inference")

        try:
            # Ensure predictions is a list
            if isinstance(predictions, str):
                predictions = [predictions]

            self.data["asr_model"] = {
                "model_name": model_name,
                "model_params": model_params,
                "inference_timestamp": datetime.now().isoformat()
            }

            self.data["predictions"] = {
                "nbest_list": predictions,
                "confidence_scores": confidence_scores or [],
                "nbest_size": len(predictions)
            }

        except Exception as e:
            self.data["errors"].append({
                "step": "asr_inference",
                "error": str(e),
                "timestamp": datetime.now().isoformat()
            })
            # Re-raise critical errors
            if "token" in str(e).lower() or "connection" in str(e).lower():
                raise
        finally:
            self._end_step_timer()
            # _end_step_timer() has just stored the time spent inside this
            # method; a caller-supplied measurement must take precedence, so
            # it is applied last instead of being overwritten.
            if processing_time is not None:
                self.step_times["asr_inference"] = processing_time

        return self
    
    def log_normalization_step(self, original_truth: str, normalized_truth: str,
                              processing_time: Optional[float] = None) -> 'ExperimentContextImproved':
        """Record the reference text before and after normalization.

        Args:
            original_truth: Reference text as it appears in the dataset.
            normalized_truth: Reference text after normalization.
            processing_time: Caller-measured normalization time (presumably
                seconds). When given, this is the value stored under
                ``"normalization"``; otherwise the time spent inside this
                call is stored.

        Returns:
            ``self``, so logging calls can be chained.

        Raises:
            Exception: Failures are appended to ``data["errors"]``; only those
                whose message mentions "token" or "connection" are re-raised.
        """
        self._start_step_timer("normalization")

        try:
            self.data["ground_truth"] = {
                "original": original_truth,
                "normalized": normalized_truth,
                "normalization_applied": original_truth != normalized_truth,
                "normalization_timestamp": datetime.now().isoformat()
            }

        except Exception as e:
            self.data["errors"].append({
                "step": "normalization",
                "error": str(e),
                "timestamp": datetime.now().isoformat()
            })
            # Re-raise LLM connectivity errors
            if "token" in str(e).lower() or "connection" in str(e).lower():
                raise
        finally:
            self._end_step_timer()
            # _end_step_timer() has just stored the time spent inside this
            # method; a caller-supplied measurement must take precedence, so
            # it is applied last instead of being overwritten.
            if processing_time is not None:
                self.step_times["normalization"] = processing_time

        return self
    
    def log_metrics_step(self, metrics_result, processing_time: Optional[float] = None) -> 'ExperimentContextImproved':
        """Record the computed metrics and the normalized texts they were based on.

        Args:
            metrics_result: Object whose attributes are read with ``getattr``:
                ``wer``, ``slot_wer``, ``ser``, ``oracle_wer``,
                ``oracle_slot_wer``, ``nbest_match``, ``match_position``,
                ``nbest_size``, ``normalized_predictions``,
                ``normalized_truth``, ``truth_slots``, ``prediction_slots``
                and, when present, ``word_errors`` and ``position_metrics``.
                Missing attributes fall back to ``None`` (``1.0`` for the two
                slot WERs, empty lists for the slot lists).
            processing_time: Caller-measured metrics time (presumably
                seconds). When given, this is the value stored under
                ``"metrics_computation"``; otherwise the time spent inside
                this call is stored.

        Returns:
            ``self``, so logging calls can be chained.

        Raises:
            Exception: Failures are appended to ``data["errors"]``; only those
                whose message mentions "token" or "connection" are re-raised.
        """
        self._start_step_timer("metrics_computation")

        try:
            # Extract core metrics
            self.data["metrics"] = {
                "wer": getattr(metrics_result, 'wer', None),
                "slot_wer": getattr(metrics_result, 'slot_wer', 1.0),  # Default 1.0 for disabled
                "ser": getattr(metrics_result, 'ser', None),
                "oracle_wer": getattr(metrics_result, 'oracle_wer', None),
                "oracle_slot_wer": getattr(metrics_result, 'oracle_slot_wer', 1.0),  # Default 1.0 for disabled
                "nbest_match": getattr(metrics_result, 'nbest_match', None),
                "match_position": getattr(metrics_result, 'match_position', None),
                "nbest_size": getattr(metrics_result, 'nbest_size', None),
                "computation_timestamp": datetime.now().isoformat(),
                "performance_optimized": True  # Flag to indicate this used optimized metrics
            }

            # Extract normalized predictions (normalized nbest list) and add to predictions section
            normalized_predictions = getattr(metrics_result, 'normalized_predictions', [])
            if normalized_predictions and "predictions" in self.data:
                self.data["predictions"]["normalized_nbest_list"] = normalized_predictions
            elif normalized_predictions:
                # Create predictions section if it doesn't exist
                self.data["predictions"] = {
                    "normalized_nbest_list": normalized_predictions
                }

            # Also extract normalized truth for completeness
            normalized_truth = getattr(metrics_result, 'normalized_truth', None)
            if normalized_truth and "ground_truth" in self.data:
                # Keep a value already logged by the normalization step
                if "normalized" not in self.data["ground_truth"]:
                    self.data["ground_truth"]["normalized"] = normalized_truth
            elif normalized_truth:
                # Create ground_truth section if it doesn't exist
                self.data["ground_truth"] = {
                    "normalized": normalized_truth
                }

            # Simplified slot information (mostly empty due to performance optimization)
            truth_slots = getattr(metrics_result, 'truth_slots', [])
            self.data["slots"] = {
                "truth_slots": truth_slots,
                "prediction_slots": getattr(metrics_result, 'prediction_slots', []),
                "slot_computation_disabled": not truth_slots,  # Flag
                "performance_note": "Slot extraction disabled for 3x performance improvement"
            }

            # Extract error analysis if available
            if hasattr(metrics_result, 'word_errors'):
                self.data["error_analysis"] = {
                    "word_errors": getattr(metrics_result, 'word_errors', {}),
                }

            # Extract per-position metrics if available (simplified)
            if hasattr(metrics_result, 'position_metrics'):
                self.data["position_metrics"] = getattr(metrics_result, 'position_metrics', [])

        except Exception as e:
            self.data["errors"].append({
                "step": "metrics_computation",
                "error": str(e),
                "timestamp": datetime.now().isoformat()
            })
            # Re-raise LLM connectivity errors
            if "token" in str(e).lower() or "connection" in str(e).lower():
                raise
        finally:
            self._end_step_timer()
            # _end_step_timer() has just stored the time spent inside this
            # method; a caller-supplied measurement must take precedence, so
            # it is applied last instead of being overwritten.
            if processing_time is not None:
                self.step_times["metrics_computation"] = processing_time

        return self
    
    def add_error(self, step: str, error: str, critical: bool = False) -> 'ExperimentContextImproved':
        """Attach an error to this sample's record and report it to the logger.

        Args:
            step: Name of the processing step the error belongs to.
            error: Error description.
            critical: Whether the logger should count it as a critical error.

        Returns:
            ``self``, so logging calls can be chained.
        """
        entry = dict(
            step=step,
            error=error,
            critical=critical,
            timestamp=datetime.now().isoformat(),
        )
        self.data["errors"].append(entry)
        # The same entry object feeds the logger's recent-error window.
        self.logger._track_error(entry)
        return self
    
    def __enter__(self):
        """Enter the ``with`` block and hand back this context for step logging."""
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finalize and log the experiment when context exits."""
        # Calculate total processing time
        total_time = time.time() - self.start_time
        
        # Update processing times
        self.data["processing_times"] = {
            **self.step_times,
            "total": total_time
        }
        
        # Log any exception that occurred
        if exc_type:
            critical_error = "token" in str(exc_val).lower() or "connection" in str(exc_val).lower()
            self.data["errors"].append({
                "step": "context_exit",
                "error": f"{exc_type.__name__}: {exc_val}",
                "critical": critical_error,
                "timestamp": datetime.now().isoformat()
            })
            
            # Update logger's error tracking
            if critical_error:
                self.logger._track_critical_error(f"{exc_type.__name__}: {exc_val}")
        
        # Write the complete record
        self.logger._write_record(self.data)
        
        # Update experiment state
        self.logger._update_state(self.sample_index, self.data)
        
        # Don't suppress exceptions
        return False


class ASRExperimentLoggerImproved:
    """
    Enhanced ASR experiment logger with auto-resume and robust error handling.
    
    Key improvements:
    - Automatic resume from checkpoints
    - State persistence for interruption recovery
    - Critical error detection and graceful shutdown
    - Performance-optimized logging
    - Reduced checkpoint frequency (500 samples default)
    """
    
    def __init__(self, experiment_id: str, output_format: str = "jsonl", 
                 output_dir: str = "asr_experiments", checkpoint_interval: int = 500):
        """
        Initialize the enhanced ASR experiment logger.
        
        Args:
            experiment_id: Base experiment identifier
            output_format: Output format ("jsonl", "csv", "sqlite")
            output_dir: Directory to store experiment logs
            checkpoint_interval: Save checkpoint every N samples (default: 500)
        """
        self.base_experiment_id = experiment_id
        self.experiment_id = f"{experiment_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.output_format = output_format.lower()
        self.output_dir = output_dir
        self.checkpoint_interval = checkpoint_interval
        
        # Create output directory
        os.makedirs(output_dir, exist_ok=True)
        
        # Setup logging
        self.logger = logging.getLogger(__name__)
        
        # Initialize output file/database
        self._initialize_output()
        
        # Enhanced experiment tracking
        self.stats = {
            "total_experiments": 0,
            "successful_experiments": 0,
            "failed_experiments": 0,
            "critical_errors": 0,
            "start_time": datetime.now().isoformat(),
            "performance_optimized": True
        }
        
        # State management for auto-resume
        self.state = ExperimentState(
            experiment_id=self.experiment_id,
            last_processed_index=-1,
            total_samples=0,
            start_time=datetime.now().isoformat(),
            last_checkpoint_time=datetime.now().isoformat(),
            processing_errors=[],
            performance_stats={}
        )
        
        # Error tracking for critical failure detection
        self.recent_errors = []
        self.critical_error_count = 0
    
    def detect_resume_point(self, experiment_base_id: str) -> int:
        """
        Detect if there are existing checkpoints and determine resume point.
        
        Args:
            experiment_base_id: Base experiment ID to look for
            
        Returns:
            Index to resume from (0 if no checkpoints found)
        """
        try:
            # Look for existing experiment directories
            pattern = os.path.join(self.output_dir, f"{experiment_base_id}_*")
            existing_dirs = glob.glob(pattern)
            
            if not existing_dirs:
                self.logger.info("No existing experiments found - starting fresh")
                return 0
            
            # Find the most recent experiment
            latest_dir = max(existing_dirs, key=os.path.getctime)
            experiment_name = os.path.basename(latest_dir)
            
            # Look for checkpoint files
            checkpoint_pattern = os.path.join(latest_dir, "checkpoints", "checkpoint_*.jsonl")
            checkpoint_files = glob.glob(checkpoint_pattern)
            
            if not checkpoint_files:
                self.logger.info(f"No checkpoints found in {latest_dir} - starting fresh")
                return 0
            
            # Find the highest checkpoint number
            checkpoint_numbers = []
            for checkpoint_file in checkpoint_files:
                try:
                    filename = os.path.basename(checkpoint_file)
                    # Extract number from checkpoint_XXXX.jsonl
                    number_str = filename.replace("checkpoint_", "").replace(".jsonl", "")
                    checkpoint_numbers.append(int(number_str))
                except ValueError:
                    continue
            
            if not checkpoint_numbers:
                self.logger.info("No valid checkpoint numbers found - starting fresh")
                return 0
            
            highest_checkpoint = max(checkpoint_numbers)
            resume_index = highest_checkpoint * self.checkpoint_interval
            
            self.logger.info(f"Found existing experiment: {experiment_name}")
            self.logger.info(f"Highest checkpoint: {highest_checkpoint}")
            self.logger.info(f"Resuming from sample index: {resume_index}")
            
            # Update our experiment ID to continue the existing one
            self.experiment_id = experiment_name
            self._initialize_output()  # Re-initialize with existing experiment ID
            
            return resume_index
            
        except Exception as e:
            self.logger.warning(f"Error detecting resume point: {e}")
            return 0
    
    def _initialize_output(self):
        """Initialize the output file or database."""
        if self.output_format == "jsonl":
            self.output_file = os.path.join(self.output_dir, f"{self.experiment_id}.jsonl")
        elif self.output_format == "csv":
            self.output_file = os.path.join(self.output_dir, f"{self.experiment_id}.csv")
            self._csv_headers_written = os.path.exists(self.output_file)  # Check if file exists
        elif self.output_format == "sqlite":
            self.output_file = os.path.join(self.output_dir, f"{self.experiment_id}.db")
            self._initialize_sqlite()
        else:
            raise ValueError(f"Unsupported output format: {self.output_format}")
    
    def _initialize_sqlite(self):
        """Initialize SQLite database with experiment table."""
        conn = sqlite3.connect(self.output_file)
        cursor = conn.cursor()
        
        # Create experiments table with enhanced schema
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS experiments (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                experiment_id TEXT,
                sample_index INTEGER,
                timestamp TEXT,
                utterance_id TEXT,
                domain TEXT,
                voice TEXT,
                asr_difficulty REAL,
                model_name TEXT,
                wer REAL,
                slot_wer REAL,
                ser REAL,
                oracle_wer REAL,
                nbest_match BOOLEAN,
                match_position INTEGER,
                processing_time_total REAL,
                error_count INTEGER,
                critical_error_count INTEGER,
                original_nbest_list TEXT,      -- JSON array of original predictions
                normalized_nbest_list TEXT,   -- JSON array of normalized predictions
                performance_optimized BOOLEAN, -- Flag for optimized processing
                full_data TEXT  -- JSON blob with complete data
            )
        """)
        
        # Create state table for resume functionality
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS experiment_state (
                experiment_id TEXT PRIMARY KEY,
                last_processed_index INTEGER,
                total_samples INTEGER,
                start_time TEXT,
                last_checkpoint_time TEXT,
                critical_errors INTEGER,
                state_data TEXT  -- JSON blob with full state
            )
        """)
        
        conn.commit()
        conn.close()
    
    def start_experiment(self, dataset_record: dict, sample_index: int) -> ExperimentContextImproved:
        """
        Start a new experiment context with sample index tracking.
        
        Args:
            dataset_record: Complete dataset record with all metadata
            sample_index: Index of the sample being processed
            
        Returns:
            ExperimentContextImproved for logging experiment steps
        """
        return ExperimentContextImproved(self, dataset_record, sample_index)
    
    def _track_error(self, error_entry: Dict[str, Any]):
        """Track errors for critical failure detection."""
        self.recent_errors.append(error_entry)
        
        # Keep only recent errors (last 10)
        if len(self.recent_errors) > 10:
            self.recent_errors.pop(0)
        
        if error_entry.get("critical", False):
            self.critical_error_count += 1
    
    def _track_critical_error(self, error_message: str):
        """Track critical errors that should trigger immediate shutdown."""
        self.critical_error_count += 1
        self.logger.error(f"Critical error detected: {error_message}")
        
        # Save current state immediately
        self._save_state()
        
        # If we have multiple critical errors, recommend shutdown
        if self.critical_error_count >= 3:
            self.logger.critical("Multiple critical errors detected - recommend immediate shutdown")
    
    def should_abort(self) -> bool:
        """Check if experiment should abort due to critical errors."""
        return self.critical_error_count >= 3
    
    def _write_record(self, data: dict):
        """Write experiment record to output with enhanced error handling."""
        try:
            if self.output_format == "jsonl":
                self._write_jsonl(data)
            elif self.output_format == "csv":
                self._write_csv(data)
            elif self.output_format == "sqlite":
                self._write_sqlite(data)
            
            # Update statistics
            self.stats["total_experiments"] += 1
            if not data.get("errors"):
                self.stats["successful_experiments"] += 1
            else:
                self.stats["failed_experiments"] += 1
                
        except Exception as e:
            self.logger.error(f"Failed to write experiment record: {e}")
            # Don't let logging failures stop the experiment
    
    def _write_jsonl(self, data: dict):
        """Write record to JSONL file."""
        with open(self.output_file, 'a', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False)
            f.write('\n')
    
    def _write_csv(self, data: dict):
        """Write record to CSV file with resume support."""
        # Flatten the nested data structure
        flat_data = self._flatten_dict(data)
        
        # Write headers if first record
        if not self._csv_headers_written:
            with open(self.output_file, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=flat_data.keys())
                writer.writeheader()
            self._csv_headers_written = True
        
        # Write data
        with open(self.output_file, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=flat_data.keys())
            writer.writerow(flat_data)
    
    def _write_sqlite(self, data: dict):
        """Write record to SQLite database with enhanced schema."""
        conn = sqlite3.connect(self.output_file)
        cursor = conn.cursor()
        
        # Extract key fields for structured columns
        dataset_info = data.get("dataset_info", {})
        metrics = data.get("metrics", {})
        processing_times = data.get("processing_times", {})
        asr_model = data.get("asr_model", {})
        
        # Extract predictions data
        predictions = data.get("predictions", {})
        original_nbest = predictions.get("nbest_list", [])
        normalized_nbest = predictions.get("normalized_nbest_list", [])
        
        # Count critical errors
        critical_error_count = sum(1 for error in data.get("errors", []) if error.get("critical", False))
        
        cursor.execute("""
            INSERT INTO experiments (
                experiment_id, sample_index, timestamp, utterance_id, domain, voice, asr_difficulty,
                model_name, wer, slot_wer, ser, oracle_wer, nbest_match, match_position,
                processing_time_total, error_count, critical_error_count, original_nbest_list, 
                normalized_nbest_list, performance_optimized, full_data
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            data.get("experiment_id"),
            data.get("sample_index"),
            data.get("timestamp"),
            dataset_info.get("utterance_id"),
            dataset_info.get("domain"),
            dataset_info.get("voice"),
            dataset_info.get("asr_difficulty"),
            asr_model.get("model_name"),
            metrics.get("wer"),
            metrics.get("slot_wer"),
            metrics.get("ser"),
            metrics.get("oracle_wer"),
            metrics.get("nbest_match"),
            metrics.get("match_position"),
            processing_times.get("total"),
            len(data.get("errors", [])),
            critical_error_count,
            json.dumps(original_nbest) if original_nbest else None,
            json.dumps(normalized_nbest) if normalized_nbest else None,
            metrics.get("performance_optimized", False),
            json.dumps(data)
        ))
        
        conn.commit()
        conn.close()
    
    def _update_state(self, sample_index: int, data: dict):
        """Update experiment state for resume functionality."""
        self.state.last_processed_index = sample_index
        self.state.last_checkpoint_time = datetime.now().isoformat()
        
        # Update performance stats
        processing_time = data.get("processing_times", {}).get("total", 0)
        if processing_time > 0:
            if "avg_processing_time" not in self.state.performance_stats:
                self.state.performance_stats["avg_processing_time"] = processing_time
                self.state.performance_stats["sample_count"] = 1
            else:
                count = self.state.performance_stats["sample_count"]
                avg = self.state.performance_stats["avg_processing_time"]
                new_avg = (avg * count + processing_time) / (count + 1)
                self.state.performance_stats["avg_processing_time"] = new_avg
                self.state.performance_stats["sample_count"] = count + 1
    
    def _save_state(self):
        """Save current experiment state for resume functionality."""
        if self.output_format == "sqlite":
            self._save_state_sqlite()
        else:
            self._save_state_json()
    
    def _save_state_sqlite(self):
        """Save state to SQLite database."""
        conn = sqlite3.connect(self.output_file)
        cursor = conn.cursor()
        
        cursor.execute("""
            INSERT OR REPLACE INTO experiment_state (
                experiment_id, last_processed_index, total_samples, start_time,
                last_checkpoint_time, critical_errors, state_data
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            self.state.experiment_id,
            self.state.last_processed_index,
            self.state.total_samples,
            self.state.start_time,
            self.state.last_checkpoint_time,
            self.critical_error_count,
            json.dumps(asdict(self.state))
        ))
        
        conn.commit()
        conn.close()
    
    def _save_state_json(self):
        """Save state to JSON file."""
        state_file = os.path.join(self.output_dir, f"{self.experiment_id}_state.json")
        with open(state_file, 'w') as f:
            json.dump(asdict(self.state), f, indent=2)
    
    def save_checkpoint(self, results: List[Dict[str, Any]], checkpoint_num: int):
        """Save checkpoint results with enhanced metadata."""
        checkpoint_dir = Path(f"{self.output_dir}/{self.experiment_id}/checkpoints")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        checkpoint_file = checkpoint_dir / f"checkpoint_{checkpoint_num:04d}.jsonl"
        
        # Add checkpoint metadata
        checkpoint_metadata = {
            "checkpoint_number": checkpoint_num,
            "timestamp": datetime.now().isoformat(),
            "sample_count": len(results),
            "experiment_id": self.experiment_id,
            "performance_optimized": True,
            "critical_errors": self.critical_error_count
        }
        
        with open(checkpoint_file, 'w') as f:
            # Write metadata as first line
            f.write(json.dumps({"_checkpoint_metadata": checkpoint_metadata}) + '\n')
            
            # Write results
            for result in results:
                f.write(json.dumps(result, default=str) + '\n')
        
        # Save current state
        self._save_state()
        
        self.logger.info(f"💾 Checkpoint saved: {checkpoint_file} ({len(results)} samples)")
    
    def _flatten_dict(self, d: dict, parent_key: str = '', sep: str = '_') -> dict:
        """Flatten nested dictionary for CSV output."""
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            elif isinstance(v, list):
                # Convert lists to JSON strings for CSV
                items.append((new_key, json.dumps(v) if v else ''))
            else:
                items.append((new_key, v))
        return dict(items)
    
    def get_stats(self) -> dict:
        """Get enhanced experiment statistics."""
        return {
            **self.stats,
            "current_time": datetime.now().isoformat(),
            "output_file": self.output_file,
            "output_format": self.output_format,
            "checkpoint_interval": self.checkpoint_interval,
            "last_processed_index": self.state.last_processed_index,
            "critical_errors": self.critical_error_count,
            "performance_stats": self.state.performance_stats,
            "auto_resume_enabled": True
        }


def main():
    """Example usage of ASRExperimentLoggerImproved."""
    # Initialize enhanced logger
    logger = ASRExperimentLoggerImproved("test_experiment_improved", 
                                       output_format="jsonl", 
                                       checkpoint_interval=500)
    
    # Check for resume point
    resume_index = logger.detect_resume_point("test_experiment_improved")
    print(f"Resume index: {resume_index}")
    
    # Simulate dataset record
    dataset_record = {
        "utterance_id": "fin_412360_af_heart",
        "domain": "financial",
        "voice": "af_heart",
        "truth": "The company's Q3 revenue was $1.2M",
        "normalized_truth": "the company's q3 revenue was one point two million dollars",
        "asr_difficulty": 0.75,
        "audio": {"sampling_rate": 22050, "array": [0.1] * 70560}
    }
    
    # Log experiment with sample index
    sample_index = resume_index
    with logger.start_experiment(dataset_record, sample_index) as exp:
        # Step 1: ASR
        exp.log_asr_step(
            model_name="Qwen2.5-Audio-7B",
            model_params={"beam_size": 5, "temperature": 0.1},
            predictions=["the company's q3 revenue was 1.2 million", "the companies q3 revenue was one point two million"],
            confidence_scores=[0.95, 0.87]
        )
        
        # Step 2: Normalization
        exp.log_normalization_step(
            original_truth="The company's Q3 revenue was $1.2M",
            normalized_truth="the company's q3 revenue was one point two million dollars"
        )
        
        # Step 3: Mock optimized metrics
