import dataclasses
import json
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, List, Optional

import jax
import jax.numpy as jnp
import numpy as np
import tyro
import tqdm

from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
from openpi.training import checkpoints as _checkpoints
import openpi.transforms as _transforms
from openpi.shared import nnx_utils
from openpi.shared import path_utils


@dataclasses.dataclass(frozen=True)
class PairBuildConfig(_config.TrainConfig):
    """Configuration for building preference action pairs offline.

    The dataset is traversed once and the resulting pairs are saved to disk.
    One action pair is built per dataset sample, and which slot holds the
    preferred action is randomized.

    NOTE(review): the original docstring said the preferred action may come
    from either the ground truth or a "good" sample; the worker in this file
    only stores sampled actions -- confirm the intended behavior.
    """

    name: str = "pi05_pairbuild"
    model: _model.BaseModelConfig = dataclasses.field(
        default_factory=lambda: pi0_config.Pi0Config(
            enable_action_comparator=True,  # sampling does not depend on the comparator, but the config may carry it
            pi05=True,
            action_horizon=10,
            discrete_state_input=False,
        )
    )

    # Dataset configuration (defaults are intended to match the online training
    # script; can be overridden from the command line).
    data: _config.DataConfigFactory = dataclasses.field(
        default_factory=lambda: _config.LeRobotLiberoDataConfig(
            repo_id=path_utils.env_path(
                "OPENPI_LIBERO_ONLINE_REPO_ID",
                default="physical-intelligence/libero",
            ),
            base_config=_config.DataConfig(prompt_from_task=True),
            extra_delta_transform=False,

        )
    )


    # Stage-1 checkpoint directory (expected to contain `params` and `assets`);
    # used to initialize the sampling model and the normalization statistics.
    stage1_checkpoint_dir: str = dataclasses.field(
        default_factory=lambda: path_utils.env_path(
            "OPENPI_PAIRBUILD_ONLINE_STAGE1_DIR",
            "OPENPI_PAIRBUILD_PI05_STAGE1_DIR",
            "OPENPI_PI05_STAGE1_CHECKPOINT_DIR",
            "OPENPI_STAGE1_CHECKPOINT_DIR",
        )
    )

    # Sampling step counts for the "good" and the "bad" action samples.
    good_steps: int = 10
    bad_steps: int = 3

    # Output directory for shards and manifests.
    output_dir: str = dataclasses.field(
        default_factory=lambda: path_utils.env_path(
            "OPENPI_PAIRBUILD_ONLINE_OUTPUT_DIR",
            "OPENPI_PI05_PAIRS_ONLINE_DIR",
            "OPENPI_PAIRS_ONLINE_DIR",
        )
    )

    # Formerly controlled skipping of norm stats. The stage-1 norm stats are
    # now always used; the field is kept only for backward compatibility and
    # has no effect.
    skip_norm_stats: bool = False

    # Internal sampling batch size during generation (only speeds up
    # construction; does not affect the stored per-sample indices).
    gen_batch_size: int = 32

    # Number of samples stored per shard (sample-level storage, so the training
    # side can shuffle).
    samples_per_shard: int = 8192

    # Parallelism (the parent process dispatches work, or this process runs as
    # a child worker).
    num_workers: int = 4
    worker_id: Optional[int] = None  # None in the parent; children receive a concrete id
    device: Optional[str] = None  # single device for a child, e.g. "cuda:0"; left empty in the parent
    range_start: Optional[int] = None  # first global sample index handled by a child (inclusive)
    range_end: Optional[int] = None  # last global sample index handled by a child (exclusive)


# Environment variable names passed to `path_utils.require_path` as `env_vars`
# when resolving the stage-1 checkpoint directory. They mirror the names used
# for the `PairBuildConfig.stage1_checkpoint_dir` default.
_STAGE1_ENV_HINT = (
    "OPENPI_PAIRBUILD_ONLINE_STAGE1_DIR",
    "OPENPI_PAIRBUILD_PI05_STAGE1_DIR",
    "OPENPI_PI05_STAGE1_CHECKPOINT_DIR",
    "OPENPI_STAGE1_CHECKPOINT_DIR",
)
# Environment variable names passed to `path_utils.require_path` as `env_vars`
# when resolving the output directory. They mirror the names used for the
# `PairBuildConfig.output_dir` default.
_OUTPUT_ENV_HINT = (
    "OPENPI_PAIRBUILD_ONLINE_OUTPUT_DIR",
    "OPENPI_PI05_PAIRS_ONLINE_DIR",
    "OPENPI_PAIRS_ONLINE_DIR",
)


def _resolve_stage1_dir(cfg: PairBuildConfig) -> Path:
    """Return the configured stage-1 checkpoint directory as a ``Path``.

    The raw config value is passed through ``path_utils.require_path`` together
    with the related environment variable names and CLI flag.
    NOTE(review): presumably ``require_path`` raises a descriptive error when
    the value is missing -- confirm against ``path_utils``.
    """
    checked = path_utils.require_path(
        cfg.stage1_checkpoint_dir,
        description="stage-1 checkpoint 目录",
        env_vars=_STAGE1_ENV_HINT,
        cli_flag="--stage1-checkpoint-dir",
    )
    return Path(checked)


def _resolve_output_dir(cfg: PairBuildConfig) -> Path:
    """Return the configured output directory for action pairs as a ``Path``.

    The raw config value is passed through ``path_utils.require_path`` together
    with the related environment variable names and CLI flag.
    NOTE(review): presumably ``require_path`` raises a descriptive error when
    the value is missing -- confirm against ``path_utils``.
    """
    checked = path_utils.require_path(
        cfg.output_dir,
        description="动作对输出目录",
        env_vars=_OUTPUT_ENV_HINT,
        cli_flag="--output-dir",
    )
    return Path(checked)


def _create_dataset_and_stats(cfg: PairBuildConfig):
    """Build the transformed dataset and load norm stats from the stage-1 checkpoint.

    Returns:
        A ``(data_config, norm_stats, dataset)`` tuple. The dataset is created
        with ``skip_norm_stats=True``; callers apply normalization themselves
        using the returned ``norm_stats``.

    Raises:
        ValueError: if the created data config has no ``asset_id``.
    """
    checkpoint_root = _resolve_stage1_dir(cfg)
    data_config = cfg.data.create(cfg.assets_dirs, cfg.model)

    # The norm stats are looked up by asset id under the checkpoint's assets dir.
    if data_config.asset_id is None:
        raise ValueError("Asset id is required to load norm stats.")
    stats = _checkpoints.load_norm_stats(checkpoint_root / "assets", data_config.asset_id)

    raw_dataset = _data_loader.create_torch_dataset(
        data_config,
        cfg.model.action_horizon,
        cfg.model,
    )
    # Normalization is skipped here on purpose; it is applied later per batch
    # with the stats loaded from the checkpoint.
    transformed = _data_loader.transform_dataset(raw_dataset, data_config, skip_norm_stats=True)
    return data_config, stats, transformed


def _maybe_mkdir(path: Path) -> None:
    """Create ``path`` together with any missing parents; a no-op if it already exists."""
    os.makedirs(path, exist_ok=True)


def _prepare_model_inputs(batch: dict, normalize_state_front: Any, state_first_dims: int) -> dict:
    """Cast prompt tensors and normalize the leading state dims of ``batch`` in place.

    Args:
        batch: collated sample dict; modified in place and also returned.
        normalize_state_front: callable applied to ``{"state": ...}`` that returns a
            dict with a normalized ``"state"`` entry.
        state_first_dims: number of leading state dims to normalize; the remaining
            dims are passed through unchanged.

    Returns:
        The same ``batch`` object.
    """
    if "tokenized_prompt" in batch:
        batch["tokenized_prompt"] = jnp.asarray(batch["tokenized_prompt"], dtype=jnp.int32)
    if "tokenized_prompt_mask" in batch:
        batch["tokenized_prompt_mask"] = jnp.asarray(batch["tokenized_prompt_mask"], dtype=jnp.bool_)
    if "state" in batch:
        state = batch["state"]
        front = normalize_state_front({"state": state[..., :state_first_dims]})["state"]
        # Only the leading dims are normalized; the tail is concatenated back untouched.
        batch["state"] = np.concatenate(
            [np.asarray(front), np.asarray(state[..., state_first_dims:])], axis=-1
        )
    return batch


def _worker_main(cfg: PairBuildConfig):
    """Generate action pairs for one index range and write them as ``.npz`` shards.

    For every dataset index in ``[cfg.range_start, cfg.range_end)`` (the whole
    dataset when unset) the stage-1 model samples actions with a "good" and a
    "bad" number of steps. Per sample, one of two step combinations is chosen
    with probability 0.5: ``(good_steps, bad_steps)`` or one step fewer for
    each (never below 1). The two actions are placed in slots ``action_a`` /
    ``action_b`` in random order; ``label`` is 1 when ``action_a`` holds the
    good action and 0 otherwise. Ground-truth actions are not stored.

    Shards and a per-worker manifest (``manifest_worker[_<id>].json``) are
    written to the resolved output directory.

    Raises:
        ValueError: if ``gen_batch_size``, ``samples_per_shard``, ``good_steps``
            or ``bad_steps`` is not positive, or if the index range is invalid.
    """
    logging.basicConfig(level=logging.INFO, force=True)

    # Non-positive sizes would make the batching / sharding loops never terminate.
    if cfg.gen_batch_size <= 0:
        raise ValueError(f"gen_batch_size must be positive, got {cfg.gen_batch_size}")
    if cfg.samples_per_shard <= 0:
        raise ValueError(f"samples_per_shard must be positive, got {cfg.samples_per_shard}")
    if cfg.good_steps <= 0 or cfg.bad_steps <= 0:
        raise ValueError(
            f"good_steps and bad_steps must be positive, got {cfg.good_steps} and {cfg.bad_steps}"
        )

    stage1_dir = _resolve_stage1_dir(cfg)
    out_dir = _resolve_output_dir(cfg)
    stage1_dir_str = str(stage1_dir)

    # Restrict the visible devices (child process). This happens before any JAX
    # call in this function.
    if cfg.device is not None:
        if cfg.device.startswith("cuda:"):
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.device.split(":", 1)[1]
            logging.info(f"Worker {cfg.worker_id} uses device {cfg.device}")
        elif cfg.device == "cpu":
            os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Sample-level dataset plus the norm stats stored with the stage-1 checkpoint.
    data_cfg, norm_stats, dataset = _create_dataset_and_stats(cfg)
    total_samples = len(dataset)

    # Work range of this worker.
    range_start = cfg.range_start if cfg.range_start is not None else 0
    range_end = cfg.range_end if cfg.range_end is not None else total_samples
    if not (0 <= range_start <= range_end <= total_samples):
        raise ValueError(f"Invalid range: [{range_start}, {range_end}) with total={total_samples}")
    local_count = range_end - range_start
    logging.info(f"Worker {cfg.worker_id} range=[{range_start}, {range_end}), local_count={local_count}, gen_batch_size={cfg.gen_batch_size}")

    # Sampling model (params restored from the stage-1 checkpoint).
    params = _model.restore_params(stage1_dir / "params", dtype=jnp.bfloat16)
    model_for_sampling = cfg.model.load(params)
    sample_actions = nnx_utils.module_jit(model_for_sampling.sample_actions)

    _maybe_mkdir(out_dir)

    # Per-worker manifest so the training side can align shards with dataset
    # indices. It is kept in memory and rewritten after every shard.
    manifest: dict[str, Any] = {
        "good_steps": cfg.good_steps,
        "bad_steps": cfg.bad_steps,
        "action_horizon": cfg.model.action_horizon,
        "action_dim": cfg.model.action_dim,
        "stage1_checkpoint_dir": stage1_dir_str,
        "data_repo_id": cfg.data.repo_id,
        "data_prompt_from_task": getattr(cfg.data.base_config, "prompt_from_task", False),
        "total_samples": total_samples,
        "samples_per_shard": cfg.samples_per_shard,
        "shards": [],  # filled in as shards are written
    }
    worker_manifest = out_dir / (f"manifest_worker_{cfg.worker_id}.json" if cfg.worker_id is not None else "manifest_worker.json")
    worker_manifest.write_text(json.dumps(manifest, indent=2))

    rng = jax.random.key(cfg.seed)
    if cfg.worker_id is not None:
        # Every worker receives the same seed; fold in the worker id so that
        # parallel workers do not replay an identical random stream.
        rng = jax.random.fold_in(rng, cfg.worker_id)

    # Step schedule: the configured combination and a one-step-shorter alternative.
    good_alt_steps = max(1, cfg.good_steps - 1)
    bad_alt_steps = max(1, cfg.bad_steps - 1)

    def _collate_fn(items: List[dict]) -> dict:
        return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)

    buffer_a: List[np.ndarray] = []
    buffer_b: List[np.ndarray] = []
    buffer_y: List[np.ndarray] = []
    buffer_idx: List[np.ndarray] = []
    shard_idx = 0
    total_written_global = 0

    # Leading dims used for state normalization and for the stored actions.
    state_first_dims = cfg.model.comparator_state_dim
    action_first_dims = cfg.model.comparator_action_dim

    # The un-normalizer is built from the actions stats only and the normalizer
    # from the state stats only, so neither transform sees the other key.
    unnorm_actions = _transforms.Unnormalize({"actions": norm_stats["actions"]}, use_quantiles=data_cfg.use_quantile_norm)
    normalize_state_front = _transforms.Normalize({"state": norm_stats["state"]}, use_quantiles=data_cfg.use_quantile_norm)

    def _sample_front(key, observation, num_steps: int):
        """Sample actions, keep the leading action dims and un-normalize them."""
        sampled = sample_actions(key, observation, num_steps=num_steps)
        return unnorm_actions({"actions": np.asarray(sampled[..., :action_first_dims])})["actions"]

    def _emit_shard(arr_a: np.ndarray, arr_b: np.ndarray, arr_y: np.ndarray, arr_idx: np.ndarray) -> None:
        """Write one shard file and record it in the worker manifest."""
        nonlocal shard_idx, total_written_global
        length = int(arr_a.shape[0])
        start = range_start + total_written_global
        shard_path = out_dir / f"shard_w{cfg.worker_id or 0}_{shard_idx:06d}.npz"
        np.savez_compressed(shard_path, action_a=arr_a, action_b=arr_b, label=arr_y, ds_idx=arr_idx)
        manifest["shards"].append({
            "path": shard_path.name,
            "start": start,
            "length": length,
        })
        worker_manifest.write_text(json.dumps(manifest, indent=2))
        # Log the index that matches the file name, before advancing it.
        logging.info(f"Saved shard {shard_idx} [{start}, {start + length}) -> {shard_path}")
        shard_idx += 1
        total_written_global += length

    # JIT warm-up on a single sample.
    if local_count > 0:
        warm_batch = jax.tree.map(lambda x: np.asarray(x)[None, ...], dataset[range_start])
        warm_batch = _prepare_model_inputs(warm_batch, normalize_state_front, state_first_dims)
        warm_obs = _model.Observation.from_dict(warm_batch)
        warm_out = sample_actions(jax.random.key(cfg.seed), warm_obs, num_steps=cfg.bad_steps)
        jax.block_until_ready(warm_out)

    show_pbar = (cfg.worker_id in (None, 0)) and (local_count > 0)
    pbar = tqdm.tqdm(total=local_count, dynamic_ncols=True, desc=f"worker {cfg.worker_id}") if show_pbar else None
    i = range_start
    while i < range_end:
        j = min(i + cfg.gen_batch_size, range_end)
        batch = _collate_fn([dataset[k] for k in range(i, j)])
        batch = _prepare_model_inputs(batch, normalize_state_front, state_first_dims)
        observation = _model.Observation.from_dict(batch)
        # Ground-truth actions are only consulted for the batch size.
        bs = batch["actions"].shape[0]

        rng, r_good, r_good_alt, r_bad, r_bad_alt, r_combo, r_coin_flip = jax.random.split(rng, 7)
        good_main = _sample_front(r_good, observation, cfg.good_steps)
        good_alt = _sample_front(r_good_alt, observation, good_alt_steps)
        bad_main = _sample_front(r_bad, observation, cfg.bad_steps)
        bad_alt = _sample_front(r_bad_alt, observation, bad_alt_steps)

        # True: (good_steps, bad_steps); False: the one-step-shorter combination.
        choose_combo = jax.random.bernoulli(r_combo, p=0.5, shape=(bs,))
        positive = jnp.where(choose_combo[:, None, None], good_main, good_alt)
        negative = jnp.where(choose_combo[:, None, None], bad_main, bad_alt)

        # Randomize which slot holds the good action; label == 1 means slot A.
        flip = jax.random.bernoulli(r_coin_flip, p=0.5, shape=(bs,))
        action_a = jnp.where(flip[:, None, None], negative, positive)
        action_b = jnp.where(flip[:, None, None], positive, negative)
        label = jnp.where(flip, 0, 1).astype(jnp.float32)

        buffer_a.append(np.asarray(action_a, dtype=np.float32))
        buffer_b.append(np.asarray(action_b, dtype=np.float32))
        buffer_y.append(np.asarray(label, dtype=np.float32))
        buffer_idx.append(np.arange(i, j, dtype=np.int64))

        if sum(x.shape[0] for x in buffer_a) >= cfg.samples_per_shard:
            # Concatenate once, then peel off as many full shards as are available.
            cat_a = np.concatenate(buffer_a, axis=0)
            cat_b = np.concatenate(buffer_b, axis=0)
            cat_y = np.concatenate(buffer_y, axis=0)
            cat_idx = np.concatenate(buffer_idx, axis=0)
            size = cfg.samples_per_shard
            while cat_a.shape[0] >= size:
                _emit_shard(cat_a[:size], cat_b[:size], cat_y[:size], cat_idx[:size])
                cat_a, cat_b, cat_y, cat_idx = cat_a[size:], cat_b[size:], cat_y[size:], cat_idx[size:]
            has_rest = cat_a.shape[0] > 0
            buffer_a = [cat_a] if has_rest else []
            buffer_b = [cat_b] if has_rest else []
            buffer_y = [cat_y] if has_rest else []
            buffer_idx = [cat_idx] if has_rest else []

        if pbar is not None:
            pbar.update(j - i)
        i = j

    # Flush the final partial shard, if any.
    if buffer_a:
        tail_a = np.concatenate(buffer_a, axis=0)
        if tail_a.shape[0] > 0:
            _emit_shard(
                tail_a,
                np.concatenate(buffer_b, axis=0),
                np.concatenate(buffer_y, axis=0),
                np.concatenate(buffer_idx, axis=0),
            )
    if pbar is not None:
        pbar.n = local_count
        pbar.refresh()
        pbar.close()


def _spawn_workers_and_merge(cfg: PairBuildConfig):
    """Split the dataset across worker subprocesses and merge their manifests.

    The dataset index space is divided into ``cfg.num_workers`` contiguous
    ranges. One subprocess of this script is launched per non-empty range,
    assigned round-robin to the detected GPU devices (or to the CPU when none
    is found). After all workers exit successfully their per-worker manifests
    are merged, in worker order, into ``manifest.json`` in the output directory.

    With ``cfg.num_workers <= 1`` the work is done in-process via
    ``_worker_main``.

    Raises:
        RuntimeError: if any worker exits with a non-zero return code; no
            merged manifest is written in that case.
    """
    logging.basicConfig(level=logging.INFO, force=True)

    n = cfg.num_workers
    if n <= 1:
        # Single worker: run in-process; `_worker_main` loads everything itself.
        return _worker_main(cfg)

    stage1_dir = _resolve_stage1_dir(cfg)
    out_dir = _resolve_output_dir(cfg)
    stage1_dir_str = str(stage1_dir)
    out_dir_str = str(out_dir)
    # Load the dataset up front only to learn the total number of samples.
    _, _, dataset = _create_dataset_and_stats(cfg)
    total_samples = len(dataset)
    _maybe_mkdir(out_dir)

    try:
        gpu_devs = jax.devices("gpu")
    except Exception:
        # Best effort: without a usable GPU backend every worker runs on CPU.
        logging.info("No GPU backend available; workers will run on CPU")
        gpu_devs = []
    if len(gpu_devs) > 0:
        # If the parent itself is restricted through CUDA_VISIBLE_DEVICES, JAX
        # device i corresponds to the i-th listed entry; pass that entry on so
        # children stay inside the allowed set.
        visible = [tok.strip() for tok in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if tok.strip()]
        devices = [
            f"cuda:{visible[i] if i < len(visible) else i}"
            for i in range(len(gpu_devs))
        ]
    else:
        devices = ["cpu"]

    per = (total_samples + n - 1) // n  # ceil division: samples per worker
    procs = []  # (worker id, Popen) for every launched worker
    for wid in range(n):
        rs = wid * per
        re = min((wid + 1) * per, total_samples)
        if rs >= re:
            continue
        dev = devices[wid % len(devices)]
        cmd = [
            sys.executable,
            str(Path(__file__).resolve()),
            f"--exp-name={cfg.exp_name}",
            "--num-workers=1",
            f"--worker-id={wid}",
            f"--device={dev}",
            f"--range-start={rs}",
            f"--range-end={re}",
            f"--output-dir={out_dir_str}",
            f"--gen-batch-size={cfg.gen_batch_size}",
            f"--samples-per-shard={cfg.samples_per_shard}",
            f"--stage1-checkpoint-dir={stage1_dir_str}",
            # Forward the settings that affect sampling so children do not
            # silently fall back to their defaults.
            f"--seed={cfg.seed}",
            f"--good-steps={cfg.good_steps}",
            f"--bad-steps={cfg.bad_steps}",
        ]
        env = os.environ.copy()
        if dev.startswith("cuda:"):
            env["CUDA_VISIBLE_DEVICES"] = dev.split(":", 1)[1]
        else:
            env["CUDA_VISIBLE_DEVICES"] = ""
        logging.info(f"Launching worker {wid} range=[{rs},{re}) on {dev}")
        procs.append((wid, subprocess.Popen(cmd, env=env)))

    # Wait for every worker before deciding, so none is left running on failure.
    failed = []
    for wid, proc in procs:
        returncode = proc.wait()
        if returncode != 0:
            failed.append((wid, returncode))
    if failed:
        details = ", ".join(f"worker {wid} exited with code {rc}" for wid, rc in failed)
        raise RuntimeError(f"Pair building failed: {details}")

    merged = {
        "good_steps": cfg.good_steps,
        "bad_steps": cfg.bad_steps,
        "action_horizon": cfg.model.action_horizon,
        "action_dim": cfg.model.action_dim,
        "stage1_checkpoint_dir": stage1_dir_str,
        "data_repo_id": cfg.data.repo_id,
        "data_prompt_from_task": getattr(cfg.data.base_config, "prompt_from_task", False),
        "total_samples": total_samples,
        "samples_per_shard": cfg.samples_per_shard,
        "shards": [],
    }
    # Merge only the manifests of workers launched in this run, in numeric
    # worker order; a glob would also pick up stale files from earlier runs.
    for wid, _ in procs:
        mf_path = out_dir / f"manifest_worker_{wid}.json"
        mf = json.loads(mf_path.read_text())
        merged["shards"].extend(mf.get("shards", []))
    (out_dir / "manifest.json").write_text(json.dumps(merged, indent=2))
    logging.info(f"Merged manifest with {len(merged['shards'])} shards -> {out_dir/'manifest.json'}")


def main(cfg: PairBuildConfig):
    """Dispatch to the parent (spawn and merge) or the worker code path.

    A process without a ``worker_id`` that is asked for more than one worker
    acts as the parent; every other configuration runs the worker directly.
    """
    acts_as_parent = cfg.worker_id is None and cfg.num_workers > 1
    entry_point = _spawn_workers_and_merge if acts_as_parent else _worker_main
    entry_point(cfg)


if __name__ == "__main__":
    # Parse the command line into a PairBuildConfig, then run.
    _cli_cfg = tyro.cli(PairBuildConfig)
    main(_cli_cfg)


