import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx

from openpi.models import pi0_config
from openpi.policies import policy as _policy


def main():
    """Run a one-shot inference smoke test with the action comparator enabled.

    Builds a model from ``Pi0Config(enable_action_comparator=True)``, wraps it
    in a ``Policy``, calls ``infer`` on a single all-zero observation and
    checks that the returned ``"actions"`` entry has shape
    ``(cfg.action_horizon, cfg.action_dim)``. Prints the observed shape and
    ``OK`` on success.

    Raises:
        AssertionError: If the returned actions do not have the expected shape.
    """
    # Model with the comparator enabled, initialised from a fixed PRNG key.
    cfg = pi0_config.Pi0Config(enable_action_comparator=True)
    model = cfg.create(jax.random.key(0))

    camera_names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
    image_shape = (224, 224, 3)

    # Fake observation without a leading batch axis: one zero image and one
    # enabled mask per camera, plus a zero state vector of length action_dim.
    obs = {
        # A fresh array per camera, so no buffer is shared between entries.
        "images": {name: np.zeros(image_shape, dtype=np.float32) for name in camera_names},
        "image_masks": {name: True for name in camera_names},
        "state": np.zeros((cfg.action_dim,), dtype=np.float32),
        "prompt": "Test task",
        # NOTE(review): presumably the number of candidate action chunks the
        # comparator selects between -- confirm against Policy.infer.
        "num_candidates": 3,
    }

    pol = _policy.Policy(model, rng=jax.random.key(42))
    out = pol.infer(obs)

    actual_shape = out["actions"].shape
    expected_shape = (cfg.action_horizon, cfg.action_dim)
    print("actions shape:", actual_shape)
    # Raise explicitly instead of using ``assert``: assert statements are
    # removed under ``python -O``, which would let this script print "OK"
    # without checking anything.
    if actual_shape != expected_shape:
        raise AssertionError(
            f"unexpected actions shape {actual_shape}, expected {expected_shape}"
        )
    print("OK")


# Run the smoke test only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


