from flax import struct
import jax.numpy as jnp
import jax
import distrax


@struct.dataclass
class Preferences:
    """Result of comparing two samples, grouped field by field.

    Declared as a ``flax.struct.dataclass``. When built by ``judge`` in this
    module, every field stacks the two compared items along a new leading
    axis of length 2, with the item the judge preferred at index 0 and the
    other at index 1.
    """

    # Observations of the two compared samples.
    obs: jnp.ndarray
    # Actions of the two compared samples.
    actions: jnp.ndarray
    # Total rewards of the two compared samples (the scores the judge compared).
    tot_rewards: jnp.ndarray


def preference_model(rng, x, y, temperature):
    """Sample whether ``x`` is preferred to ``y``.

    The two scores are used as the logits ``[x, y]`` of a ``distrax.Softmax``
    distribution with the given ``temperature``; a single index is drawn from
    it using ``rng``.

    Args:
        rng: JAX PRNG key passed to the distribution's ``sample`` as ``seed``.
        x: Score of the first candidate (logit at index 0).
        y: Score of the second candidate (logit at index 1).
        temperature: Temperature forwarded to ``distrax.Softmax``.

    Returns:
        A boolean array that is true when the sampled index is 0, i.e. when
        ``x`` is the preferred candidate, and false otherwise.
    """
    # NOTE(review): presumably x and y are scalars so the logits have shape
    # (2,) - confirm against callers before passing batched scores.
    logits = jnp.array([x, y])
    sampled_index = distrax.Softmax(logits, temperature=temperature).sample(
        seed=rng
    )
    # Index 0 is the slot holding x.
    return sampled_index == 0


def judge(rng, obs1, action1, tot_reward1, obs2, action2, tot_reward2, args):
    """Compare two samples and return them ordered by sampled preference.

    The preference is drawn by ``preference_model`` from the two total
    rewards, using ``args.judge_temp`` as the temperature. Each pair of
    fields is then stacked along a new leading axis so that the preferred
    sample sits at index 0 and the other at index 1.

    Args:
        rng: JAX PRNG key used to sample the preference.
        obs1: Observation of the first sample.
        action1: Action of the first sample.
        tot_reward1: Total reward of the first sample.
        obs2: Observation of the second sample.
        action2: Action of the second sample.
        tot_reward2: Total reward of the second sample.
        args: Object exposing a ``judge_temp`` attribute.

    Returns:
        A ``Preferences`` instance whose ``obs``, ``actions`` and
        ``tot_rewards`` hold the stacked pairs, preferred sample first.
    """
    first_wins = preference_model(rng, tot_reward1, tot_reward2, args.judge_temp)

    def preferred_first(item1, item2):
        # Build both orderings and let the sampled outcome select one; this
        # stays traceable because no Python branching on `first_wins` occurs.
        # NOTE(review): assumes `first_wins` is a scalar (or broadcastable
        # against the stacked pair) - confirm for batched inputs.
        as_given = jnp.array([item1, item2])
        swapped = jnp.array([item2, item1])
        return jnp.where(first_wins, as_given, swapped)

    return Preferences(
        obs=preferred_first(obs1, obs2),
        actions=preferred_first(action1, action2),
        tot_rewards=preferred_first(tot_reward1, tot_reward2),
    )
