from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from jax import random, vmap


@struct.dataclass
class EnvHistory:
    """Rolling window of the most recent environment transitions.

    The per-step arrays all share a leading time axis of length ``T``; the
    last slot along that axis holds the newest entry. The first three fields
    are declared with ``pytree_node=False``. They therefore live in the pytree
    structure rather than among its leaves, and ``jit``/``vmap`` never trace
    them.
    """

    # Static (non-traced) metadata.
    gamma: float = struct.field(pytree_node=False)  # discount applied to rewards
    num_before: int = struct.field(pytree_node=False)  # past slots kept before the newest one
    num_actions: int = struct.field(pytree_node=False)  # size of the discrete action space

    # Traced per-step buffers, newest entry last.
    obs: jnp.ndarray  # observations, (T, *obs_shape)
    a: jnp.ndarray  # actions, (T,)
    r: jnp.ndarray  # rewards, (T,)
    done: jnp.ndarray  # termination flags, (T,)
    acc_r: jnp.ndarray  # running discounted return, (T,)
    curr_step: jnp.ndarray  # step counter, shape (1,)


def make_history(
    num_before: int,
    num_actions: int,
    obs_shape: Tuple[int, ...],
    gamma: float = 0.99,
    obs_dtype=jnp.float32,
) -> EnvHistory:
    """Create a zero-initialised history holding ``num_before + 1`` steps.

    Args:
        num_before: Number of past steps kept in addition to the newest one.
            Must be non-negative.
        num_actions: Size of the discrete action space (stored as static
            metadata; ``reset`` samples actions in ``[0, num_actions)``).
        obs_shape: Shape of a single observation. Any sequence of ints works.
        gamma: Discount factor used for the running discounted return.
        obs_dtype: dtype of the observation buffer. Defaults to float32,
            which matches the previous hard-coded behaviour.

    Returns:
        An ``EnvHistory`` whose per-step buffers have a leading axis of
        length ``num_before + 1``.

    Raises:
        ValueError: If ``num_before`` is negative.
    """
    if num_before < 0:
        raise ValueError(f"num_before must be non-negative, got {num_before}")
    T = num_before + 1
    return EnvHistory(
        gamma=gamma,
        num_before=num_before,
        num_actions=num_actions,
        obs=jnp.zeros((T,) + tuple(obs_shape), dtype=obs_dtype),
        a=jnp.zeros((T,), dtype=jnp.int32),
        r=jnp.zeros((T,), dtype=jnp.float32),
        done=jnp.zeros((T,), dtype=jnp.bool_),
        acc_r=jnp.zeros((T,), dtype=jnp.float32),
        # Shape (1,) rather than a scalar, so it batches cleanly under vmap.
        curr_step=jnp.zeros((1,), dtype=jnp.int32),
    )


@jax.jit
def reset(hist: EnvHistory, obs: jnp.ndarray, key) -> EnvHistory:
    """Reset the history with a new observation.

    Every observation slot is filled with ``obs``, cast to the dtype of
    ``hist.obs``. The action slots are filled with uniform random integers in
    ``[0, hist.num_actions)`` drawn from ``key``. Rewards, done flags,
    discounted returns and the step counter are zeroed.

    Args:
        hist: History to reset (a single environment, not batched).
        obs: Initial observation. Presumably shaped like ``hist.obs[0]``
            (TODO confirm with callers).
        key: PRNG key used to sample the placeholder actions.

    Returns:
        A fresh ``EnvHistory`` with the same static metadata as ``hist``.
    """
    T = hist.obs.shape[0]
    # Keep the buffer dtype stable. Otherwise reset would return e.g. uint8
    # observations while make_history/step use float32, and the pytree dtypes
    # would change between calls (this breaks lax.scan carries and makes
    # reset_at_done's jnp.where promote).
    obs = jnp.asarray(obs, dtype=hist.obs.dtype)
    obs_hist = jnp.broadcast_to(obs[None, ...], (T,) + obs.shape)  # (T, *obs_shape)
    a = random.randint(
        key, shape=(T,), minval=0, maxval=hist.num_actions, dtype=jnp.int32
    )
    return hist.replace(  # type: ignore
        obs=obs_hist,
        a=a,
        r=jnp.zeros((T,), dtype=jnp.float32),
        done=jnp.zeros((T,), dtype=jnp.bool_),
        acc_r=jnp.zeros((T,), dtype=jnp.float32),
        curr_step=jnp.zeros((1,), dtype=jnp.int32),
    )


@jax.jit
def reset_at_done(
    hist: EnvHistory, obs: jnp.ndarray, key, done: jnp.ndarray
) -> EnvHistory:
    """Return a reset history where ``done`` is true, otherwise ``hist``.

    A fresh history ``reset(hist, obs, key)`` is always computed. Each traced
    field is then chosen element-wise with ``jnp.where(done, fresh, old)``.
    Under the vmapped ``history_reset_at_done`` alias, ``done`` is one value
    per environment.

    Args:
        hist: Current history.
        obs: Observation used to fill the history if it is reset.
        key: PRNG key forwarded to ``reset``.
        done: Flag (or mask broadcastable to each field) selecting the reset
            history.

    Returns:
        An ``EnvHistory`` with the same static metadata and dtypes as ``hist``.
    """
    reset_hist = reset(hist, obs, key)

    def _select(fresh: jnp.ndarray, old: jnp.ndarray) -> jnp.ndarray:
        # jnp.where promotes when the two branches differ in dtype. Casting
        # back to the existing dtype keeps the pytree type-stable across steps.
        return jnp.where(done, fresh, old).astype(old.dtype)

    # Static fields (gamma, num_before, num_actions) live in the treedef, so
    # tree_map only touches the traced arrays and carries the metadata over.
    return jax.tree_util.tree_map(_select, reset_hist, hist)


@jax.jit
def step(hist: EnvHistory, obs, a, r, done) -> EnvHistory:
    """Append one transition to the rolling history.

    Every buffer is shifted left by one slot (the oldest entry is dropped) and
    the new value is written into the last slot. The newest ``acc_r`` entry is
    the previous newest entry plus ``gamma ** curr_step * r``, i.e. a running
    discounted return. ``curr_step`` is incremented by one.

    Args:
        hist: History to update (a single environment, not batched).
        obs: Observation for this step, shaped like ``hist.obs[0]``.
        a: Scalar action.
        r: Scalar reward.
        done: Scalar termination flag.

    Returns:
        The updated ``EnvHistory``. Every buffer keeps its original dtype.
    """

    def _push(buf: jnp.ndarray, new) -> jnp.ndarray:
        # Cast to the buffer dtype so inputs such as a float `done` or an
        # int `r` cannot silently promote the stored arrays. That would change
        # the pytree dtypes and break lax.scan carries.
        new = jnp.asarray(new, dtype=buf.dtype)
        return jnp.concatenate([buf[1:], new[None, ...]], axis=0)

    # curr_step has shape (1,); index it so the discounted reward is a scalar.
    discount = hist.gamma ** hist.curr_step[0]
    new_acc_r = hist.acc_r[-1] + discount * r

    return hist.replace(  # type: ignore
        obs=_push(hist.obs, obs),
        a=_push(hist.a, a),
        r=_push(hist.r, r),
        done=_push(hist.done, done),
        acc_r=_push(hist.acc_r, new_acc_r),
        curr_step=hist.curr_step + 1,
    )


def make_batch_history(
    batch_size: int,
    num_before: int,
    num_actions: int,
    obs_shape: Tuple[int, ...],
    gamma: float = 0.99,
) -> EnvHistory:
    """Create ``batch_size`` identical, freshly initialised histories.

    The result has the same structure as ``make_history(...)``, but every
    traced array gains a leading batch axis of length ``batch_size``. The
    static metadata is shared by the whole batch.
    """

    def _build_one(_unused_index):
        # The mapped index is ignored; vmap is only used to stack copies.
        return make_history(num_before, num_actions, obs_shape, gamma)

    batch_indices = jnp.arange(batch_size)
    return jax.vmap(_build_one)(batch_indices)


# Batched variants of the single-environment functions above. jax.vmap's
# default in_axes=0 maps over the leading axis of every argument. Histories,
# observations, PRNG keys, actions, rewards and done flags must therefore all
# carry a leading batch dimension (e.g. one key per environment).
history_reset = jax.vmap(reset)
history_step = jax.vmap(step)
history_reset_at_done = jax.vmap(reset_at_done)
