import jax
import jax.numpy as jnp
from jax import Array

from medium_rl.config import TGMConfig


def random_action(
    legal_action_mask: Array,  # [B, D]
    rng: jax.random.PRNGKey,
):
    """Sample one action per row, uniformly over the entries the mask allows.

    Args:
        legal_action_mask: Mask of shape [B, D]; each row is normalised by
            its own sum to form the sampling distribution, so zero entries
            are never selected.
        rng: JAX PRNG key used for sampling.

    Returns:
        The sampled action index for every row.
    """
    # Normalise each row so the allowed entries share the probability mass.
    row_totals = legal_action_mask.sum(axis=-1, keepdims=True)
    uniform_probs = legal_action_mask / row_totals
    # log(0) = -inf, which removes masked-out entries from the draw.
    uniform_logits = jnp.log(uniform_probs)
    return jax.random.categorical(rng, uniform_logits)


def make_alg_policy(eps: float, alg_cfg, inv_temp_mod: float = 1.0):
    """Build the sampling policy that matches ``alg_cfg.name``.

    Args:
        eps: Probability of replacing the sampled action with a uniformly
            random legal one.
        alg_cfg: Algorithm configuration. Only ``name`` is read here, plus
            ``gen_inv_temp`` for "TGM"; the TGM/TGMW factories read further
            fields from it.
        inv_temp_mod: Multiplier applied to the inverse temperature.

    Returns:
        A callable ``policy_fn(logits, legal_action_mask, step, policy_rng)``.

    Raises:
        ValueError: If ``alg_cfg.name`` is not a supported algorithm.
            Previously this case silently returned ``None``.
    """
    name = alg_cfg.name
    if name == "TGM":
        # TGM scales the modifier by the configured generation inverse temperature.
        return make_tgm_policy(eps, alg_cfg, inv_temp_mod * alg_cfg.gen_inv_temp)
    if name == "TGMW":
        return make_tgmw_policy(eps, alg_cfg, inv_temp_mod)
    if name in ("SAC", "PPO", "TGMP"):
        # These algorithms sample directly from the (tempered) logits.
        return make_eps_inv_temp_softmax_policy(eps, inv_temp_mod)
    raise ValueError(f"Unknown algorithm {name!r}")


def make_tgm_policy(eps: float, alg_cfg: TGMConfig, inv_temp_mod: float = 1.0):
    """Create a policy that samples from TGM-mixed logits.

    The returned callable combines ``alg_cfg.q * alg_cfg.alpha *
    alg_cfg.q_fn(logits)`` with ``alg_cfg.omega * logits`` and hands the
    result to ``eps_inv_temp_softmax_policy``.

    Args:
        eps: Probability of taking a uniformly random legal action.
        alg_cfg: Configuration providing ``q``, ``alpha``, ``q_fn`` and
            ``omega``; they are read on every policy call.
        inv_temp_mod: Passed through as the inverse temperature of the
            softmax sampler.

    Returns:
        A callable ``policy_fn(logits, legal_action_mask, step, policy_rng)``.
    """

    def policy_fn(logits, legal_action_mask, step, policy_rng):
        # Transformed-logit term, weighted by q * alpha.
        q_term = alg_cfg.q * alg_cfg.alpha * alg_cfg.q_fn(logits)
        # Raw-logit term, weighted by the fixed omega.
        omega_term = alg_cfg.omega * logits
        mixed_logits = q_term + omega_term
        return eps_inv_temp_softmax_policy(
            mixed_logits,
            legal_action_mask,
            step,
            eps,
            inv_temp_mod,
            policy_rng,
        )

    return policy_fn


def make_tgmw_policy(eps: float, alg_cfg: TGMConfig, inv_temp_mod: float = 1.0):
    """Create a TGM-style policy whose omega weight depends on ``step``.

    With ``alg_cfg.scaling == "linear"`` omega is
    ``alg_cfg.init_omega + step * alg_cfg.omega_c``; with ``"exponential"``
    it is ``exp(step / alg_cfg.omega_c)``. The mixed logits are then sampled
    through ``eps_inv_temp_softmax_policy``.

    Args:
        eps: Probability of taking a uniformly random legal action.
        alg_cfg: Configuration providing ``scaling``, ``init_omega``,
            ``omega_c``, ``q``, ``alpha`` and ``q_fn``; they are read on
            every policy call.
        inv_temp_mod: Passed through as the inverse temperature of the
            softmax sampler.

    Returns:
        A callable ``policy_fn(logits, legal_action_mask, step, policy_rng)``.
        Calling it raises ``ValueError`` when ``alg_cfg.scaling`` is neither
        "linear" nor "exponential".
    """

    def _omega_at(step):
        # Step-dependent weight on the raw logits.
        if alg_cfg.scaling == "linear":
            return alg_cfg.init_omega + step * alg_cfg.omega_c
        if alg_cfg.scaling == "exponential":
            return jnp.exp(step / alg_cfg.omega_c)
        raise ValueError(f"Invalid scaling {alg_cfg.scaling}")

    def policy_fn(logits, legal_action_mask, step, policy_rng):
        omega = _omega_at(step)
        q_term = alg_cfg.q * alg_cfg.alpha * alg_cfg.q_fn(logits)
        mixed_logits = q_term + omega * logits
        return eps_inv_temp_softmax_policy(
            mixed_logits,
            legal_action_mask,
            step,
            eps,
            inv_temp_mod,
            policy_rng,
        )

    return policy_fn


def make_eps_inv_temp_softmax_policy(eps: float, inv_temp: float):
    """Bind ``eps`` and ``inv_temp`` into an epsilon/softmax sampling policy.

    Args:
        eps: Probability of taking a uniformly random legal action.
        inv_temp: Factor the logits are multiplied by before sampling.

    Returns:
        A callable ``policy_fn(logits, legal_action_mask, step, policy_rng)``
        that forwards to ``eps_inv_temp_softmax_policy``.
    """

    def policy_fn(logits, legal_action_mask, step, policy_rng):
        return eps_inv_temp_softmax_policy(
            logits,
            legal_action_mask,
            step,
            eps=eps,
            inv_temp=inv_temp,
            rng=policy_rng,
        )

    return policy_fn


@jax.jit
def eps_inv_temp_softmax_policy(
    logits: Array,  # [B, T, D]
    legal_action_mask: Array,  # [B, D]
    curr_step: int,
    eps: float,
    inv_temp: float,
    rng: jax.random.PRNGKey,
):
    """Sample actions from tempered logits with epsilon-uniform exploration.

    Args:
        logits: Logits of shape [B, T, D]; only position ``curr_step`` along
            the second axis is used.
        legal_action_mask: Mask of shape [B, D]. Entries equal to 1 are
            eligible for the softmax draw; the uniform draw weights each
            entry by its mask value.
        curr_step: Index of the position whose logits are sampled.
        eps: Per-row probability of using the uniform draw instead of the
            softmax draw.
        inv_temp: Factor the logits are multiplied by before sampling.
        rng: JAX PRNG key; split into three independent keys internally.

    Returns:
        One sampled action index per row.
    """
    # Key order matters for reproducibility: softmax draw, uniform draw, mixing mask.
    sample_key, uniform_key, explore_key = jax.random.split(rng, 3)

    # Softmax draw over the current step's logits; illegal entries get -inf.
    step_logits = logits[:, curr_step, :]
    is_legal = legal_action_mask == 1
    tempered_logits = jnp.where(is_legal, step_logits * inv_temp, -jnp.inf)
    sampled_actions = jax.random.categorical(sample_key, tempered_logits, axis=-1)

    # Uniform draw over the legal entries of each row.
    mask_totals = legal_action_mask.sum(axis=-1, keepdims=True)
    uniform_logits = jnp.log(legal_action_mask / mask_totals)
    uniform_actions = jax.random.categorical(uniform_key, uniform_logits)

    # With probability eps per row, take the uniform draw.
    explore = jax.random.bernoulli(explore_key, p=eps, shape=sampled_actions.shape)
    return jnp.where(explore, uniform_actions, sampled_actions)
