import functools
from typing import Callable

import chex
import einops
import jax
import jax.numpy as jnp
from flax import struct

# A prior is any callable that maps a PRNG key to one sampled array, e.g. one
# transition-kernel row or one state's cumulant (see
# `MarkovRewardProcess.from_independent_priors`, which vmaps it over states).
Prior = Callable[[chex.PRNGKey], chex.Array]


@struct.dataclass
class MarkovRewardProcess:
    """A finite-state Markov reward process stored as dense JAX arrays.

    Attributes:
        transition_kernel: ``(num_states, num_states)`` matrix; row ``s`` is
            used as the next-state distribution from state ``s`` (presumably
            row-stochastic -- this is not validated here).
        cumulants: Per-state cumulants, indexed by the state being left.
            Either ``(num_states,)`` for scalar rewards or
            ``(num_states, reward_dim)`` for vector-valued rewards.
    """

    transition_kernel: chex.Array
    cumulants: chex.Array

    @functools.cached_property
    def num_states(self) -> int:
        """Number of states, read from the leading axis of the kernel."""
        return self.transition_kernel.shape[0]

    @functools.cached_property
    def reward_dim(self) -> int:
        """Cumulant dimensionality; ``1`` when cumulants are 1-D (scalar)."""
        if self.cumulants.ndim == 1:
            return 1
        return self.cumulants.shape[-1]

    def sample_from_state(
        self, rng: chex.PRNGKey, source_state: int
    ) -> tuple[chex.Array, chex.Array]:
        """Take one transition out of ``source_state``.

        Args:
            rng: PRNG key used to sample the successor state.
            source_state: Index of the current state (a Python int or a traced
                integer, e.g. the loop carry of ``jax.lax.fori_loop``).

        Returns:
            ``(cumulant, next_state)``: the cumulant attached to
            ``source_state`` and the sampled successor index.
        """
        probs = self.transition_kernel[source_state, :]
        next_state = jax.random.choice(rng, self.num_states, p=probs)
        cumulant = self.cumulants[source_state]
        return cumulant, next_state

    def monte_carlo_return(
        self,
        rng: chex.PRNGKey,
        source_state: int,
        discount: float,
        max_steps: int = 200,
    ) -> chex.Array:
        """Sample one truncated discounted return starting at ``source_state``.

        Accumulates ``discount**t * cumulant_t`` for ``t = 0 .. max_steps - 1``
        along a single simulated trajectory.

        Args:
            rng: Base PRNG key; step ``t`` uses ``jax.random.fold_in(rng, t)``.
            source_state: Index of the initial state.
            discount: Per-step discount factor.
            max_steps: Truncation horizon (number of simulated transitions).

        Returns:
            Array of shape ``(reward_dim,)`` holding the discounted sum.
        """

        def env_step(
            timestep: int, carry: tuple[int, chex.Array]
        ) -> tuple[int, chex.Array]:
            state, discounted_sum = carry
            # Derive a per-step key from the step index so no key needs carrying.
            key = jax.random.fold_in(rng, timestep)
            cumulant, next_state = self.sample_from_state(key, state)
            return next_state, discounted_sum + cumulant * discount**timestep

        init = (source_state, jnp.zeros(self.reward_dim))
        _, discounted_return = jax.lax.fori_loop(0, max_steps, env_step, init)
        return discounted_return

    @classmethod
    def from_independent_priors(
        cls,
        seed: int,
        num_states: int,
        transition_prior: Prior,
        cumulant_prior: Prior,
    ) -> "MarkovRewardProcess":
        """Sample an MRP whose per-state rows are drawn independently.

        Args:
            seed: Integer seed passed to ``jax.random.PRNGKey``.
            num_states: Number of states to sample.
            transition_prior: Maps a PRNG key to one transition-kernel row;
                applied once per state via ``jax.vmap``.
            cumulant_prior: Maps a PRNG key to one state's cumulant; applied
                once per state via ``jax.vmap``.

        Returns:
            An instance whose kernel rows and cumulants are stacked along a
            leading axis of length ``num_states``.
        """
        key = jax.random.PRNGKey(seed)
        transition_key, cumulant_key = jax.random.split(key)
        # `jax.random.split` already returns a stacked key array, which is what
        # `jax.vmap` maps over; no extra `jnp.array` copy is needed.
        transition_keys = jax.random.split(transition_key, num_states)
        transition_kernel = jax.vmap(transition_prior)(transition_keys)
        cumulant_keys = jax.random.split(cumulant_key, num_states)
        cumulants = jax.vmap(cumulant_prior)(cumulant_keys)
        return cls(transition_kernel=transition_kernel, cumulants=cumulants)


def RowlandTwoStateMRP(
    seed: int,
    cumulant_prior: Prior | None = None,
    cumulant_value: chex.Array | None = None,
) -> MarkovRewardProcess:
    """Two-state MRP from the scalar-reward example of Rowland et al.

    See https://arxiv.org/abs/2301.04462 (Figure 5).

    Cumulants are chosen by priority: ``cumulant_value`` if given, otherwise a
    draw from ``cumulant_prior`` keyed by ``seed``, otherwise ``[2.0, -1.0]``.

    Args:
        seed: Seed for ``cumulant_prior``; unused when no prior is sampled.
        cumulant_prior: Optional prior mapping a PRNG key to the cumulants of
            both states.
        cumulant_value: Optional explicit cumulants; any array-like is
            accepted and converted with ``jnp.asarray``.

    Returns:
        A ``MarkovRewardProcess`` with transition kernel
        ``[[0.6, 0.4], [0.8, 0.2]]``.

    Raises:
        ValueError: If the cumulants do not have a leading axis of length 2
            (one entry per state).
    """
    transition_kernel = jnp.array([[0.6, 0.4], [0.8, 0.2]])
    if cumulant_value is not None:
        # Convert array-likes (lists, NumPy arrays) so `.shape` works and the
        # cumulants can be indexed by traced states inside `fori_loop`.
        cumulants = jnp.asarray(cumulant_value)
    elif cumulant_prior is not None:
        key = jax.random.PRNGKey(seed)
        cumulants = jnp.asarray(cumulant_prior(key))
    else:
        cumulants = jnp.array([2.0, -1.0])
    num_states = transition_kernel.shape[0]
    if cumulants.ndim == 0 or cumulants.shape[0] != num_states:
        raise ValueError(
            "RowlandTwoStateMRP expects cumulants with a leading axis of "
            f"length {num_states}, got shape {cumulants.shape}."
        )
    return MarkovRewardProcess(transition_kernel=transition_kernel, cumulants=cumulants)


def RowlandTwoStateMRPMultivariateNaive(reward_dim: int) -> MarkovRewardProcess:
    """Vector-reward Rowland two-state MRP built from sign-alternated copies.

    Reward axis ``k`` of state 0 is ``2.0 * (-1)**k`` and of state 1 is
    ``-1.0 * (-1)**k``; the transition kernel is that of
    ``RowlandTwoStateMRP`` (built with seed 0).

    Args:
        reward_dim: Number of reward dimensions.

    Returns:
        A two-state ``MarkovRewardProcess`` with ``(2, reward_dim)`` cumulants.
    """
    # +1, -1, +1, ... across the reward axes.
    signs = jnp.where(jnp.arange(reward_dim) % 2 == 0, 1.0, -1.0)
    scalar_rewards = jnp.array([2.0, -1.0])
    # Outer product: row s is scalar_rewards[s] times the sign pattern.
    cumulants = scalar_rewards[:, None] * signs[None, :]
    return RowlandTwoStateMRP(0, cumulant_value=cumulants)


def ExtrapolatedRowlandMRP() -> MarkovRewardProcess:
    """Four-state, two-dimensional-reward extension of the Rowland MRP.

    States 0 and 1 emit 4.0 on reward axis 0 and 1 respectively and stay put
    with probability 0.6. States 2 and 3 emit -1.0 on reward axis 0 and 1
    respectively and stay put with probability 0.2. In every row the
    remaining probability mass is split evenly between the two states of the
    other group.
    """
    kernel = jnp.array(
        [
            # to: 0    1    2    3
            [0.6, 0.0, 0.2, 0.2],  # from 0
            [0.0, 0.6, 0.2, 0.2],  # from 1
            [0.4, 0.4, 0.2, 0.0],  # from 2
            [0.4, 0.4, 0.0, 0.2],  # from 3
        ]
    )
    rewards = jnp.array(
        [
            [4.0, 0.0],  # state 0
            [0.0, 4.0],  # state 1
            [-1.0, 0.0],  # state 2
            [0.0, -1.0],  # state 3
        ]
    )
    return MarkovRewardProcess(transition_kernel=kernel, cumulants=rewards)


def FiniteHorizonTerminalRewardMRP(
    base_mrp: MarkovRewardProcess, horizon: int
) -> MarkovRewardProcess:
    """Unroll ``base_mrp`` into ``horizon`` layers with rewards only at the end.

    With ``n = base_mrp.num_states``, state ``level * n + s`` is copy ``s`` of
    the base MRP at ``level``. Levels ``0 .. horizon - 2`` move into the next
    level using the base kernel. The final level is absorbing (identity
    kernel) and is the only level that carries the base cumulants.

    Args:
        base_mrp: MRP to unroll.
        horizon: Number of levels; must be at least 1.

    Returns:
        An MRP with ``horizon * n`` states. Its cumulants have shape
        ``(horizon * n, *base_mrp.cumulants.shape[1:])``, so 1-D (scalar)
        base cumulants stay 1-D.

    Raises:
        ValueError: If ``horizon < 1``.
    """
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}.")
    base_n = base_mrp.num_states
    num_states = base_n * horizon

    def _indices_at_level(level: int) -> tuple[int, int]:
        """Return the half-open state range ``[start, end)`` of ``level``."""
        start = level * base_n
        return start, start + base_n

    transition_kernel = jnp.zeros((num_states, num_states))
    for level in range(horizon - 1):
        src_start, src_end = _indices_at_level(level)
        tgt_start, tgt_end = _indices_at_level(level + 1)
        transition_kernel = transition_kernel.at[
            src_start:src_end, tgt_start:tgt_end
        ].set(base_mrp.transition_kernel)
    last_start, last_end = _indices_at_level(horizon - 1)
    # The final level is absorbing: each state stays put with probability 1.
    transition_kernel = transition_kernel.at[
        last_start:last_end, last_start:last_end
    ].set(jnp.eye(base_n))
    # Mirror the base cumulant layout instead of forcing (num_states, reward_dim):
    # a 1-D `(base_n,)` update cannot broadcast into a `(base_n, 1)` slice.
    cumulants = (
        jnp.zeros((num_states, *base_mrp.cumulants.shape[1:]))
        .at[last_start:last_end]
        .set(base_mrp.cumulants)
    )
    return MarkovRewardProcess(transition_kernel=transition_kernel, cumulants=cumulants)


def TerminalRewardMRP(
    base_mrp: MarkovRewardProcess, num_terminal: int
) -> MarkovRewardProcess:
    """Append an absorbing state and keep rewards only on terminal states.

    The last ``num_terminal`` states of ``base_mrp`` are treated as terminal.
    They keep their cumulants, and their outgoing transitions are replaced by
    a deterministic jump to a new absorbing state (index
    ``base_mrp.num_states``). The absorbing state self-loops with zero
    cumulant. All other states keep their base transitions but get zero
    cumulants.

    Args:
        base_mrp: MRP to augment.
        num_terminal: Number of trailing base states that are terminal; must
            satisfy ``0 <= num_terminal <= base_mrp.num_states``.

    Returns:
        An MRP with ``base_mrp.num_states + 1`` states. Its cumulants have
        shape ``(num_states, *base_mrp.cumulants.shape[1:])``, so 1-D
        (scalar) base cumulants stay 1-D.

    Raises:
        ValueError: If ``num_terminal`` is outside the allowed range.
    """
    base_n = base_mrp.num_states
    if not 0 <= num_terminal <= base_n:
        raise ValueError(
            f"num_terminal must be in [0, {base_n}], got {num_terminal}."
        )
    num_states = base_n + 1
    first_terminal = base_n - num_terminal
    # Use explicit bounds: `cumulants[-num_terminal:]` would select every row
    # when num_terminal == 0.
    cumulants = (
        jnp.zeros((num_states, *base_mrp.cumulants.shape[1:]))
        .at[first_terminal:base_n]
        .set(base_mrp.cumulants[first_terminal:base_n])
    )
    # One-hot row pointing at the absorbing state (the last index).
    to_absorbing = jnp.eye(num_states)[-1]
    transition_kernel = (
        jnp.zeros((num_states, num_states))
        .at[:base_n, :base_n]
        .set(base_mrp.transition_kernel)
        # Terminal states and the absorbing state itself all move to the
        # absorbing state; the one-hot row broadcasts over those rows.
        .at[first_terminal:, :]
        .set(to_absorbing)
    )
    return MarkovRewardProcess(transition_kernel=transition_kernel, cumulants=cumulants)
