import numpy as np
import jax
import jax.numpy as jnp

from openpi.models import pi0_config


def make_fake_pair(cfg, batch_size: int = 1):
    """Build a synthetic (observation, action_a, action_b, label) batch for the comparator.

    Args:
        cfg: Model config. Must provide ``fake_obs(batch_size=...)`` and the
            ``action_horizon`` / ``action_dim`` attributes.
        batch_size: Number of pairs in the batch. The default of 1 reproduces
            the original single-pair output exactly.

    Returns:
        A tuple ``(obs, a, b, label)``:
          * ``obs`` is whatever ``cfg.fake_obs(batch_size=batch_size)`` returns.
          * ``a`` and ``b`` are standard-normal arrays of shape
            ``(batch_size, cfg.action_horizon, cfg.action_dim)``.
          * ``label`` is a float32 array of ones with shape ``(batch_size,)``.
    """
    obs = cfg.fake_obs(batch_size=batch_size)
    action_shape = (batch_size, cfg.action_horizon, cfg.action_dim)
    # Fixed, distinct seeds keep the two candidates different from each other
    # but reproducible across runs.
    a = jax.random.normal(jax.random.key(1), action_shape)
    b = jax.random.normal(jax.random.key(2), action_shape)
    # Every pair gets label 1.0. NOTE(review): which of a/b a label of 1.0
    # prefers is defined by the comparator's loss, not visible here — confirm.
    label = jnp.ones((batch_size,), dtype=jnp.float32)
    return obs, a, b, label


def main():
    """Smoke-test the Pi0 action-comparator path on fake inputs.

    Builds a ``Pi0Config`` with ``enable_action_comparator=True`` and creates
    the model. It then runs the query prefill plus pairwise action comparison
    on one synthetic pair, and finally runs ``sample_and_compare`` with 3
    candidates and 5 steps. The output shapes are printed.
    """
    # Split the root key so that model initialization and sampling draw from
    # independent random streams. Reusing one key for both would correlate them.
    init_key, sample_key = jax.random.split(jax.random.key(0))
    cfg = pi0_config.Pi0Config(enable_action_comparator=True)
    model = cfg.create(init_key)

    # Fake pair batch. The label is not needed for this forward-only check.
    obs, a, b, _label = make_fake_pair(cfg)

    # Prefill with queries once, then compare the two action chunks against
    # that shared context.
    context = model._prefill_vlm_with_queries(obs)
    logits = model.compare_actions_with_context(obs, context, a, b)
    print("logits shape:", logits.shape)

    # Sample-and-compare with K candidates; the result is reported as the
    # selected action.
    out = model.sample_and_compare(sample_key, obs, num_candidates=3, num_steps=5)
    print("selected action shape:", out.shape)


if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    main()


