import flashbax as fbx
import jax
import jax.numpy as jnp
from flax import struct
from jax import Array

from medium_rl.envs.sequence_env import SequenceEnv


@struct.dataclass
class Transition:
    """Record of a single step, presumably the item type stored in the replay
    buffer built by ``make_jit_transition_buffer`` (TODO confirm with callers).

    Declared with ``flax.struct.dataclass``, so instances are immutable and are
    JAX pytrees in which every field below is a leaf. The field order is
    significant: it defines both the positional constructor order and the
    pytree leaf order, so do not reorder fields.

    The per-field shapes/dtypes noted below are the ones allocated by
    ``make_dummy_transition`` for a single unbatched step; batched instances
    presumably carry extra leading axes (TODO confirm).
    """

    obs: Array  # shape (env.max_len,), dtype env.obs_dtype
    next_obs: Array  # shape (env.max_len,), dtype env.obs_dtype; presumably the observation after `action`
    legal_action_mask: Array  # shape (env.num_tokens,), int32; presumably nonzero marks a legal token — TODO confirm
    next_legal_action_mask: Array  # shape (env.num_tokens,), int32; mask paired with `next_obs`
    action: Array  # scalar int32; presumably a token index in [0, env.num_tokens) — TODO confirm
    terminating: Array  # scalar int32 flag; how it differs from `done` is not visible here — TODO confirm
    done: Array  # scalar int32 flag
    reward: Array  # scalar float32
    val: Array  # scalar float32; presumably a value estimate for `obs` — TODO confirm
    log_prob: Array  # scalar float32; presumably log-probability of `action` under the acting policy
    step: Array  # scalar int32; presumably the step index within the sequence — TODO confirm


def make_dummy_transition(env: SequenceEnv):
    """Return an all-zeros ``Transition`` describing one unbatched step.

    Only the shapes and dtypes of the result are meaningful; every element is
    zero. This is presumably used as the example item passed to a buffer's
    ``init`` (TODO confirm with callers).

    Args:
        env: Environment supplying ``max_len``, ``num_tokens`` and ``obs_dtype``.

    Returns:
        A ``Transition`` in which:
          * ``obs`` / ``next_obs`` have shape ``(env.max_len,)`` and dtype
            ``env.obs_dtype``;
          * ``legal_action_mask`` / ``next_legal_action_mask`` have shape
            ``(env.num_tokens,)`` and dtype int32;
          * ``action``, ``terminating``, ``done`` and ``step`` are int32 scalars;
          * ``reward``, ``val`` and ``log_prob`` are float32 scalars.
    """

    # Each helper allocates a fresh array per call, so no two fields share
    # the same array object.
    def zeros_sequence():
        return jnp.zeros((env.max_len,), dtype=env.obs_dtype)

    def zeros_token_mask():
        return jnp.zeros((env.num_tokens,), dtype=jnp.int32)

    def zeros_scalar(dtype):
        return jnp.zeros((), dtype=dtype)

    return Transition(
        obs=zeros_sequence(),
        next_obs=zeros_sequence(),
        legal_action_mask=zeros_token_mask(),
        next_legal_action_mask=zeros_token_mask(),
        action=zeros_scalar(jnp.int32),
        terminating=zeros_scalar(jnp.int32),
        done=zeros_scalar(jnp.int32),
        reward=zeros_scalar(jnp.float32),
        val=zeros_scalar(jnp.float32),
        log_prob=zeros_scalar(jnp.float32),
        step=zeros_scalar(jnp.int32),
    )


def make_jit_transition_buffer(
    max_length: int,
    min_length: int,
    sample_batch_size: int,
    add_batches: bool,
):
    """Create a flashbax item buffer whose functions are pre-compiled with ``jax.jit``.

    All arguments are forwarded unchanged to ``fbx.make_item_buffer``.

    Args:
        max_length: Capacity of the buffer (number of items it can hold).
        min_length: Number of items required before ``can_sample`` reports True.
        sample_batch_size: Number of items returned by each ``sample`` call.
        add_batches: Boolean flag forwarded to flashbax. When True, ``add``
            expects items with a leading batch axis. When False, it expects a
            single item. (Previously mis-annotated as ``float``.)

    Returns:
        The flashbax buffer with ``init``, ``add``, ``sample`` and
        ``can_sample`` replaced by ``jax.jit``-wrapped versions.

        ``add`` is jitted with ``donate_argnums=0``, so the buffer state passed
        as its first argument is donated to XLA. Callers must use the state
        returned by ``add`` and never reuse the state they passed in.
    """
    buffer = fbx.make_item_buffer(
        max_length=max_length,
        min_length=min_length,
        sample_batch_size=sample_batch_size,
        add_batches=add_batches,
    )
    # Buffer functions are pure, so jitting them once here avoids retracing at
    # every call site. Donating the state to `add` lets XLA update the
    # (potentially large) storage in place instead of copying it.
    buffer = buffer.replace(
        init=jax.jit(buffer.init),
        add=jax.jit(buffer.add, donate_argnums=0),
        sample=jax.jit(buffer.sample),
        can_sample=jax.jit(buffer.can_sample),
    )
    return buffer
