import flax.struct as struct
import jax
import jax.numpy as jnp

from nais.gym.base import EnvState, LogRewardBase


def at(x, indices):
    """Gather rows of ``x`` at ``indices`` when ``x`` is a batched array.

    Only JAX arrays with two or more dimensions are indexed along their
    leading axis. Any other leaf (scalars, 1-D arrays, non-array values) is
    handed back unchanged. Used with ``jax.tree.map`` over state pytrees.
    """
    is_batched = isinstance(x, jax.Array) and x.ndim >= 2
    return x[indices] if is_batched else x


def vstack(x: jax.Array, y: jax.Array):
    """Append ``y`` below ``x`` along the leading axis for batched leaves.

    When ``x`` is a JAX array with at least two dimensions the result is
    ``jnp.vstack([x, y])``; in every other case ``y`` is returned as-is.
    """
    if not isinstance(x, jax.Array) or x.ndim < 2:
        # HACK: non-batched leaves keep only the most recent values (``y``).
        # Do not rely on this behaviour with the lines environment.
        return y
    return jnp.vstack([x, y])


@struct.dataclass
class ReplayBuffer:
    """Fixed-capacity buffer of states together with their log rewards.

    A flax ``struct.dataclass`` (a JAX pytree); updates produce new instances
    via ``.replace`` rather than mutating in place.
    """

    # Capacity. Kept static (not a traced leaf) because it is used as a static
    # shape argument (e.g. ``size=`` of ``jnp.unique``).
    size: int = struct.field(pytree_node=False)
    # Stored states, batched along the leading axis: (size, dim).
    states: EnvState
    # Log reward of each stored state: (size,). A plain array, not a reward object.
    log_rewards: jax.Array

    @classmethod
    def create(cls, states: EnvState) -> "ReplayBuffer":
        """Build a buffer whose capacity equals ``states.batch_size``.

        The given states seed the slots, but every slot starts with a log
        reward of ``-inf``. These placeholders therefore rank below any
        finite-reward entry and get zero sampling probability.
        """
        capacity = states.batch_size
        return cls(
            size=capacity,
            states=states,
            log_rewards=-jnp.inf * jnp.ones((capacity,)),
        )


def add_to_buffer(replay_buffer: ReplayBuffer, states: EnvState, log_rewards: jax.Array):
    """Merge a batch into the buffer, keeping the highest-reward unique entries.

    The buffer's ``S`` entries and the ``B`` incoming entries are pooled.
    Rows identical in (log reward, state) are collapsed, and the top
    ``replay_buffer.size`` rows by log reward are kept.

    Args:
        replay_buffer: Buffer to update; a new instance is returned.
        states: Incoming states, batched along the leading axis.
        log_rewards: Log rewards of ``states``; presumably shape ``(B,)``.

    Returns:
        A new ``ReplayBuffer`` with updated ``states`` and ``log_rewards``.
        Slots left empty because fewer than ``size`` unique rows exist get a
        log reward of ``-inf``.

    Note:
        Pytree leaves with fewer than two dimensions are not merged.
        ``vstack`` returns the incoming batch's leaf and ``at`` leaves it
        unindexed, so such leaves reflect the new batch only.
    """
    # Pool buffer and batch leaf by leaf: (S, ...) + (B, ...) -> (S + B, ...).
    all_states = jax.tree.map(vstack, replay_buffer.states, states)
    all_log_rewards = jnp.hstack([replay_buffer.log_rewards, log_rewards])  # (S + B,)

    # The row-wise `unique` below needs one row per candidate. Flatten any
    # trailing state dimensions; 2-D states are used unchanged.
    flat_states = all_states.state
    if flat_states.ndim > 2:
        flat_states = flat_states.reshape(flat_states.shape[0], -1)
    # Prepend the log reward so it is the primary sort key.
    r_states = jnp.hstack([all_log_rewards[..., None], flat_states])  # (S + B, 1 + dim)

    # `unique` sorts rows lexicographically in ascending order, so negating
    # puts the highest log reward first. Duplicate rows are collapsed and only
    # the first `size` rows are kept; missing rows are padded with NaN.
    u_states, indices = jnp.unique(
        -r_states,
        axis=0,
        size=replay_buffer.size,
        return_index=True,
        fill_value=jnp.nan,
    )
    u_log_rewards = all_log_rewards[indices]
    # Gather the surviving entries from every batched leaf.
    new_states = jax.tree.map(lambda x: at(x, indices), all_states)
    # NaN-padded rows are empty slots; their gathered states are placeholders.
    # Give them -inf so they carry zero sampling weight and rank last.
    new_log_rewards = jnp.where(
        jnp.any(jnp.isnan(u_states), axis=1),
        -jnp.inf,
        u_log_rewards,
    )

    return replay_buffer.replace(
        states=new_states,
        log_rewards=new_log_rewards,
    )


def sample_from_buffer(
    replay_buffer: ReplayBuffer, batch_size: int, key: jax.random.PRNGKey
) -> tuple[EnvState, jax.Array, jax.random.PRNGKey]:
    """Sample ``batch_size`` entries with replacement, proportionally to reward.

    Entry ``i`` is drawn with probability ``exp(log_rewards[i])``, normalised
    over the buffer (a softmax of the log rewards). Entries with a ``-inf``
    log reward have zero probability.

    Args:
        replay_buffer: Buffer to sample from.
        batch_size: Number of entries to draw.
        key: PRNG key. It is split internally and never consumed twice.

    Returns:
        ``(states, log_rewards, key)``, where:
        - ``states`` are the sampled states, with batched leaves gathered by ``at``;
        - ``log_rewards`` are their log rewards, shape ``(batch_size,)``;
        - ``key`` is a fresh key for the caller's next random operation.

    Note:
        If every log reward is ``-inf`` (e.g. a freshly created buffer with
        nothing added), ``logsumexp`` is ``-inf``. The probabilities are then
        NaN and the draw is meaningless.
    """
    # Split before consuming. The key used by `choice` must not also be used
    # to derive the key handed back to the caller, which is `split(key)[1]`.
    sample_key, key = jax.random.split(key)

    # Normalise log rewards into log-probabilities over the buffer slots.
    log_p = replay_buffer.log_rewards - jax.nn.logsumexp(replay_buffer.log_rewards, axis=0)
    indices = jax.random.choice(
        key=sample_key,
        a=jnp.arange(replay_buffer.size),
        replace=True,
        shape=(batch_size,),
        p=jnp.exp(log_p),
    )
    # Gather the sampled entries from the rewards and every batched state leaf.
    log_rewards = replay_buffer.log_rewards[indices]
    states = jax.tree.map(lambda x: at(x, indices), replay_buffer.states)

    return states, log_rewards, key
