import jax.numpy as jnp


def maze_success_fn(observations, goals):
    """Test whether the first two observation entries lie within 0.5 of the goal.

    Args:
        observations: Array whose last axis starts with two coordinates
            (presumably the agent's x-y position — confirm against the env);
            any further entries on that axis are ignored.
        goals: Array of 2-D goal coordinates, broadcastable against
            ``observations[..., :2]``.

    Returns:
        Boolean array over the leading (non-last) axes: True where the
        Euclidean distance to the goal is at most 0.5 (inclusive).
    """
    distance = jnp.linalg.norm(observations[..., :2] - goals, axis=-1)
    return distance <= 0.5


def cube_success_fn(observations, goals):
    """Test whether every tracked (x, y, z) triple is within 0.4 of its goal.

    Args:
        observations: Array whose last axis is sampled at entries 19, 20, 21
            and then every 9th entry after each. This presumably reads each
            cube's position from 9-wide per-cube slots that follow a 19-entry
            prefix; confirm against the environment's observation layout.
        goals: Array whose last axis is a flattened list of (x, y, z) goal
            positions, one triple per cube; it is reshaped to
            ``(..., num_cubes, 3)``.

    Returns:
        Boolean array over the leading (non-last) axes: True only when all
        cubes are within a Euclidean distance of 0.4 (inclusive).
    """
    # One strided slice per coordinate, stacked into a trailing xyz axis.
    block_positions = jnp.stack(
        [observations[..., offset::9] for offset in (19, 20, 21)], axis=-1
    )
    target_positions = goals.reshape(goals.shape[:-1] + (-1, 3))

    per_cube_close = jnp.linalg.norm(block_positions - target_positions, axis=-1) <= 0.4
    return jnp.all(per_cube_close, axis=-1)


def puzzle_success_fn(observations, goals):
    """Test whether every sampled state entry is within 0.5 of its goal value.

    Args:
        observations: Array whose last axis is sampled at index 20 and every
            4th entry after it. These are presumably per-button state
            indicators; confirm against the environment's observation layout.
        goals: Array of target values, broadcastable against the sampled
            entries.

    Returns:
        Boolean array over the leading (non-last) axes: True only when every
        sampled entry differs from its goal by strictly less than 0.5.
    """
    observed_states = observations[..., 20::4]
    deviation = jnp.abs(observed_states - goals)
    return jnp.all(deviation < 0.5, axis=-1)

def scene_success_fn(observations, goals):
    """Test whether every sampled state entry is within 0.5 of its goal value.

    Samples the observation's last axis at index 20 and every 4th entry after
    it, then requires each sampled value to differ from its goal by strictly
    less than 0.5.

    NOTE(review): this body is identical to ``puzzle_success_fn``, and
    ``get_success_fn`` in this module does not dispatch to it. Confirm that
    the ``20::4`` slicing actually matches the scene environment's
    observation layout before wiring it in.

    Args:
        observations: Array whose last axis holds the sampled state entries.
        goals: Array of target values, broadcastable against the sampled
            entries.

    Returns:
        Boolean array over the leading (non-last) axes.
    """
    return jnp.all(jnp.abs(observations[..., 20::4] - goals) < 0.5, axis=-1)

def get_success_fn(env_name):
    """Return the success function for the given environment.

    Dispatch is by substring match on ``env_name``, checked in this order:

    * ``'antmaze'`` or ``'humanoidmaze'`` -> ``maze_success_fn``
    * ``'cube'`` -> ``cube_success_fn``
    * ``'puzzle'`` -> ``puzzle_success_fn``

    Args:
        env_name: Environment name string.

    Returns:
        The matching success function, or ``None`` when no pattern matches.
        Callers should be prepared to receive ``None`` for unsupported
        environments.
    """
    # NOTE(review): ``scene_success_fn`` is defined in this module but is not
    # dispatched here. Names matching none of the patterns below therefore
    # get None.
    if 'antmaze' in env_name or 'humanoidmaze' in env_name:
        return maze_success_fn
    if 'cube' in env_name:
        return cube_success_fn
    if 'puzzle' in env_name:
        return puzzle_success_fn
    # Unsupported environments return None instead of raising.
    return None
