import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import random
import pickle
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax import struct

# Import run utilities
from run_utils import wandb_init

import wandb

# ==================== JAX Configuration ====================
# Keep JAX in its default 32-bit mode (64-bit dtypes disabled).
jax.config.update('jax_enable_x64', False)
# Default precision for matmuls/convolutions. 'high' is below 'highest'; what it
# maps to concretely (e.g. reduced-precision passes) is backend-dependent.
jax.config.update('jax_default_matmul_precision', 'high')

# A training batch: [states, actions, rewards, next_states, dones], in the
# order produced by ReplayBuffer.sample and unpacked by train_step.
TensorBatch = List[jnp.ndarray]

# Upper bound on the advantage weights exp(beta * adv) used in the actor loss.
EXP_ADV_MAX = 100.0
# Clipping range applied to the Gaussian policy's log standard deviation.
LOG_STD_MIN = -20.0
LOG_STD_MAX = 2.0


@dataclass
class TrainConfig:
    """Configuration for ensemble-critic IQL training (parsed by pyrallis)."""

    # --- Environment ---
    env: str = "halfcheetah-medium-expert-v2"  # D4RL task name passed to gym.make
    seed: int = 0

    # --- Critic ensemble ---
    num_critics: int = 10  # ensemble size; the V target uses the first two members

    # --- IQL hyperparameters ---
    discount: float = 0.99
    tau: float = 0.005  # Polyak coefficient for the target critics
    beta: float = 3.0  # inverse temperature of the advantage weights exp(beta * adv)
    iql_tau: float = 0.7  # expectile used by the asymmetric value loss
    iql_deterministic: bool = False  # deterministic actor instead of a Gaussian one

    # --- Training ---
    max_timesteps: int = 1_000_000
    buffer_size: int = 2_000_000
    batch_size: int = 256

    # --- Normalization ---
    normalize: bool = True  # standardize observations with dataset statistics
    normalize_reward: bool = False

    # --- Learning rates ---
    vf_lr: float = 3e-4
    qf_lr: float = 3e-4
    actor_lr: float = 3e-4  # initial value; cosine-decayed over max_timesteps
    actor_dropout: Optional[float] = None

    # --- Evaluation ---
    eval_freq: int = 5_000
    n_episodes: int = 10

    # --- Checkpoints ---
    checkpoints_path: Optional[str] = None  # None -> derived from the logging directory
    load_model: str = ""  # checkpoint to resume from; empty string starts fresh

    # --- Logging (WandB) ---
    log_dir: str = "runs"
    project: str = "CORL"
    group: str = "IQL-D4RL"

    # --- Device (kept for config compatibility; not read by the training code here) ---
    device: str = "cuda"

    def __post_init__(self):
        """No derived fields to compute; kept as an extension hook."""
        return None



# ==================== Utility Functions ====================


def soft_update_params(target_params, source_params, tau: float):
    """Polyak-average ``source_params`` into ``target_params``.

    Every leaf becomes ``(1 - tau) * target + tau * source``. Both pytrees must
    have the same structure.

    Args:
        target_params: current target-network parameters.
        source_params: online-network parameters to blend in.
        tau: interpolation coefficient; 0 keeps the target, 1 copies the source.

    Returns:
        A new pytree with the blended parameters.
    """
    # jax.tree_map is deprecated (and removed in recent JAX releases);
    # jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(
        lambda t, s: (1.0 - tau) * t + tau * s, target_params, source_params
    )


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-feature mean and (std + eps) of ``states`` along axis 0.

    ``eps`` keeps the standard deviation away from zero so it can be used as a
    divisor when normalizing.
    """
    per_feature_mean = states.mean(axis=0)
    per_feature_std = states.std(axis=0)
    return per_feature_mean, per_feature_std + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std`` (broadcasting)."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardized and rewards optionally scaled.

    Observations become ``(obs - state_mean) / state_std``. A reward wrapper
    multiplying by ``reward_scale`` is added only when the scale is not 1.0.
    """
    wrapped = gym.wrappers.TransformObservation(
        env, lambda obs: (obs - state_mean) / state_std
    )
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(wrapped, lambda r: reward_scale * r)


def set_seed(seed: int, env: Optional[gym.Env] = None):
    """Seed Python's and NumPy's global RNGs and, best-effort, the environment.

    JAX randomness is not touched here; it is driven by explicit PRNG keys.

    Args:
        seed: seed value applied everywhere.
        env: optional gym environment; its ``seed`` and action-space ``seed``
            methods are called if available.
    """
    if env is not None:
        # The two calls are attempted independently: gym versions that dropped
        # `Env.seed` make the first call fail, and the action space should
        # still be seeded in that case. Both stay best-effort because the
        # seeding API differs across gym versions.
        try:
            env.seed(seed)
        except Exception:
            pass
        try:
            env.action_space.seed(seed)
        except Exception:
            pass

    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    random.seed(seed)


def return_reward_range(dataset, max_episode_steps):
    """Return (min, max) undiscounted return over the dataset's completed episodes.

    An episode ends at a terminal flag or after ``max_episode_steps`` steps.
    A trailing, unfinished episode is counted in the length bookkeeping but
    not in the returns.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_length = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_length += 1
        if terminal or running_length == max_episode_steps:
            episode_returns.append(running_return)
            episode_lengths.append(running_length)
            running_return, running_length = 0.0, 0
    # Length of the trailing partial episode (0 if the data ends on a boundary).
    episode_lengths.append(running_length)
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def modify_reward(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset["rewards"]`` in place, depending on the task family.

    - Locomotion (halfcheetah / hopper / walker2d): divide by the spread of
      episode returns, then multiply by ``max_episode_steps``.
    - AntMaze: subtract 1 from every reward.
    - Anything else: leave rewards unchanged.
    """
    locomotion_tasks = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion_tasks):
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        return
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0


def asymmetric_l2_loss(u: jnp.ndarray, tau: float) -> jnp.ndarray:
    """Expectile (asymmetric L2) loss used by IQL.

    Computes ``mean(|tau - 1[u < 0]| * u**2)``. Positive residuals are weighted
    by ``tau`` and negative residuals by ``1 - tau`` (for tau in [0, 1]).
    """
    is_negative = (u < 0).astype(jnp.float32)
    weights = jnp.abs(tau - is_negative)
    squared = u ** 2
    return (weights * squared).mean()


# ==================== Replay Buffer (CPU) ====================

class ReplayBuffer:
    """Fixed-capacity store of offline transitions held in host (CPU) NumPy arrays.

    ``sample`` gathers a batch on the host and converts each field with
    ``jnp.asarray``, which places it on JAX's default device. Nothing here
    schedules the host-to-device copy explicitly; any overlap with device
    computation is left to the JAX runtime.
    """

    def __init__(self, state_dim: int, action_dim: int, buffer_size: int):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        # Preallocated float32 storage; rewards and dones keep a trailing unit axis.
        self._states = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self._actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self._rewards = np.zeros((buffer_size, 1), dtype=np.float32)
        self._next_states = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self._dones = np.zeros((buffer_size, 1), dtype=np.float32)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Copy a transition dataset into this (empty) buffer.

        Reads the "observations", "actions", "rewards", "next_observations" and
        "terminals" arrays. Rewards and terminals get a trailing unit axis.

        Raises:
            ValueError: if the buffer already holds data, or the dataset does
                not fit into ``buffer_size``.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError("Replay buffer is smaller than dataset!")

        self._states[:n_transitions] = data["observations"].astype(np.float32)
        self._actions[:n_transitions] = data["actions"].astype(np.float32)
        self._rewards[:n_transitions] = data["rewards"][..., None].astype(np.float32)
        self._next_states[:n_transitions] = data["next_observations"].astype(
            np.float32
        )
        self._dones[:n_transitions] = data["terminals"][..., None].astype(np.float32)

        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Indices come from NumPy's global RNG, so results depend on
        ``np.random.seed``.

        Returns:
            ``[states, actions, rewards, next_states, dones]`` as JAX arrays;
            rewards and dones have shape ``(batch_size, 1)``.

        Raises:
            ValueError: if the buffer holds no transitions.
        """
        n_valid = min(self._size, self._pointer)
        if n_valid <= 0:
            # np.random.randint(0, 0) would fail with an unhelpful "high <= 0".
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, n_valid, size=batch_size)

        fields = (
            self._states,
            self._actions,
            self._rewards,
            self._next_states,
            self._dones,
        )
        return [jnp.asarray(field[indices]) for field in fields]

    def add_transition(self):
        """Online insertion is not supported; the buffer is filled once from a dataset."""
        raise NotImplementedError(
            "ReplayBuffer only supports offline data via load_d4rl_dataset"
        )


# ==================== Neural Networks ====================


class MLP(nn.Module):
    """Fully connected network.

    Each hidden layer is Dense -> activation (-> dropout). The output layer is
    Dense, optionally followed by an activation and a squeeze of the last axis.

    ``dims`` lists layer widths including the input width. ``dims[0]`` only
    matters for the length check, because ``nn.Dense`` infers its input size.

    Initialization: the default ``kernel_init`` is Flax's
    ``kaiming_uniform()``, i.e. He-uniform
    (``variance_scaling(2.0, "fan_in", "uniform")``, bound sqrt(6 / fan_in)),
    with zero biases. This is NOT the PyTorch ``nn.Linear`` default. PyTorch
    draws weights and biases from U(-1/sqrt(fan_in), 1/sqrt(fan_in)). For
    PyTorch-like weights pass
    ``kernel_init=nn.initializers.variance_scaling(1 / 3, "fan_in", "uniform")``.
    """

    dims: List[int]
    activation_fn: Callable = nn.relu
    output_activation_fn: Optional[Callable] = None
    squeeze_output: bool = False
    dropout: Optional[float] = None
    kernel_init: Callable = nn.initializers.kaiming_uniform()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        """Apply the network; dropout is active only when ``training`` is True."""
        if len(self.dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        # Hidden layers: every width strictly between the input and output dims.
        for width in self.dims[1:-1]:
            x = nn.Dense(
                width, kernel_init=self.kernel_init, bias_init=self.bias_init
            )(x)
            x = self.activation_fn(x)
            if self.dropout is not None:
                x = nn.Dropout(rate=self.dropout)(x, deterministic=not training)

        x = nn.Dense(
            self.dims[-1], kernel_init=self.kernel_init, bias_init=self.bias_init
        )(x)

        if self.output_activation_fn is not None:
            x = self.output_activation_fn(x)

        if self.squeeze_output:
            if self.dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            x = jnp.squeeze(x, axis=-1)
        return x


class GaussianPolicy(nn.Module):
    """Gaussian policy: tanh-squashed MLP mean plus a learned, state-independent log-std.

    ``max_action`` is stored but not applied here; callers rescale actions.
    """

    state_dim: int
    act_dim: int
    max_action: float
    hidden_dim: int = 256
    n_hidden: int = 2
    dropout: Optional[float] = None

    def setup(self):
        hidden_widths = [self.hidden_dim] * self.n_hidden
        layer_dims = [self.state_dim] + hidden_widths + [self.act_dim]
        self.net = MLP(
            dims=layer_dims,
            output_activation_fn=jnp.tanh,
            dropout=self.dropout,
        )
        self.log_std = self.param("log_std", nn.initializers.zeros, (self.act_dim,))

    def __call__(
        self, obs: jnp.ndarray, training: bool = False
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(mean, std)``; log-std is clipped to [LOG_STD_MIN, LOG_STD_MAX]."""
        mu = self.net(obs, training=training)
        sigma = jnp.exp(jnp.clip(self.log_std, LOG_STD_MIN, LOG_STD_MAX))
        return mu, sigma


class DeterministicPolicy(nn.Module):
    """Deterministic policy: a tanh-squashed MLP mapping observations to actions.

    ``max_action`` is stored but not applied here; callers rescale actions.
    """

    state_dim: int
    act_dim: int
    max_action: float
    hidden_dim: int = 256
    n_hidden: int = 2
    dropout: Optional[float] = None

    def setup(self):
        layer_dims = [self.state_dim]
        layer_dims.extend([self.hidden_dim] * self.n_hidden)
        layer_dims.append(self.act_dim)
        self.net = MLP(
            dims=layer_dims,
            output_activation_fn=jnp.tanh,
            dropout=self.dropout,
        )

    def __call__(self, obs: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        """Return actions in [-1, 1] (dropout active only when ``training``)."""
        actions = self.net(obs, training=training)
        return actions


class Critic(nn.Module):
    """A single Q(s, a) network; the ensemble is formed by vmapping over stacked params."""

    state_dim: int
    action_dim: int
    hidden_dim: int = 256
    n_hidden: int = 2

    @nn.compact
    def __call__(self, state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Return Q-values with the trailing unit axis squeezed away."""
        layer_dims = [self.state_dim + self.action_dim]
        layer_dims += [self.hidden_dim] * self.n_hidden
        layer_dims.append(1)
        state_action = jnp.concatenate([state, action], axis=-1)
        q_net = MLP(dims=layer_dims, squeeze_output=True)
        return q_net(state_action, training=False)


class ValueFunction(nn.Module):
    """State-value network V(s) with a squeezed scalar output per state."""

    state_dim: int
    hidden_dim: int = 256
    n_hidden: int = 2

    @nn.compact
    def __call__(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return V(state) with the trailing unit axis squeezed away."""
        hidden_widths = [self.hidden_dim] * self.n_hidden
        value_net = MLP(dims=[self.state_dim] + hidden_widths + [1], squeeze_output=True)
        return value_net(state, training=False)


def gaussian_log_prob(
    actions: jnp.ndarray, mean: jnp.ndarray, std: jnp.ndarray
) -> jnp.ndarray:
    """Diagonal-Gaussian log-density of ``actions``, summed over the last axis.

    Matches ``torch.distributions.Normal(mean, std).log_prob(actions).sum(-1)``.
    """
    standardized = (actions - mean) / std
    log_two_pi = jnp.log(2.0 * jnp.pi)
    per_dim = -0.5 * (jnp.square(standardized) + log_two_pi) - jnp.log(std)
    return per_dim.sum(axis=-1)


# ==================== Training State ====================


@struct.dataclass
class IQLState:
    """All trainable parameters and optimizer states carried between IQL steps.

    Registered as a JAX pytree via ``flax.struct.dataclass`` so it can be passed
    through ``jax.jit``. Checkpoints pickle this class, so its field layout is
    part of the checkpoint format.
    """

    # Policy network parameters.
    actor_params: Any
    q_params: Any  # Ensemble: PyTree with leading dimension N
    # Polyak-averaged copy of q_params (same leading ensemble dimension N).
    q_target_params: Any
    # State-value network parameters.
    v_params: Any
    # optax optimizer states for the actor, the critic ensemble and V.
    actor_opt_state: optax.OptState
    q_opt_state: optax.OptState
    v_opt_state: optax.OptState
    # Number of training steps applied so far (int32 scalar).
    total_it: jnp.ndarray


# ==================== Training Step ====================


def make_train_step(
    actor_def,
    critic_def,
    v_def,
    actor_tx,
    q_tx,
    v_tx,
    *,
    discount: float,
    tau: float,
    beta: float,
    iql_tau: float,
    deterministic_actor: bool,
):
    """Build the jitted IQL update for an ensemble critic.

    One call of the returned ``train_step`` performs, in order:

    1. An expectile (asymmetric L2) update of V against the min of the first
       two target critics (TwinQ-style).
    2. An MSE update of all N critics towards
       ``r + discount * (1 - done) * V(s')``, with V(s') taken before step 1.
    3. A Polyak update of all target critics with coefficient ``tau``.
    4. An advantage-weighted behavior-cloning update of the actor, with weights
       ``min(exp(beta * adv), EXP_ADV_MAX)``.

    Args:
        actor_def: policy module; applied with ``training=True`` and a
            "dropout" rng during the update.
        critic_def: single-critic module; ensemble parameters are stacked
            along a leading axis and evaluated with ``jax.vmap``.
        v_def: state-value module.
        actor_tx: optax transformation for the actor.
        q_tx: optax transformation for the critic ensemble.
        v_tx: optax transformation for V.
        discount: Bellman discount factor.
        tau: Polyak coefficient for the target critics.
        beta: inverse temperature of the advantage weights.
        iql_tau: expectile used by the value loss.
        deterministic_actor: use a squared-error BC loss instead of the
            Gaussian negative log-likelihood.

    Returns:
        ``train_step(state, batch, rng) -> (new_state, metrics, rng)``.
    """

    def ensemble_critic_apply(q_params, state, action):
        """Evaluate every critic stacked in ``q_params`` (leading axis K) -> (K, B)."""
        return jax.vmap(
            lambda p: critic_def.apply({"params": p}, state, action)
        )(q_params)

    def twin_target_q(q_target_params, state, action):
        """Min over the first (up to) two target critics -> (B,)."""
        # Only members 0 and 1 feed the V target. Slice their parameters before
        # the vmap instead of evaluating all N critics and discarding N - 2.
        twin_params = jax.tree_util.tree_map(lambda x: x[:2], q_target_params)
        q_twin = ensemble_critic_apply(twin_params, state, action)  # (min(N, 2), B)
        # jnp.min over the member axis also works for num_critics == 1, where
        # indexing member 1 would raise. For N >= 2 it equals min(q[0], q[1]).
        return jnp.min(q_twin, axis=0)

    @jax.jit
    def train_step(state: IQLState, batch: TensorBatch, rng: jax.Array):
        observations, actions, rewards, next_observations, dones = batch
        rewards_ = jnp.squeeze(rewards, axis=-1)
        dones_ = jnp.squeeze(dones, axis=-1)

        # ===== next_v is computed BEFORE the V update (as in the PyTorch IQL) =====
        next_v = v_def.apply({"params": state.v_params}, next_observations)

        # ===== V update against the TwinQ-style target =====
        target_q = twin_target_q(state.q_target_params, observations, actions)  # (B,)

        def v_loss_fn(v_params):
            v = v_def.apply({"params": v_params}, observations)
            adv = target_q - v
            v_loss = asymmetric_l2_loss(adv, iql_tau)
            return v_loss, adv

        (v_loss, adv), v_grads = jax.value_and_grad(v_loss_fn, has_aux=True)(
            state.v_params
        )
        v_updates, new_v_opt_state = v_tx.update(
            v_grads, state.v_opt_state, state.v_params
        )
        new_v_params = optax.apply_updates(state.v_params, v_updates)

        # ===== Q update (all N critics regress to the same target) =====
        targets = rewards_ + (1.0 - dones_) * discount * jax.lax.stop_gradient(next_v)

        def q_loss_fn(q_params):
            q_all = ensemble_critic_apply(q_params, observations, actions)  # (N, B)
            # (B,) -> (1, B) so the target broadcasts over the ensemble axis.
            squared_errors = (q_all - targets[None, :]) ** 2
            # Mean over the batch per critic, then over critics.
            return jnp.mean(jnp.mean(squared_errors, axis=1))

        q_loss, q_grads = jax.value_and_grad(q_loss_fn)(state.q_params)
        q_updates, new_q_opt_state = q_tx.update(
            q_grads, state.q_opt_state, state.q_params
        )
        new_q_params = optax.apply_updates(state.q_params, q_updates)

        # Polyak update of all N target critics.
        new_q_target_params = soft_update_params(
            state.q_target_params, new_q_params, tau
        )

        # ===== Actor update =====
        # jnp.minimum instead of jnp.clip(..., a_max=...): the a_min/a_max
        # keywords are deprecated in recent JAX; the result is identical.
        exp_adv = jnp.minimum(
            jnp.exp(beta * jax.lax.stop_gradient(adv)), EXP_ADV_MAX
        )

        rng, dropout_key = jax.random.split(rng)

        if deterministic_actor:
            def actor_loss_fn(actor_params):
                pred = actor_def.apply(
                    {"params": actor_params},
                    observations,
                    training=True,
                    rngs={"dropout": dropout_key},
                )
                bc_losses = jnp.sum((pred - actions) ** 2, axis=-1)
                return jnp.mean(exp_adv * bc_losses)
        else:
            def actor_loss_fn(actor_params):
                mean, std = actor_def.apply(
                    {"params": actor_params},
                    observations,
                    training=True,
                    rngs={"dropout": dropout_key},
                )
                bc_losses = -gaussian_log_prob(actions, mean, std)
                return jnp.mean(exp_adv * bc_losses)

        actor_loss, actor_grads = jax.value_and_grad(actor_loss_fn)(state.actor_params)
        actor_updates, new_actor_opt_state = actor_tx.update(
            actor_grads, state.actor_opt_state, state.actor_params
        )
        new_actor_params = optax.apply_updates(state.actor_params, actor_updates)

        new_state = IQLState(
            actor_params=new_actor_params,
            q_params=new_q_params,
            q_target_params=new_q_target_params,
            v_params=new_v_params,
            actor_opt_state=new_actor_opt_state,
            q_opt_state=new_q_opt_state,
            v_opt_state=new_v_opt_state,
            total_it=state.total_it + jnp.array(1, dtype=jnp.int32),
        )

        metrics = {
            "value_loss": v_loss,
            "q_loss": q_loss,
            "actor_loss": actor_loss,
        }
        return new_state, metrics, rng

    return train_step


# ==================== Evaluation ====================


def make_eval_policy_fn(actor_def, deterministic_actor: bool):
    """Build a jitted ``policy(params, obs_b1) -> action`` for evaluation.

    A deterministic actor's output is returned as-is. A Gaussian actor is
    evaluated greedily: only its mean is returned and the std is ignored.
    """

    def greedy_action(params, obs_b1):
        variables = {"params": params}
        if deterministic_actor:
            return actor_def.apply(variables, obs_b1, training=False)
        mean_action, _unused_std = actor_def.apply(variables, obs_b1, training=False)
        return mean_action

    return jax.jit(greedy_action)


def eval_actor(
    env: gym.Env,
    actor_def,
    actor_params,
    *,
    max_action: float,
    n_episodes: int,
    seed: int,
    deterministic_actor: bool,
):
    """Roll out the greedy policy for ``n_episodes`` and return per-episode returns.

    The parameters are copied to the CPU and the policy is jitted, so every
    environment step runs a small CPU computation. Actions in [-1, 1] are
    scaled by ``max_action`` and clipped to [-max_action, max_action].

    Args:
        env: environment using the 4-tuple ``step`` API (obs, reward, done, info).
        actor_def: policy module definition.
        actor_params: policy parameters (any device).
        max_action: action scale.
        n_episodes: number of evaluation episodes.
        seed: environment seed, applied once before the first episode.
        deterministic_actor: whether ``actor_def`` is deterministic.

    Returns:
        float32 NumPy array of shape ``(n_episodes,)`` with undiscounted returns.
    """
    # Seed once before all episodes, matching the PyTorch reference
    # (env.seed(seed), then n_episodes rollouts). Best-effort: not every gym
    # version provides env.seed().
    try:
        env.seed(seed)
    except Exception:
        pass

    cpu = jax.devices("cpu")[0]
    actor_params_cpu = jax.device_put(actor_params, cpu)

    policy = make_eval_policy_fn(actor_def, deterministic_actor)

    episode_rewards = []
    for _ in range(n_episodes):
        # Always reset. Swallowing a reset failure would silently start the
        # episode from the previous episode's terminal observation.
        obs = env.reset()
        done = False
        ep_ret = 0.0

        while not done:
            # Pass a host NumPy array: jit runs on the device the params are
            # committed to (CPU). This avoids staging the observation on the
            # default (possibly GPU) device first.
            obs_b1 = np.asarray(obs, dtype=np.float32)[None, ...]
            a = policy(actor_params_cpu, obs_b1)
            action = np.clip(np.asarray(a[0]) * max_action, -max_action, max_action)

            obs, reward, done, _ = env.step(action)
            ep_ret += float(reward)

        episode_rewards.append(ep_ret)

    return np.asarray(episode_rewards, dtype=np.float32)



# ==================== Checkpoint Management ====================


def save_checkpoint(
    path: str,
    state: IQLState,
    eval_score: float,
    step: int,
    config: TrainConfig,
    state_mean: np.ndarray,
    state_std: np.ndarray,
):
    """Pickle the training state plus the metadata needed to reuse it (e.g. by SPAR).

    Args:
        path: destination file; parent directories are created if needed.
        state: training state; all array leaves are converted to NumPy.
        eval_score: evaluation score stored alongside the state.
        step: training step stored alongside the state.
        config: run configuration; a subset of fields is stored.
        state_mean: observation normalization mean.
        state_std: observation normalization std.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # Pull everything to host memory and store plain NumPy arrays so the pickle
    # does not reference device buffers. (jax.tree_map is deprecated/removed in
    # recent JAX; jax.tree_util.tree_map is equivalent.)
    cpu_state = jax.tree_util.tree_map(
        lambda x: np.asarray(x) if hasattr(x, "dtype") else x,
        jax.device_get(state),
    )

    checkpoint = {
        "state": cpu_state,
        "eval_score": eval_score,
        "step": step,
        "config": {
            "env": config.env,
            "state_mean": np.asarray(state_mean),
            "state_std": np.asarray(state_std),
            "num_critics": config.num_critics,
            "iql_deterministic": config.iql_deterministic,
            "beta": config.beta,
            "iql_tau": config.iql_tau,
            "discount": config.discount,
        },
    }

    # Write to a temporary file, then atomically move it into place. An
    # interrupted save then never leaves a truncated file behind (important
    # for files that are overwritten repeatedly, such as the "last" checkpoint).
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(checkpoint, f)
    os.replace(tmp_path, path)


import sys
import pickle

def load_checkpoint(path: str) -> Tuple[IQLState, Dict[str, Any]]:
    """Load a checkpoint written by ``save_checkpoint``.

    Only load files you trust: unpickling can execute arbitrary code.

    Args:
        path: checkpoint file path.

    Returns:
        ``(state, metadata)``. ``state`` is placed on JAX's default device.
        ``metadata`` has the keys "eval_score", "step" and "config" (each
        ``None`` or ``{}`` when absent).
    """

    class IQLUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # A checkpoint written by this script run directly records the class
            # as `__main__.IQLState`. Resolve it to the IQLState defined in this
            # module instead of importing a module by a hard-coded name. That
            # works whether this file is run as a script or imported under any
            # module name.
            if module == "__main__" and name == "IQLState":
                return IQLState
            return super().find_class(module, name)

    with open(path, "rb") as f:
        checkpoint = IQLUnpickler(f).load()

    state = jax.device_put(checkpoint["state"])
    metadata = {
        "eval_score": checkpoint.get("eval_score", None),
        "step": checkpoint.get("step", None),
        "config": checkpoint.get("config", {}),
    }
    return state, metadata


# ==================== Main Training Function ====================


def _loss_postfix(metrics_np: Dict[str, Any]) -> Dict[str, str]:
    """Format the three training losses for the tqdm progress-bar postfix."""
    return {
        "v_loss": f"{float(metrics_np['value_loss']):.4f}",
        "q_loss": f"{float(metrics_np['q_loss']):.4f}",
        "a_loss": f"{float(metrics_np['actor_loss']):.4f}",
    }


def _evaluate_policy(
    env: gym.Env,
    actor_def,
    actor_params,
    config: TrainConfig,
    max_action: float,
) -> Tuple[np.ndarray, float, float, float]:
    """Run ``config.n_episodes`` evaluation episodes and summarize them.

    Returns:
        ``(per-episode returns, mean return, std of returns, D4RL normalized
        score * 100)``.
    """
    # Make sure pending training updates are finished before evaluating.
    jax.block_until_ready(actor_params)
    scores = eval_actor(
        env,
        actor_def,
        actor_params,
        max_action=max_action,
        n_episodes=config.n_episodes,
        seed=config.seed,
        deterministic_actor=config.iql_deterministic,
    )
    mean_score = float(scores.mean())
    normalized = float(env.get_normalized_score(mean_score) * 100.0)
    return scores, mean_score, float(scores.std()), normalized


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train IQL with an ensemble critic on a D4RL dataset.

    Steps:
    1. Load and normalize the dataset.
    2. Initialize networks and optimizers (optionally resuming from
       ``config.load_model``).
    3. Evaluate once at step 0.
    4. Run ``config.max_timesteps`` updates, logging losses to WandB every
       1000 steps and evaluating and checkpointing every ``config.eval_freq``
       steps.
    """
    # ===== Environment and dataset =====
    env = gym.make(config.env)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    if config.normalize_reward:
        modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )

    env = wrap_env(env, state_mean=state_mean, state_std=state_std)

    replay_buffer = ReplayBuffer(state_dim, action_dim, config.buffer_size)
    replay_buffer.load_d4rl_dataset(dataset)

    max_action = float(env.action_space.high[0])

    # ===== Hierarchical logging setup =====
    hyperparams = {
        # Ensemble
        "num_critics": config.num_critics,

        # IQL core
        "beta": config.beta,
        "iql_tau": config.iql_tau,

        # Learning rates
        "actor_lr": config.actor_lr,
        "qf_lr": config.qf_lr,
        "vf_lr": config.vf_lr,

        # RL basics
        "discount": config.discount,
        "tau": config.tau,
        "batch_size": config.batch_size,

        # Policy type
        "iql_deterministic": config.iql_deterministic,
        "actor_dropout": config.actor_dropout,

        # Preprocessing
        "normalize": config.normalize,
        "normalize_reward": config.normalize_reward,

        # Training
        "max_timesteps": config.max_timesteps,
        "eval_freq": config.eval_freq,
        "n_episodes": config.n_episodes,
        "seed": config.seed,
    }

    # Initialize WandB through the project's run_utils helper.
    _run_name, log_dir = wandb_init(
        config=config,
        log_base_dir=config.log_dir,
        env_name=config.env,
        method="IQL",
        hyperparams=hyperparams,
    )

    # ===== Checkpoint directory =====
    if config.checkpoints_path is not None:
        checkpoints_dir = config.checkpoints_path
    elif log_dir is not None:
        checkpoints_dir = os.path.join(log_dir, "checkpoints")
    else:
        checkpoints_dir = os.path.join("checkpoints", config.env)

    os.makedirs(checkpoints_dir, exist_ok=True)

    # Always save the resolved config next to the checkpoints.
    config_save_path = os.path.join(checkpoints_dir, "config.yaml")
    with open(config_save_path, "w") as f:
        pyrallis.dump(config, f)

    print(f"💾 Checkpoints will be saved to: {checkpoints_dir}")

    # ===== Seeding =====
    seed = config.seed
    set_seed(seed, env)

    # ===== Networks =====
    critic_def = Critic(state_dim=state_dim, action_dim=action_dim)
    v_def = ValueFunction(state_dim=state_dim)

    if config.iql_deterministic:
        actor_def = DeterministicPolicy(
            state_dim=state_dim,
            act_dim=action_dim,
            max_action=max_action,
            dropout=config.actor_dropout,
        )
    else:
        actor_def = GaussianPolicy(
            state_dim=state_dim,
            act_dim=action_dim,
            max_action=max_action,
            dropout=config.actor_dropout,
        )

    rng = jax.random.PRNGKey(seed)
    rng, actor_key, drop_key, v_key = jax.random.split(rng, 4)

    obs_dummy = jnp.zeros((1, state_dim), dtype=jnp.float32)
    act_dummy = jnp.zeros((1, action_dim), dtype=jnp.float32)

    actor_vars = actor_def.init(
        {"params": actor_key, "dropout": drop_key}, obs_dummy, training=True
    )
    actor_params = actor_vars["params"]

    # ===== Ensemble critic initialization (one key per member, vmapped) =====
    print(f"🔧 Initializing {config.num_critics} critics using vmap...")

    rng, subkey = jax.random.split(rng)
    q_keys = jax.random.split(subkey, config.num_critics)

    def init_single_critic(key):
        return critic_def.init({"params": key}, obs_dummy, act_dummy)["params"]

    q_params = jax.vmap(init_single_critic)(q_keys)
    q_target_params = jax.tree_util.tree_map(jnp.copy, q_params)

    print(f"✅ Ensemble critics initialized with shape: {jax.tree_util.tree_map(lambda x: x.shape, q_params)}")

    v_vars = v_def.init({"params": v_key}, obs_dummy)
    v_params = v_vars["params"]

    # ===== Optimizers =====
    v_tx = optax.adam(learning_rate=config.vf_lr)
    q_tx = optax.adam(learning_rate=config.qf_lr)

    actor_lr_schedule = optax.cosine_decay_schedule(
        init_value=config.actor_lr,
        decay_steps=config.max_timesteps,
        alpha=0.0,
    )
    actor_tx = optax.adam(learning_rate=actor_lr_schedule)

    train_state = IQLState(
        actor_params=actor_params,
        q_params=q_params,
        q_target_params=q_target_params,
        v_params=v_params,
        actor_opt_state=actor_tx.init(actor_params),
        q_opt_state=q_tx.init(q_params),
        v_opt_state=v_tx.init(v_params),
        total_it=jnp.array(0, dtype=jnp.int32),
    )

    print("=" * 80)
    print(f"🚀 Training IQL with Ensemble Critic (N={config.num_critics})")
    print(f"📊 Environment: {config.env}")
    print(f"🎲 Seed: {seed}")
    print(f"🔧 Beta: {config.beta}, IQL Tau: {config.iql_tau}")
    print(f"📈 Policy: {'Deterministic' if config.iql_deterministic else 'Gaussian'}")
    print(f"🎯 V target uses first 2 critics (TwinQ-style)")
    print("=" * 80)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        train_state, metadata = load_checkpoint(str(policy_file))
        print(f"✅ Loaded checkpoint from: {policy_file}")
        print(f"   Eval score: {metadata.get('eval_score', 'N/A')}")
        print(f"   Step: {metadata.get('step', 'N/A')}")

    train_step = make_train_step(
        actor_def,
        critic_def,
        v_def,
        actor_tx,
        q_tx,
        v_tx,
        discount=config.discount,
        tau=config.tau,
        beta=config.beta,
        iql_tau=config.iql_tau,
        deterministic_actor=config.iql_deterministic,
    )

    evaluations: List[float] = []
    best_eval_score = -np.inf

    # ===== Evaluate once before training (step 0) =====
    print("📊 Running Initial Evaluation (Step 0)...")
    init_eval_scores, init_score, init_std, init_norm_score = _evaluate_policy(
        env, actor_def, train_state.actor_params, config, max_action
    )
    evaluations.append(init_norm_score)

    print(f"   Initial D4RL Score: {init_norm_score:.2f}")

    wandb.log({
        "eval/episode_reward": init_score,
        "eval/episode_reward_std": init_std,
        "eval/d4rl_normalized_score": init_norm_score,
        "eval/episode_rewards_distribution": wandb.Histogram(init_eval_scores),
    }, step=0)

    # ===== Training loop =====
    pbar = tqdm(
        range(int(config.max_timesteps)),
        desc=f"Training IQL on {config.env}",
        dynamic_ncols=True,
    )

    for t in pbar:
        step = t + 1
        batch = replay_buffer.sample(config.batch_size)
        train_state, metrics, rng = train_step(train_state, batch, rng)

        # Training metrics to WandB every 1000 steps.
        if step % 1000 == 0:
            metrics_np = jax.device_get(metrics)
            log_dict = {f"train/{k}": float(v) for k, v in metrics_np.items()}
            wandb.log(log_dict, step=int(train_state.total_it))
            pbar.set_postfix(_loss_postfix(metrics_np))

        if step % config.eval_freq == 0:
            # Fetch the latest losses here as well. Relying on the value set by
            # the logging branch above fails when the first evaluation happens
            # before the first log (eval_freq < 1000 or not a multiple of it).
            metrics_np = jax.device_get(metrics)

            pbar.write(f"\n{'='*80}")
            pbar.write(f"⏱️  Time steps: {step}")

            eval_scores, eval_score, eval_std, normalized_eval_score = _evaluate_policy(
                env, actor_def, train_state.actor_params, config, max_action
            )
            evaluations.append(normalized_eval_score)

            pbar.write(
                f"📊 Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} ± {eval_std:.3f}, D4RL score: {normalized_eval_score:.3f}"
            )

            is_best = normalized_eval_score > best_eval_score
            if is_best:
                best_eval_score = normalized_eval_score
                pbar.write(f"🏆 New best score: {best_eval_score:.3f}")

            pbar.write(f"{'='*80}\n")

            wandb.log({
                "eval/episode_reward": eval_score,
                "eval/episode_reward_std": eval_std,
                "eval/d4rl_normalized_score": normalized_eval_score,
                "eval/episode_rewards_distribution": wandb.Histogram(eval_scores),
            }, step=int(train_state.total_it))

            pbar.set_postfix({
                **_loss_postfix(metrics_np),
                "d4rl_score": f"{normalized_eval_score:.2f}",
                "best": f"{best_eval_score:.2f}",
            })

            # Periodic + "last" checkpoints every evaluation; "best" on improvement.
            best_path = os.path.join(checkpoints_dir, "checkpoint_best.pkl")
            ckpt_paths = [
                os.path.join(checkpoints_dir, f"checkpoint_{step}.pkl"),
                os.path.join(checkpoints_dir, "checkpoint_last.pkl"),
            ]
            if is_best:
                ckpt_paths.append(best_path)
            for ckpt_path in ckpt_paths:
                save_checkpoint(
                    ckpt_path, train_state, normalized_eval_score, step,
                    config, state_mean, state_std
                )
            if is_best:
                pbar.write(f"💾 Saved best checkpoint to: {best_path}")

    pbar.close()
    wandb.finish()

    print("\n" + "=" * 80)
    print("🎉 Training completed!")
    if len(evaluations) > 0:
        print(f"📈 Final D4RL score: {evaluations[-1]:.3f}")
        print(f"🏆 Best  D4RL score: {best_eval_score:.3f}")
    print("=" * 80)


if __name__ == "__main__":
    # `train` is wrapped by pyrallis.wrap(), which builds the TrainConfig
    # from command-line arguments, so it is called without arguments here.
    train()
