import jax
import jax.numpy as jnp
import jax.random as random

from nais.gflownet import GFlowNet
from nais.gym.base import EnvironmentConfig, EnvState, LogRewardBase, LogRewardConfig


def log_binomial(n: int, k: int) -> float:
    """Return the natural logarithm of the binomial coefficient ``C(n, k)``.

    Relies on ``C(n, k) = prod_{i=1}^{n-k} (k + i) / i``: the logs of the
    numerator factors ``k+1 .. n`` and of the denominator factors
    ``1 .. n-k`` are subtracted term by term and summed, so no factorial is
    ever formed.

    Args:
        n: Total number of items.
        k: Number of items chosen.

    Returns:
        The summed log-ratio. Although annotated ``float``, the value is the
        scalar JAX array produced by the reduction. When ``k >= n`` both
        ranges are empty and the result is ``0``.
    """
    log_numerator = jnp.log(jnp.arange(k + 1, n + 1))
    log_denominator = jnp.log(jnp.arange(1, n - k + 1))
    return (log_numerator - log_denominator).sum()


# This will be environment-dependent
def apply_fn(gflownet_key: tuple[GFlowNet, jax.Array], _) -> tuple[tuple[GFlowNet, jax.Array], tuple]:
    """Advance every still-active trajectory by one forward-policy step.

    For each batch row whose ``stopped`` flag is ``0.0`` an action is sampled
    from ``gflownet.pf`` and the corresponding entry of the state matrix is
    set to ``1.0``. Rows that are already stopped keep their state, flags and
    get a ``0.0`` written into their ``log_pf`` slot for this step.

    The ``(carry, x) -> (carry, y)`` shape suggests this is meant as a
    ``jax.lax.scan`` body -- NOTE(review): confirm against the caller.

    Args:
        gflownet_key: Pair ``(gflownet, key)``; ``key`` is forwarded to
            ``gflownet.pf.sample_actions``.
        _: Unused second argument.

    Returns:
        A pair ``(carry, outputs)`` where ``carry`` is
        ``(gflownet, out_pf.key)`` and ``outputs`` is
        ``(out_pf.actions, is_active, env_state)``. The returned actions are
        the raw sampled ones (not masked for inactive rows), ``is_active`` is
        the mask computed *before* the step and ``env_state`` is the state
        *after* the step.

    Note:
        ``gflownet.state`` is reassigned on the object that was passed in.
    """
    gflownet, key = gflownet_key
    gfn_state = gflownet.state
    env_state = gfn_state.env_state

    # A row takes part in this step only if it has not stopped yet.
    active = env_state.stopped == 0.0

    sampled = gflownet.pf.sample_actions(env_state, key=key)

    # Inactive rows are redirected to column 0 and written back with the value
    # that is already stored there, which leaves them unchanged.
    rows = env_state.batch_ids
    cols = jnp.where(active, sampled.actions, 0)
    new_values = jnp.where(active, 1.0, env_state.state[rows, cols])
    state = env_state.state.at[rows, cols].set(new_values)

    # An active row stops once its row-sum equals the maximum trajectory length.
    reached_limit = state.sum(axis=1) == env_state.max_trajectory_length
    stopped = jnp.where(
        active,
        reached_limit.astype(env_state.stopped.dtype),
        env_state.stopped,
    )
    # Any row that just took a step is no longer in the initial state.
    is_initial = jnp.where(
        active,
        jnp.zeros_like(env_state.is_initial),
        env_state.is_initial,
    )

    env_state = env_state.replace(
        state=state,
        # Forward moves are allowed on zero entries, backward moves on positive ones.
        forward_mask=jnp.where(state == 0.0, 1.0, 0.0),
        backward_mask=jnp.where(state > 0.0, 1.0, 0.0),
        stopped=stopped,
        is_initial=is_initial,
    )

    # Record this step's log-probability in column ``idx``, then move the cursor forward.
    step = gfn_state.idx
    log_pf = gfn_state.log_pf.at[:, step].set(jnp.where(active, sampled.log_pf, 0.0))

    gflownet.state = gfn_state.replace(
        env_state=env_state,
        log_pf=log_pf,
        idx=step + 1,
    )

    return (gflownet, sampled.key), (sampled.actions, active, env_state)


def backward_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """Undo one step for every trajectory that is not at its initial state.

    For each batch row whose ``is_initial`` flag is ``0.0`` an action is
    sampled from ``gflownet.pb`` and the corresponding entry of the state
    matrix is reset to ``0.0``. Rows already at the initial state keep their
    state and flags and get a ``0.0`` written into their ``log_pb`` slot.

    Args:
        gflownet_key: Pair ``(gflownet, key)``; ``key`` is forwarded to
            ``gflownet.pb.sample_actions``.
        _: Unused second argument.

    Returns:
        A pair ``(carry, outputs)`` where ``carry`` is
        ``(gflownet, out_pb.key)`` and ``outputs`` is
        ``(out_pb.actions, is_active, env_state)``: the raw sampled actions,
        the activity mask computed before the step, and the environment state
        after the step.

    Note:
        ``gflownet.state`` is reassigned on the object that was passed in.
    """
    gflownet, key = gflownet_key
    gfn_state = gflownet.state
    env_state = gfn_state.env_state

    # Only rows that have left the initial state can step backwards.
    active = env_state.is_initial == 0.0

    sampled = gflownet.pb.sample_actions(env_state, key=key)

    # Inactive rows target column 0 and are rewritten with their current value (a no-op).
    rows = env_state.batch_ids
    cols = jnp.where(active, sampled.actions, 0)
    cleared = jnp.where(active, 0.0, env_state.state[rows, cols])
    state = env_state.state.at[rows, cols].set(cleared)

    # A row is back at the initial state when nothing is left in it.
    back_at_start = (state.sum(axis=1) == 0).astype(env_state.is_initial.dtype)

    env_state = env_state.replace(
        state=state,
        forward_mask=jnp.where(state == 0.0, 1.0, 0.0),
        backward_mask=jnp.where(state > 0.0, 1.0, 0.0),
        # A row that just stepped backwards is no longer flagged as stopped.
        stopped=jnp.where(active, jnp.zeros_like(env_state.stopped), env_state.stopped),
        is_initial=jnp.where(active, back_at_start, env_state.is_initial),
    )

    # Move the cursor back first, then store log P_B in the column it now points to.
    prev_idx = gfn_state.idx - 1
    log_pb = gfn_state.log_pb.at[:, prev_idx].set(jnp.where(active, sampled.log_pb, 0.0))

    gflownet.state = gfn_state.replace(
        env_state=env_state,
        log_pb=log_pb,
        idx=prev_idx,
    )

    return (gflownet, sampled.key), (sampled.actions, active, env_state)


def factory(k: int, s: int, config: EnvironmentConfig):
    """Build the initial batched ``EnvState``.

    Every row starts with an all-zero state of width ``s``, every forward
    action enabled, no backward action enabled, ``stopped == 0`` and
    ``is_initial == 1``.

    Args:
        k: Stored as ``max_trajectory_length``.
        s: Width of the per-row state vector and of both masks.
        config: Supplies ``batch_size``.

    Returns:
        A freshly constructed ``EnvState`` for ``config.batch_size`` rows.
    """
    n_rows = config.batch_size
    matrix_shape = (n_rows, s)
    vector_shape = (n_rows,)

    return EnvState(
        state=jnp.zeros(matrix_shape),
        forward_mask=jnp.ones(matrix_shape),
        backward_mask=jnp.zeros(matrix_shape),
        batch_ids=jnp.arange(n_rows),
        stopped=jnp.zeros(vector_shape),
        is_initial=jnp.ones(vector_shape),
        max_trajectory_length=k,
        batch_size=n_rows,
    )


class LogReward(LogRewardBase):
    """Log-reward that is linear in the state with fixed random per-item weights."""

    def __init__(self, k: int, s: int, config: LogRewardConfig, *, seed: int = 42):
        """Draw the per-item log-utilities.

        Args:
            k: Stored on the instance; not read elsewhere in this class.
            s: Number of log-utilities to draw.
            config: Forwarded to ``LogRewardBase.__init__``.
            seed: Seed for the PRNG key used to draw the log-utilities.
        """
        super().__init__(config)
        self.k = k
        self.s = s

        # Student-t with one degree of freedom (i.e. standard Cauchy) draws, fixed by ``seed``.
        utilities_key = random.key(seed)
        self.log_utilities = random.t(key=utilities_key, df=1, shape=(s,))

    def __call__(self, state: EnvState):
        """Return the temperature-scaled weighted sum of each row of ``state.state``.

        ``self.temperature`` is not set in this class -- presumably provided
        by ``LogRewardBase`` (TODO confirm).
        """
        weighted = state.state * self.log_utilities
        return jnp.sum(weighted, axis=1) / self.temperature
