import logging
import os
import re
import time
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.nn as nn
from tqdm import tqdm
from config.sdk.config import Config  # Adjust based on actual config path

# -------------------------- Logging Configuration --------------------------
def setup_logging(cfg):
    """
    Configure root logging (file + console) and build a compact experiment ID.

    Experiment ID format, e.g. ``e2005s5d256t0820l15``:
    - e2005: EPOCHS_PHASE1 then EPOCHS_PHASE2, each zero-padded to at least 2 digits
    - s5:    EBA_STEPS (EBA reasoning steps)
    - d256:  HIDDEN_DIM
    - t0820: int(TRANSITION_CENTER*10) then int(ANNEALING_SLPOE*10), each zero-padded
             to at least 2 digits (int() truncates the scaled value)
    - l15:   LAMBDA_INIT as formatted by an f-string (15 -> "l15", 1.5 -> "l1.5")

    Records at INFO and above go to the console and to
    ``<cfg.LOG_DIR>/tlad_training_<exp_id>_<YYYYmmdd_HHMMSS>.log``. Handlers that are
    already installed on the root logger are closed and replaced. Calling this again,
    or after another library has configured logging, therefore still routes output
    to the new file.

    Args:
        cfg: Config instance with the hyperparameters above plus LOG_DIR.
            NOTE: the slope is read from ``cfg.ANNEALING_SLPOE`` (sic) to match the
            config attribute name; rename both together if it is ever corrected.

    Returns:
        logger: The "TLAD-Training" logger.
        exp_id: Compact experiment ID string for traceability.
    """
    phase_epochs = f"{str(cfg.EPOCHS_PHASE1).zfill(2)}{str(cfg.EPOCHS_PHASE2).zfill(2)}"
    schedule = (
        f"{str(int(cfg.TRANSITION_CENTER * 10)).zfill(2)}"
        f"{str(int(cfg.ANNEALING_SLPOE * 10)).zfill(2)}"
    )
    exp_id = f"e{phase_epochs}s{cfg.EBA_STEPS}d{cfg.HIDDEN_DIM}t{schedule}l{cfg.LAMBDA_INIT}"
    time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())

    log_dir = cfg.LOG_DIR
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"tlad_training_{exp_id}_{time_str}.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(),
        ],
        # basicConfig is a silent no-op when the root logger already has handlers
        # (a previous call, a library, a test runner). The FileHandler above would
        # then create an empty log file that is never written to and never closed.
        # force=True (Python 3.8+) closes and replaces the existing root handlers.
        force=True,
    )
    logger = logging.getLogger("TLAD-Training")
    logger.info(f"Experiment ID initialized: {exp_id}")

    return logger, exp_id

# -------------------------- Model Checkpoint --------------------------
def save_checkpoint(model, optimizer, epoch, metrics, cfg, exp_id, is_best=False):
    """
    Save a training checkpoint named by experiment ID, epoch and timestamp.

    The model is unwrapped from nn.DataParallel / DistributedDataParallel before its
    state dict is taken. Saved keys therefore carry no ``module.`` prefix and load
    into a bare model.

    The checkpoint dict holds: epoch, model_state_dict, optimizer_state_dict
    (None when no optimizer is given), metrics, the full cfg object, exp_id and
    time_str. Because cfg is pickled as-is, loading requires ``weights_only=False``.

    Args:
        model: Model to save, optionally wrapped in (Distributed)DataParallel.
        optimizer: Training optimizer, or None.
        epoch: Current training epoch.
        metrics: Metrics to store alongside the weights.
        cfg: Config instance providing SAVE_DIR (also stored in the checkpoint).
        exp_id: Experiment ID string used in the file name.
        is_best: If True, also write a ``<exp_id>_best_<timestamp>.pth`` copy.

    Returns:
        Path of the best-model file when ``is_best`` is True, otherwise the path of
        the per-epoch checkpoint.
    """
    time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    checkpoint_dir = cfg.SAVE_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)

    save_path = os.path.join(checkpoint_dir, f"{exp_id}_epoch{epoch}_{time_str}.pth")

    # Both wrappers store the real network in `.module` and prefix its keys with
    # "module."; unwrap so the checkpoint loads regardless of the parallel setup.
    wrapper_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    raw_model = model.module if isinstance(model, wrapper_types) else model

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": raw_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "metrics": metrics,
        "config": cfg,
        "exp_id": exp_id,
        "time_str": time_str,
    }

    torch.save(checkpoint, save_path)
    logger = logging.getLogger("TLAD-Training")
    logger.info(f"Checkpoint saved to: {save_path}")

    if is_best:
        best_path = os.path.join(checkpoint_dir, f"{exp_id}_best_{time_str}.pth")
        torch.save(checkpoint, best_path)
        logger.info(f"Best model checkpoint saved to: {best_path}")
        save_path = best_path  # callers track the best checkpoint by this path

    return save_path

def load_checkpoint(model, optimizer, load_path, cfg):
    """
    Load model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: Model instance, optionally wrapped in (Distributed)DataParallel; the
            weights are loaded into the underlying module.
        optimizer: Optimizer to restore, or None (e.g. for inference). It is also
            skipped when the checkpoint stored no optimizer state (None).
        load_path: Path to the checkpoint file.
        cfg: Config instance providing DEVICE, used as ``map_location``.

    Returns:
        epoch: Epoch number stored in the checkpoint.
        metrics: Metrics stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``load_path`` does not exist.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Checkpoint not found: {load_path}")

    # weights_only=False is required because save_checkpoint pickles the whole
    # Config object. SECURITY: this unpickles arbitrary objects, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(load_path, map_location=cfg.DEVICE, weights_only=False)
    logger = logging.getLogger("TLAD-Training")

    if "exp_id" in checkpoint:
        logger.info(f"Loaded checkpoint metadata - Exp ID: {checkpoint['exp_id']}, Epoch: {checkpoint['epoch']}")

    # Checkpoints hold un-prefixed keys, so load into the wrapped network itself.
    wrapper_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrapper_types) else model
    target.load_state_dict(checkpoint["model_state_dict"])

    # save_checkpoint writes None here when saved without an optimizer; passing
    # None to load_state_dict would crash.
    optimizer_state = checkpoint.get("optimizer_state_dict")
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    return checkpoint["epoch"], checkpoint["metrics"]

# -------------------------- Visualization --------------------------
def visualize_tlad_dynamics(puzzle, outputs, epoch, cfg, exp_id):
    """
    Render a 2x2 diagnostic figure for one forward pass and save it as a 300-DPI PNG.

    Panels (row-major):
    - the energy trajectory;
    - gradient norms on a log y-axis;
    - the per-cell attention change ``|A_final - A_0|`` summed along axis 1;
    - the predicted 9x9 grid, with given cells in bold black and model-filled
      cells in blue.

    The file is written to ``cfg.VIS_DIR`` as ``<exp_id>_epoch<epoch>_<timestamp>.png``.

    Args:
        puzzle: Input puzzle tensor; element 0 must flatten to 81 values
            (documented as shape [1, C, H, W]; TODO confirm against the dataset).
        outputs: Model output dict with 'energy', 'grad_norm', 'A_0', 'A_final' and
            'probs' (a sequence whose last entry holds per-cell digit scores).
        epoch: Epoch number used in the title and file name.
        cfg: Config instance providing VIS_DIR.
        exp_id: Experiment ID string used in the title and file name.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    out_dir = os.path.join(cfg.VIS_DIR)
    os.makedirs(out_dir, exist_ok=True)

    energy_trace = outputs['energy']
    grad_trace = outputs['grad_norm']

    def head_mean(attn):
        # Sample 0, averaged over dim 0 (presumably attention heads; TODO confirm).
        return attn[0].detach().cpu().mean(dim=0).numpy()

    attn_start = head_mean(outputs['A_0'])
    attn_end = head_mean(outputs['A_final'])
    # Summed absolute attention change per row, laid out on the 9x9 board.
    change_map = np.abs(attn_end - attn_start).sum(axis=1).reshape(9, 9)

    last_scores = outputs['probs'][-1][0].detach().cpu().numpy()
    predicted = last_scores.argmax(axis=-1).reshape(9, 9) + 1  # class 0-8 -> digit 1-9
    givens = puzzle[0].cpu().numpy().flatten().reshape(9, 9)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    (ax_energy, ax_grad), (ax_delta, ax_pred) = axes
    fig.suptitle(f"TLAD Dynamics - Epoch {epoch} (Exp ID: {exp_id})", fontsize=16, fontweight='bold')

    ax_energy.plot(energy_trace, 'r-o', linewidth=2)
    ax_energy.set_title("Energy Descent (Constraint Interactions)", fontweight='bold')
    ax_energy.set_ylabel("Free Energy")
    ax_energy.grid(True, alpha=0.3)

    ax_grad.plot(grad_trace, 'b-x', linewidth=1.5)
    ax_grad.set_title("Gradient Strength", fontweight='bold')
    ax_grad.set_yscale('log')
    ax_grad.grid(True, alpha=0.3)

    sns.heatmap(change_map, ax=ax_delta, cmap="Reds", cbar=True)
    ax_delta.set_title("System 2 Corrections (Delta Attention)", fontweight='bold')
    ax_delta.axis('off')

    # All-zero image on a 0..1 grey scale acts as a blank canvas for the digit labels.
    ax_pred.imshow(np.zeros((9,9)), cmap='Greys', vmin=0, vmax=1)
    ax_pred.axis('off')
    ax_pred.set_title("Prediction (Blue=AI, Black=Input)", fontweight='bold')

    for (row, col), digit in np.ndenumerate(predicted):
        is_given = givens[row, col] != 0
        ax_pred.text(col, row, str(digit),
                     color='black' if is_given else 'blue',
                     fontweight='bold' if is_given else 'normal',
                     ha='center', va='center', fontsize=12)

    out_path = os.path.join(out_dir, f"{exp_id}_epoch{epoch}_{stamp}.png")
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)  # free the figure; matters when called every epoch

    logging.getLogger("TLAD-Training").info(f"Dynamics visualization saved to: {out_path}")

def plot_accuracy_curve(phase1_accs, phase2_accs, cfg, exp_id, curve_type="test"):
    """
    Plot per-epoch puzzle accuracy across both training phases and save a 300-DPI PNG.

    Phase 1 points sit at epochs 1..len(phase1_accs). Phase 2 points continue from
    len(phase1_accs)+1. A dashed red vertical line at x = len(phase1_accs) marks the
    switch between phases. The y-axis is fixed to 0-100, so accuracies are expected
    in percent. The file goes to ``cfg.VIS_DIR`` as
    ``<exp_id>_<curve_type>_accuracy_evolution_<timestamp>.png``.

    Args:
        phase1_accs: Sequence of Phase 1 (Perception) accuracies (%).
        phase2_accs: Sequence of Phase 2 (Thermodynamic) accuracies (%).
        cfg: Config instance providing VIS_DIR.
        exp_id: Experiment ID string used in the file name.
        curve_type: "val" or "test"; selects the title/log label and the file name.

    Raises:
        ValueError: If curve_type is not "val" or "test".
    """
    if curve_type not in ("val", "test"):
        raise ValueError(f"Invalid curve type: {curve_type} (must be 'val' or 'test')")

    stamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    out_dir = os.path.join(cfg.VIS_DIR)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{exp_id}_{curve_type}_accuracy_evolution_{stamp}.png")

    n_first = len(phase1_accs)
    n_total = n_first + len(phase2_accs)
    set_type = "Test" if curve_type == "test" else "Validation"

    plt.figure(figsize=(12, 7))

    # (x positions, values, format, legend label) for each phase, drawn in order.
    series = (
        (np.arange(1, n_first + 1), phase1_accs, 'b-o', 'Phase 1 (Perception)'),
        (np.arange(n_first + 1, n_total + 1), phase2_accs, 'g-s', 'Phase 2 (Thermodynamic Reasoning)'),
    )
    for xs, ys, fmt, label in series:
        plt.plot(xs, ys, fmt, label=label, linewidth=2.5, markersize=8, alpha=0.8)

    plt.axvline(x=n_first, color='r', linestyle='--', linewidth=3, label='Phase Transition', alpha=0.9)

    plt.xlabel('Epoch', fontsize=14, fontweight='bold')
    plt.ylabel('Puzzle Accuracy (%)', fontsize=14, fontweight='bold')
    plt.title(f'{set_type} Puzzle Accuracy Evolution: Perception → Thermodynamic Reasoning',
              fontsize=16, fontweight='bold', pad=20)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12, loc='lower right')
    plt.ylim(bottom=0, top=100)
    plt.xticks(np.arange(0, n_total + 2, 1), fontsize=10)
    plt.yticks(fontsize=10)

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()

    logging.getLogger("TLAD-Training").info(f"{set_type} accuracy curve saved to: {out_path}")
    print(f"\n{set_type} accuracy curve saved to: {out_path}")
    print(f"\n{set_type} accuracy curve saved to: {save_path}")

def visualize_best_model_step_by_step(model, correct_puzzle, cfg, exp_id, device):
    """
    Save per-step reasoning figures and a summary figure for one test puzzle.

    For each reasoning step a 2x2 figure is written. It shows:
    - the energy trajectory up to that step;
    - the gradient-norm trajectory up to that step;
    - the per-cell attention change ``|A_final - A_0|`` at that step;
    - the predicted 9x9 grid, with given cells in bold black and model-filled
      cells in blue.

    A final 1x2 summary shows the full energy curve and the last step's attention
    change. All PNGs (300 DPI) go to
    ``<cfg.VIS_DIR>/vis_<exp_id>_best_model_step_by_step/``.

    Args:
        model: Trained TLAD model, optionally wrapped in nn.DataParallel. It is called
            as ``model(x, mode='reasoning', return_step_wise=True)`` and must return
            a dict with 'step_energy', 'step_grad_norms', 'step_A_0', 'step_A_final',
            'probs' and optionally 'step_probs'.
        correct_puzzle: Batch holding one puzzle; element 0 must flatten to 81 values
            (documented as shape [1, C, H, W]; TODO confirm against the dataset).
        cfg: Config instance providing VIS_DIR.
        exp_id: Experiment ID string used in titles and file names.
        device: Device the puzzle is moved to before inference; skipped when None.
    """
    model.eval()
    logger = logging.getLogger("TLAD-Training")
    logger.info("Generating step-by-step visualization for best model (ICML paper)")

    time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    step_vis_dir = os.path.join(cfg.VIS_DIR, f"vis_{exp_id}_best_model_step_by_step")
    os.makedirs(step_vis_dir, exist_ok=True)

    def delta_attention_grid(attn_start, attn_end):
        # Sample 0, averaged over dim 0 (presumably attention heads; TODO confirm),
        # then the summed absolute change per row laid out on the 9x9 board.
        start = attn_start[0].detach().cpu().mean(dim=0).numpy()
        end = attn_end[0].detach().cpu().mean(dim=0).numpy()
        return np.abs(end - start).sum(axis=1).reshape(9, 9)

    def probs_to_grid(probs):
        # Sample 0 as NumPy, most likely class per cell, class index 0-8 -> digit 1-9.
        return probs[0].detach().cpu().numpy().argmax(axis=-1).reshape(9, 9) + 1

    def label_step_axis(ax, count):
        # One tick per plotted step, labelled 1..count rather than 0-based indices.
        ax.set_xlabel("Reasoning Step")
        ax.set_xticks(range(count))
        ax.set_xticklabels([f"{i+1}" for i in range(count)])

    def annotate_grid(ax, grid, givens):
        # Write each predicted digit; given cells (non-zero input) in bold black.
        for i in range(9):
            for j in range(9):
                is_given = givens[i, j] != 0
                ax.text(j, i, str(grid[i, j]),
                        color='black' if is_given else 'blue',
                        fontweight='bold' if is_given else 'normal',
                        ha='center', va='center', fontsize=12)

    def save_figure(fig, path):
        fig.tight_layout()
        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    # Place the input on the requested device (assumes the model lives there too).
    if device is not None:
        correct_puzzle = correct_puzzle.to(device)

    with torch.no_grad():
        # Step-wise outputs are requested from the unwrapped module when the model is
        # DataParallel-wrapped (presumably because DataParallel's scatter/gather does
        # not handle them; TODO confirm).
        net = model.module if isinstance(model, nn.DataParallel) else model
        outputs = net(correct_puzzle, mode='reasoning', return_step_wise=True)

        step_energies = outputs['step_energy']
        step_grad_norms = outputs['step_grad_norms']
        step_A_0 = outputs['step_A_0']
        step_A_final = outputs['step_A_final']
        num_steps = len(step_energies)

        # Convert to a NumPy digit grid, as is done for 'step_probs'. Indexing a raw
        # tensor would label cells with strings like "tensor(5)" instead of digits.
        pred_grid = probs_to_grid(outputs['probs'][-1])
        input_grid = correct_puzzle[0].cpu().numpy().flatten().reshape(9, 9)
        has_step_probs = 'step_probs' in outputs

        for step in range(num_steps):
            fig, axes = plt.subplots(2, 2, figsize=(14, 10))
            fig.suptitle(f"TLAD Best Model - Step {step+1}/{num_steps} (Exp ID: {exp_id})", fontsize=16, fontweight='bold')

            # 1. Energy trajectory up to the current step
            ax = axes[0, 0]
            ax.plot(step_energies[:step+1], 'r-o', linewidth=2)
            ax.set_title(f"Energy Descent (Step {step+1})", fontweight='bold')
            ax.set_ylabel("Free Energy")
            label_step_axis(ax, step + 1)
            ax.grid(True, alpha=0.3)

            # 2. Gradient norms (log scale) up to the current step
            ax = axes[0, 1]
            ax.plot(step_grad_norms[:step+1], 'b-x', linewidth=1.5)
            ax.set_title(f"Gradient Strength (Step {step+1})", fontweight='bold')
            ax.set_ylabel("Gradient Norm (log)")
            ax.set_yscale('log')
            label_step_axis(ax, step + 1)
            ax.grid(True, alpha=0.3)

            # 3. Attention change at the current step
            ax = axes[1, 0]
            sns.heatmap(delta_attention_grid(step_A_0[step], step_A_final[step]),
                        ax=ax, cmap="Reds", cbar=True)
            ax.set_title(f"Delta Attention (Step {step+1})", fontweight='bold')
            ax.axis('off')

            # 4. Prediction grid; falls back to the final prediction without step_probs
            ax = axes[1, 1]
            ax.imshow(np.zeros((9,9)), cmap='Greys', vmin=0, vmax=1)
            ax.axis('off')
            ax.set_title(f"Prediction (Step {step+1})", fontweight='bold')
            step_pred_grid = probs_to_grid(outputs['step_probs'][step]) if has_step_probs else pred_grid
            annotate_grid(ax, step_pred_grid, input_grid)

            save_path = os.path.join(step_vis_dir, f"{exp_id}_best_model_step_{step+1}_{time_str}.png")
            save_figure(fig, save_path)
            logger.info(f"Step {step+1}/{num_steps} visualization saved to: {save_path}")

        # Summary: full energy curve + attention change after the last step
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        fig.suptitle(f"TLAD Best Model - Full Reasoning Process (Exp ID: {exp_id})", fontsize=16, fontweight='bold')

        axes[0].plot(step_energies, 'r-o', linewidth=2, markersize=8)
        axes[0].set_title("Full Energy Descent Curve", fontweight='bold')
        axes[0].set_ylabel("Free Energy")
        label_step_axis(axes[0], num_steps)
        axes[0].grid(True, alpha=0.3)

        sns.heatmap(delta_attention_grid(step_A_0[-1], step_A_final[-1]),
                    ax=axes[1], cmap="Reds", cbar=True)
        axes[1].set_title("Final Delta Attention (After All Steps)", fontweight='bold')
        axes[1].axis('off')

        summary_path = os.path.join(step_vis_dir, f"{exp_id}_best_model_summary_{time_str}.png")
        save_figure(fig, summary_path)

        logger.info(f"Best model summary visualization saved to: {summary_path}")
        logger.info("Step-by-step visualization completed (ready for ICML paper)")

# -------------------------- Accuracy Calculation --------------------------
def check_accuracy_detailed(probs, gt):
    """
    Compute puzzle-level and cell-level accuracy for a batch of predictions.

    Args:
        probs: Per-cell digit scores, shape [B, cells, digits] (e.g. [B, 81, 9]);
            class index k stands for digit k+1.
        gt: Ground-truth digits (1-based), shape [B, cells].

    Returns:
        puzzle_acc: Fraction of puzzles whose every cell is correct (Python float, 0-1).
        cell_acc: Fraction of all cells that are correct (Python float, 0-1).
    """
    digit_guess = torch.argmax(probs, dim=-1) + 1  # class index -> digit 1..9
    hits = digit_guess == gt
    fully_solved = hits.all(dim=1)
    return fully_solved.float().mean().item(), hits.float().mean().item()

# -------------------------- Evaluation Utilities --------------------------
def validate(model, loader, device, is_test=False):
    """
    Evaluate puzzle- and cell-level accuracy of reasoning-mode inference over a loader.

    Per-batch accuracies are weighted by batch size. A smaller final batch therefore
    does not skew the result: every puzzle counts once, and every cell counts once
    provided all puzzles have the same number of cells.

    Args:
        model: Trained TLAD model; called as ``model(puzzle, mode='reasoning')`` and
            expected to return a dict with 'logits'.
        loader: Iterable of dicts with 'puzzle' and 'solution' tensors.
        device: Device to move each batch to.
        is_test: Only changes the progress-bar label ("Testing" vs "Validating").

    Returns:
        avg_puzzle: Sample-weighted puzzle accuracy (0-1).
        avg_cell: Sample-weighted cell accuracy (0-1).

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    puzzle_correct = 0.0
    cell_correct = 0.0
    num_samples = 0

    with torch.no_grad(), torch.autograd.profiler.record_function("evaluation"):
        loop = tqdm(loader, desc="Testing" if is_test else "Validating", leave=False)

        for batch in loop:
            puzzle = batch['puzzle'].to(device, non_blocking=True)
            solution = batch['solution'].to(device, non_blocking=True)

            outputs = model(puzzle, mode='reasoning')
            probs = torch.softmax(outputs['logits'], dim=-1)
            puzzle_acc, cell_acc = check_accuracy_detailed(probs, solution)

            # Batch means -> sums, so batches of different size weigh correctly.
            batch_size = solution.shape[0]
            puzzle_correct += puzzle_acc * batch_size
            cell_correct += cell_acc * batch_size
            num_samples += batch_size

            loop.set_postfix(puzzle_acc=f"{puzzle_acc:.1%}")

    if num_samples == 0:
        raise ValueError("validate() received an empty loader; no accuracy to report")

    return puzzle_correct / num_samples, cell_correct / num_samples

def find_first_correct_test_sample(model, test_loader, device):
    """
    Return the first test puzzle that the model solves completely.

    Args:
        model: Trained TLAD model; called as ``model(puzzle, mode='reasoning')`` and
            expected to return a dict with 'logits'.
        test_loader: Iterable of dicts with 'puzzle' and 'solution' tensors.
        device: Device to move each batch to.

    Returns:
        correct_puzzle: The solved puzzle, sliced to keep a batch dimension of 1.
        correct_solution: Its ground-truth solution, also with batch dimension 1.

    Raises:
        ValueError: If no sample in the loader is solved completely.
    """
    model.eval()
    logger = logging.getLogger("TLAD-Training")
    logger.info("Searching for first correctly predicted test sample (paper visualization)")

    # Context manager closes the progress bar even when returning early.
    with torch.no_grad(), tqdm(test_loader, desc="Searching for correct test sample") as loop:
        for batch_idx, batch in enumerate(loop):
            puzzle = batch['puzzle'].to(device, non_blocking=True)
            solution = batch['solution'].to(device, non_blocking=True)

            outputs = model(puzzle, mode='reasoning')
            probs = torch.softmax(outputs['logits'], dim=-1)

            # Same criterion as check_accuracy_detailed's puzzle accuracy (every cell
            # matches), evaluated for the whole batch at once rather than with one
            # .item() round-trip per sample.
            preds = probs.argmax(dim=-1) + 1
            solved = (preds == solution).all(dim=1)
            hits = solved.nonzero(as_tuple=True)[0]

            if hits.numel() > 0:
                idx = int(hits[0])
                logger.info(f"Found correct test sample - Batch: {batch_idx}, Sample: {idx}")
                return puzzle[idx:idx+1], solution[idx:idx+1]

    raise ValueError("No correctly predicted test samples found (adjust model or dataset)")

def measure_inference_time(model, test_loader, device, num_warmup=5, num_measure=20):
    """
    Measure per-sample latency and throughput of reasoning-mode inference.

    First ``num_warmup`` batches are run untimed. Then the forward pass of up to
    ``num_measure`` batches is timed with ``time.perf_counter``. Host-to-device
    copies are issued before the timer starts. On CUDA the device is synchronized
    before and after each timed call, so pending copies are excluded and queued
    kernels are included.

    Args:
        model: Trained TLAD model; called as ``model(puzzle, mode='reasoning')``.
        test_loader: Iterable of dicts with a 'puzzle' tensor.
        device: Device string or ``torch.device`` (e.g. 'cpu', 'cuda', 'cuda:0').
        num_warmup: Number of untimed warm-up batches.
        num_measure: Maximum number of timed batches.

    Returns:
        avg_time_per_sample: Mean inference time per sample, in milliseconds.
        throughput: Samples processed per second.

    Raises:
        ValueError: If no samples were timed (empty loader or ``num_measure <= 0``).
    """
    model.eval()

    # Resolve the device type once: a plain string comparison with 'cuda' misses
    # indexed devices such as 'cuda:0' and would skip synchronization entirely.
    dev = torch.device(device)
    on_cuda = dev.type == 'cuda'

    def sync():
        if on_cuda:
            torch.cuda.synchronize(dev)

    total_time = 0.0
    total_samples = 0

    with torch.no_grad():
        # Warm-up: initialize kernels/allocators so they do not bias the timing.
        for i, batch in enumerate(test_loader):
            if i >= num_warmup:
                break
            model(batch['puzzle'].to(device, non_blocking=True), mode='reasoning')
            sync()

        with tqdm(test_loader, desc="Measuring inference time") as loop:
            for i, batch in enumerate(loop):
                if i >= num_measure:
                    break

                puzzle = batch['puzzle'].to(device, non_blocking=True)
                total_samples += puzzle.shape[0]

                sync()
                start = time.perf_counter()
                model(puzzle, mode='reasoning')
                sync()
                total_time += time.perf_counter() - start

    if total_samples == 0:
        raise ValueError("measure_inference_time() timed no samples (empty loader or num_measure <= 0)")

    avg_time_per_sample = (total_time / total_samples) * 1000
    throughput = total_samples / total_time
    return avg_time_per_sample, throughput

# -------------------------- Reproducibility --------------------------
def seed_all(seed=1234, *, benchmark=False):
    """
    Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and configure cuDNN.

    ``torch.backends.cudnn.deterministic`` is always set to True.

    Args:
        seed: Random seed (default 1234).
        benchmark: Value for ``torch.backends.cudnn.benchmark``. Defaults to False:
            cuDNN autotuning may pick different algorithms from run to run, which
            breaks run-to-run reproducibility even with ``deterministic=True``.
            Pass True to trade reproducibility for speed.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = benchmark

def clean_state_dict(state_dict):
    """
    Return a new dict whose keys have one leading ``module.`` removed.

    Useful for state dicts saved from an ``nn.DataParallel``-wrapped model, whose
    parameter names carry that prefix. Only a prefix at the very start of a key is
    stripped, and only once ("module.module.x" -> "module.x"). Values are not copied.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A plain dict with cleaned keys, in the original iteration order.
    """
    leading_module = re.compile(r'^module\.')
    return {leading_module.sub('', name): tensor for name, tensor in state_dict.items()}

def calculate_sudoku_constraint_error(pred_grid):
    """
    Count repeated values in each row, column and 3x3 box of a 9x9 Sudoku grid.

    Each unit contributes ``9 - (number of distinct values)``, so a valid solution
    scores 0. Zero cells are counted like any other value.

    Args:
        pred_grid: 9x9 array-like of digits.

    Returns:
        Tuple ``(total, row_err, col_err, box_err)`` of ints.
    """
    board = np.array(pred_grid)

    def duplicates(unit):
        return 9 - len(np.unique(unit))

    rows = sum(duplicates(board[r, :]) for r in range(9))
    cols = sum(duplicates(board[:, c]) for c in range(9))
    boxes = sum(
        duplicates(board[3 * br:3 * br + 3, 3 * bc:3 * bc + 3])
        for br in range(3)
        for bc in range(3)
    )
    return rows + cols + boxes, rows, cols, boxes