# spar_jax.py - PRODUCTION READY VERSION (WITH MLP BRANCH)
import os
import math
import yaml

# Turn off XLA's up-front memory preallocation and select the "platform" allocator.
# NOTE(review): these are assigned before `import jax` further down, presumably because
# they have to be present in the environment when JAX starts up — confirm.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import pickle
from dataclasses import dataclass, asdict
from typing import Any, List, Optional, Tuple

import d4rl
import gym
import numpy as np
import pyrallis
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax import struct
import wandb

from iql_jax import (
    Critic,
    GaussianPolicy,
    DeterministicPolicy,
    ReplayBuffer,
    normalize_states,
    wrap_env,
    set_seed,
    modify_reward,
    load_checkpoint as load_iql_checkpoint,
)

from run_utils import wandb_init, find_latest_checkpoint

# Process-wide JAX settings: 64-bit mode stays disabled and the default matmul
# precision is set to "high".
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_default_matmul_precision", "high")

# Annotation alias for a training batch: a list of JAX arrays.
TensorBatch = List[jnp.ndarray]


# ==================== Config ====================
@dataclass
class SPARConfig:
    """Hyper-parameters for SPAR training, guidance, inference and logging.

    ``__post_init__`` only acts when ``guide_mode == "proj"``: it rejects a
    non-positive ``proj_period`` and rounds ``log_freq`` up to the next multiple
    of ``proj_period``.

    Raises:
        ValueError: if ``guide_mode == "proj"`` and ``proj_period <= 0``.
    """
    env: str = "halfcheetah-medium-expert-v2"
    seed: int = 0
    base_model_path: str = ""

    # CVAE
    cvae_hidden_dims: Tuple[int, ...] = (256, 256)
    cvae_latent_dim: int = 16
    cvae_layernorm: bool = False
    cvae_lr: float = 3e-4
    cvae_steps: int = 1_000_000
    batch_size: int = 256

    # Loss Weights
    kl_weight: float = 0.5
    recon_weight: float = 1.0

    # Guide Mode
    guide_mode: str = "plas"  # ["proj", "grad", "plas", "trainz", "mlp"]
    guide_weight: float = 0.1
    uncertainty_weight: float = 1.0
    guide_warmup_steps: int = 0
    guide_ramp_steps: int = 0     # afterwards, ramp linearly up to guide_weight over this many steps
    phase2_start_step: int = 1_000_000  # Phase-2 start step (per the original note, written inside train())

    # Advantage Weighting
    use_soft_filtering: bool = False
    advantage_temperature: float = 1.0
    weight_clip_max: float = 100.0
    relative_advantage_scale: float = 1.0

    # Latent Policy (PLAS/TrainZ)
    latent_policy_hidden_dims: Tuple[int, ...] = (256, 256)
    latent_policy_lr: float = 3e-4
    latent_max: float = 2.0

    # MLP Residual
    mlp_residual_hidden_dims: Tuple[int, ...] = (256, 256)
    mlp_residual_lr: float = 3e-4

    # TrainZ Specific
    trainz_kl_beta: float = 1e-3
    trainz_logstd_min: float = -5.0
    trainz_logstd_max: float = 2.0
    trainz_num_samples: int = 1

    # Proj Specific
    proj_num_candidates: int = 64
    proj_period: int = 10
    proj_temperature: float = 1.0
    proj_use_relative_gain: bool = True
    proj_only_positive: bool = True
    proj_sigma: float = 1.0
    proj_use_target_proposal: bool = True
    target_tau: float = 0.005
    target_update_period: int = 1

    # Grad Specific
    use_adaptive_guide: bool = False
    guide_weight_clip: float = 10.0
    guide_weight_eps: float = 1e-6

    # Speed Knobs
    data_num_critics: int = 2
    guide_num_critics: int = 2
    proj_num_critics: int = 2
    rect_num_critics: int = 2

    # Inference
    num_candidates: int = 10
    use_rectification: bool = True
    rectification_abs_threshold: float = 1e-4
    rectification_rel_threshold: float = 0.01
    use_dual_threshold: bool = True
    rectification_smooth_eps: float = 1.0

    # Evaluation
    eval_freq: int = 50_000
    n_episodes: int = 20

    # Logging
    log_dir: str = "runs"
    checkpoints_path: Optional[str] = None
    log_freq: int = 1000
    project: str = "SPAR"
    group: str = "SPAR"

    # Buffer
    buffer_size: int = 2_000_000
    normalize: bool = True
    normalize_reward: bool = False

    def __post_init__(self):
        """Validate and align the 'proj'-mode settings; no-op for other modes."""
        if self.guide_mode != "proj":
            return

        # A zero period would otherwise surface as a bare ZeroDivisionError in the
        # modulo below; fail with a message that names the offending field instead.
        if self.proj_period <= 0:
            raise ValueError(
                f"proj_period must be a positive integer in 'proj' mode, got {self.proj_period}"
            )

        if self.log_freq % self.proj_period != 0:
            # Round log_freq up to the next multiple of proj_period.
            new_freq = math.ceil(self.log_freq / self.proj_period) * self.proj_period
            print(f"[Config Warning] log_freq ({self.log_freq}) 不是 proj_period ({self.proj_period}) 的倍数。")
            print(f"                 已自动调整 log_freq -> {new_freq}")
            self.log_freq = new_freq


# ==================== Networks ====================
class CVAE(nn.Module):
    """Conditional VAE over action residuals.

    The encoder consumes ``[state, base_action, residual]`` and produces a latent
    mean plus a softplus-parameterised standard deviation. The decoder consumes
    ``[state, base_action, z]`` and emits a residual squashed by ``2 * tanh``.
    The decoder's output layer starts with zero kernel and bias.
    """
    action_dim: int
    hidden_dims: Tuple[int, ...] = (256, 256)
    latent_dim: int = 16
    use_layernorm: bool = False

    def setup(self):
        widths = self.hidden_dims
        with_norm = self.use_layernorm

        # Encoder trunk and its two heads.
        self.encoder_layers = [nn.Dense(width) for width in widths]
        self.encoder_norms = [nn.LayerNorm() for _ in widths] if with_norm else None
        self.encoder_mean = nn.Dense(self.latent_dim)
        self.encoder_logvar = nn.Dense(self.latent_dim)

        # Decoder trunk and zero-initialised output head.
        self.decoder_layers = [nn.Dense(width) for width in widths]
        self.decoder_norms = [nn.LayerNorm() for _ in widths] if with_norm else None
        zeros = nn.initializers.zeros
        self.decoder_out = nn.Dense(self.action_dim, kernel_init=zeros, bias_init=zeros)

    def __call__(self, state, base_action, residual, rng, training: bool = False):
        """Encode ``residual``, sample a latent, decode it.

        Returns ``(recon, mean, logvar)``. ``training`` is accepted but not read.
        """
        hidden = jnp.concatenate([state, base_action, residual], axis=-1)
        for idx, dense in enumerate(self.encoder_layers):
            hidden = dense(hidden)
            if self.encoder_norms is not None:
                hidden = self.encoder_norms[idx](hidden)
            hidden = nn.relu(hidden)

        mean = self.encoder_mean(hidden)
        # The "logvar" head is passed through softplus to give a positive std,
        # which is then clamped to [1e-4, 20] before converting to log-variance.
        std = jnp.clip(jax.nn.softplus(self.encoder_logvar(hidden)) + 1e-4, 1e-4, 20.0)
        logvar = 2.0 * jnp.log(std)

        # Reparameterised sample.
        noise = jax.random.normal(rng, mean.shape)
        latent = mean + noise * std

        return self._decode_internal(state, base_action, latent), mean, logvar

    def _decode_internal(self, state, base_action, z):
        """Map ``[state, base_action, z]`` to a residual bounded by ``2 * tanh``."""
        hidden = jnp.concatenate([state, base_action, z], axis=-1)
        for idx, dense in enumerate(self.decoder_layers):
            hidden = dense(hidden)
            if self.decoder_norms is not None:
                hidden = self.decoder_norms[idx](hidden)
            hidden = nn.relu(hidden)
        return 2.0 * jnp.tanh(self.decoder_out(hidden))

    def decode(self, state, base_action, z):
        """Public decoder entry point; forwards to ``_decode_internal``."""
        return self._decode_internal(state, base_action, z)



class LatentPolicy(nn.Module):
    """MLP mapping ``[obs, base_action]`` to a latent code.

    When ``is_stochastic`` is set it returns ``(mu, logstd)`` with ``logstd``
    clipped to ``[logstd_min, logstd_max]``; otherwise it returns a single
    latent ``latent_max * tanh(raw)``.
    """
    latent_dim: int
    hidden_dims: Tuple[int, ...] = (256, 256)
    is_stochastic: bool = False
    logstd_min: float = -5.0
    logstd_max: float = 2.0
    latent_max: float = 2.0

    def setup(self):
        self.layers = [nn.Dense(width) for width in self.hidden_dims]
        # Only the heads needed for the chosen variant are created.
        if not self.is_stochastic:
            self.z_layer = nn.Dense(self.latent_dim)
        else:
            self.mu_layer = nn.Dense(self.latent_dim)
            self.logstd_layer = nn.Dense(self.latent_dim)

    def __call__(self, obs: jnp.ndarray, base_action: jnp.ndarray, training: bool = False):
        """Run the trunk and the variant-specific head; ``training`` is not read."""
        hidden = jnp.concatenate([obs, base_action], axis=-1)
        for dense in self.layers:
            hidden = nn.relu(dense(hidden))

        if not self.is_stochastic:
            # Deterministic variant: bounded latent.
            return self.latent_max * jnp.tanh(self.z_layer(hidden))

        # Stochastic variant: Gaussian parameters with clipped log-std.
        mu = self.mu_layer(hidden)
        logstd = jnp.clip(self.logstd_layer(hidden), self.logstd_min, self.logstd_max)
        return mu, logstd


class MLPResidual(nn.Module):
    """Deterministic MLP predicting a residual from ``[state, base_action]``.

    The output layer starts with zero kernel and bias, and the result is
    squashed by ``2 * tanh``.
    """
    action_dim: int
    hidden_dims: Tuple[int, ...] = (256, 256)

    def setup(self):
        zeros = nn.initializers.zeros
        self.layers = [nn.Dense(width) for width in self.hidden_dims]
        self.output_layer = nn.Dense(self.action_dim, kernel_init=zeros, bias_init=zeros)

    def __call__(self, state: jnp.ndarray, base_action: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        """Return the bounded residual; ``training`` is accepted but not read."""
        hidden = jnp.concatenate([state, base_action], axis=-1)
        for dense in self.layers:
            hidden = nn.relu(dense(hidden))
        return 2.0 * jnp.tanh(self.output_layer(hidden))


# ==================== State ====================
@struct.dataclass
class SPARState:
    """Immutable container for everything the SPAR train step reads and writes."""
    # CVAE parameters and the optimizer state that goes with them.
    cvae_params: Any
    # Slowly-updated copy of the CVAE parameters (moved towards cvae_params via
    # optax.incremental_update in the train step).
    cvae_target_params: Any
    cvae_opt_state: optax.OptState
    # Latent-policy parameters and optimizer state.
    latent_policy_params: Any
    latent_policy_opt_state: optax.OptState
    # MLP-residual parameters and optimizer state.
    mlp_residual_params: Any
    mlp_residual_opt_state: optax.OptState
    # Base actor and critic parameters; the train step passes these through unchanged.
    actor_params: Any
    q_params: Any
    # Step counter, incremented by one on every train step.
    total_it: jnp.ndarray


# ==================== Utils ====================
def cvae_kl_divergence(mean: jnp.ndarray, logvar: jnp.ndarray) -> jnp.ndarray:
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over the last axis."""
    per_dim = jnp.exp(logvar) + mean**2 - 1.0 - logvar
    return 0.5 * jnp.sum(per_dim, axis=-1)


def gaussian_kl_to_std_normal(mu: jnp.ndarray, logstd: jnp.ndarray) -> jnp.ndarray:
    """KL(N(mu, exp(logstd)^2) || N(0, I)), summed over the last axis.

    The log-variance term is computed as ``log(sigma^2 + 1e-8)``.
    """
    variance = jnp.exp(2.0 * logstd)
    per_dim = mu**2 + variance - 1.0 - jnp.log(variance + 1e-8)
    return 0.5 * jnp.sum(per_dim, axis=-1)


def _subset_critics(q_params, n: int):
    if n is None or n <= 0:
        return q_params
    return jax.tree_util.tree_map(lambda x: x[:n], q_params)


def compute_ensemble_q_robust(critic_def, q_params, state, action, uncertainty_weight):
    """Evaluate every critic in the ensemble and return a pessimistic estimate.

    ``critic_def.apply`` is vmapped over the leading axis of ``q_params``.
    Returns ``(mean - uncertainty_weight * std, mean, std)``, where mean and
    std are taken across the ensemble axis.
    """
    def single_critic(params):
        return critic_def.apply({"params": params}, state, action)

    ensemble_q = jax.vmap(single_critic)(q_params)
    mean_q = jnp.mean(ensemble_q, axis=0)
    std_q = jnp.std(ensemble_q, axis=0)
    return mean_q - uncertainty_weight * std_q, mean_q, std_q


def _get_base_action(actor_def, actor_params, obs, deterministic_actor):
    if deterministic_actor:
        a_base = actor_def.apply({"params": actor_params}, obs, training=False)
    else:
        mean, _ = actor_def.apply({"params": actor_params}, obs, training=False)
        a_base = mean
    return jax.lax.stop_gradient(a_base)


def _decode_residual(cvae_def, cvae_params, obs, base_action, z):
    return cvae_def.apply({"params": cvae_params}, obs, base_action, z, method=cvae_def.decode)


def _safe_softmax(logits: jnp.ndarray, axis: int = -1) -> jnp.ndarray:
    logits = logits - jnp.max(logits, axis=axis, keepdims=True)
    exp = jnp.exp(logits)
    denom = jnp.sum(exp, axis=axis, keepdims=True) + 1e-8
    return exp / denom


# ==================== Training Step ====================
def make_spar_train_step(
    cvae_def: CVAE,
    latent_policy_def: LatentPolicy,
    mlp_residual_def: MLPResidual,
    critic_def: Critic,
    actor_def,
    cvae_tx: optax.GradientTransformation,
    latent_policy_tx: optax.GradientTransformation,
    mlp_residual_tx: optax.GradientTransformation,
    *,
    config: SPARConfig,
    deterministic_actor: bool,
):
    """Build the jitted SPAR training step.

    The returned ``train_step(state, batch, rng)`` computes one joint loss over
    the CVAE, latent-policy and MLP-residual parameters, applies each
    optimizer, optionally moves the CVAE target parameters, and returns
    ``(new_state, metrics, rng_out)``.

    ``config.guide_mode`` is read as a Python value while tracing, so the
    branch taken ("mlp", "plas", "trainz", "grad", "proj") is fixed for the
    returned function. Actor and critic parameters are only read, never updated.
    """
    guide_mode = config.guide_mode

    @jax.jit
    def train_step(state: SPARState, batch: TensorBatch, rng: jax.Array):
        """Run one optimisation step; see ``make_spar_train_step``."""
        # NOTE(review): rewards, next_observations and dones are unpacked but not used below.
        observations, actions, rewards, next_observations, dones = batch
        keys = jax.random.split(rng, 5)
        rng_out, key_encode, key_trainz, key_proj, key_misc = keys

        # Base action from the actor (gradient-stopped) and the residual the
        # dataset action has relative to it.
        a_base = _get_base_action(actor_def, state.actor_params, observations, deterministic_actor)
        residual_gt = actions - a_base

        # Score dataset actions and base actions in one critic call by stacking
        # them along the batch axis, then split the result back.
        q_params_data = _subset_critics(state.q_params, config.data_num_critics)
        B = observations.shape[0]
        obs2 = jnp.concatenate([observations, observations], axis=0)
        act2 = jnp.concatenate([actions, a_base], axis=0)

        q2, _, _ = compute_ensemble_q_robust(
            critic_def, q_params_data, obs2, act2, config.uncertainty_weight
        )
        q_real = q2[:B]
        q_base_data = q2[B:]

        adv_rel = q_real - q_base_data

        # Stable denominator: avoids blow-up when the base Q values are small.
        adv_scale = jax.lax.stop_gradient(jnp.maximum(jnp.mean(jnp.abs(q_base_data)), 1.0))

        # Relative advantage without the temperature; used for the sign test / filtering.
        adv_raw = (adv_rel / adv_scale) * config.relative_advantage_scale

        # Divide by the temperature in log-space and clip before exp to avoid overflow.
        log_w = adv_raw / config.advantage_temperature
        log_w = jnp.clip(log_w, -20.0, 20.0)
        w_data = jnp.exp(log_w)

        if not config.use_soft_filtering:
            # Hard filter: the sign is decided on adv_raw, so it does not depend on the temperature.
            w_data = jnp.where(adv_raw > 0, w_data, 0.0)

        w_data = jnp.minimum(w_data, config.weight_clip_max)
        w_data = jax.lax.stop_gradient(w_data)
        # NOTE(review): data_mask is not referenced anywhere below.
        data_mask = jax.lax.stop_gradient(adv_raw > 0)


        # Guide-weight schedule: zero until `guide_warmup_steps` steps past
        # `phase2_start_step`, then a linear ramp over `guide_ramp_steps`
        # (or an immediate jump to the full weight when the ramp length is 0).
        step_it = state.total_it + jnp.array(1, dtype=jnp.int32)
        phase_step = jnp.maximum(
            step_it - jnp.asarray(config.phase2_start_step, dtype=jnp.int32),
            jnp.asarray(0, dtype=jnp.int32),
        )

        warm = jnp.asarray(config.guide_warmup_steps, dtype=jnp.int32)
        ramp = jnp.asarray(config.guide_ramp_steps, dtype=jnp.int32)

        after = jnp.maximum(phase_step - warm, jnp.asarray(0, dtype=jnp.int32)).astype(jnp.float32)
        ramp_f = jnp.maximum(ramp.astype(jnp.float32), 1.0)
        frac = jnp.clip(after / ramp_f, 0.0, 1.0)
        frac = jnp.where(ramp > 0, frac, (phase_step >= warm).astype(jnp.float32))

        guide_w = jnp.asarray(config.guide_weight, dtype=jnp.float32) * frac

        def loss_fn(cvae_params, latent_policy_params, mlp_residual_params):
            """Return ``(total_joint_loss, logs)`` for the configured guide mode."""
            # ===== MLP Branch =====
            if guide_mode == "mlp":
                # The MLP receives the base action alongside the observation.
                recon_res = mlp_residual_def.apply({"params": mlp_residual_params}, observations, a_base, training=True)
                recon_sq = jnp.sum((recon_res - residual_gt) ** 2, axis=-1)
                w_sum = jnp.sum(w_data)
                # Weighted mean when the weights carry mass, plain mean otherwise.
                def weighted(_):
                    return jnp.sum(w_data * recon_sq) / (w_sum + 1e-8)
                def unweighted(_):
                    return jnp.mean(recon_sq)
                recon_loss = jax.lax.cond(w_sum > 1e-4, weighted, unweighted, operand=None)
                kl_loss = jnp.array(0.0)
            else:
                # CVAE Branch
                # The CVAE receives the base action alongside the observation and residual.
                recon_res, enc_mu, enc_logvar = cvae_def.apply(
                    {"params": cvae_params}, observations, a_base, residual_gt, key_encode, training=True
                )
                recon_sq = jnp.sum((recon_res - residual_gt) ** 2, axis=-1)
                w_sum = jnp.sum(w_data)
                # Weighted mean when the weights carry mass, plain mean otherwise.
                def weighted(_):
                    return jnp.sum(w_data * recon_sq) / (w_sum + 1e-8)
                def unweighted(_):
                    return jnp.mean(recon_sq)
                recon_loss = jax.lax.cond(w_sum > 1e-4, weighted, unweighted, operand=None)
                kl_loss = jnp.mean(cvae_kl_divergence(enc_mu, enc_logvar))

            guide_loss = jnp.array(0.0)
            latent_policy_loss = jnp.array(0.0)
            
            # Every metric key is pre-populated with zero; only the active
            # guide mode overwrites its own entries below.
            logs = {
                "Value/Q_Synth": jnp.array(0.0),
                "Loss/LatentKL": jnp.array(0.0),
                "Diagnostics/LatentSigma": jnp.array(0.0),
                "Loss/Proj_Core": jnp.array(0.0),
                "Diagnostics/DoProj": jnp.array(0.0),
                "Diagnostics/Proj_Entropy": jnp.array(0.0),
                "Diagnostics/Proj_Utilization": jnp.array(0.0),
                "Diagnostics/Proj_BestGain": jnp.array(0.0),
                "Value/Q_Prop_Avg": jnp.array(0.0),
                "Diagnostics/Adaptive_Guide_Weight": jnp.array(0.0),
            }

            # Critic subset used for guidance; any mode not listed here uses proj_num_critics.
            if guide_mode in ["plas", "trainz", "grad", "mlp"]:
                q_params_guide = _subset_critics(state.q_params, config.guide_num_critics)
            else:
                q_params_guide = _subset_critics(state.q_params, config.proj_num_critics)

            # ===== MLP Guide =====
            if guide_mode == "mlp":
                # The MLP outputs the residual directly; guidance maximises the robust Q of the result.
                gen_res = mlp_residual_def.apply({"params": mlp_residual_params}, observations, a_base, training=True)
                a_synth = jnp.clip(a_base + gen_res, -1.0, 1.0)

                q_synth, _, _ = compute_ensemble_q_robust(
                    critic_def, q_params_guide, observations, a_synth, config.uncertainty_weight
                )

                guide_loss = guide_w * (-jnp.mean(q_synth))
                logs["Value/Q_Synth"] = jnp.mean(q_synth)

            # ===== PLAS Guide =====
            elif guide_mode == "plas":
                z = latent_policy_def.apply({"params": latent_policy_params}, observations, a_base, training=True)
                # Decoder parameters are gradient-stopped, so the guide loss
                # does not send gradients into the CVAE parameters.
                decoder_params_frozen = jax.lax.stop_gradient(cvae_params)
                gen_res = _decode_residual(cvae_def, decoder_params_frozen, observations, a_base, z)
                a_synth = jnp.clip(a_base + gen_res, -1.0, 1.0)

                q_synth, _, _ = compute_ensemble_q_robust(
                    critic_def, q_params_guide, observations, a_synth, config.uncertainty_weight
                )

                guide_loss = guide_w * (-jnp.mean(q_synth))
                latent_policy_loss = jnp.array(0.0)
                logs["Value/Q_Synth"] = jnp.mean(q_synth)

            # ===== TrainZ Guide =====
            elif guide_mode == "trainz":
                B = observations.shape[0]
                M = config.trainz_num_samples
                # NOTE(review): sd and ad are assigned but not used below.
                sd = observations.shape[1]
                ad = a_base.shape[1]  # or a_base.shape[-1]

                mu, logstd = latent_policy_def.apply(
                    {"params": latent_policy_params}, observations, a_base, training=True
                )
                sigma = jnp.exp(logstd)

                # Draw M reparameterised latents per state and bound them with tanh.
                eps = jax.random.normal(key_trainz, (B, M, config.cvae_latent_dim))
                raw_z = mu[:, None, :] + sigma[:, None, :] * eps
                z = config.latent_max * jnp.tanh(raw_z)
                z_flat = z.reshape(B * M, config.cvae_latent_dim)

                # Repeat obs / base_action M times along axis 0 to line up with z_flat.
                obs_rep = jnp.repeat(observations, M, axis=0)   # (B*M, sd)
                abase_rep = jnp.repeat(a_base, M, axis=0)       # (B*M, ad)

                decoder_params_frozen = jax.lax.stop_gradient(cvae_params)
                gen_res = _decode_residual(cvae_def, decoder_params_frozen, obs_rep, abase_rep, z_flat)
                a_synth = jnp.clip(abase_rep + gen_res, -1.0, 1.0)

                q_synth_flat, _, _ = compute_ensemble_q_robust(
                    critic_def, q_params_guide, obs_rep, a_synth, config.uncertainty_weight
                )
                # Average the M samples back down to one value per state.
                q_synth = q_synth_flat.reshape(B, M).mean(axis=1)

                kl_z = jnp.mean(gaussian_kl_to_std_normal(mu, logstd))
                guide_loss = guide_w * (-jnp.mean(q_synth))
                latent_policy_loss = config.trainz_kl_beta * kl_z

                logs["Value/Q_Synth"] = jnp.mean(q_synth)
                logs["Loss/LatentKL"] = kl_z
                logs["Diagnostics/LatentSigma"] = jnp.mean(sigma)


            # ===== Grad Guide =====
            elif guide_mode == "grad":
                # Latents come from a standard normal; the decoder is NOT
                # gradient-stopped here, so the guide loss reaches cvae_params.
                z_prior = jax.random.normal(key_misc, (observations.shape[0], config.cvae_latent_dim))
                gen_res = _decode_residual(cvae_def, cvae_params, observations, a_base, z_prior)
                a_synth = jnp.clip(a_base + gen_res, -1.0, 1.0)

                q_synth, _, _ = compute_ensemble_q_robust(
                    critic_def, q_params_guide, observations, a_synth, config.uncertainty_weight
                )

                if config.use_adaptive_guide:
                    # Scale the guide weight by the mean |Q| and cap it at guide_weight_clip.
                    q_scale = jax.lax.stop_gradient(jnp.mean(jnp.abs(q_synth)) + config.guide_weight_eps)
                    adaptive = jnp.minimum(guide_w / q_scale, config.guide_weight_clip)
                else:
                    adaptive = guide_w

                guide_loss = adaptive * (-jnp.mean(q_synth))
                logs["Value/Q_Synth"] = jnp.mean(q_synth)
                logs["Diagnostics/Adaptive_Guide_Weight"] = adaptive

            # ===== Proj Guide =====
            elif guide_mode == "proj":
                # The projection term only runs on steps that are multiples of proj_period.
                do_proj = (step_it % jnp.array(config.proj_period, dtype=jnp.int32)) == 0

                def proj_branch(_):
                    B = observations.shape[0]
                    K = config.proj_num_candidates

                    # K standard-normal latents per state, flattened to one batch.
                    z_bk = jax.random.normal(key_proj, (B, K, config.cvae_latent_dim))
                    z = z_bk.reshape(B * K, config.cvae_latent_dim)

                    obs_rep = jnp.repeat(observations, K, axis=0)
                    abase_rep = jnp.repeat(a_base, K, axis=0)

                    # Proposals are decoded with the target parameters or the live
                    # ones, and gradient-stopped in both cases.
                    proposal_params = state.cvae_target_params if config.proj_use_target_proposal else cvae_params
                    res_tgt = _decode_residual(cvae_def, proposal_params, obs_rep, abase_rep, z)
                    res_tgt = jax.lax.stop_gradient(res_tgt)
                    a_prop = jnp.clip(abase_rep + res_tgt, -1.0, 1.0)

                    q_prop, _, _ = compute_ensemble_q_robust(
                        critic_def, q_params_guide, obs_rep, a_prop, config.uncertainty_weight
                    )
                    q_prop = q_prop.reshape(B, K)

                    q_base_proj, _, _ = compute_ensemble_q_robust(
                        critic_def, q_params_guide, observations, a_base, config.uncertainty_weight
                    )

                    # Gain of each proposal (optionally relative to the base action),
                    # normalised by its per-state mean absolute value.
                    gain = (q_prop - q_base_proj[:, None]) if config.proj_use_relative_gain else q_prop
                    gain_scale = jax.lax.stop_gradient(jnp.mean(jnp.abs(gain), axis=1, keepdims=True) + 1e-6)
                    gain_norm = gain / gain_scale

                    if config.proj_only_positive:
                        # Softmax restricted to positive-gain proposals; rows with no
                        # positive proposal are multiplied by zero and end up with all-zero weights.
                        pos = gain_norm > 0
                        has_pos = jnp.any(pos, axis=1, keepdims=True)
                        logits = jnp.where(pos, gain_norm / config.proj_temperature, -1e9)
                        w = _safe_softmax(logits, axis=1)
                        w = w * has_pos.astype(w.dtype)
                        w = w / (jnp.sum(w, axis=1, keepdims=True) + 1e-8)
                    else:
                        logits = gain_norm / config.proj_temperature
                        w = _safe_softmax(logits, axis=1)

                    w = jax.lax.stop_gradient(w)

                    proj_entropy = -jnp.mean(jnp.sum(w * jnp.log(w + 1e-8), axis=1))
                    best_gain = jnp.mean(jnp.max(gain, axis=1))
                    proj_util = jnp.mean(jnp.max(gain, axis=1) > 0)
                    q_prop_avg = jnp.mean(q_prop)

                    # Decode the same latents with the live parameters and penalise
                    # the weighted squared distance to the gradient-stopped proposals.
                    res_pred = _decode_residual(cvae_def, cvae_params, obs_rep, abase_rep, z)
                    res_pred = res_pred.reshape(B, K, -1)
                    res_tgt_bk = res_tgt.reshape(B, K, -1)

                    sq = jnp.mean((res_pred - res_tgt_bk) ** 2, axis=-1)
                    per_state = jnp.sum(w * sq, axis=1) / (2.0 * (config.proj_sigma ** 2))
                    proj_core = jnp.mean(per_state)

                    return (
                        guide_w * proj_core,
                        proj_core,
                        proj_entropy,
                        proj_util,
                        best_gain,
                        q_prop_avg,
                    )

                def skip_branch(_):
                    # Off-period steps: all six outputs are zero.
                    z0 = jnp.array(0.0)
                    return z0, z0, z0, z0, z0, z0

                guide_loss, proj_core, proj_entropy, proj_util, best_gain, q_prop_avg = jax.lax.cond(
                    do_proj, proj_branch, skip_branch, operand=None
                )

                logs["Loss/Proj_Core"] = proj_core
                logs["Diagnostics/DoProj"] = do_proj.astype(jnp.float32)
                logs["Diagnostics/Proj_Entropy"] = proj_entropy
                logs["Diagnostics/Proj_Utilization"] = proj_util
                logs["Diagnostics/Proj_BestGain"] = best_gain
                logs["Value/Q_Prop_Avg"] = q_prop_avg
                logs["Value/Q_Synth"] = q_prop_avg
                
            # ===== Loss Aggregation =====
            if guide_mode == "mlp":
                # MLP: recon + guide (MaxQ)
                total_mlp_loss = config.recon_weight * recon_loss + guide_loss
                total_cvae_loss = jnp.array(0.0)
                total_policy_loss = jnp.array(0.0)
                total_joint_loss = total_mlp_loss
            elif guide_mode == "proj":
                # Proj: CVAE recon + KL + guide
                total_cvae_loss = config.recon_weight * recon_loss + config.kl_weight * kl_loss + guide_loss
                total_policy_loss = jnp.array(0.0)
                total_mlp_loss = jnp.array(0.0)
                total_joint_loss = total_cvae_loss
            else:
                # PLAS/TrainZ/Grad: CVAE recon + KL, plus the guide and latent-policy terms
                total_cvae_loss = config.recon_weight * recon_loss + config.kl_weight * kl_loss
                total_policy_loss = guide_loss + latent_policy_loss
                total_mlp_loss = jnp.array(0.0)
                total_joint_loss = total_cvae_loss + total_policy_loss

            # enc_logvar only exists on the CVAE path, hence the mode check.
            if guide_mode != "mlp":
                logs["Diagnostics/Latent_Std"] = jnp.mean(jnp.exp(0.5 * enc_logvar))
            else:
                logs["Diagnostics/Latent_Std"] = jnp.array(0.0)
            logs["Diagnostics/GuideW"] = guide_w
            logs.update({
                "Loss/Recon": recon_loss,
                "Loss/KL": kl_loss,
                "Loss/Guide": guide_loss,
                "Loss/LatentPolicy": latent_policy_loss,
                "Loss/Total": total_joint_loss,
                "Value/Q_Base_Avg": jnp.mean(q_base_data),
                "Value/Q_Data_Avg": jnp.mean(q_real),
                "Diagnostics/Data_Utilization": jnp.mean(jax.lax.stop_gradient(adv_rel > 0)),
                "Diagnostics/Recon_Error_Raw": jnp.mean(recon_sq),
            })

            return total_joint_loss, logs


        # Differentiate the single joint loss with respect to all three parameter sets.
        (total_loss, metrics), grads = jax.value_and_grad(loss_fn, argnums=(0, 1, 2), has_aux=True)(
            state.cvae_params, state.latent_policy_params, state.mlp_residual_params
        )

        grads_cvae, grads_policy, grads_mlp = grads

        # Update CVAE
        updates_cvae, new_cvae_opt = cvae_tx.update(grads_cvae, state.cvae_opt_state, state.cvae_params)
        new_cvae_params = optax.apply_updates(state.cvae_params, updates_cvae)

        # Update Latent Policy
        updates_policy, new_policy_opt = latent_policy_tx.update(grads_policy, state.latent_policy_opt_state, state.latent_policy_params)
        new_policy_params = optax.apply_updates(state.latent_policy_params, updates_policy)

        # Update MLP Residual
        updates_mlp, new_mlp_opt = mlp_residual_tx.update(grads_mlp, state.mlp_residual_opt_state, state.mlp_residual_params)
        new_mlp_params = optax.apply_updates(state.mlp_residual_params, updates_mlp)

        # Target update: skipped entirely in "mlp" mode; otherwise applied
        # every `target_update_period` steps with step size `target_tau`.
        if guide_mode == "mlp":
            new_target = state.cvae_target_params
        else:
            def do_target_update(args):
                new_p, old_t = args
                return optax.incremental_update(new_p, old_t, config.target_tau)

            do_update = (step_it % jnp.array(config.target_update_period, dtype=jnp.int32)) == 0
            new_target = jax.lax.cond(
                do_update,
                do_target_update,
                lambda x: x[1],
                (new_cvae_params, state.cvae_target_params)
            )

        new_state = SPARState(
            cvae_params=new_cvae_params,
            cvae_target_params=new_target,
            cvae_opt_state=new_cvae_opt,
            latent_policy_params=new_policy_params,
            latent_policy_opt_state=new_policy_opt,
            mlp_residual_params=new_mlp_params,
            mlp_residual_opt_state=new_mlp_opt,
            actor_params=state.actor_params,
            q_params=state.q_params,
            total_it=step_it,
        )

        return new_state, metrics, rng_out

    return train_step


# ==================== Inference ====================
def make_spar_policy_fn(
    cvae_def: CVAE,
    latent_policy_def: LatentPolicy,
    mlp_residual_def: MLPResidual,
    critic_def: Critic,
    actor_def,
    config: SPARConfig,
    deterministic_actor: bool,
):
    """Build the jitted acting function used at evaluation time.

    The returned ``policy(...)`` yields ``(action, rng, diagnostics)``. Without
    rectification it adds one mode-specific residual to the base action. With
    rectification it scores ``config.num_candidates`` candidate actions with the
    robust ensemble Q and keeps the best one only if its gain over the base
    action passes the configured threshold(s); otherwise the base action is used.

    NOTE(review): the indexing (``[0]``, ``squeeze(axis=0)``) and the hard-coded
    ``(1, latent_dim)`` zeros suggest ``obs`` carries a leading batch axis of
    size 1 — confirm against callers.
    """
    mode = config.guide_mode
    latent_dim = config.cvae_latent_dim

    @jax.jit
    def policy(cvae_params, actor_params, q_params, latent_policy_params, mlp_residual_params, obs, rng):
        rng, key = jax.random.split(rng)

        # Base action; the batch axis is kept.
        a_base = _get_base_action(actor_def, actor_params, obs, deterministic_actor)

        def run_mlp():
            return mlp_residual_def.apply({"params": mlp_residual_params}, obs, a_base, training=False)

        def run_latent_policy():
            return latent_policy_def.apply({"params": latent_policy_params}, obs, a_base, training=False)

        # ==================== No-rectification path ====================
        if not config.use_rectification:
            if mode == "mlp":
                residual = run_mlp()
            else:
                if mode == "plas":
                    z = run_latent_policy()
                elif mode == "trainz":
                    # Use the mean of the latent Gaussian, bounded like in training.
                    mu, _ = run_latent_policy()
                    z = config.latent_max * jnp.tanh(mu)
                else:
                    z = jnp.zeros((1, latent_dim))
                residual = _decode_residual(cvae_def, cvae_params, obs, a_base, z)

            # Drop the batch axis to get a single action vector.
            action = jnp.clip(a_base + residual, -1.0, 1.0)[0]
            return action, rng, {
                "use_candidate": jnp.array(False),
                "residual_norm": jnp.linalg.norm(residual),
                "q_gain": jnp.array(0.0),
            }

        # ==================== Rectification path ====================
        K = config.num_candidates
        obs_rep = jnp.repeat(obs, K, axis=0)

        def decode_candidates(z_cand):
            # Decode K latents against K copies of the observation / base action.
            base_rep = jnp.repeat(a_base, K, axis=0)
            return _decode_residual(cvae_def, cvae_params, obs_rep, base_rep, z_cand)

        # Candidate residuals, one row per candidate.
        if mode == "mlp":
            # Deterministic residual perturbed with Gaussian noise of scale 0.1.
            res_det = run_mlp()
            act_dim = res_det.shape[-1]
            res_noise = jax.random.normal(key, (K, act_dim)) * 0.1
            residuals = jnp.broadcast_to(res_det, (K, act_dim)) + res_noise
        elif mode == "plas":
            # Deterministic latent perturbed with Gaussian noise of scale 0.1.
            z_det = run_latent_policy()
            z_noise = jax.random.normal(key, (K, latent_dim)) * 0.1
            residuals = decode_candidates(jnp.broadcast_to(z_det, (K, latent_dim)) + z_noise)
        elif mode == "trainz":
            # Sample latents from the policy's Gaussian, then bound with tanh.
            mu, logstd = run_latent_policy()
            sigma = jnp.exp(logstd)
            eps = jax.random.normal(key, (K, latent_dim))
            raw_z = jnp.broadcast_to(mu, (K, latent_dim)) + jnp.broadcast_to(sigma, (K, latent_dim)) * eps
            residuals = decode_candidates(config.latent_max * jnp.tanh(raw_z))
        else:  # grad / proj
            residuals = decode_candidates(jax.random.normal(key, (K, latent_dim)))

        # Explicitly broadcast the base action before adding the residuals.
        a_base_expanded = jnp.broadcast_to(a_base, (K, a_base.shape[-1]))
        candidates = jnp.clip(a_base_expanded + residuals, -1.0, 1.0)

        # Robust Q for every candidate and for the base action.
        q_params_rect = _subset_critics(q_params, config.rect_num_critics)
        q_cand = compute_ensemble_q_robust(
            critic_def, q_params_rect, obs_rep, candidates, config.uncertainty_weight
        )[0]
        q_base = compute_ensemble_q_robust(
            critic_def, q_params_rect, obs, a_base, config.uncertainty_weight
        )[0]

        best_idx = jnp.argmax(q_cand)
        best_action = candidates[best_idx]

        # Reduce to scalars before comparing.
        q_base_scalar = jnp.squeeze(q_base)
        gain = q_cand[best_idx] - q_base_scalar

        # Acceptance rule.
        if config.use_dual_threshold:
            rel_gain = gain / (jnp.abs(q_base_scalar) + 1e-5)
            passes_abs = gain > config.rectification_abs_threshold
            passes_rel = rel_gain > config.rectification_rel_threshold
            accept = passes_abs & passes_rel
        else:
            accept = gain > 0

        a_base_single = jnp.squeeze(a_base, axis=0)
        final_action = jnp.where(accept, best_action, a_base_single)

        diagnostics = {
            "use_candidate": accept,
            "residual_norm": jnp.linalg.norm(final_action - a_base_single),
            "q_gain": gain,
        }
        return final_action, rng, diagnostics

    return policy



# ==================== Eval ====================
def eval_spar_actor(env, policy_fn, state, config, rng):
    """Roll out ``policy_fn`` in ``env`` and collect returns plus diagnostics.

    Args:
        env: Gym-style environment using the 4-tuple ``step`` API
            (``obs, reward, done, info``) and exposing ``action_space.high``.
        policy_fn: Callable invoked as ``policy_fn(cvae_params, actor_params,
            q_params, latent_policy_params, mlp_residual_params, obs, rng)``
            and returning ``(action, rng, diagnostics)``. ``diagnostics`` must
            contain ``"use_candidate"``, ``"q_gain"`` and ``"residual_norm"``.
        state: Object exposing the five parameter attributes passed to
            ``policy_fn``.
        config: Object providing ``seed`` and ``n_episodes``.
        rng: Random state threaded through every ``policy_fn`` call.

    Returns:
        A tuple ``(episode_rewards, rng, stats)`` where ``episode_rewards`` is
        a NumPy array with one undiscounted return per episode, ``rng`` is the
        random state after the last policy call, and ``stats`` is a dict with
        ``"Acceptance_Rate"``, ``"Avg_Q_Gain"``, ``"Residual_Norm"`` and
        ``"Positive_Gain_Ratio"`` aggregated over all environment steps
        (all ``0.0`` when no step was taken).
    """
    max_action = float(env.action_space.high[0])
    episode_rewards = []
    q_gains = []
    residual_norms = []
    accept_counts = 0

    # Seed once before the first episode (not per episode) so the environment's
    # randomness evolves naturally across episodes. Seeding is deliberately
    # best-effort: environments without a usable ``seed`` are evaluated unseeded.
    try:
        env.seed(config.seed)
    except Exception:
        pass

    for _ in range(config.n_episodes):
        obs = env.reset()
        done = False
        ep_ret = 0.0
        while not done:
            # The policy expects a leading batch axis of size 1.
            obs_jax = jnp.asarray(obs, dtype=jnp.float32)[None, :]
            action, rng, diag = policy_fn(
                state.cvae_params,
                state.actor_params,
                state.q_params,
                state.latent_policy_params,
                state.mlp_residual_params,
                obs_jax,
                rng,
            )
            # Fetch the action and all diagnostics in a single host transfer
            # instead of one blocking conversion per value.
            action_host, diag_host = jax.device_get((action, diag))
            if bool(diag_host["use_candidate"]):
                accept_counts += 1
            q_gains.append(float(diag_host["q_gain"]))
            residual_norms.append(float(diag_host["residual_norm"]))

            # Rescale by max_action, then clip to the environment's bounds.
            action_np = np.clip(np.asarray(action_host) * max_action, -max_action, max_action)
            obs, r, done, _ = env.step(action_np)
            ep_ret += r
        episode_rewards.append(ep_ret)

    total_steps = len(q_gains)
    has_steps = total_steps > 0
    stats = {
        "Acceptance_Rate": accept_counts / max(total_steps, 1),
        "Avg_Q_Gain": np.mean(q_gains) if has_steps else 0.0,
        "Residual_Norm": np.mean(residual_norms) if has_steps else 0.0,
        "Positive_Gain_Ratio": np.mean(np.array(q_gains) > 0) if has_steps else 0.0,
    }
    return np.array(episode_rewards), rng, stats




# ==================== Checkpoint ====================
def save_spar_checkpoint(
    path: str,
    state: SPARState,
    eval_score: float,
    step: int,
    config: SPARConfig,
    state_mean: np.ndarray,
    state_std: np.ndarray,
):
    """Save checkpoint with explicit state_mean/std for reproducibility.

    The checkpoint is a pickled dict with the keys ``"state"`` (the training
    state pulled to host memory), ``"eval_score"``, ``"step"`` and ``"config"``
    (a subset of ``config`` fields plus the observation normalization stats).

    The file is written to a temporary sibling path and then moved into place
    with ``os.replace`` so that an interrupted run never leaves a truncated
    checkpoint at ``path``.

    Args:
        path: Destination file; missing parent directories are created.
        state: Training state to serialize.
        eval_score: Evaluation score stored alongside the state.
        step: Training step stored alongside the state.
        config: Run configuration; selected fields are copied into the file.
        state_mean: Observation normalization mean.
        state_std: Observation normalization standard deviation.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # Pull everything to host memory, then coerce array-like leaves to NumPy so
    # the pickle does not carry device-backed array objects.
    cpu_state = jax.device_get(state)
    cpu_state = jax.tree_util.tree_map(lambda x: np.asarray(x) if hasattr(x, "dtype") else x, cpu_state)

    checkpoint = {
        "state": cpu_state,
        "eval_score": eval_score,
        "step": step,
        "config": {
            "env": config.env,
            "state_mean": np.asarray(state_mean),
            "state_std": np.asarray(state_std),
            "cvae_latent_dim": config.cvae_latent_dim,
            "cvae_hidden_dims": list(config.cvae_hidden_dims),
            "mlp_residual_hidden_dims": list(config.mlp_residual_hidden_dims),
            "uncertainty_weight": config.uncertainty_weight,
            "num_candidates": config.num_candidates,
            "guide_mode": config.guide_mode,
            "proj_num_candidates": config.proj_num_candidates,
            "proj_period": config.proj_period,
            "proj_temperature": config.proj_temperature,
            "proj_use_target_proposal": config.proj_use_target_proposal,
            "target_tau": config.target_tau,
            "data_num_critics": config.data_num_critics,
            "proj_num_critics": config.proj_num_critics,
            "rect_num_critics": config.rect_num_critics,
            "relative_advantage_scale": config.relative_advantage_scale,
        },
    }

    # Write to a temp file in the same directory, then rename over the target.
    tmp_path = f"{path}.tmp.{os.getpid()}"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(checkpoint, f)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up a partial file left behind by a failed dump.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ==================== Main ====================
@pyrallis.wrap()
def train(config: SPARConfig):
    """Train the SPAR residual modules on top of a frozen IQL base model.

    Steps, in order:
      1. Resolve and load the IQL checkpoint (auto-searched when
         ``config.base_model_path`` is empty or ``"AUTO"``).
      2. Build the environment and the D4RL replay buffer, reusing the IQL
         checkpoint's observation normalization statistics.
      3. Initialize the CVAE, latent policy and MLP residual networks and
         their optimizers.
      4. Build the jitted-style train step and the evaluation policy.
      5. Initialize WandB, write ``config.yaml`` and run a baseline evaluation.
      6. Run ``config.cvae_steps`` updates, logging every ``config.log_freq``
         steps and evaluating/checkpointing every ``config.eval_freq`` steps.

    Raises:
        FileNotFoundError: If no base checkpoint can be resolved.
        ValueError: If the checkpoint was trained on a different environment.
    """
    # Config serialization: tuples become lists so YAML/WandB store plain lists.
    def config_to_dict(cfg):
        d = asdict(cfg)
        for k, v in d.items():
            if isinstance(v, tuple):
                d[k] = list(v)
        return d

    # 1. Load IQL
    if not config.base_model_path or config.base_model_path == "AUTO":
        print(f"🔍 Searching for IQL checkpoint for {config.env}...")
        # AntMaze runs start from the best checkpoint, all others from the last one.
        checkpoint_name = (
            "checkpoint_best.pkl" if config.env.startswith("antmaze") else "checkpoint_last.pkl"
        )
        base_model_path = find_latest_checkpoint(
            log_base_dir=config.log_dir,
            env_name=config.env,
            method=["IQL", "SARSA"],
            checkpoint_name=checkpoint_name,
        )
        # Fail here with a clear message instead of deep inside the loader.
        if not base_model_path:
            raise FileNotFoundError(
                f"No IQL/SARSA checkpoint '{checkpoint_name}' found for {config.env} "
                f"under {config.log_dir}"
            )
    else:
        base_model_path = config.base_model_path
        if not os.path.exists(base_model_path):
            raise FileNotFoundError(f"Checkpoint not found: {base_model_path}")

    print("=" * 80)
    print("🔧 Loading IQL Base Model...")
    print(f"   Path: {base_model_path}")

    iql_state, iql_metadata = load_iql_checkpoint(base_model_path)
    iql_config = iql_metadata["config"]

    if iql_config["env"] != config.env:
        raise ValueError(f"Environment mismatch! IQL={iql_config['env']} SPAR={config.env}")

    state_mean = iql_config["state_mean"]
    state_std = iql_config["state_std"]
    num_critics = iql_config["num_critics"]
    iql_deterministic = iql_config.get("iql_deterministic", False)

    # Phase 2 step counter starts at a fixed offset; the train step reads it
    # back from the config.
    start_step = 1_000_000
    config.phase2_start_step = start_step

    print("✅ IQL loaded:")
    print(f"   Eval score: {iql_metadata.get('eval_score', 'N/A')}")
    print(f"   Step: {iql_metadata.get('step', 'N/A')}")
    print(f"   Critics: {num_critics}, Deterministic: {iql_deterministic}")
    print("=" * 80)

    # 2. Env
    env = gym.make(config.env)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)
    if config.normalize_reward:
        modify_reward(dataset, config.env)
    # Use the IQL checkpoint's statistics so observations match what the base
    # actor and critics were trained on.
    dataset["observations"] = normalize_states(dataset["observations"], state_mean, state_std)
    dataset["next_observations"] = normalize_states(dataset["next_observations"], state_mean, state_std)
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)

    replay_buffer = ReplayBuffer(state_dim, action_dim, config.buffer_size)
    replay_buffer.load_d4rl_dataset(dataset)

    # 3. Networks
    set_seed(config.seed, env)
    rng = jax.random.PRNGKey(config.seed)

    critic_def = Critic(state_dim=state_dim, action_dim=action_dim)
    if iql_deterministic:
        actor_def = DeterministicPolicy(state_dim=state_dim, act_dim=action_dim, max_action=1.0)
    else:
        actor_def = GaussianPolicy(state_dim=state_dim, act_dim=action_dim, max_action=1.0)

    cvae_def = CVAE(
        action_dim=action_dim,
        hidden_dims=config.cvae_hidden_dims,
        latent_dim=config.cvae_latent_dim,
        use_layernorm=config.cvae_layernorm,
    )

    latent_policy_def = LatentPolicy(
        latent_dim=config.cvae_latent_dim,
        hidden_dims=config.latent_policy_hidden_dims,
        is_stochastic=(config.guide_mode == "trainz"),
        latent_max=config.latent_max,
        logstd_min=config.trainz_logstd_min,
        logstd_max=config.trainz_logstd_max,
    )

    mlp_residual_def = MLPResidual(
        action_dim=action_dim,
        hidden_dims=config.mlp_residual_hidden_dims,
    )

    rng, key_cvae, key_lp, key_mlp = jax.random.split(rng, 4)
    dummy_obs = jnp.zeros((1, state_dim))
    dummy_act = jnp.zeros((1, action_dim))

    cvae_params = cvae_def.init(key_cvae, dummy_obs, dummy_act, dummy_act, key_cvae, training=True)["params"]
    latent_policy_params = latent_policy_def.init(key_lp, dummy_obs, dummy_act)["params"]
    mlp_residual_params = mlp_residual_def.init(key_mlp, dummy_obs, dummy_act)["params"]

    # Optimizer factory: Adam behind global-norm gradient clipping, with
    # non-finite updates skipped (up to 10 consecutive ones).
    def make_optimizer(lr):
        base = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(lr),
        )
        return optax.apply_if_finite(base, max_consecutive_errors=10)

    cvae_tx = make_optimizer(config.cvae_lr)
    latent_policy_tx = make_optimizer(config.latent_policy_lr)
    mlp_residual_tx = make_optimizer(config.mlp_residual_lr)

    spar_state = SPARState(
        cvae_params=cvae_params,
        cvae_target_params=cvae_params,
        cvae_opt_state=cvae_tx.init(cvae_params),
        latent_policy_params=latent_policy_params,
        latent_policy_opt_state=latent_policy_tx.init(latent_policy_params),
        mlp_residual_params=mlp_residual_params,
        mlp_residual_opt_state=mlp_residual_tx.init(mlp_residual_params),
        actor_params=iql_state.actor_params,
        q_params=iql_state.q_params,
        total_it=jnp.array(start_step, dtype=jnp.int32),
    )

    # 4. Train Step & Policy
    train_step = make_spar_train_step(
        cvae_def, latent_policy_def, mlp_residual_def, critic_def, actor_def,
        cvae_tx, latent_policy_tx, mlp_residual_tx, config=config, deterministic_actor=iql_deterministic
    )

    policy_fn = make_spar_policy_fn(
        cvae_def, latent_policy_def, mlp_residual_def, critic_def, actor_def, config, iql_deterministic
    )

    # 5. WandB Init
    hyperparams = config_to_dict(config)

    env_parts = config.env.split("-")
    domain = env_parts[0] if len(env_parts) > 0 else "unknown"
    dataset_type = "-".join(env_parts[1:]) if len(env_parts) > 1 else "unknown"

    tags = [domain, dataset_type, "SPAR"]
    if config.guide_mode:
        tags.append(f"gm:{config.guide_mode.lower()}")

    run_name, log_dir = wandb_init(
        config=config,
        log_base_dir=config.log_dir,
        env_name=config.env,
        method="SPAR",
        hyperparams=hyperparams,
        tags=tags,
    )

    wandb.config.update({
        "runtime/base_model_path_resolved": os.path.abspath(base_model_path),
        "runtime/iql_ckpt_step": iql_metadata.get("step", -1),
        "runtime/iql_ckpt_eval_score": iql_metadata.get("eval_score", -1.0),
        "runtime/iql_deterministic": iql_deterministic,
    }, allow_val_change=True)

    # 6. Checkpoints
    if config.checkpoints_path is not None:
        checkpoints_dir = config.checkpoints_path
    elif log_dir is not None:
        checkpoints_dir = os.path.join(log_dir, "checkpoints")
    else:
        checkpoints_dir = os.path.join("checkpoints", config.env)
    os.makedirs(checkpoints_dir, exist_ok=True)

    config_save_path = os.path.join(checkpoints_dir, "config.yaml")
    with open(config_save_path, "w") as f:
        yaml.dump(hyperparams, f, default_flow_style=False)

    print(f"💾 Checkpoints will be saved to: {checkpoints_dir}")

    def run_eval(eval_state, eval_rng):
        """Evaluate ``eval_state``.

        Returns ``(normalized_score, diagnostics, wandb_payload, rng)``; the
        caller decides when and at which step to log the payload.
        """
        # Make sure pending parameter updates are finished before evaluating.
        jax.block_until_ready(eval_state.cvae_params)
        eval_scores, eval_rng, eval_diag = eval_spar_actor(env, policy_fn, eval_state, config, eval_rng)
        eval_score = float(eval_scores.mean())
        eval_std = float(eval_scores.std())
        normalized_eval_score = float(env.get_normalized_score(eval_score) * 100.0)
        payload = {
            "eval/episode_reward": eval_score,
            "eval/episode_reward_std": eval_std,
            "eval/d4rl_normalized_score": normalized_eval_score,
            "eval/episode_rewards_distribution": wandb.Histogram(eval_scores),
            "eval/acceptance_rate": eval_diag["Acceptance_Rate"],
            "eval/avg_q_gain": eval_diag["Avg_Q_Gain"],
            "eval/residual_norm": eval_diag["Residual_Norm"],
            "eval/positive_gain_ratio": eval_diag["Positive_Gain_Ratio"],
        }
        return normalized_eval_score, eval_diag, payload, eval_rng

    # 7. Baseline Eval
    print("=" * 80)
    print("📊 Running Baseline Evaluation @ Step 1M...")
    normalized_eval_score, eval_diag, eval_payload, rng = run_eval(spar_state, rng)

    print(f"   D4RL Score: {normalized_eval_score:.2f}")
    print(f"   AcceptRate={eval_diag['Acceptance_Rate']:.1%}")
    print(f"   Avg Q Gain: {eval_diag['Avg_Q_Gain']:.4f}")
    print("=" * 80)

    wandb.log(eval_payload, step=start_step)

    # The baseline score is the bar a later checkpoint must beat to be "best".
    best_eval_score = normalized_eval_score

    # 8. Training Loop
    print("=" * 80)
    print(f"🚀 Training SPAR Phase 2 | Env: {config.env} | Seed: {config.seed}")
    print(f"📈 Starting counter from: {start_step}")
    print(f"🎯 GuideMode={config.guide_mode} | UncLambda={config.uncertainty_weight}")
    print(f"⚡ Critic subsets: data={config.data_num_critics}, guide={config.guide_num_critics}, proj={config.proj_num_critics}, rect={config.rect_num_critics}")
    if config.guide_mode == "proj":
        print(
            f"🎯 Proj: K={config.proj_num_candidates}, period={config.proj_period}, "
            f"pT={config.proj_temperature}, target={config.proj_use_target_proposal}, tau={config.target_tau}"
        )
    print("=" * 80)

    pbar = tqdm(range(config.cvae_steps), desc=f"SPAR Phase 2 [{config.env}]", dynamic_ncols=True)

    for local_step in pbar:
        batch = replay_buffer.sample(config.batch_size)
        spar_state, metrics, rng = train_step(spar_state, batch, rng)

        # Steps are reported on the global axis, continuing from start_step.
        global_step = start_step + local_step + 1

        if global_step % config.log_freq == 0:
            metrics_np = jax.device_get(metrics)
            wandb.log(metrics_np, step=global_step)

            postfix = {
                "Step": global_step,
                "Recon": f"{float(metrics_np['Loss/Recon']):.3f}",
                "Guide": f"{float(metrics_np['Loss/Guide']):.3f}",
            }
            if config.guide_mode == "proj":
                postfix["DoProj"] = f"{float(metrics_np.get('Diagnostics/DoProj', 0.0)):.0f}"
                postfix["Proj"] = f"{float(metrics_np.get('Loss/Proj_Core', 0.0)):.3f}"
                postfix["Use%"] = f"{float(metrics_np.get('Diagnostics/Proj_Utilization', 0.0)):.1%}"
            postfix["DataUse%"] = f"{float(metrics_np['Diagnostics/Data_Utilization']):.1%}"
            pbar.set_postfix(postfix)

        if (global_step % config.eval_freq) == 0:
            normalized_eval_score, eval_diag, eval_payload, rng = run_eval(spar_state, rng)

            pbar.write(f"\n{'=' * 80}")
            pbar.write(
                f"[{config.env}] Step {global_step}: D4RL={normalized_eval_score:.2f}, "
                f"AcceptRate={eval_diag['Acceptance_Rate']:.1%}, "
                f"Q_Gain={eval_diag['Avg_Q_Gain']:.4f}"
            )

            is_best = normalized_eval_score > best_eval_score
            if is_best:
                best_eval_score = normalized_eval_score
                pbar.write(f"🏆 New best: {best_eval_score:.3f}")
            pbar.write(f"{'=' * 80}\n")

            wandb.log(eval_payload, step=global_step)

            # Every evaluation writes a step-stamped checkpoint and refreshes
            # "last"; "best" is refreshed only on improvement.
            ckpt_path = os.path.join(checkpoints_dir, f"checkpoint_{global_step}.pkl")
            save_spar_checkpoint(ckpt_path, spar_state, normalized_eval_score, global_step, config, state_mean, state_std)

            last_path = os.path.join(checkpoints_dir, "checkpoint_last.pkl")
            save_spar_checkpoint(last_path, spar_state, normalized_eval_score, global_step, config, state_mean, state_std)

            if is_best:
                best_path = os.path.join(checkpoints_dir, "checkpoint_best.pkl")
                save_spar_checkpoint(best_path, spar_state, normalized_eval_score, global_step, config, state_mean, state_std)
                pbar.write(f"💾 Saved: {best_path}")

    pbar.close()
    wandb.finish()
    print("\n🎉 Training completed!")


# Script entry point. train() is called without arguments: it is wrapped by
# pyrallis.wrap(), which is expected to supply the SPARConfig argument.
if __name__ == "__main__":
    train()

