import jax
import jax.numpy as jnp

from openpi.models import pi0_config


def main():
    """Smoke-test the default pi0 model with one loss pass and one sampling pass.

    Builds a default ``pi0_config.Pi0Config``, instantiates the model from it,
    feeds it the config-provided fake observation/action batch, and prints the
    shapes of the training loss and of the sampled actions.
    """
    batch_size = 1

    # Give every random consumer its own key. Reusing a single key for model
    # creation, the loss pass and the sampling pass would correlate their
    # randomness, which JAX's PRNG design explicitly says not to do.
    init_key, loss_key, sample_key = jax.random.split(jax.random.key(0), 3)

    cfg = pi0_config.Pi0Config()
    model = cfg.create(init_key)

    # Fake inputs produced by the config itself for the chosen batch size.
    obs = cfg.fake_obs(batch_size=batch_size)
    act = cfg.fake_act(batch_size=batch_size)

    # Training forward pass (loss).
    loss = model.compute_loss(loss_key, obs, act, train=True)
    print("loss shape:", loss.shape)

    # Inference forward pass (sampling), run with num_steps=5.
    actions = model.sample_actions(sample_key, obs, num_steps=5)
    print("actions shape:", actions.shape)


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


