import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
from divgfn.utils import HasState, merge


def at(x: jax.Array, indices: jax.Array):
    """Gather ``x[indices]`` along the leading axis when ``x`` is an indexable JAX array.

    Values that are not ``jax.Array`` instances, and 0-d arrays, are returned unchanged.
    This makes the helper safe to map over every leaf of a pytree with
    ``jax.tree_util.tree_map``.
    """
    indexable = isinstance(x, jax.Array) and x.ndim >= 1
    if not indexable:
        return x
    return x[indices]


@struct.dataclass
class ReplayBuffer:
    """Fixed-capacity buffer of states ranked by log-reward (this resembles a ModeQueue).

    Instances are immutable pytrees; updates produce new instances via ``replace``.
    """

    # Stored states; each array leaf is indexed along its leading (slot) axis.
    x: HasState
    # Per-slot log-reward; initialised to -inf.
    y: jax.Array
    # True where a slot holds a valid entry; invalid slots are excluded when sampling.
    mask: jax.Array
    # Per-slot recency counter consumed by ``push`` when ranking entries.
    recency: jax.Array
    # Capacity (number of slots); static, so it is not traced by JAX.
    size: int = struct.field(pytree_node=False)

    @classmethod
    def create(cls, state: HasState):
        """Build an empty buffer whose capacity equals the batch size of ``state``.

        ``state`` provides the placeholder storage for ``x``. Its ``.state`` array must be
        2-D (slots, features). Every slot starts invalid, with ``y = -inf`` and recency 0.
        """
        capacity, _ = state.state.shape
        no_reward = jnp.ones((capacity,)) * -jnp.inf
        all_invalid = jnp.zeros((capacity,), dtype=jnp.bool)
        fresh_recency = jnp.zeros((capacity,))
        return cls(
            x=state,
            y=no_reward,
            mask=all_invalid,
            recency=fresh_recency,
            size=capacity,
        )


@jax.jit
def push(buffer: ReplayBuffer, states: HasState, log_rewards: jax.Array, iteration: jax.Array) -> ReplayBuffer:
    """Insert ``states`` into ``buffer``, keeping the ``buffer.size`` best distinct states.

    Existing entries and the new batch are pooled and ranked by
    ``log_reward + 0.01 * recency / (iteration + 1)``. Existing entries have their recency
    decremented by one; new entries get recency ``iteration``. For non-negative
    ``iteration``, the bonus therefore favours recently inserted states. Identical states
    (equal ``.state`` rows) are deduplicated, and only their highest-priority copy is kept.
    The top ``buffer.size`` survivors are stored in descending priority order.

    Args:
        buffer: Current buffer.
        states: New states. ``states.state`` is assumed to be (batch, features), like the
            buffer's own ``x.state`` — TODO confirm against callers.
        log_rewards: Log-rewards of the new states, one per row of ``states``.
        iteration: Current iteration, used for the recency bonus.

    Returns:
        The updated buffer. ``y`` holds raw log-rewards, because the recency bonus only
        affects ranking. Slots without a valid entry have ``mask=False`` and ``y=-inf``.
    """
    # Pool buffer and batch; ``merge`` presumably concatenates along the leading axis.
    all_states = jax.tree_util.tree_map(merge, buffer.x, states)
    all_logr = jnp.hstack([buffer.y, log_rewards])
    recency = jnp.hstack(
        [
            buffer.recency - 1,  # Age existing states
            jnp.full_like(log_rewards, iteration),  # New states get current iteration
        ]
    )
    # Ranking key only; the bonus is NOT written back into ``y`` (it would compound).
    priority = all_logr + 0.01 * recency / (iteration + 1)

    # Deduplicate on the state itself. lexsort treats its LAST key as primary, so rows are
    # grouped by state (column by column). Within a group they are ordered by descending
    # priority, so the first row of each group is that state's best copy.
    flat = all_states.state
    grouping = jnp.lexsort((-priority, *(flat[:, j] for j in reversed(range(flat.shape[1])))))
    grouped = flat[grouping]
    starts_group = jnp.concatenate(
        [jnp.ones((1,), dtype=bool), jnp.any(grouped[1:] != grouped[:-1], axis=1)]
    )
    is_best = jnp.zeros_like(starts_group).at[grouping].set(starts_group)

    # Rank survivors by priority; non-best duplicates get -inf and sort to the end.
    ranked = jnp.where(is_best, priority, -jnp.inf)
    indices = jnp.argsort(-ranked)[: buffer.size]

    # A slot is valid when its ranking key is finite (not a duplicate, not an empty slot,
    # not an infinite reward) and its state contains no infinite values.
    mask = ~jnp.isinf(ranked[indices]) & ~jnp.isinf(flat[indices]).any(axis=1)
    # Gather from the pooled states: ``indices`` index buffer+batch, not the batch alone.
    x = jax.tree_util.tree_map(lambda s: at(s, indices), all_states)
    # Invalid slots get -inf so they can never outrank real entries on later pushes.
    y = jnp.where(mask, all_logr[indices], -jnp.inf)
    stored_recency = recency[indices]

    return buffer.replace(mask=mask, x=x, y=y, recency=stored_recency)


def sample(buffer: ReplayBuffer, batch_size: int, key: jax.Array):
    """Draw ``batch_size`` entries from ``buffer`` with replacement.

    Each valid slot (``buffer.mask`` True) is drawn with probability proportional to
    ``exp(buffer.y)``. Invalid slots have zero probability.
    NOTE(review): if no slot is valid, the logits are all NaN after log_softmax.

    Returns:
        A tuple ``(states, y, key)``:
        - ``states``: the gathered states, with ``batch_ids`` reset to ``0..batch_size-1``;
        - ``y``: their stored ``y`` values;
        - ``key``: a new PRNG key for the caller to carry forward.
    """
    key, draw_key = jax.random.split(key, 2)
    masked_scores = jnp.where(buffer.mask, buffer.y, -jnp.inf)
    chosen = jax.random.categorical(
        key=draw_key,
        logits=nnx.log_softmax(masked_scores),
        replace=True,
        shape=(batch_size,),
    )

    def gather(leaf):
        # Pick the chosen slots from every indexable leaf of the stored states.
        return at(leaf, chosen)

    picked = jax.tree_util.tree_map(gather, buffer.x)
    picked_y = buffer.y[chosen]
    picked = picked.replace(batch_ids=jnp.arange(batch_size))
    return picked, picked_y, key
