import flashbax as fbx
import jax
import jax.numpy as jnp
from flax import struct
from jax import Array

from medium_rl.envs.sequence_env import SequenceEnv


@struct.dataclass
class SubTrajectory:
    """Container for one stored sub-trajectory (a flax struct, i.e. a JAX pytree).

    Field shapes are not enforced here. `make_dummy_subtrajectory` in this module
    builds a zero template with obs (max_len,), legal_action_mask
    (max_len, num_tokens), action (max_len,), done (max_len,) and reward (1,).
    NOTE(review): presumably real data follows the same layout — confirm.
    """

    # Observations; dtype taken from env.obs_dtype in the dummy template.
    obs: Array
    # Legal-action mask; int32 with a trailing num_tokens axis in the dummy template.
    legal_action_mask: Array
    # Actions taken; int32 in the dummy template.
    action: Array
    # Done flags; int32 in the dummy template.
    done: Array
    # Reward; shape (1,) float32 in the dummy template (not per-step) — confirm intent.
    reward: Array


def make_dummy_subtrajectory(env: SequenceEnv):
    """Build an all-zeros ``SubTrajectory`` that serves as a structural template.

    NOTE(review): presumably passed to a replay buffer's ``init`` so storage is
    allocated with the right shapes/dtypes — confirm against callers.

    Args:
        env: Environment supplying ``max_len``, ``num_tokens`` and ``obs_dtype``.

    Returns:
        A ``SubTrajectory`` whose leaves are zero arrays with shapes
        obs ``(max_len,)``, legal_action_mask ``(max_len, num_tokens)``,
        action ``(max_len,)``, done ``(max_len,)`` and reward ``(1,)``.
    """
    max_len = env.max_len
    return SubTrajectory(
        obs=jnp.zeros((max_len,), dtype=env.obs_dtype),
        legal_action_mask=jnp.zeros((max_len, env.num_tokens), dtype=jnp.int32),
        action=jnp.zeros((max_len,), dtype=jnp.int32),
        # Explicit 1-tuple, consistent with the other fields (`(x)` is just `x`).
        done=jnp.zeros((max_len,), dtype=jnp.int32),
        # Reward is shape (1,), not (max_len,) — presumably one value per
        # sub-trajectory; confirm.
        reward=jnp.zeros((1,), dtype=jnp.float32),
    )


def make_jit_trajectory_buffer(max_length: int, min_length: int, sample_batch_size: int, add_batches: bool):
    """Create a flashbax item buffer whose four entry points are jit-compiled.

    Args:
        max_length: Maximum number of items the buffer stores.
        min_length: Number of items required before sampling is allowed.
        sample_batch_size: Number of items returned per ``sample`` call.
        add_batches: Whether ``add`` receives a leading batch dimension.

    Returns:
        The buffer from ``fbx.make_item_buffer`` with ``init``, ``add``,
        ``sample`` and ``can_sample`` wrapped in ``jax.jit``. ``add`` donates
        its first argument (the buffer state), so callers must not reuse the
        state they passed in.
    """
    base = fbx.make_item_buffer(
        max_length=max_length,
        min_length=min_length,
        sample_batch_size=sample_batch_size,
        add_batches=add_batches,
    )
    compiled = {
        "init": jax.jit(base.init),
        # Donating argument 0 lets XLA update the state buffers in place.
        "add": jax.jit(base.add, donate_argnums=0),
        "sample": jax.jit(base.sample),
        "can_sample": jax.jit(base.can_sample),
    }
    return base.replace(**compiled)
