import dataclasses
import json
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
import tyro

from openpi.shared import path_utils


@dataclasses.dataclass
class InspectConfig:
    """Options for inspecting stored offline action pairs to verify value scales.

    The inspection prints summary statistics for action_a, action_b, and (a-b),
    plus a basic label distribution. Use this to confirm whether saved actions
    are already normalized (e.g., around [-1, 1]) or in real-world units.

    An instance is built from command-line arguments via ``tyro.cli`` in the
    script entry point.
    """

    # Directory containing manifest.json and shard_*.npz.
    # The default comes from path_utils.env_path called with the environment
    # variable names below (presumably the first one that is set wins -- confirm
    # against path_utils.env_path).
    pairs_dir: str = dataclasses.field(
        default_factory=lambda: path_utils.env_path(
            "OPENPI_INSPECT_PAIRS_DIR",
            "OPENPI_PI05_PAIRS_OFFLINE_GT_DIR",
            "OPENPI_PAIRS_OFFLINE_GT_DIR",
        )
    )

    # Max number of pairs to sample in total (across shards)
    max_pairs: int = 4096

    # Limit number of shards to scan (in manifest order); None means all
    max_shards: int | None = 8

    # Number of rows to print as examples (after flattening H dimension)
    show_samples: int = 3

    # Random seed for sampling within shards
    seed: int = 0

    # Whether to print per-dimension stats (across sampled pairs and horizon)
    per_dim: bool = False


def _load_manifest(pairs_dir: Path) -> dict:
    mf_path = pairs_dir / "manifest.json"
    if not mf_path.exists():
        raise FileNotFoundError(f"manifest.json not found under {pairs_dir}")
    return json.loads(mf_path.read_text())


def _iter_shards(pairs_dir: Path, manifest: dict, max_shards: int | None) -> List[Path]:
    shards = manifest.get("shards", [])
    shard_paths = [pairs_dir / s["path"] for s in shards]
    if max_shards is not None:
        shard_paths = shard_paths[: max_shards]
    return shard_paths


def _sample_indices(n: int, k: int, rng: random.Random) -> np.ndarray:
    k = min(k, n)
    if k <= 0:
        return np.zeros((0,), dtype=np.int64)
    # Uniform sample without replacement
    return np.asarray(rng.sample(range(n), k), dtype=np.int64)


def _concat_stats(arrs: List[np.ndarray]) -> np.ndarray:
    if not arrs:
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(arrs, axis=0)


def _compute_stats(x: np.ndarray) -> dict:
    # 仅对 D 维逐位统计：
    # - x 为 [N, H, D] 时：在 (N, H) 上聚合，返回 [D]
    # - x 为 [N, D] 时：在 N 上聚合，返回 [D]
    # - x 为 [N] 时：返回标量
    if x.size == 0:
        return {
            "min": float("nan"),
            "max": float("nan"),
            "mean": float("nan"),
            "std": float("nan"),
            "q01": float("nan"),
            "q99": float("nan"),
        }

    if x.ndim == 3:
        # [N, H, D] -> 在 (N, H) 维聚合，得到 [D]
        axis = (0, 1)
        return {
            "min": np.min(x, axis=axis),
            "max": np.max(x, axis=axis),
            "mean": np.mean(x, axis=axis),
            "std": np.std(x, axis=axis),
            "q01": np.quantile(x, 0.01, axis=axis),
            "q99": np.quantile(x, 0.99, axis=axis),
        }
    elif x.ndim == 2:
        # [N, D] -> 在 N 维聚合，得到 [D]
        axis = 0
        return {
            "min": np.min(x, axis=axis),
            "max": np.max(x, axis=axis),
            "mean": np.mean(x, axis=axis),
            "std": np.std(x, axis=axis),
            "q01": np.quantile(x, 0.01, axis=axis),
            "q99": np.quantile(x, 0.99, axis=axis),
        }
    else:
        # [N] -> 标量
        overall = x.reshape(-1)
        return {
            "min": float(np.min(overall)),
            "max": float(np.max(overall)),
            "mean": float(np.mean(overall)),
            "std": float(np.std(overall)),
            "q01": float(np.quantile(overall, 0.01)),
            "q99": float(np.quantile(overall, 0.99)),
        }


def _compute_per_dim_stats(x: np.ndarray) -> List[dict]:
    # shape [..., D]
    if x.ndim == 3:
        flat = x.reshape(-1, x.shape[-1])
    elif x.ndim == 2:
        flat = x
    else:
        flat = x.reshape(-1, 1)
    out = []
    D = flat.shape[-1]
    for d in range(D):
        col = flat[:, d]
        out.append(
            {
                "dim": d,
                "min": float(np.min(col)) if col.size > 0 else float("nan"),
                "max": float(np.max(col)) if col.size > 0 else float("nan"),
                "mean": float(np.mean(col)) if col.size > 0 else float("nan"),
                "std": float(np.std(col)) if col.size > 0 else float("nan"),
                "q01": float(np.quantile(col, 0.01)) if col.size > 0 else float("nan"),
                "q99": float(np.quantile(col, 0.99)) if col.size > 0 else float("nan"),
            }
        )
    return out


def _looks_normalized(a: np.ndarray, b: np.ndarray) -> Tuple[bool, float]:
    # Heuristic: if max abs <= 1.2, treat as normalized
    max_abs = float(max(np.max(np.abs(a)), np.max(np.abs(b)))) if a.size and b.size else float("nan")
    return (max_abs <= 1.2), max_abs


def main(cfg: InspectConfig):
    """Sample stored action pairs and print summary statistics about them.

    Steps:
      1. Resolve ``cfg.pairs_dir`` through ``path_utils.require_path`` and load
         the manifest found there.
      2. Walk the listed shards in order (at most ``cfg.max_shards``), randomly
         sampling rows from each until ``cfg.max_pairs`` pairs are collected.
      3. Print aggregate statistics for ``action_a``, ``action_b`` and their
         difference, the label distribution, the normalization heuristic,
         optional per-dimension statistics, and a few example rows.

    Args:
        cfg: Inspection options.

    Raises:
        FileNotFoundError: If the manifest is missing (from ``_load_manifest``).
        RuntimeError: If the manifest lists no shards or no pairs were sampled.
    """
    pairs_env_hint = (
        "OPENPI_INSPECT_PAIRS_DIR",
        "OPENPI_PI05_PAIRS_OFFLINE_GT_DIR",
        "OPENPI_PAIRS_OFFLINE_GT_DIR",
    )
    pairs_dir_str = path_utils.require_path(
        cfg.pairs_dir,
        description="pairs_dir",
        env_vars=pairs_env_hint,
        cli_flag="--pairs-dir",
    )
    pairs_dir = Path(pairs_dir_str)
    manifest = _load_manifest(pairs_dir)
    shards = _iter_shards(pairs_dir, manifest, cfg.max_shards)
    if not shards:
        raise RuntimeError("No shards found in manifest.")

    rng = random.Random(cfg.seed)

    a_chunks: List[np.ndarray] = []
    b_chunks: List[np.ndarray] = []
    y_chunks: List[np.ndarray] = []
    collected = 0
    # Shards actually opened; can be fewer than len(shards) when the pair
    # budget is exhausted early.
    shards_read = 0

    for spath in shards:
        if collected >= cfg.max_pairs:
            break
        with np.load(spath) as data:
            a = data["action_a"]  # [N, H, D]
            b = data["action_b"]  # [N, H, D]
            y = data["label"]     # [N]
            n = a.shape[0]
            # Take as much of the remaining budget as this shard can supply.
            take = min(cfg.max_pairs - collected, n)
            idxs = _sample_indices(n, take, rng)
            a_chunks.append(a[idxs])
            b_chunks.append(b[idxs])
            y_chunks.append(y[idxs])
            collected += take
        shards_read += 1

    if collected == 0:
        raise RuntimeError("No pairs sampled. Check shards and max_pairs.")

    A = _concat_stats(a_chunks)
    B = _concat_stats(b_chunks)
    Y = _concat_stats(y_chunks)

    # Shapes (report the shards that contributed, not merely those listed)
    print(f"Sampled pairs: {A.shape[0]}  shape(H,D)={A.shape[1:]}  from {shards_read} shard(s)")

    # Aggregate stats
    Sa = _compute_stats(A)
    Sb = _compute_stats(B)
    Df = _compute_stats(A - B)
    y_mean = float(np.mean(Y)) if Y.size else float("nan")
    # Fraction of labels above 0.5.
    y_pos = float(np.mean((Y > 0.5).astype(np.float32))) if Y.size else float("nan")
    looks_norm, max_abs = _looks_normalized(A, B)

    print("action_a stats:", Sa)
    print("action_b stats:", Sb)
    print("(a-b) stats:", Df)
    print(f"labels: mean={y_mean:.4f}, pos_frac={y_pos:.4f}")
    print(f"looks_normalized={looks_norm} (max_abs={max_abs:.4f})")

    if cfg.per_dim:
        print("Per-dim stats for action_a:")
        for s in _compute_per_dim_stats(A):
            print(s)
        print("Per-dim stats for action_b:")
        for s in _compute_per_dim_stats(B):
            print(s)

    # Show a few samples (flatten H for compact view)
    if cfg.show_samples > 0:
        n_show = min(cfg.show_samples, A.shape[0])
        Af = A.reshape(A.shape[0], -1)
        Bf = B.reshape(B.shape[0], -1)
        for i in range(n_show):
            print(f"sample[{i}] a[:8]=" + np.array2string(Af[i][:8], precision=3))
            print(f"sample[{i}] b[:8]=" + np.array2string(Bf[i][:8], precision=3))


# Script entry point: build an InspectConfig from command-line arguments via
# tyro and run the inspection.
if __name__ == "__main__":
    main(tyro.cli(InspectConfig))


