# -*- coding: utf-8 -*-
import logging
import os
import pickle
import time
import math
from pprint import pformat

import numpy as np
import torch
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm

import plotting

from experiments.models.cdtd.layers.mlp import MLP, TabDDPM_MLP
from experiments.models.cdtd.layers.train_utils import InverseSquareRootScheduler, LinearScheduler
from .utils import cycle

from experiments.models.cdtd.mixed_type_diffusion import MixedTypeDiffusion


from .experiment import Experiment
from .utils import get_total_trainable_params, set_seeds
import json

from data.data_utils import FastTensorDataLoader

@torch.no_grad()
def compute_cdtd_sensitivity(
    diff_model, x_cat, x_cont, y_cond, m_cat, m_cont,
    epsilon=0.01, K=2, reduction="rms", device='cuda'
):
    """
    Output-based sensitivity for a CDTD diffusion model.

    Missing coordinates (mask == 0) are perturbed by +/- epsilon along random
    Rademacher directions, and the change of the model outputs on the observed
    coordinates (mask == 1) is measured with a central difference. Both
    perturbed inputs share the same diffusion noise and a fixed noise level
    (u = 0.5), so the difference reflects the perturbation only.

    The model is switched to eval mode for the measurement and the training
    flag of every submodule is restored afterwards, so this is safe to call
    in the middle of a training loop.

    Args:
        diff_model: MixedTypeDiffusion model; must expose ``cat_emb``, ``dim``,
            ``num_cat_features``, ``timewarp_cdf`` and ``precondition``.
        x_cat: [B, num_cat_features] categorical indices, or None.
        x_cont: [B, num_cont_features] continuous input, or None.
        y_cond: conditional input forwarded to ``precondition`` (can be None).
        m_cat: [B, num_cat_features] categorical mask (1=observed, 0=missing).
        m_cont: [B, num_cont_features] continuous mask (1=observed, 0=missing).
        epsilon: perturbation scale (assumes features are normalized).
        K: number of random Rademacher directions to average (>= 1).
        reduction: "rms" -> sqrt(sum_sq_diff / n_obs);
            "l1" -> sqrt(sum_sq_diff) / n_obs. Both are divided by 2*epsilon.
        device: device on which random perturbations and buffers are created.

    Returns:
        Scalar float sensitivity, averaged over samples and the K directions.

    Raises:
        ValueError: if ``reduction`` is not "rms"/"l1" or ``K`` < 1.
    """
    if reduction not in ("rms", "l1"):
        raise ValueError("reduction must be 'rms' or 'l1'")
    if K < 1:
        raise ValueError("K must be a positive integer")

    # Remember each submodule's mode so it can be restored exactly afterwards.
    saved_modes = [(module, module.training) for module in diff_model.modules()]
    diff_model.eval()
    try:
        m_cat = m_cat.float() if m_cat is not None else None
        m_cont = m_cont.float() if m_cont is not None else None

        B = x_cat.shape[0] if x_cat is not None else x_cont.shape[0]
        num_cat = x_cat.shape[1] if x_cat is not None else 0
        num_cont = x_cont.shape[1] if x_cont is not None else 0
        has_cat = num_cat > 0
        has_cont = num_cont > 0

        # Observed coordinates per sample across both modalities. The total is
        # clamped once (only to avoid division by zero); clamping each modality
        # separately would count a modality with no observed coordinates as 1.
        k_obs_total = torch.zeros(B, device=device)
        if has_cat and m_cat is not None:
            k_obs_total = k_obs_total + m_cat.sum(dim=1)
        if has_cont and m_cont is not None:
            k_obs_total = k_obs_total + m_cont.sum(dim=1)
        k_obs_total = k_obs_total.clamp_min(1.0)

        # Fixed u = 0.5 so every measurement uses the same noise level sigma.
        u_fixed = torch.full((B,), 0.5, device=device, dtype=torch.float32)
        sigma = diff_model.timewarp_cdf(u_fixed, invert=True).detach().to(torch.float32)
        n_cat_model = diff_model.num_cat_features

        sens_accum = 0.0
        for _ in range(K):
            # --- Rademacher perturbations, applied only to missing coordinates ---
            if has_cat:
                # Categorical features are perturbed in embedding space (not indices).
                # Scaling by 1/sqrt(dim) keeps the per-feature L2 norm at ~epsilon.
                x_cat_emb = diff_model.cat_emb(x_cat)  # [B, num_cat, dim]
                eta_cat = torch.randint(0, 2, (B, num_cat, diff_model.dim), device=device).float() * 2 - 1.0
                perturb_cat = (epsilon / math.sqrt(diff_model.dim)) * eta_cat * (1.0 - m_cat).unsqueeze(2)
                x_cat_plus = x_cat_emb + perturb_cat
                x_cat_minus = x_cat_emb - perturb_cat
            if has_cont:
                eta_cont = torch.randint(0, 2, (B, num_cont), device=device).float() * 2 - 1.0
                perturb_cont = epsilon * eta_cont * (1.0 - m_cont)
                x_cont_plus = x_cont + perturb_cont
                x_cont_minus = x_cont - perturb_cont

            # --- Shared diffusion noise: x_t = x + sigma * noise (EDM form) ---
            # The same noise on both sides keeps noise variance out of the
            # central difference.
            if has_cat:
                noise_cat = torch.randn_like(x_cat_plus)  # [B, num_cat, dim]
                sigma_cat = sigma[:, :n_cat_model].unsqueeze(2)
                x_cat_t_plus = x_cat_plus + noise_cat * sigma_cat
                x_cat_t_minus = x_cat_minus + noise_cat * sigma_cat
            else:
                x_cat_t_plus = x_cat_t_minus = torch.zeros(B, 0, diff_model.dim, device=device)
            if has_cont:
                noise_cont = torch.randn_like(x_cont_plus)  # [B, num_cont]
                sigma_cont = sigma[:, n_cat_model:]
                x_cont_t_plus = x_cont_plus + noise_cont * sigma_cont
                x_cont_t_minus = x_cont_minus + noise_cont * sigma_cont
            else:
                # No continuous features (None or zero-width): pass through unchanged.
                x_cont_t_plus = x_cont_t_minus = x_cont

            m_cat_arg = m_cat if has_cat else None
            cat_logits_plus, cont_preds_plus = diff_model.precondition(
                x_cat_t_plus, x_cont_t_plus, y_cond, u_fixed, sigma, m_cat_arg, m_cont
            )
            cat_logits_minus, cont_preds_minus = diff_model.precondition(
                x_cat_t_minus, x_cont_t_minus, y_cond, u_fixed, sigma, m_cat_arg, m_cont
            )

            # --- Output differences on observed coordinates ---
            diff_cat = torch.zeros(B, device=device)
            if has_cat and m_cat is not None and len(cat_logits_plus) > 0:
                per_feature = []
                for i in range(num_cat):
                    if not m_cat[:, i].any():
                        continue  # no observed value for this feature in the batch
                    # Softmax bounds outputs to [0, 1], comparable to continuous outputs.
                    prob_diff = (
                        torch.softmax(cat_logits_plus[i], dim=-1)
                        - torch.softmax(cat_logits_minus[i], dim=-1)
                    )  # [B, n_classes]
                    per_feature.append(((prob_diff * m_cat[:, i:i + 1]) ** 2).sum(dim=1))
                if per_feature:
                    diff_cat = torch.stack(per_feature, dim=1).sum(dim=1)  # [B]

            diff_cont = torch.zeros(B, device=device)
            if has_cont and m_cont is not None:
                diff_cont = (((cont_preds_plus - cont_preds_minus) * m_cont) ** 2).sum(dim=1)

            total_diff = diff_cat + diff_cont
            # Normalize by the number of observed dimensions and by 2*epsilon.
            if reduction == "rms":
                sens = (total_diff / k_obs_total).sqrt() / (2.0 * epsilon)
            else:
                sens = (total_diff.sqrt() / k_obs_total) / (2.0 * epsilon)
            sens_accum += sens.mean()

        return (sens_accum / K).item()
    finally:
        for module, mode in saved_modes:
            module.training = mode

class Experiment_CDTD(Experiment):
    """Train and evaluate a continuous diffusion model for mixed-type tabular data."""

    def __init__(self, data_path, sample_path, config, exp_path, dataset, device, preproc, m_rounds=1, strategy=0, breaks=30000, run_name='', beta='0p7', use_log='switch'):
        """Set up a CDTD experiment.

        All shared experiment arguments are forwarded to the base class. The
        CDTD-specific ones are kept on the instance: ``strategy`` and ``breaks``
        are passed to ``MixedTypeDiffusion`` in ``get_model()``, and ``m_rounds``
        scales the default validation interval in ``train()``.
        """
        super().__init__(
            data_path, sample_path, config, exp_path, dataset, device, preproc,
            run_name=run_name, beta=beta, use_log=use_log,
        )
        self.strategy, self.breaks = strategy, breaks
        self.m_rounds = m_rounds

    def get_model(self):
        """Build the ``MixedTypeDiffusion`` model described by ``self.config.model``.

        Side effects: caches feature metadata from ``self.data_wrangler`` on
        ``self`` (``categories``, ``num_cat_features``, ``num_cont_features``,
        ``num_features``), plus ``simulate_missings``, ``calibrate_losses`` and
        ``proportions``. When ``calibrate_losses`` is enabled, category
        proportions are computed from ``self.train_loader.X_cat``, so the train
        loader must already exist (``train()`` creates it before calling this).

        Returns:
            A ``MixedTypeDiffusion`` wrapping an ``MLP`` or ``TabDDPM_MLP`` score model.

        Raises:
            ValueError: if ``config.model.architecture`` is neither "mlp" nor "tabddpm".
        """
        args = self.config.model
        wrangler = self.data_wrangler
        self.categories = wrangler.num_cats
        self.num_cat_features = wrangler.num_cat_features
        self.num_cont_features = wrangler.num_cont_features
        self.num_features = wrangler.num_total_features
        self.simulate_missings = self.config.training.get("simulate_missings", False)

        self.calibrate_losses = args.calibrate_losses

        if self.calibrate_losses:
            X_cat = self.train_loader.X_cat
            n_sample = X_cat.shape[0]
            # Empirical frequency of each category, per categorical feature.
            # NOTE(review): unique() omits categories that never occur in the
            # training data, so a feature's tensor can be shorter than its
            # category count — confirm downstream code tolerates that.
            self.proportions = [
                X_cat[:, i].unique(return_counts=True)[1] / n_sample
                for i in range(len(self.categories))
            ]
        else:
            self.proportions = None

        architecture = args.architecture
        if architecture == "mlp":
            score_model = MLP(
                self.num_cont_features,
                args.dim,
                self.categories,
                wrangler.num_y_classes,
                args.mlp_emb_dim,
                args.mlp_n_layers,
                args.mlp_n_units,
                proportions=self.proportions,
                use_fourier_features=args.use_fourier_features,
                act=args.act,
                feat_spec_cond=args.use_feat_spec_cond,
                time_fourier=args.use_time_fourier,
            )
        elif architecture == "tabddpm":
            score_model = TabDDPM_MLP(
                self.num_cont_features,
                args.dim,
                self.categories,
                wrangler.num_y_classes,
                args.mlp_emb_dim,
                args.mlp_n_layers,
                args.mlp_n_units,
                proportions=self.proportions,
                use_fourier_features=args.use_fourier_features,
            )
        else:
            # Without this, an unknown name surfaced as an UnboundLocalError below.
            raise ValueError(
                f"Unknown model architecture {architecture!r}; expected 'mlp' or 'tabddpm'."
            )

        return MixedTypeDiffusion(
            model=score_model,
            dim=args.dim,
            categories=self.categories,
            num_features=self.num_features,
            task=wrangler.task,
            sigma_data_cat=args.sigma_data_cat,
            sigma_data_cont=args.sigma_data_cont,
            sigma_min_cat=args.sigma_min_cat,
            sigma_max_cat=args.sigma_max_cat,
            sigma_min_cont=args.sigma_min_cont,
            sigma_max_cont=args.sigma_max_cont,
            calibrate_losses=args.calibrate_losses,
            proportions=self.proportions,
            cat_emb_init_sigma=args.cat_emb_init_sigma,
            timewarp_variant=args.timewarp_variant,
            timewarp_type=args.timewarp_type,
            timewarp_weight_low_noise=args.timewarp_weight_low_noise,
            timewarp_bins=args.timewarp_bins,
            timewarp_decay=args.timewarp_decay,
            cat_bias=args.use_cat_bias,
            simulate_missings=self.simulate_missings,
            strategy=self.strategy,
            breaks=self.breaks,
        )

    def get_optimizer(self):
        """Create the optimizer named by ``self.config.optimizer.name``.

        Supported names are "adam" and "adamw"; ``config.optimizer.args`` is
        forwarded as keyword arguments. The optimizer covers all parameters of
        ``self.diff_model``.

        Raises:
            ValueError: for an unknown optimizer name (a subclass of
                ``Exception``, so existing ``except Exception`` handlers still apply).
        """
        config = self.config.optimizer
        optimizer_classes = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
        }
        if config.name not in optimizer_classes:
            raise ValueError(f"Unknown optimizer: {config.name!r}.")

        # train() never clips gradients, so make an explicitly configured
        # clip norm visible instead of silently ignoring it.
        if getattr(config, "gradient_clip_norm", None) is not None:
            logging.warning(
                "optimizer.gradient_clip_norm is set, but gradient clipping is not "
                "implemented for this experiment; the setting is ignored."
            )

        return optimizer_classes[config.name](self.diff_model.parameters(), **config.args)

    def get_lr_scheduler(self):
        """Return the learning-rate schedule selected by ``config.training.scheduler``.

        "linear" yields a ``LinearScheduler`` and "inverse_sqrt" an
        ``InverseSquareRootScheduler``, both starting their warm-up at 1e-6.
        Any other value yields ``None``, and ``train()`` then leaves the
        learning rate untouched.
        """
        train_cfg = self.config.training
        kind = train_cfg.scheduler

        if kind == "linear":
            return LinearScheduler(
                train_cfg.num_steps_train,
                base_lr=self.config.optimizer.args.lr,
                final_lr=1e-6,
                warmup_steps=train_cfg.num_steps_lr_warmup,
                warmup_begin_lr=1e-6,
                anneal_lr=train_cfg.anneal_lr,
            )

        if kind == "inverse_sqrt":
            return InverseSquareRootScheduler(
                base_lr=self.config.optimizer.args.lr,
                ref_step=train_cfg.ref_step,
                warmup_steps=train_cfg.num_steps_lr_warmup,
                anneal_lr=train_cfg.anneal_lr,
                warmup_begin_lr=1e-6,
            )

        return None

    def train(self, **kwargs):
        """Train the diffusion model with EMA weights and validation-based early stopping.

        Keyword Args:
            save_model: when true, shows the progress bar and writes
                ``hparams.txt`` to ``self.logdir``. Other keyword arguments
                (e.g. ``plot_figures``) are accepted and ignored.

        Every ``validation_interval`` steps the masked DSM loss is evaluated on
        a few batches of the validation partition; whenever it improves, the
        EMA weights are saved via ``self.save_model()``. Training stops after
        ``max_patience`` validations without improvement, or after
        ``num_steps_train`` steps.

        Outputs derived from ``self.sample_path`` (its 4-character extension
        replaced): a ``.pt`` list of per-validation log dicts, a
        ``_volatility.json`` summary and a ``_sensitivity_history.json`` file.
        Finally a sample is drawn from the EMA weights and saved to
        ``self.sample_path``.
        """
        save_model = kwargs.get("save_model", False)
        config = self.config.training

        self.train_loader = self.data_wrangler.get_train_loader(config.batch_size)
        train_iter = cycle(self.train_loader)
        self.test_loader = self.data_wrangler.get_train_loader(
            config.batch_size, partition='val'
        )

        logging.warning("=== Initializing model ===")
        set_seeds(self.seed, cuda_deterministic=True)
        self.diff_model = self.get_model().to(self.device)
        self.diff_model.train()
        logging.info(f"Trainable parameters: {int(get_total_trainable_params(self.diff_model)):,}")
        self.ema_diff_model = ExponentialMovingAverage(
            self.diff_model.parameters(), decay=self.config.optimizer.ema_decay
        )

        # initialize optimizer
        self.optimizer = self.get_optimizer()
        self.scheduler = self.get_lr_scheduler()

        logging.warning("=== Start training... ===")
        validation_rounds, max_patience = self._resolve_validation_schedule(config)
        logging.warning(f'Validation configuration: validation_interval={validation_rounds}, max_patience={max_patience}')

        log_lst = []
        # sample_path with its last 4 characters (the extension) replaced by ".pt".
        lst_dir = self.sample_path[:-4] + '.pt'

        self.current_step = 0
        batch_step = 0             # batches seen in the current epoch
        avg_loss1 = avg_loss2 = 0  # running sums of per-batch losses in the current epoch
        patience = 0
        l_dict = None              # most recent epoch-level training-loss summary
        best_val_loss = 1e10
        best_gap = 1e10
        best_step = self.current_step

        # Sharpness tracking (TEMPORARILY COMMENTED OUT):
        # sharpness_history = []  # list of (step, sharpness) tuples
        # self.current_train_loss_for_sharpness = None
        # total_norm = 1.0

        # (step, sensitivity) tuples; only filled by the commented-out hook below.
        sensitivity_history = []
        # Batch cached at step 300 for consistent sensitivity measurements.
        self.cached_batch = None

        if save_model:
            with open(os.path.join(self.logdir, "hparams.txt"), "w") as f:
                f.write(pformat(self.config))

        training_start_time = time.time()
        with tqdm(
            initial=self.current_step,
            total=config.num_steps_train,
            disable=(not save_model),
        ) as pbar:
            while self.current_step < config.num_steps_train:
                self.optimizer.zero_grad()

                inputs = next(train_iter)
                x_cat, x_cont, m_cat, m_cont, y_cond = (
                    tensor.to(self.device) if tensor is not None else None
                    for tensor in inputs
                )

                losses, losses_obs = self.diff_model.loss_fn(x_cat, x_cont, y_cond, m_cat, m_cont)
                losses["train_loss"].backward()

                # TEMPORARILY COMMENTED OUT - Sharpness computation (gradient_norm / train_loss)
                # if self.current_step >= 300 and (self.current_step) % 300 == 0:
                #     total_norm = 0.0
                #     for name, param in self.diff_model.model.named_parameters():
                #         if param.grad is not None:
                #             param_norm = param.grad.data.norm(2)
                #             total_norm += param_norm.item() ** 2
                #     total_norm = total_norm ** 0.5
                #     if losses_obs is not None and losses_obs.item() > 0:
                #         current_train_loss = losses_obs.item()
                #         sharpness = total_norm / current_train_loss
                #         sharpness_history.append((self.current_step, sharpness))
                #         self.current_train_loss_for_sharpness = current_train_loss

                # TEMPORARILY COMMENTED OUT - Sensitivity computation (not useful currently)
                # Compute sensitivity every 300 steps starting from step 300
                # if self.current_step >= 300 and (self.current_step) % 300 == 0:
                #     # Cache batch at step 300 for consistent measurements
                #     if self.current_step == 300:
                #         self.cached_batch = (
                #             x_cat.clone().detach() if x_cat is not None else None,
                #             x_cont.clone().detach() if x_cont is not None else None,
                #             y_cond.clone().detach() if y_cond is not None else None,
                #             m_cat.clone().detach() if m_cat is not None else None,
                #             m_cont.clone().detach() if m_cont is not None else None
                #         )
                #
                #     # Use cached batch for sensitivity computation
                #     if self.cached_batch is not None:
                #         x_cat_cached, x_cont_cached, y_cond_cached, m_cat_cached, m_cont_cached = self.cached_batch
                #         try:
                #             sensitivity = compute_cdtd_sensitivity(
                #                 self.diff_model,
                #                 x_cat_cached, x_cont_cached, y_cond_cached,
                #                 m_cat_cached, m_cont_cached,
                #                 epsilon=0.05, K=2, reduction="rms", device=self.device
                #             )
                #             sensitivity_history.append((self.current_step, sensitivity))
                #             print(f"Step {self.current_step}: Sensitivity = {sensitivity:.6f}")
                #         except Exception as e:
                #             print(f"Warning: Failed to compute sensitivity at step {self.current_step}: {e}")
                self.optimizer.step()
                self.diff_model.timewarp_cdf.update_ema()
                self.ema_diff_model.update()

                train_dict = self.get_metric_dict(
                    losses["train_loss"],
                    losses["weighted_calibrated"],
                    losses["timewarping"],
                )

                batch_step += 1
                self.current_step += 1

                avg_loss1 += train_dict['total_train_loss']
                # Detach so the running sum does not keep autograd graphs alive for a whole epoch.
                avg_loss2 += losses_obs.detach()

                pbar.update(1)
                if batch_step * config.batch_size > self.train_loader.dataset_len:
                    # End of an epoch: summarize the average training losses.
                    with torch.no_grad():
                        # NOTE(review): this adds the overshoot fraction of the last
                        # batch to the denominator; confirm that is the intended
                        # weighting for a partial final batch.
                        batch_step += (batch_step * config.batch_size - self.train_loader.dataset_len) / config.batch_size
                        l_dict = {
                            'CDTD_loss': avg_loss1 / batch_step,
                            'DSM_loss(obs, train)': (avg_loss2 / batch_step).item(),
                            'step': self.current_step,
                        }

                    batch_step = 0
                    avg_loss1 = 0
                    avg_loss2 = 0

                if self.current_step % validation_rounds == 0:
                    current_val_loss = self._masked_validation_loss(max_batches=3)
                    # No epoch may have completed yet, so a training loss can be missing.
                    train_loss = l_dict.get('DSM_loss(obs, train)') if l_dict else None

                    # Fresh dict per validation: several validations can fall within
                    # one epoch, and appending the shared l_dict would alias them.
                    log_entry = dict(l_dict) if l_dict else {'step': self.current_step}
                    log_entry['DSM_loss(obs, val)'] = current_val_loss
                    log_lst.append(log_entry)
                    current_gap = (
                        abs(current_val_loss - train_loss) if train_loss is not None else float('nan')
                    )

                    if current_val_loss < best_val_loss:
                        best_gap = current_gap
                        patience = 0
                        best_val_loss = current_val_loss
                        best_step = self.current_step
                        # Save the EMA weights, then restore the live training weights.
                        self.ema_diff_model.store()
                        self.ema_diff_model.copy_to()
                        self.save_model()
                        self.ema_diff_model.restore()
                        logging.warning(f' {best_step} (gap: {best_gap:.4f}, val_loss: {best_val_loss:.4f})')
                    else:
                        patience += 1

                    train_loss_txt = f"{train_loss:.4f}" if train_loss is not None else "N/A"
                    pbar.set_description(
                        f"Step {self.current_step}/{config.num_steps_train} | "
                        f"Loss: {train_loss_txt} | Val: {current_val_loss:.4f} "
                        f"Patience: {patience}/{max_patience}"
                    )

                if self.scheduler:
                    lr = self.scheduler(self.current_step)
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = lr

                if patience >= max_patience:
                    logging.warning(f'Early stopping at step {self.current_step} (patience {patience}/{max_patience})')
                    logging.warning(f'Best validation loss: {best_val_loss:.4f} at step {best_step}')
                    break

        training_duration = time.time() - training_start_time
        self.save_train_time(training_duration)

        torch.save(log_lst, lst_dir)
        logging.info(f"Best model (validation) at step {best_step}")

        # Sensitivity statistics over the last 25% of training steps.
        final_step = self.current_step
        metrics, sensitivity_last_25 = self._summarize_sensitivity_history(sensitivity_history, final_step)

        # Store metrics for potential use in evaluation
        self.volatility = metrics['volatility']
        self.logscaleV = metrics['logscaleV']
        self.mean_log_sensitivity = metrics['mean_log_sensitivity']
        self.q90_log_sensitivity = metrics['q90_log_sensitivity']
        self.q95_log_sensitivity = metrics['q95_log_sensitivity']
        self.cv_sensitivity = metrics['cv_sensitivity']

        # Save volatility and related metrics to a JSON file for the evaluator to read.
        volatility_file = lst_dir.replace('.pt', '_volatility.json')
        volatility_data = {
            key: (float(value) if value is not None else None) for key, value in metrics.items()
        }
        volatility_data.update({
            'final_step': int(final_step),
            'num_sensitivity_measurements': len(sensitivity_history),
            'sensitivity_last_25_count': len(sensitivity_last_25),
        })
        with open(volatility_file, 'w') as f:
            json.dump(volatility_data, f, indent=2)

        # Save sensitivity history for time series logging in wandb
        sensitivity_history_file = lst_dir.replace('.pt', '_sensitivity_history.json')
        sensitivity_history_data = {
            'sensitivity_history': [{'step': int(step), 'sensitivity': float(sens)} for step, sens in sensitivity_history],
            'final_step': int(final_step)
        }
        with open(sensitivity_history_file, 'w') as f:
            json.dump(sensitivity_history_data, f, indent=2)

        self.diff_model.eval()
        logging.info("=== Finished model training. ===")

        # Validation loss variance statistics (for verification); best effort.
        if self.data_wrangler.preproc == 'LGB_S' and self.data_wrangler.val_prop > 0:
            try:
                self.compute_validation_loss_variance(K=30, random_seed=42)
            except Exception as e:
                logging.warning(f"Failed to compute validation loss variance: {e}", exc_info=True)

        # Final sample from the EMA weights for quality evaluation; best effort.
        try:
            self.ema_diff_model.copy_to()
            self.diff_model.eval()
            X_cat_gen, X_cont_gen, y_gen = self.sample_tabular_data(
                self.num_samples, seed=42, verbose=False
            )
            self.data_wrangler.save_data(X_cat_gen, X_cont_gen, y_gen, sample_path=self.sample_path)
            logging.info(f"Final sample saved to {self.sample_path}")
        except Exception as e:
            logging.warning(f"Failed to generate final sample: {e}", exc_info=True)

    def _resolve_validation_schedule(self, config):
        """Return ``(validation_interval, max_patience)`` for ``train()``, both at least 1."""
        validation_rounds = config.get('validation_interval', None)
        if validation_rounds is None:
            # Default: every ~1% of training, stretched by m_rounds / 6 when m_rounds > 6.
            validation_rounds = round(max(1, self.m_rounds / 6) * config.num_steps_train // 100)
        # A zero interval would make `step % interval` divide by zero.
        validation_rounds = max(1, int(validation_rounds))

        max_patience = config.get('max_patience', None)
        if max_patience is None:
            # Default: 15% of the total number of validation rounds.
            total_rounds = config.num_steps_train // validation_rounds
            max_patience = round(total_rounds * 0.15)
        # With 0, `patience >= max_patience` would stop training after the first step.
        max_patience = max(1, int(max_patience))
        return validation_rounds, max_patience

    @torch.no_grad()
    def _masked_validation_loss(self, max_batches=3):
        """Mean masked DSM loss over at most ``max_batches`` batches of ``self.test_loader``.

        Returns NaN (and logs a warning) when the loader yields no batches.
        NOTE(review): the model is left in whatever mode the caller set (train
        mode during training) — confirm that is intended for validation.
        """
        batch_losses = []
        for batch in self.test_loader:
            x_cat, x_cont, m_cat, m_cont, y_cond = (
                tensor.to(self.device) if tensor is not None else None
                for tensor in batch
            )
            _, loss_masked = self.diff_model.loss_fn(x_cat, x_cont, y_cond, m_cat, m_cont)
            batch_losses.append(float(loss_masked))
            # Only the first few batches are evaluated.
            if len(batch_losses) >= max_batches:
                break

        if not batch_losses:
            logging.warning("No validation batches processed for validation loss")
            return float('nan')
        return torch.tensor(batch_losses).mean().item()

    @staticmethod
    def _summarize_sensitivity_history(sensitivity_history, final_step):
        """Summarize sensitivity measurements from the last 25% of training steps.

        Returns ``(metrics, last_25)``. ``metrics`` maps volatility (Var(S)),
        logscaleV (Var(log S)), mean/q90/q95 of log(S) and cv_sensitivity
        (Std(S) / mean(S)) to values or ``None``. Everything is ``None`` when
        fewer than two measurements fall in the window; the log-based entries
        are ``None`` when any value is non-positive. ``last_25`` is the list of
        values used.
        """
        metrics = dict.fromkeys((
            'volatility', 'logscaleV', 'mean_log_sensitivity',
            'q90_log_sensitivity', 'q95_log_sensitivity', 'cv_sensitivity',
        ))
        if not sensitivity_history:
            logging.warning("No sensitivity measurements collected during training")
            return metrics, []

        last_25_percent_start = int(final_step * 0.75)
        last_25 = [sens for step, sens in sensitivity_history if step >= last_25_percent_start]
        if len(last_25) <= 1:
            logging.warning(f"Not enough sensitivity measurements in last 25% (got {len(last_25)})")
            return metrics, last_25

        values = np.array(last_25)
        metrics['volatility'] = np.var(values)

        # log(S) stabilizes spikes; only defined when every value is positive.
        if np.all(values > 0):
            log_values = np.log(values)
            metrics['logscaleV'] = np.var(log_values)
            metrics['mean_log_sensitivity'] = np.mean(log_values)
            metrics['q90_log_sensitivity'] = np.percentile(log_values, 90)
            metrics['q95_log_sensitivity'] = np.percentile(log_values, 95)

        mean_sensitivity = np.mean(values)
        if mean_sensitivity > 0:
            metrics['cv_sensitivity'] = np.std(values) / mean_sensitivity

        logscale_txt = (
            f", logscaleV={metrics['logscaleV']:.6f}" if metrics['logscaleV'] is not None else ""
        )
        logging.info(f"Training dynamics: steps={final_step}, volatility={metrics['volatility']:.6f}" + logscale_txt)
        return metrics, last_25

    def sample_tabular_data(self, num_samples, **kwargs):
        """
        Draw ``num_samples`` synthetic rows from the diffusion model and post-process them.

        Keyword Args:
            seed: If not ``None``, reseed via ``set_seeds(seed, cuda_deterministic=True)``
                before sampling. Any integer seed, including ``0``, is honoured.
            verbose: Forwarded to ``self.diff_model.sample_data`` (default ``False``).

        Returns:
            Tuple ``(X_cat_gen, X_cont_gen, y_gen)`` as returned by
            ``self.data_wrangler.postprocess_gen_data``. Any part the model returns
            as ``None`` is passed on to post-processing as ``None``.
        """
        seed = kwargs.get("seed", None)
        verbose = kwargs.get("verbose", False)

        # Compare against None so that seed=0 still seeds the RNGs.
        if seed is not None:
            set_seeds(seed, cuda_deterministic=True)

        if self.config.model.y_cond:
            # Condition on labels drawn with the empirical training-class frequencies.
            _, _, y_train = self.data_wrangler.data.get_train_data()
            _, y_counts = np.unique(y_train, return_counts=True)
            y_dist = y_counts / y_counts.sum()
        else:
            y_dist = None

        X_cat_gen, X_cont_gen, y_gen = self.diff_model.sample_data(
            num_samples,
            self.config.training.batch_size_eval,
            y_dist,
            self.config.model.generation_steps,
            verbose=verbose,
        )

        # Treat all three outputs alike: a model without categorical (or continuous)
        # features may hand back None for that part.
        X_cat_gen, X_cont_gen, y_gen = self.data_wrangler.postprocess_gen_data(
            X_cat_gen.to(torch.long).numpy() if X_cat_gen is not None else None,
            X_cont_gen.numpy() if X_cont_gen is not None else None,
            y_gen.numpy() if y_gen is not None else None,
        )

        return X_cat_gen, X_cont_gen, y_gen

    def compute_validation_loss_variance(self, K=30, random_seed=42):
        """
        Estimate validation-loss statistics under Rademacher perturbations and save them as JSON.

        For each validation row, every missing continuous entry with a cached prediction
        ``(mu, std_hat)`` is replaced by ``mu + std_hat * eps`` with ``eps`` in {-1, +1}.
        This is done independently for K draws. The masked per-row diffusion loss of each
        draw fills an ``[n_val, K]`` loss matrix. From that matrix the method derives:

        - within-row and between-row variances;
        - an effective variance ``V_between + V_within / K``;
        - a Bernstein-style score ``mean_loss + C``.

        Side effect: the EMA weights are copied into ``self.diff_model``
        (``self.ema_diff_model.copy_to()``) before evaluation.

        Results are written next to ``self.sample_path`` as
        ``<sample>_val_loss_variance.json``. Nothing is sent to wandb.

        Args:
            K: Number of Rademacher perturbations per row (default: 30). Must be >= 2
                so the sample variances (ddof=1) are defined.
            random_seed: Seed for the Rademacher draws (default: 42). A private
                ``np.random.RandomState`` is used, so the global NumPy RNG is left
                untouched while the ±1 sequence stays identical to a global seed.

        Returns:
            None. Returns early with a warning when K < 2, or when there is no
            validation set, no continuous features, fewer than 2 validation rows,
            or no predictions cache.
        """
        if K < 2:
            logging.warning(f"Validation loss variance needs K >= 2 perturbations (got {K}); skipping.")
            return

        val_data = self.data_wrangler.data.get_val_data()
        if val_data[0] is None and val_data[1] is None:
            logging.warning("No validation set; skipping validation loss variance.")
            return

        X_cat_val, X_cont_val, M_cat_val, M_cont_val, _y_val = val_data
        if X_cont_val is None or M_cont_val is None:
            logging.warning("No continuous features in validation set; skipping validation loss variance.")
            return

        n_val = X_cont_val.shape[0]
        num_cont_features = X_cont_val.shape[1]
        if n_val < 2:
            logging.warning(f"Validation set has {n_val} row(s); need at least 2 for variance estimates. Skipping.")
            return

        predictions_cache_path = f"{self.data_wrangler.data_path}/val_mu_std_hat_{self.data_wrangler.preproc}_beta{self.beta}_use_log{self.data_wrangler.use_log}.pkl"
        if not os.path.exists(predictions_cache_path):
            logging.warning(f"Predictions cache not found at {predictions_cache_path}; skipping.")
            return
        # NOTE: pickle.load can execute arbitrary code; this cache must come from a trusted source.
        with open(predictions_cache_path, 'rb') as f:
            predictions_dict = pickle.load(f)

        # Reconstruct the validation indices from the train/val split. The comment in the
        # original code states it mirrors data_prepm.py, which uses random_state=42.
        from sklearn import model_selection
        train_data = self.data_wrangler.data.get_train_data()
        X_cat_train, X_cont_train = train_data[0], train_data[1]
        n_train = X_cat_train.shape[0] if X_cat_train is not None else X_cont_train.shape[0]

        val_prop = self.data_wrangler.val_prop
        test_prop = self.data_wrangler.test_prop
        if val_prop > 0:
            prop = val_prop / (1 - test_prop)
            _, val_indices_in_combined = model_selection.train_test_split(
                np.arange(n_train + n_val),
                test_size=prop,
                random_state=42,
                shuffle=True,
            )
            if len(val_indices_in_combined) != n_val:
                logging.warning(
                    f"Reconstructed validation split has {len(val_indices_in_combined)} rows "
                    f"but the validation set has {n_val}; prediction mapping may be wrong."
                )
        else:
            val_indices_in_combined = []

        # Position of each combined-row index inside the validation split. A dict gives
        # O(1) lookups instead of scanning the index array for every predicted row.
        val_pos_by_combined_idx = {
            idx: pos for pos, idx in enumerate(np.asarray(val_indices_in_combined).tolist())
        }

        # predictions_dict keys are column indices in the combined DataFrame.
        n_cat = self.data_wrangler.num_cat_features
        task = self.data_wrangler.task

        def map_col_idx(combined_col_idx):
            """Map a combined-DataFrame column index to an X_cont_val column index, or None."""
            if task == "regression":
                # Combined: [cat_cols..., y, cont_cols...]; X_cont_val: [y, cont_cols...]
                if combined_col_idx >= n_cat:
                    return combined_col_idx - n_cat
            else:
                # Classification. Combined: [y, cat_cols..., cont_cols...]; X_cont_val: [cont_cols...]
                if combined_col_idx >= n_cat + 1:
                    return combined_col_idx - n_cat - 1
            return None  # Not a continuous column

        # (val_row_idx, X_cont_col_idx) -> (mu, std_hat)
        predictions_map = {}
        for combined_col_idx, pred_data in predictions_dict.items():
            x_cont_col_idx = map_col_idx(combined_col_idx)
            if x_cont_col_idx is None:
                continue  # Not a continuous column
            mu_test = pred_data['mu_test']
            std_hat = pred_data['std_hat']
            for i, row_idx_combined in enumerate(pred_data['row_indices']):
                val_row_idx = val_pos_by_combined_idx.get(row_idx_combined)
                if val_row_idx is not None:
                    predictions_map[(val_row_idx, x_cont_col_idx)] = (mu_test[i], std_hat[i])

        logging.info(f"Mapped {len(predictions_map)} predictions to validation set")

        # K Rademacher perturbations per (missing, predicted) cell. They are cached on disk
        # so every run uses the same values.
        epsilon_cache_path = f"{self.data_wrangler.data_path}/val_rademacher_epsilon_K{K}_seed{random_seed}.pkl"
        if os.path.exists(epsilon_cache_path):
            logging.info(f"Loading cached Rademacher perturbations from {epsilon_cache_path}")
            with open(epsilon_cache_path, 'rb') as f:
                epsilon_dict = pickle.load(f)  # (val_row_idx, col_idx, k) -> ±1
        else:
            # A private RandomState yields the same draws as seeding the global legacy RNG,
            # without mutating np.random's global state for the rest of the program.
            rng = np.random.RandomState(random_seed)
            epsilon_dict = {}  # (val_row_idx, col_idx, k) -> ±1
            for val_row_idx in range(n_val):
                for col_idx in range(num_cont_features):
                    if M_cont_val[val_row_idx, col_idx] == 0 and (val_row_idx, col_idx) in predictions_map:
                        for k in range(K):
                            epsilon_dict[(val_row_idx, col_idx, k)] = rng.randint(0, 2) * 2 - 1

            cache_dir = os.path.dirname(epsilon_cache_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            with open(epsilon_cache_path, 'wb') as f:
                pickle.dump(epsilon_dict, f)
            logging.info(f"Cached {len(epsilon_dict)} Rademacher perturbations to {epsilon_cache_path}")

        # Evaluate with the EMA weights: copy_to() writes them into self.diff_model.
        self.ema_diff_model.copy_to()
        self.diff_model.eval()

        batch_size = 32
        loss_matrix = np.zeros((n_val, K))  # [n_val, K]

        logging.info("Computing validation losses with perturbations...")
        # The loop order (perturbation outer, batch inner) is kept, because loss_fn draws
        # random noise and reordering would change the results.
        for k in tqdm(range(K), desc="Processing perturbations"):
            for batch_start in range(0, n_val, batch_size):
                batch_end = min(batch_start + batch_size, n_val)
                n_real = batch_end - batch_start
                batch_indices = np.arange(batch_start, batch_end)
                if n_real == 1:
                    # Duplicate a lone row to avoid dimension issues in weight_network;
                    # only the real row's loss is kept below.
                    batch_indices = np.concatenate([batch_indices, batch_indices])
                batch_size_actual = len(batch_indices)

                # Perturb missing continuous values that have a prediction and an epsilon.
                x_cont_perturbed_batch = X_cont_val[batch_indices].copy()
                for i, val_row_idx in enumerate(batch_indices):
                    for col_idx in range(num_cont_features):
                        if M_cont_val[val_row_idx, col_idx] != 0:
                            continue
                        key = (val_row_idx, col_idx, k)
                        pred = predictions_map.get((val_row_idx, col_idx))
                        if pred is not None and key in epsilon_dict:
                            mu_val, std_hat_val = pred
                            x_cont_perturbed_batch[i, col_idx] = mu_val + std_hat_val * epsilon_dict[key]

                x_cat_tensor = None
                if X_cat_val is not None:
                    x_cat_tensor = torch.tensor(X_cat_val[batch_indices], dtype=torch.long).to(self.device)
                x_cont_tensor = torch.tensor(x_cont_perturbed_batch, dtype=torch.float32).to(self.device)
                m_cat_tensor = None
                if M_cat_val is not None:
                    m_cat_tensor = torch.tensor(M_cat_val[batch_indices], dtype=torch.float32).to(self.device)
                m_cont_tensor = torch.tensor(M_cont_val[batch_indices], dtype=torch.float32).to(self.device)
                y_cond_tensor = None  # y is already in X_cat or X_cont

                with torch.no_grad():
                    losses_dict, _ = self.diff_model.loss_fn(
                        x_cat_tensor, x_cont_tensor, y_cond_tensor,
                        m_cat_tensor, m_cont_tensor
                    )
                    # Per-sample loss: mean of weighted_calibrated over observed (mask=1) features.
                    weighted_calibrated = losses_dict["weighted_calibrated"]
                    input_mask = torch.cat([m_cat_tensor if m_cat_tensor is not None else torch.zeros(batch_size_actual, 0, device=self.device),
                                           m_cont_tensor if m_cont_tensor is not None else torch.zeros(batch_size_actual, 0, device=self.device)], dim=1)
                    per_sample_loss = (input_mask * weighted_calibrated).sum(dim=1) / input_mask.sum(dim=1).clamp_min(1e-8)
                    loss_values = per_sample_loss.cpu().numpy()
                    # Keep only the real rows (drops the duplicate added for 1-row batches).
                    loss_matrix[batch_start:batch_end, k] = loss_values[:n_real]

        # Variance decomposition (ddof=1 sample variances; K >= 2 and n_val >= 2 are guaranteed above).
        mean_loss_per_row = np.mean(loss_matrix, axis=1)  # [n_val]
        var_within_row = np.var(loss_matrix, axis=1, ddof=1)  # [n_val]

        mean_loss = np.mean(mean_loss_per_row)
        V_within = np.mean(var_within_row)
        V_between = np.var(mean_loss_per_row, ddof=1)
        V_total = V_within + V_between
        V_eff = V_between + V_within / K  # Effective variance (v_hat)

        # Bernstein bound score
        n = n_val
        H = 1  # Single beta per run (not sweeping)
        delta = 0.05
        delta_prime = delta / H
        v_hat = V_eff
        B = np.max(loss_matrix)  # Loss clip bound (max observed loss)
        logterm = np.log(2.0 / delta_prime)
        C = np.sqrt(2.0 * v_hat * logterm / n) + (7.0 * B * logterm) / (3.0 * (n - 1))
        score = mean_loss + C

        logging.info(f"Validation loss variance: mean={mean_loss:.4f}, V_total={V_total:.4f}, score={score:.4f}")

        # Save next to the sample CSV so the evaluator can read it.
        val_loss_variance_file = self.sample_path.replace('.csv', '_val_loss_variance.json')
        val_loss_variance_data = {
            'mean_loss': float(mean_loss),
            'V_within': float(V_within),
            'V_between': float(V_between),
            'V_total': float(V_total),
            'V_eff': float(V_eff),
            'loss_mean_plus_V_total': float(mean_loss + V_total),
            'C': float(C),
            'score': float(score),
            'B': float(B),
            'n_val': int(n_val),
            'K': int(K),
            'delta': float(delta),
            'H': int(H)
        }
        with open(val_loss_variance_file, 'w') as f:
            json.dump(val_loss_variance_data, f, indent=2)
        logging.info(f"Saved validation loss variance to {val_loss_variance_file}")

    def get_metric_dict(self, diff_loss, weighted_losses, timewarp_losses):
        """
        Reduce training-loss tensors to a flat ``{name: float}`` dict for logging.

        Each tensor is detached and averaged into a Python float. The
        ``avg_weighted_loss`` entry is only present when ``weighted_losses`` is not ``None``.
        """
        def _as_scalar(tensor):
            return tensor.detach().mean().item()

        metrics = {
            "total_train_loss": _as_scalar(diff_loss),
            "timewarp/total_loss": _as_scalar(timewarp_losses),
        }
        if weighted_losses is None:
            return metrics

        metrics["avg_weighted_loss"] = _as_scalar(weighted_losses)
        return metrics

    def log_fn(self, writer, step, train_dict):
        """Write each ``train_dict`` entry to ``writer`` as a ``train/<name>`` scalar at ``step``."""
        for name, value in train_dict.items():
            tag = f"train/{name}"
            writer.add_scalar(tag, value, global_step=step)

    def save_model(self):
        """
        Save a training checkpoint to ``<ckpt_restore_dir>/model.pt``.

        The checkpoint holds:

        - the current step;
        - ``self.diff_model.state_dict()``;
        - the pickled ``data_wrangler`` and ``train_loader`` objects.

        The checkpoint directory is created first if it does not exist, since
        ``torch.save`` would otherwise fail with a missing-directory error.

        NOTE(review): ``self.ema_diff_model`` state is not included in the
        checkpoint — confirm this is intended for resumed runs.
        """
        if self.ckpt_restore_dir:
            os.makedirs(self.ckpt_restore_dir, exist_ok=True)
        checkpoint = {
            "current_step": self.current_step,
            "diff_model": self.diff_model.state_dict(),
            "data_wrangler": self.data_wrangler,
            "train_loader": self.train_loader,
        }
        torch.save(checkpoint, os.path.join(self.ckpt_restore_dir, "model.pt"))

    
    def load_model(self):
        checkpoint = torch.load(os.path.join(self.ckpt_restore_dir, "model.pt"))
        self.current_step = checkpoint["current_step"]
        self.train_loader = checkpoint["train_loader"]
        
        # Store original data_wrangler values before replacing (for comparison)
        original_data_wrangler = self.data_wrangler if hasattr(self, 'data_wrangler') else None
        original_num_cats = None
        original_num_cont_features = None
        if original_data_wrangler is not None:
            if hasattr(original_data_wrangler, 'num_cats'):
                original_num_cats = original_data_wrangler.num_cats
            if hasattr(original_data_wrangler, 'num_cont_features'):
                original_num_cont_features = original_data_wrangler.num_cont_features
            # print(f'DEBUG [load_model]: Original data_wrangler (from __init__):')
            # print(f'  num_cats: {original_num_cats}')
            # print(f'  num_cont_features: {original_num_cont_features}')
        
        # Replace with checkpoint's data_wrangler
        self.data_wrangler = checkpoint["data_wrangler"]
        
        # Check checkpoint's expected output size
        checkpoint_bias_size = checkpoint["diff_model"]["model.final_layer.linear.bias"].shape[0]
        #print(f'DEBUG [load_model]: Checkpoint model bias size: {checkpoint_bias_size}')
        
        
        # Store checkpoint_num_cats for later verification
        checkpoint_num_cats = None
        if hasattr(self.data_wrangler, 'num_cats') and self.data_wrangler.num_cats is not None:
            checkpoint_num_cats = self.data_wrangler.num_cats
            #print(f'  num_cats: {checkpoint_num_cats}')
            
            # Compare original vs checkpoint data_wrangler
            checkpoint_num_cont = self.data_wrangler.num_cont_features
            checkpoint_categories_with_unk = [c + 1 for c in checkpoint_num_cats]
            sum_categories_with_unk = sum(checkpoint_categories_with_unk)
            checkpoint_expected_size_from_data_wrangler = sum_categories_with_unk + checkpoint_num_cont
            
            # Compare with original
            if original_num_cats is not None and original_num_cats != checkpoint_num_cats:
                print(f'  ⚠️  WARNING: Original num_cats ({original_num_cats}) differs from checkpoint num_cats ({checkpoint_num_cats})!')
                print(f'      This suggests categorical encoding changed (likely due to different n_cached).')
                print(f'      The checkpoint model was likely trained with n_cached>1 (more categories).')
                print(f'      Will verify if original num_cats works with checkpoint model...')
                
                original_categories_with_unk = [c + 1 for c in original_num_cats]
                original_sum_categories_with_unk = sum(original_categories_with_unk)
                inferred_num_cont_from_original = checkpoint_bias_size - original_sum_categories_with_unk
                
                print(f'  Checking if original num_cats works with checkpoint model:')
                print(f'    Original sum categories_with_unk: {original_sum_categories_with_unk}')
                print(f'    Inferred num_cont_features using original num_cats: {inferred_num_cont_from_original}')
                
                if inferred_num_cont_from_original >= 0:
                    print(f'  ✅ Original num_cats works with checkpoint model! Using original num_cats.')
                    print(f'      This ensures we use the correct categorical encoding for current n_cached.')
                    # Update checkpoint data_wrangler to use original num_cats
                    self.data_wrangler.num_cats = original_num_cats
                    # Update checkpoint_num_cats for later verification
                    checkpoint_num_cats = original_num_cats
                    # Recalculate num_cat_features if needed
                    if hasattr(self.data_wrangler, 'num_cat_features'):
                        self.data_wrangler.num_cat_features = len(original_num_cats)
                    # Update num_cont_features to match what the model expects
                    self.data_wrangler.num_cont_features = inferred_num_cont_from_original
                    # Update num_total_features
                    if hasattr(self.data_wrangler, 'num_total_features'):
                        self.data_wrangler.num_total_features = (
                            self.data_wrangler.num_cat_features + inferred_num_cont_from_original
                        )
                    if inferred_num_cont_from_original != original_num_cont_features:
                        print(f'      Note: Updated num_cont_features to {inferred_num_cont_from_original} (was {original_num_cont_features}) to match checkpoint model.')
                    
                    # Verify the update
                    updated_sum_cats = sum([c + 1 for c in original_num_cats])
                    updated_expected_size = updated_sum_cats + inferred_num_cont_from_original
                    print(f'      Verification: Updated data_wrangler expects {updated_expected_size} features (should match checkpoint {checkpoint_bias_size})')
                    if updated_expected_size == checkpoint_bias_size:
                        print(f'      ✅ Verification passed: Updated data_wrangler matches checkpoint model size.')
                    else:
                        print(f'      ⚠️  WARNING: Updated data_wrangler still doesn\'t match! This is unexpected.')
                else:
                    print(f'  ⚠️  Original num_cats does not work with checkpoint model.')
                    print(f'      Will use checkpoint num_cats (but this may cause issues if checkpoint was saved incorrectly).')
            
            if original_num_cont_features is not None and original_num_cont_features != checkpoint_num_cont:
                print(f'  ⚠️  WARNING: Original num_cont_features ({original_num_cont_features}) differs from checkpoint ({checkpoint_num_cont})!')
                print(f'      Using checkpoint num_cont_features (or inferred value if fixed below).')
            
            # Check if checkpoint data_wrangler matches checkpoint model (CRITICAL FIX)
            # This is the source of truth - the model size is what matters
            if checkpoint_expected_size_from_data_wrangler != checkpoint_bias_size:
                print(f'  ⚠️  WARNING: Checkpoint data_wrangler does not match checkpoint model size!')
                print(f'      Checkpoint data_wrangler expects {checkpoint_expected_size_from_data_wrangler} features,')
                print(f'      but checkpoint model has {checkpoint_bias_size} features.')
                print(f'      This suggests the checkpoint data_wrangler is inconsistent with the checkpoint model.')
                print(f'      Attempting to fix by inferring correct feature counts from checkpoint model...')
                
                # FIRST: Try using original data_wrangler's num_cats if available (it might be correct)
                # This handles the case where checkpoint was saved with wrong num_cats
                tried_original = False
                if original_num_cats is not None and original_num_cats != checkpoint_num_cats:
                    original_categories_with_unk = [c + 1 for c in original_num_cats]
                    original_sum_categories_with_unk = sum(original_categories_with_unk)
                    inferred_num_cont_from_original = checkpoint_bias_size - original_sum_categories_with_unk
                    
                    print(f'  Trying original data_wrangler num_cats first: {original_num_cats}')
                    print(f'  Original sum categories_with_unk: {original_sum_categories_with_unk}')
                    print(f'  Inferred num_cont_features using original num_cats: {inferred_num_cont_from_original}')
                    
                    if inferred_num_cont_from_original >= 0:
                        print(f'  ✅ Original num_cats works with checkpoint model! Using original num_cats.')
                        # Update checkpoint data_wrangler to use original num_cats
                        self.data_wrangler.num_cats = original_num_cats
                        # Update checkpoint_num_cats for later verification
                        checkpoint_num_cats = original_num_cats
                        # Recalculate num_cat_features if needed
                        if hasattr(self.data_wrangler, 'num_cat_features'):
                            self.data_wrangler.num_cat_features = len(original_num_cats)
                        # Update num_cont_features to match what the model expects
                        self.data_wrangler.num_cont_features = inferred_num_cont_from_original
                        # Update num_total_features
                        if hasattr(self.data_wrangler, 'num_total_features'):
                            self.data_wrangler.num_total_features = (
                                self.data_wrangler.num_cat_features + inferred_num_cont_from_original
                            )
                        if inferred_num_cont_from_original != original_num_cont_features:
                            print(f'  Note: Updated num_cont_features to {inferred_num_cont_from_original} (was {original_num_cont_features}) to match checkpoint model.')
                        print(f'  Updated checkpoint data_wrangler to use original num_cats (computed with current n_cached).')
                        
                        # Verify the update
                        updated_sum_cats = sum([c + 1 for c in original_num_cats])
                        updated_expected_size = updated_sum_cats + inferred_num_cont_from_original
                        print(f'  Verification: Updated data_wrangler expects {updated_expected_size} features (should match checkpoint {checkpoint_bias_size})')
                        if updated_expected_size != checkpoint_bias_size:
                            print(f'  ⚠️  WARNING: Updated data_wrangler still doesn\'t match! This is unexpected.')
                        else:
                            print(f'  ✅ Verification passed: Updated data_wrangler matches checkpoint model size.')
                        tried_original = True
                
                # If original num_cats didn't work, try to infer num_cont_features assuming num_cats is correct
                if not tried_original:
                    inferred_num_cont_features = checkpoint_bias_size - sum_categories_with_unk
                    
                    if inferred_num_cont_features < 0:
                        # This means the checkpoint's num_cats is wrong (categorical encoding mismatch)
                        # But we already tried original num_cats above, so if we're here, neither worked
                        raise RuntimeError(
                            f"Cannot fix categorical encoding mismatch: checkpoint model size ({checkpoint_bias_size}) "
                            f"is smaller than sum of categories_with_unk ({sum_categories_with_unk}). "
                            f"Checkpoint num_cats: {checkpoint_num_cats}, Original num_cats: {original_num_cats if original_num_cats else 'N/A'}. "
                            f"This suggests the model was trained with a different n_cached value than what's currently configured. "
                            f"Please ensure the current run uses the same n_cached value as the training run."
                        )
                    else:
                        # num_cont_features mismatch (not num_cats) - fix it
                        print(f'  Inferred num_cont_features from checkpoint model: {inferred_num_cont_features}')
                        print(f'  (Checkpoint data_wrangler had: {checkpoint_num_cont})')
                        
                        if inferred_num_cont_features != checkpoint_num_cont:
                            print(f'  ✅ Mismatch is in num_cont_features - updating checkpoint data_wrangler...')
                            # Update data_wrangler's num_cont_features to match checkpoint model
                            self.data_wrangler.num_cont_features = inferred_num_cont_features
                            # Also update num_total_features if it exists
                            if hasattr(self.data_wrangler, 'num_total_features'):
                                old_total = self.data_wrangler.num_total_features
                                self.data_wrangler.num_total_features = (
                                    self.data_wrangler.num_cat_features + inferred_num_cont_features
                                )
                                print(f'  Updated num_total_features: {self.data_wrangler.num_total_features} (was {old_total})')
                    # num_cont_features mismatch (not num_cats)
                    print(f'  Inferred num_cont_features from checkpoint model: {inferred_num_cont_features}')
                    print(f'  (Checkpoint data_wrangler had: {checkpoint_num_cont})')
                    
                    if inferred_num_cont_features != checkpoint_num_cont:
                        print(f'  ✅ Mismatch is in num_cont_features - updating checkpoint data_wrangler...')
                        # Update data_wrangler's num_cont_features to match checkpoint model
                        self.data_wrangler.num_cont_features = inferred_num_cont_features
                        # Also update num_total_features if it exists
                        if hasattr(self.data_wrangler, 'num_total_features'):
                            old_total = self.data_wrangler.num_total_features
                            self.data_wrangler.num_total_features = (
                                self.data_wrangler.num_cat_features + inferred_num_cont_features
                            )
                            print(f'  Updated num_total_features: {self.data_wrangler.num_total_features} (was {old_total})')
        else:
            print(f'  ⚠️  WARNING: Checkpoint data_wrangler has no num_cats attribute!')
        
        # Store checkpoint_num_cats for verification after model creation
        if checkpoint_num_cats is None and hasattr(self.data_wrangler, 'num_cats') and self.data_wrangler.num_cats is not None:
            checkpoint_num_cats = self.data_wrangler.num_cats
            checkpoint_num_cats = self.data_wrangler.num_cats
            checkpoint_num_cont = self.data_wrangler.num_cont_features
            checkpoint_categories_with_unk = [c + 1 for c in checkpoint_num_cats]
            sum_categories_with_unk = sum(checkpoint_categories_with_unk)
            checkpoint_expected_size_from_data_wrangler = sum_categories_with_unk + checkpoint_num_cont
            
            # Check if checkpoint data_wrangler matches checkpoint model
            if checkpoint_expected_size_from_data_wrangler != checkpoint_bias_size:
                print(f'  ⚠️  WARNING: Checkpoint data_wrangler does not match checkpoint model size!')
                print(f'      The checkpoint data_wrangler is inconsistent with the checkpoint model.')
                print(f'      Attempting to fix by inferring correct feature counts from checkpoint model...')
                
                # First, try to infer num_cont_features assuming num_cats is correct
                inferred_num_cont_features = checkpoint_bias_size - sum_categories_with_unk
                
                if inferred_num_cont_features < 0:
                    # This means the checkpoint's num_cats is wrong (categorical encoding mismatch)
                    print(f'  ⚠️  Categorical encoding mismatch detected!')
                    print(f'      Checkpoint model size ({checkpoint_bias_size}) < sum of categories_with_unk ({sum_categories_with_unk})')
                    print(f'      This suggests checkpoint data_wrangler.num_cats is incorrect.')
                    print(f'      Attempting to infer correct num_cats from checkpoint model...')
                    
                    # Try using original data_wrangler's num_cats if available
                    if original_num_cats is not None:
                        original_categories_with_unk = [c + 1 for c in original_num_cats]
                        original_sum_categories_with_unk = sum(original_categories_with_unk)
                        inferred_num_cont_from_original = checkpoint_bias_size - original_sum_categories_with_unk
                        
                        print(f'  Trying original data_wrangler num_cats: {original_num_cats}')
                        print(f'  Original sum categories_with_unk: {original_sum_categories_with_unk}')
                        print(f'  Inferred num_cont_features using original num_cats: {inferred_num_cont_from_original}')
                        
                        if inferred_num_cont_from_original >= 0 and inferred_num_cont_from_original == original_num_cont_features:
                            print(f'  ✅ Original num_cats matches checkpoint model! Using original num_cats.')
                            # Update checkpoint data_wrangler to use original num_cats
                            self.data_wrangler.num_cats = original_num_cats
                            # Update checkpoint_num_cats for later verification
                            checkpoint_num_cats = original_num_cats
                            # Recalculate num_cat_features if needed
                            if hasattr(self.data_wrangler, 'num_cat_features'):
                                self.data_wrangler.num_cat_features = len(original_num_cats)
                            # Update num_cont_features
                            self.data_wrangler.num_cont_features = original_num_cont_features
                            # Update num_total_features
                            if hasattr(self.data_wrangler, 'num_total_features'):
                                self.data_wrangler.num_total_features = (
                                    self.data_wrangler.num_cat_features + original_num_cont_features
                                )
                            print(f'  Updated checkpoint data_wrangler to use original num_cats and num_cont_features.')
                        else:
                            # Try to infer by checking if difference is exactly 1 (one category more)
                            size_diff = checkpoint_bias_size - checkpoint_expected_size_from_data_wrangler
                            if size_diff == 1:
                                print(f'  ⚠️  Size difference is exactly 1 - one categorical feature has 1 more category.')
                                print(f'      However, cannot determine which feature changed.')
                                print(f'      Attempting to infer num_cont_features assuming one category was added...')
                                # If one category was added, we need to find which one
                                # For now, try to infer num_cont_features by assuming the difference is in categories
                                # and that num_cont_features should match the original
                                if original_num_cont_features is not None:
                                    # Check if using original num_cont_features would work
                                    needed_sum_cats = checkpoint_bias_size - original_num_cont_features
                                    if needed_sum_cats > sum_categories_with_unk:
                                        print(f'  ⚠️  Cannot automatically fix categorical encoding mismatch.')
                                        print(f'      Checkpoint model expects {needed_sum_cats} categorical features,')
                                        print(f'      but checkpoint data_wrangler has {sum_categories_with_unk}.')
                                        print(f'      Difference: {needed_sum_cats - sum_categories_with_unk} categories.')
                                        raise RuntimeError(
                                            f"Cannot fix categorical encoding mismatch: checkpoint model size ({checkpoint_bias_size}) "
                                            f"requires {needed_sum_cats} categorical features, but checkpoint data_wrangler has "
                                            f"{sum_categories_with_unk}. This suggests num_cats changed between save and load. "
                                            f"Original num_cats: {original_num_cats}, Checkpoint num_cats: {checkpoint_num_cats}. "
                                            f"Please retrain the model with current preprocessing code."
                                        )
                            else:
                                raise RuntimeError(
                                    f"Cannot fix checkpoint inconsistency: checkpoint model size ({checkpoint_bias_size}) "
                                    f"is smaller than sum of categories_with_unk ({sum_categories_with_unk}). "
                                    f"This suggests the checkpoint data_wrangler's num_cats is wrong. "
                                    f"Original num_cats: {original_num_cats}, Checkpoint num_cats: {checkpoint_num_cats}. "
                                    f"Please retrain the model with current preprocessing code."
                                )
                    else:
                        raise RuntimeError(
                            f"Cannot fix checkpoint inconsistency: checkpoint model size ({checkpoint_bias_size}) "
                            f"is smaller than sum of categories_with_unk ({sum_categories_with_unk}). "
                            f"This suggests the checkpoint data_wrangler's num_cats is wrong, "
                            f"and no original data_wrangler is available for comparison. "
                            f"Please retrain the model with current preprocessing code."
                        )
                else:
                    # num_cont_features mismatch (not num_cats)
                    print(f'  Inferred num_cont_features from checkpoint model: {inferred_num_cont_features}')
                    print(f'  (Checkpoint data_wrangler had: {checkpoint_num_cont})')
                    
                    if inferred_num_cont_features != checkpoint_num_cont:
                        print(f'  ✅ Mismatch is in num_cont_features - updating checkpoint data_wrangler...')
                        # Update data_wrangler's num_cont_features to match checkpoint model
                        self.data_wrangler.num_cont_features = inferred_num_cont_features
                        # Also update num_total_features if it exists
                        if hasattr(self.data_wrangler, 'num_total_features'):
                            old_total = self.data_wrangler.num_total_features
                            self.data_wrangler.num_total_features = (
                                self.data_wrangler.num_cat_features + inferred_num_cont_features
                            )
                            print(f'  Updated num_total_features: {self.data_wrangler.num_total_features} (was {old_total})')
            
            # Also compare with original data_wrangler to see if there's a discrepancy
            if original_num_cats is not None and original_num_cats != checkpoint_num_cats:
                print(f'  ⚠️  WARNING: Original num_cats ({original_num_cats}) differs from checkpoint num_cats ({checkpoint_num_cats})!')
                print(f'      This suggests categorical encoding changed. Using checkpoint num_cats.')
            
        
        # Final verification and fix before creating model: ALWAYS try original num_cats if it differs
        # This is critical because the checkpoint's data_wrangler might be saved incorrectly
        if hasattr(self.data_wrangler, 'num_cats') and self.data_wrangler.num_cats is not None:
            final_sum_cats = sum([c + 1 for c in self.data_wrangler.num_cats])
            final_expected_size = final_sum_cats + self.data_wrangler.num_cont_features
            print(f'  Current data_wrangler expected size: {final_expected_size}')
            
            # ALWAYS try original num_cats if it differs, even if checkpoint appears to match
            # This handles the case where checkpoint was saved with wrong num_cats but happens to match
            if original_num_cats is not None:
                # Check if they're actually different (list comparison)
                if original_num_cats != self.data_wrangler.num_cats:
                    original_sum_cats = sum([c + 1 for c in original_num_cats])
                    inferred_num_cont = checkpoint_bias_size - original_sum_cats
                    print(f'  🔧 FINAL FIX: Original num_cats differs from checkpoint, trying it...')
                    print(f'      Original num_cats: {original_num_cats} (sum: {original_sum_cats})')
                    print(f'      Checkpoint num_cats: {self.data_wrangler.num_cats} (sum: {final_sum_cats})')
                    print(f'      Inferred num_cont: {inferred_num_cont}')
                    
                    if inferred_num_cont >= 0:
                        # Use original num_cats - it's more likely to be correct for current n_cached
                        print(f'      ✅ Using original num_cats with inferred num_cont_features={inferred_num_cont}')
                        self.data_wrangler.num_cats = original_num_cats
                        self.data_wrangler.num_cont_features = inferred_num_cont
                        if hasattr(self.data_wrangler, 'num_cat_features'):
                            self.data_wrangler.num_cat_features = len(original_num_cats)
                        if hasattr(self.data_wrangler, 'num_total_features'):
                            self.data_wrangler.num_total_features = (
                                self.data_wrangler.num_cat_features + inferred_num_cont
                            )
                        
                        # Update checkpoint_num_cats for verification
                        checkpoint_num_cats = original_num_cats
                        
                        # Verify
                        verify_sum = sum([c + 1 for c in self.data_wrangler.num_cats])
                        verify_size = verify_sum + self.data_wrangler.num_cont_features
                        print(f'      Verification: {verify_sum} + {self.data_wrangler.num_cont_features} = {verify_size} (should be {checkpoint_bias_size})')
                        if verify_size != checkpoint_bias_size:
                            print(f'      ⚠️  WARNING: Verification failed!')
                    else:
                        print(f'      ⚠️  Cannot use original num_cats: inferred_num_cont would be negative ({inferred_num_cont})')
                else:
                    print(f'  Original num_cats same as checkpoint num_cats, no fix needed')
            else:
                print(f'  ⚠️  WARNING: original_num_cats is None, cannot apply fix!')
            
            # Final check after potential fix
            final_sum_cats = sum([c + 1 for c in self.data_wrangler.num_cats])
            final_expected_size = final_sum_cats + self.data_wrangler.num_cont_features
            print(f'  Final data_wrangler expected size: {final_expected_size} (checkpoint: {checkpoint_bias_size})')
            if final_expected_size != checkpoint_bias_size:
                print(f'  ⚠️  FINAL WARNING: data_wrangler still doesn\'t match checkpoint model!')
                print(f'      Final num_cats: {self.data_wrangler.num_cats}')
                print(f'      Final num_cont_features: {self.data_wrangler.num_cont_features}')
                print(f'      This should have been fixed above. There may be a bug in the fix logic.')
        
        # Set debug flag to enable detailed logging in get_model
        self._debug_load_model = True
        self.diff_model = self.get_model()
        delattr(self, '_debug_load_model')
        
        # Verify we used checkpoint's num_cats
        if checkpoint_num_cats is not None:
            if hasattr(self, 'categories') and self.categories != checkpoint_num_cats:
                print(f'  ⚠️  ERROR: Model created with categories={self.categories}, but checkpoint has {checkpoint_num_cats}!')
        
        # Check model architecture
        current_bias_size = self.diff_model.model.final_layer.linear.bias.shape[0]
        current_out_features = self.diff_model.model.final_layer.linear.out_features
        
        # Use out_features as the source of truth (bias size might be wrong due to a bug)
        # But also check bias size for compatibility
        model_size = current_out_features  # Prefer out_features over bias size
        if current_bias_size != current_out_features:
            print(f'  ⚠️  WARNING: Model bias size ({current_bias_size}) != out_features ({current_out_features})!')
            print(f'      Using out_features ({current_out_features}) as the model size.')
        
        if model_size != checkpoint_bias_size:
            # Model was created with wrong size - infer correct num_cont_features from checkpoint
            print(f'  🔧 Model size mismatch detected! Checkpoint has {checkpoint_bias_size}, model has {current_bias_size}')
            print(f'  🔧 Inferring correct num_cont_features from checkpoint model size...')
            
            if hasattr(self.data_wrangler, 'num_cats') and self.data_wrangler.num_cats is not None:
                current_num_cats = self.data_wrangler.num_cats
                current_sum_cats = sum([c + 1 for c in current_num_cats])
                inferred_num_cont = checkpoint_bias_size - current_sum_cats
                
                print(f'  🔧 Current num_cats: {current_num_cats} (sum with +1: {current_sum_cats})')
                print(f'  🔧 Inferred num_cont_features: {inferred_num_cont} (checkpoint: {checkpoint_bias_size} - {current_sum_cats})')
                
                if inferred_num_cont >= 0:
                    print(f'  🔧 POST-MODEL FIX: Updating num_cont_features to {inferred_num_cont} (was {self.data_wrangler.num_cont_features})')
                    # Delete the incorrectly created model
                    del self.diff_model
                    # Update data_wrangler's num_cont_features
                    self.data_wrangler.num_cont_features = inferred_num_cont
                    if hasattr(self.data_wrangler, 'num_total_features'):
                        self.data_wrangler.num_total_features = (
                            self.data_wrangler.num_cat_features + inferred_num_cont
                        )
                    # Recreate the model with correct architecture
                    print(f'  🔧 Recreating model with correct architecture...')
                    self._debug_load_model = True
                    self.diff_model = self.get_model()
                    delattr(self, '_debug_load_model')
                    # Verify the fix worked
                    new_bias_size = self.diff_model.model.final_layer.linear.bias.shape[0]
                    new_out_features = self.diff_model.model.final_layer.linear.out_features
                    print(f'  🔧 New model bias size: {new_bias_size} (should be {checkpoint_bias_size})')
                    print(f'  🔧 New model out_features: {new_out_features} (should be {checkpoint_bias_size})')
                    print(f'  🔧 Debug: linear layer weight shape: {self.diff_model.model.final_layer.linear.weight.shape}')
                    print(f'  🔧 Debug: linear layer bias shape: {self.diff_model.model.final_layer.linear.bias.shape}')
                    
                    # Check if there's a mismatch between out_features and bias size
                    if new_out_features != new_bias_size:
                        print(f'  ⚠️  WARNING: Model out_features ({new_out_features}) != bias size ({new_bias_size})!')
                        print(f'      This suggests a bug in model creation. Using out_features as the correct size.')
                        # Use out_features as the source of truth
                        if new_out_features == checkpoint_bias_size:
                            print(f'  ✅ Fix successful! Model out_features matches checkpoint.')
                            current_bias_size = new_out_features
                        else:
                            print(f'  ⚠️  Fix failed! Model out_features ({new_out_features}) != checkpoint ({checkpoint_bias_size})')
                    elif new_bias_size == checkpoint_bias_size:
                        print(f'  ✅ Fix successful! Model now has correct architecture.')
                        # Update current_bias_size so we skip the error below
                        current_bias_size = new_bias_size
                    else:
                        print(f'  ⚠️  Fix failed! New model still has wrong size.')
                else:
                    print(f'  ⚠️  Cannot fix: inferred_num_cont would be negative ({inferred_num_cont})')
            
            # Only raise error if fix didn't work
            # Re-check model size after fix attempt
            if hasattr(self, 'diff_model') and self.diff_model is not None:
                model_size = self.diff_model.model.final_layer.linear.out_features
            else:
                model_size = current_bias_size
            
            if model_size != checkpoint_bias_size:
                print(f'ERROR [load_model]: Size mismatch detected!')
                print(f'  Current model expects {current_bias_size} features')
                print(f'  Checkpoint has {checkpoint_bias_size} features')
                print(f'  This usually means num_cats calculation changed between save and load.')
                print(f'  Solution: Retrain the model with current preprocessing code, or')
                print(f'            ensure checkpoint was saved with same preprocessing as current code.')
                raise RuntimeError(
                    f"Model architecture mismatch: checkpoint has {checkpoint_bias_size} output features, "
                    f"but current model architecture expects {current_bias_size}. "
                    f"This usually means the data_wrangler's num_cats changed between save and load. "
                    f"Check if preprocessing or categorical encoding changed. "
                    f"Checkpoint num_cats: {self.data_wrangler.num_cats if hasattr(self.data_wrangler, 'num_cats') else 'N/A'}"
                )
        
        self.diff_model.load_state_dict(checkpoint["diff_model"])
        self.diff_model.to(self.device)
        self.diff_model.eval()
        print(f'Loading model at step:{self.current_step}')