import bisect
import dataclasses
import functools
import json
import logging
import math
import random
from pathlib import Path
from typing import Any, List, Iterator, Sequence

from flax import nnx
from flax import traverse_util
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import tyro

from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.shared import array_typing as at
from openpi.shared import nnx_utils
from openpi.shared import path_utils
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
from openpi.training import optimizer as _optimizer
from openpi.training import sharding as _sharding
from openpi.training import utils as training_utils
import openpi.transforms as _transforms


@dataclasses.dataclass(frozen=True)
class ComparatorOfflineTrainConfig(_config.TrainConfig):
    """Train the action comparator on offline-built (better, worse) action pairs.

    Pairs are shuffled at the individual-sample level and each pair is aligned with its
    source observation through the ``ds_idx`` stored next to it.

    ``pairs_dir`` must contain the ``shard_*.npz`` files and ``manifest.json`` produced by
    ``build_action_pairs.py``.
    """

    name: str = "pi05_comparator_offline_new"
    model: _model.BaseModelConfig = dataclasses.field(
        default_factory=lambda: pi0_config.Pi0Config(
            enable_action_comparator=True,
            pi05=True,
            action_horizon=10,
            discrete_state_input=False,
        )
    )

    # Dataset used to load Observations (can be overridden on the command line).
    data: _config.DataConfigFactory = dataclasses.field(
        default_factory=lambda: _config.LeRobotLiberoDataConfig(
            repo_id="physical-intelligence/libero",
            base_config=_config.DataConfig(prompt_from_task=True),
            extra_delta_transform=False,
        )
    )

    # Directory holding the offline pairs; the default is looked up by
    # path_utils.env_path from the environment variables listed below.
    pairs_dir: str = dataclasses.field(
        default_factory=lambda: path_utils.env_path(
            "OPENPI_COMPARATOR_OFFLINE_NEW_PAIRS_DIR",
            "OPENPI_PI05_PAIRS_OFFLINE_NEW_DIR",
            "OPENPI_PAIRS_OFFLINE_NEW_DIR",
        )
    )

    # Stage-1 checkpoint directory (contains `params` and `assets`). When unset, main()
    # falls back to `stage1_checkpoint_dir` from the pairs manifest.
    stage1_checkpoint_dir: str | None = dataclasses.field(
        default_factory=lambda: path_utils.env_path(
            "OPENPI_COMPARATOR_STAGE1_DIR",
            "OPENPI_PAIRBUILD_PI05_STAGE1_DIR",
            "OPENPI_PI05_STAGE1_CHECKPOINT_DIR",
            "OPENPI_STAGE1_CHECKPOINT_DIR",
        )
        or None
    )

    # These overrides MUST carry type annotations. A bare `name = value` in a dataclass
    # subclass only creates a class attribute: the generated __init__ still assigns the
    # inherited TrainConfig default to the instance, silently discarding the override.
    lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(
        default_factory=lambda: _optimizer.CosineDecaySchedule(
            warmup_steps=20_000,
            peak_lr=1e-5,
            decay_steps=1_000_000,
            decay_lr=1e-5,
        )
    )
    optimizer: _optimizer.OptimizerConfig = dataclasses.field(
        default_factory=lambda: _optimizer.AdamW(clip_gradient_norm=1.0)
    )
    ema_decay: float | None = 0.999
    num_train_steps: int = 30_000
    # Debug: collect comparator feature-scale statistics in train_step.
    debug_stats: bool = True


def init_train_state(config: ComparatorOfflineTrainConfig, mesh) -> tuple[training_utils.TrainState, Any, ComparatorOfflineTrainConfig]:
    """Create the initial train state in which only the comparator is trainable.

    Args:
        config: Training configuration. Its ``freeze_filter`` is replaced so that every
            parameter whose path matches neither ``action_comparator`` nor
            ``comparison_queries`` is frozen.
        mesh: Device mesh; not used here, kept for call-site compatibility.

    Returns:
        ``(train_state, None, train_config)`` where ``train_config`` is ``config`` with the
        comparator-only freeze filter installed.
    """
    model = config.model.create(jax.random.key(config.seed))

    # Parameters whose path matches one of these patterns are the only ones trained.
    trainable_paths = nnx.Any(
        nnx_utils.PathRegex(".*action_comparator.*"),
        nnx_utils.PathRegex(".*comparison_queries.*"),
    )

    def freeze_filter(path, value):
        return not trainable_paths(path, value)

    train_config = dataclasses.replace(config, freeze_filter=freeze_filter)

    def _detached_copy(leaf):
        # Fresh buffers so EMA params never alias `params`, which main() donates to jit.
        if isinstance(leaf, jax.Array):
            return jnp.copy(leaf)
        if isinstance(leaf, np.ndarray):
            return leaf.copy()
        return leaf

    params = nnx.state(model)
    graphdef = nnx.split(model)[0]
    tx = _optimizer.create_optimizer(train_config.optimizer, train_config.lr_schedule)
    if train_config.ema_decay is None:
        ema_params = None
    else:
        ema_params = jax.tree.map(_detached_copy, params)

    train_state = training_utils.TrainState(
        step=jnp.array(0),
        params=params,
        model_def=graphdef,
        opt_state=tx.init(params.filter(train_config.trainable_filter)),
        tx=tx,
        ema_decay=train_config.ema_decay,
        ema_params=ema_params,
    )
    return train_state, None, train_config


def bce_with_logits(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Element-wise binary cross-entropy computed directly from logits.

    Uses the numerically stable form ``max(x, 0) - x * y + log(1 + exp(-|x|))``, which
    only ever exponentiates non-positive values. Both inputs are cast to float32 and no
    reduction is applied.
    """
    x = logits.astype(jnp.float32)
    y = labels.astype(jnp.float32)
    positive_part = jnp.maximum(x, 0.0)
    stable_tail = jnp.log1p(jnp.exp(-jnp.abs(x)))
    return positive_part - x * y + stable_tail


def train_step(config: ComparatorOfflineTrainConfig, rng, state: training_utils.TrainState, batch):
    """Run one comparator optimization step.

    Args:
        config: Training config. ``trainable_filter`` selects the params that receive
            gradients and optimizer updates; ``debug_stats`` enables extra comparator stats.
        rng: PRNG key. It is folded with ``state.step`` and passed to ``loss_fn``, which
            does not currently use it.
        state: Current train state.
        batch: ``(observation, action_a, action_b, label)``.

    Returns:
        ``(new_state, info)``. ``new_state`` has ``step + 1`` plus updated params and
        optimizer state, and updated EMA params when ``ema_decay`` is set. ``info`` is
        ``{"loss": loss}`` merged with the health/debug statistics built in ``loss_fn``.
    """
    model = nnx.merge(state.model_def, state.params)
    model.train()
    observation, action_a, action_b, label = batch

    @at.typecheck
    def loss_fn(model: _model.BaseModel, rng: at.KeyArrayLike):
        """Return ``(mean BCE loss, aux stats dict)`` for the pair batch."""
        context = model._prefill_vlm_with_queries(observation)
        if getattr(config, "debug_stats", False):
            logits_stats = model.compare_actions_with_context(observation, context, action_a, action_b, return_stats=True)
            logits, comp_stats = logits_stats
            logits = logits.squeeze(-1)
        else:
            logits = model.compare_actions_with_context(observation, context, action_a, action_b).squeeze(-1)
        # Mean BCE over the batch.
        loss_local = jnp.mean(bce_with_logits(logits, label))
        # NOTE(review): under jax.jit with a batch-sharded input (as main() calls this),
        # jnp.mean is computed over the *global* batch; JAX inserts the cross-device
        # reduction. So this is already the global loss, not a per-device value.
        loss = loss_local

        # VLM feature health statistics. Best-effort: if the context lacks these entries
        # (jnp.stack on an empty list raises) or anything else fails, placeholders are used.
        # NOTE(review): the placeholders (finite_frac=1.0, min/max=0.0) look healthy and can
        # hide missing stats; consider narrowing the except and logging.
        try:
            raw_layers = jnp.stack([jnp.asarray(x) for x in context.get("raw_features_by_layer", [])], axis=0)
            core_layers = jnp.stack([jnp.asarray(x) for x in context.get("core_features_by_layer", [])], axis=0)
            raw_finite_frac = jnp.mean(jnp.isfinite(raw_layers).astype(jnp.float32))
            core_finite_frac = jnp.mean(jnp.isfinite(core_layers).astype(jnp.float32))
            raw_min = jnp.min(raw_layers)
            raw_max = jnp.max(raw_layers)
            core_min = jnp.min(core_layers)
            core_max = jnp.max(core_layers)
        except Exception:
            raw_finite_frac = jnp.array(1.0, dtype=jnp.float32)
            core_finite_frac = jnp.array(1.0, dtype=jnp.float32)
            raw_min = jnp.array(0.0, dtype=jnp.float32)
            raw_max = jnp.array(0.0, dtype=jnp.float32)
            core_min = jnp.array(0.0, dtype=jnp.float32)
            core_max = jnp.array(0.0, dtype=jnp.float32)

        # Action / state health statistics (only the leading comparator_state_dim state dims).
        state_first_dims = getattr(config.model, "comparator_state_dim", 8)
        s_head = observation.state[..., :state_first_dims]
        aux = {
            "logits_min": jnp.min(logits),
            "logits_max": jnp.max(logits),
            "logits_mean": jnp.mean(logits),
            "logits_finite_frac": jnp.mean(jnp.isfinite(logits).astype(jnp.float32)),
            "label_min": jnp.min(label),
            "label_max": jnp.max(label),
            "label_mean": jnp.mean(label),
            # extra
            "vlm_raw_finite_frac": raw_finite_frac,
            "vlm_core_finite_frac": core_finite_frac,
            "vlm_raw_min": raw_min,
            "vlm_raw_max": raw_max,
            "vlm_core_min": core_min,
            "vlm_core_max": core_max,
            "action_a_finite_frac": jnp.mean(jnp.isfinite(action_a).astype(jnp.float32)),
            "action_b_finite_frac": jnp.mean(jnp.isfinite(action_b).astype(jnp.float32)),
            "state_finite_frac": jnp.mean(jnp.isfinite(s_head).astype(jnp.float32)),
            "action_a_min": jnp.min(action_a),
            "action_a_max": jnp.max(action_a),
            "action_b_min": jnp.min(action_b),
            "action_b_max": jnp.max(action_b),
        }
        if getattr(config, "debug_stats", False):
            # Copy only a few key comparator stats, to keep the aux tree small.
            keys = [
                "in_a_flat_rms","in_b_flat_rms","in_d_flat_rms",
                "proj_a_rms","proj_b_rms","proj_d_rms","proj_s_rms",
                "summary_rms","logits_rms","logits_abs_max",
            ]
            for k in keys:
                if k in comp_stats:
                    aux[f"comp_{k}"] = comp_stats[k]
            # Per-layer intermediate scales for layers 0-16 (extend the range if needed).
            for l in range(17):
                k1 = f"l{l}_x_attn_rms"
                k2 = f"l{l}_x_after_block_rms"
                if k1 in comp_stats:
                    aux[f"comp_{k1}"] = comp_stats[k1]
                if k2 in comp_stats:
                    aux[f"comp_{k2}"] = comp_stats[k2]
        return loss, aux

    train_rng = jax.random.fold_in(rng, state.step)
    # Differentiate only w.r.t. the params selected by trainable_filter (the comparator).
    diff_state = nnx.DiffState(0, config.trainable_filter)
    (loss, aux), grads = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)(model, train_rng)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)
    # Write the updated trainable params back, then take the full (frozen + trainable) state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        # EMA over the full parameter state, frozen params included.
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params),
        )

    info = {"loss": loss} | aux
    return new_state, info


class PairwiseSampleIndex:
    """Map a global pair index to ``(shard_path, local_idx)``.

    Reads ``manifest.json`` in ``pairs_dir``. ``shards`` is a list of dicts with ``path``
    (relative to ``pairs_dir``) and ``length``; the shards are concatenated in manifest
    order. ``total_samples`` is exposed as ``self.total`` (validated against the shard
    lengths by the caller, not here).
    """

    def __init__(self, pairs_dir: Path):
        self._pairs_dir = Path(pairs_dir)
        # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
        mf = json.loads((self._pairs_dir / "manifest.json").read_text(encoding="utf-8"))
        self._shards = mf.get("shards", [])
        self.total = int(mf.get("total_samples", 0))
        # Cumulative shard ends: shard i covers global indices [prefix[i-1], prefix[i]).
        self._prefix: list[int] = []
        acc = 0
        for shard in self._shards:
            acc += int(shard["length"])
            self._prefix.append(acc)

    def locate(self, global_idx: int) -> tuple[Path, int]:
        """Return the shard file holding ``global_idx`` and the row index within it.

        Raises:
            IndexError: if ``global_idx`` is outside ``[0, sum of shard lengths)``. Without
                this check a negative index would yield a negative local index, which numpy
                would silently wrap to a different row.
        """
        num_rows = self._prefix[-1] if self._prefix else 0
        if not 0 <= global_idx < num_rows:
            raise IndexError(f"global_idx {global_idx} out of range [0, {num_rows})")
        # First shard whose cumulative end is strictly greater than global_idx;
        # zero-length shards are skipped naturally.
        shard_idx = bisect.bisect_right(self._prefix, global_idx)
        prev_acc = self._prefix[shard_idx - 1] if shard_idx > 0 else 0
        shard = self._shards[shard_idx]
        return self._pairs_dir / shard["path"], global_idx - prev_acc


def _load_npz_indices(path: Path, idxs: np.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, np.ndarray]:
    with np.load(path) as data:
        a = data["action_a"][idxs]
        b = data["action_b"][idxs]
        y = data["label"][idxs]
        ds_idx = data["ds_idx"][idxs]
    return jnp.asarray(a, dtype=jnp.float32), jnp.asarray(b, dtype=jnp.float32), jnp.asarray(y, dtype=jnp.float32), np.asarray(ds_idx, dtype=np.int64)


def _pairwise_batch_iter(pairs_dir: Path, batch_size: int, shuffle: bool, seed: int) -> Iterator[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, np.ndarray]]:
    """Yield an endless stream of fixed-size pair batches.

    Each epoch visits every pair at most once; the trailing partial batch is dropped. With
    ``shuffle`` the global order is reshuffled at the start of every epoch by a
    ``random.Random`` seeded with ``seed``.

    Within a batch, rows are grouped by shard (in first-seen order) rather than kept in
    the sampled order. ``action_a``, ``action_b``, ``label`` and ``ds_idx`` are permuted
    identically, so row ``i`` of every output refers to the same pair.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` is not positive or exceeds
            the number of pairs. Previously the former raised ZeroDivisionError and the
            latter spun forever without yielding.
    """
    index = PairwiseSampleIndex(pairs_dir)
    total = index.total
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if total < batch_size:
        raise ValueError(f"not enough pairs ({total}) to form one batch of size {batch_size}")

    order = list(range(total))
    rng = random.Random(seed)
    usable = total - total % batch_size  # drop the trailing partial batch
    while True:
        if shuffle:
            rng.shuffle(order)
        for start in range(0, usable, batch_size):
            # Group the batch's global ids by shard so each shard file is opened once.
            by_shard: dict[Path, list[int]] = {}
            for gid in order[start : start + batch_size]:
                spath, lidx = index.locate(gid)
                by_shard.setdefault(spath, []).append(lidx)
            loaded = [
                _load_npz_indices(spath, np.asarray(local_list, dtype=np.int64))
                for spath, local_list in by_shard.items()
            ]
            a_parts, b_parts, y_parts, ds_parts = zip(*loaded)
            yield (
                jnp.concatenate(a_parts, axis=0),
                jnp.concatenate(b_parts, axis=0),
                jnp.concatenate(y_parts, axis=0),
                np.concatenate(ds_parts, axis=0),
            )


def _collate_samples(items: Sequence[dict]) -> dict:
    return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)


def main(config: ComparatorOfflineTrainConfig):
    """Train the action comparator from offline pairs.

    Flow:
    - Build the train state (comparator-only trainable) and optionally merge stage-1 weights.
    - Load the Observation dataset and the norm stats.
    - Loop over pair batches. Each step fetches the matching observations by ``ds_idx``,
      normalizes them, shards the batch and runs the jitted ``train_step``.
    - Log every ``log_interval`` steps; checkpoint every ``save_interval`` steps and at the
      final step.

    Raises:
        ValueError: if the batch size is not divisible by the device count, the pairs
            manifest is inconsistent, the dataset length differs from the manifest, or the
            data config has no asset id.
    """
    logging.basicConfig(level=logging.INFO, force=True)
    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )
    mesh = _sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(_sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    train_state, _, train_config = init_train_state(config, mesh)
    # Ensure initial train state is materialized
    train_state = jax.block_until_ready(train_state)
    logging.info("Initial train state created and materialized")

    pairs_env_hint = (
        "OPENPI_COMPARATOR_OFFLINE_NEW_PAIRS_DIR",
        "OPENPI_PI05_PAIRS_OFFLINE_NEW_DIR",
        "OPENPI_PAIRS_OFFLINE_NEW_DIR",
    )
    pairs_dir_str = path_utils.require_path(
        train_config.pairs_dir,
        description="pairs_dir",
        env_vars=pairs_env_hint,
        cli_flag="--pairs-dir",
    )
    train_config = dataclasses.replace(train_config, pairs_dir=pairs_dir_str)

    # Read and validate the pairs manifest.
    manifest_path = Path(train_config.pairs_dir) / "manifest.json"
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    total_samples = int(manifest.get("total_samples", 0))
    if total_samples <= 0:
        raise ValueError("pairs manifest missing or invalid total_samples")
    # Consistency check: the shard lengths must add up to total_samples.
    shards = manifest.get("shards", [])
    if not isinstance(shards, list) or len(shards) == 0:
        raise ValueError("pairs manifest missing shards list")
    shards_len = sum(int(s.get("length", 0)) for s in shards)
    if shards_len != total_samples:
        raise ValueError(f"pairs manifest inconsistent: sum(shards.length)={shards_len} != total_samples={total_samples}")

    # Load the stage-1 checkpoint (params + norm stats); the config value wins over the manifest.
    ckpt_dir = train_config.stage1_checkpoint_dir or manifest.get("stage1_checkpoint_dir")
    if ckpt_dir:
        logging.info(f"Loading stage-1 parameters from: {ckpt_dir}")
        params = _model.restore_params(Path(ckpt_dir) / "params", dtype=jnp.bfloat16)
        # Merge checkpoint weights into the initialized params, overwriting only the paths
        # present in both. Newly added modules (e.g. the comparator) keep their initialized
        # concrete values, so no ShapeDtypeStruct placeholders enter the computation.
        base_state = train_state.params
        base_pure = base_state.to_pure_dict()
        flat_base = traverse_util.flatten_dict(base_pure)
        flat_loaded = traverse_util.flatten_dict(params)
        for kp, v in flat_loaded.items():
            if kp in flat_base:
                flat_base[kp] = v
        merged = traverse_util.unflatten_dict(flat_base)
        base_state.replace_by_pure_dict(merged)
        # Rebuild the optimizer state for the merged params.
        opt_state = train_state.tx.init(base_state.filter(train_config.trainable_filter))
        train_state = dataclasses.replace(
            train_state,
            params=base_state,
            opt_state=opt_state,
            # Deep copy so EMA params never share buffers with the donated params.
            ema_params=(jax.tree.map(lambda x: jnp.copy(x) if isinstance(x, jax.Array) else (x.copy() if isinstance(x, np.ndarray) else x), base_state) if train_config.ema_decay is not None else None),
        )
        # Ensure all arrays are materialized
        train_state = jax.block_until_ready(train_state)
        logging.info("Stage-1 checkpoint loaded and materialized")

    # Raw Observation dataset (indexed exactly by ds_idx) plus the normalization stats.
    data_cfg = train_config.data.create(train_config.assets_dirs, train_config.model)
    dataset = _data_loader.create_torch_dataset(data_cfg, train_config.model.action_horizon, train_config.model)
    if data_cfg.asset_id is None:
        raise ValueError("Asset id is required to load norm stats.")
    norm_stats = _checkpoints.load_norm_stats(Path(ckpt_dir) / "assets", data_cfg.asset_id) if ckpt_dir else None
    dataset = _data_loader.transform_dataset(dataset, data_cfg, skip_norm_stats=True)
    # Consistency check: the raw dataset length must equal manifest total_samples.
    dataset_len = len(dataset)
    if dataset_len != total_samples:
        raise ValueError(f"dataset length {dataset_len} != manifest total_samples {total_samples}")

    pair_iter = _pairwise_batch_iter(Path(train_config.pairs_dir), train_config.batch_size, shuffle=True, seed=train_config.seed)

    ckpt_mngr, resuming = _checkpoints.initialize_checkpoint_dir(
        train_config.checkpoint_dir, keep_period=train_config.keep_period, overwrite=train_config.overwrite, resume=train_config.resume
    )
    if resuming:
        dummy_loader = _data_loader.DataLoaderImpl(data_cfg, None)  # type: ignore[arg-type]
        train_state = _checkpoints.restore_state(ckpt_mngr, train_state, dummy_loader)

    # Determine the starting step (non-zero when resuming).
    start_step = int(jax.device_get(train_state.step))
    if start_step >= train_config.num_train_steps:
        logging.info(
            f"Training already completed: start_step={start_step} >= num_train_steps={train_config.num_train_steps}"
        )
        return
    if resuming and start_step > 0:
        logging.info(f"Resuming training from step {start_step}")

    # The normalizers depend only on fixed stats, so build them once instead of every step.
    state_first_dims = getattr(train_config.model, "comparator_state_dim", 8)
    action_normalizer = None
    state_normalizer = None
    if norm_stats is not None:
        action_normalizer = _transforms.Normalize(norm_stats, use_quantiles=data_cfg.use_quantile_norm)
        if isinstance(norm_stats, dict) and ("state" in norm_stats):
            state_normalizer = _transforms.Normalize({"state": norm_stats["state"]}, use_quantiles=data_cfg.use_quantile_norm)

    # Key health metrics printed at every log step (fixed for the whole run).
    keys_main = [
        "loss","logits_mean","logits_min","logits_max","logits_finite_frac",
        "vlm_raw_finite_frac","vlm_core_finite_frac","state_finite_frac",
        "action_a_finite_frac","action_b_finite_frac",
    ]
    if train_config.debug_stats:
        keys_main += [
            "comp_in_a_flat_rms","comp_in_b_flat_rms","comp_in_d_flat_rms",
            "comp_proj_a_rms","comp_proj_b_rms","comp_proj_d_rms","comp_proj_s_rms",
            "comp_summary_rms","comp_logits_rms","comp_logits_abs_max",
        ]
        # Per-layer intermediate scales for layers 0-16, to watch how they change with depth.
        for layer in range(17):
            keys_main += [f"comp_l{layer}_x_attn_rms", f"comp_l{layer}_x_after_block_rms"]

    pbar = tqdm.tqdm(
        range(start_step, train_config.num_train_steps),
        total=train_config.num_train_steps,
        initial=start_step,
        dynamic_ncols=True,
    )
    infos = []
    # As in scripts/train.py: use jax.jit and call it inside the mesh context.
    prun = jax.jit(
        functools.partial(train_step, train_config),
        in_shardings=(replicated_sharding, replicated_sharding, data_sharding),
        out_shardings=(replicated_sharding, replicated_sharding),
        # Donate only train_state (arg 1); its params never share buffers with ema_params.
        donate_argnums=(1,),
    )

    def _shard(x):
        # Build a global array sharded along the data axis from this process's data.
        return jax.make_array_from_process_local_data(data_sharding, jnp.asarray(x))

    for step in pbar:
        rng = jax.random.key(train_config.seed + step)
        action_a, action_b, label, ds_idx = next(pair_iter)
        # Fetch the exact source samples by ds_idx and build the Observation batch.
        batch_items = [dataset[int(k)] for k in ds_idx.tolist()]
        obs_dict = _collate_samples(batch_items)
        # Fix token dtypes, then normalize the leading state dims (obs actions are dropped).
        if "tokenized_prompt" in obs_dict:
            obs_dict["tokenized_prompt"] = jnp.asarray(obs_dict["tokenized_prompt"], dtype=jnp.int32)
        if "tokenized_prompt_mask" in obs_dict:
            obs_dict["tokenized_prompt_mask"] = jnp.asarray(obs_dict["tokenized_prompt_mask"], dtype=jnp.bool_)
        if norm_stats is not None:
            # The comparator does not need obs actions; dropping them avoids normalizing the
            # 32-dim padded actions (the assets only carry 7-dim stats).
            obs_dict.pop("actions", None)
            # Normalize only the first comparator_state_dim state dims; padding stays as-is.
            if state_normalizer is not None and "state" in obs_dict:
                s = obs_dict["state"]
                s_first_norm = state_normalizer({"state": s[..., :state_first_dims]})["state"]
                obs_dict["state"] = np.concatenate([np.asarray(s_first_norm), np.asarray(s[..., state_first_dims:])], axis=-1)
        observation = _model.Observation.from_dict(obs_dict)

        # Normalize pair actions at training time; keep the raw copies for the debug print.
        if action_normalizer is not None:
            action_a_np = np.asarray(action_a)
            action_b_np = np.asarray(action_b)
            action_a = jnp.asarray(action_normalizer({"actions": action_a_np})["actions"], dtype=jnp.float32)
            action_b = jnp.asarray(action_normalizer({"actions": action_b_np})["actions"], dtype=jnp.float32)

        # Distribute the batch across devices.
        sharded_observation = _model.Observation.from_dict(jax.tree.map(_shard, observation.to_dict()))
        sharded_action_a = _shard(action_a)
        sharded_action_b = _shard(action_b)
        sharded_label = _shard(label)

        batch = (sharded_observation, sharded_action_a, sharded_action_b, sharded_label)
        # Enter the JAX mesh context so the ('batch', 'fsdp') axes are available.
        with mesh, _sharding.set_mesh(mesh):
            train_state, info = prun(rng, train_state, batch)
        infos.append(info)
        if step % train_config.log_interval == 0:
            # Host-side debug: raw (pre-normalization) pair action ranges, only when
            # normalization is active.
            if norm_stats is not None:
                a_sample = np.asarray(action_a_np[:4])
                b_sample = np.asarray(action_b_np[:4])
                diff_sample = a_sample - b_sample
                pbar.write(
                    f"pairs dbg: a[min,max]={a_sample.min():.3f},{a_sample.max():.3f} "
                    f"b[min,max]={b_sample.min():.3f},{b_sample.max():.3f} "
                    f"diff[min,max]={diff_sample.min():.3f},{diff_sample.max():.3f}"
                )
            # Average each metric over the steps since the last log (at most log_interval).
            reduced = jax.device_get(
                jax.tree.map(lambda *xs: jnp.mean(jnp.stack(xs)), *infos)
            )
            # Start a fresh window. Without this reset, every step's info arrays would stay
            # alive for the entire run.
            infos = []
            msg = ", ".join(
                f"{k}={float(reduced[k]):.4f}" for k in keys_main if k in reduced
            )
            pbar.write(f"Step {step}: {msg}")
            # If VLM features contain non-finite values, also print their ranges.
            if float(reduced.get("vlm_raw_finite_frac", 1.0)) < 1.0 or float(reduced.get("vlm_core_finite_frac", 1.0)) < 1.0:
                pbar.write(
                    f"vlm_raw[min,max]={float(reduced.get('vlm_raw_min', 0.0)):.4f},{float(reduced.get('vlm_raw_max', 0.0)):.4f} "
                    f"vlm_core[min,max]={float(reduced.get('vlm_core_min', 0.0)):.4f},{float(reduced.get('vlm_core_max', 0.0)):.4f}"
                )
        if (step % train_config.save_interval == 0 and step > 0) or step == train_config.num_train_steps - 1:
            _checkpoints.save_state(ckpt_mngr, train_state, _data_loader.DataLoaderImpl(data_cfg, None), step)  # type: ignore[arg-type]


if __name__ == "__main__":
    main(tyro.cli(ComparatorOfflineTrainConfig))


