"""
This file implements the third environment from the paper "Convergence Results
for Some Temporal Difference Methods Based on Least Squares" by Yu and
Bertsekas. See page 11.
"""

import chex
import jax
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple, Tuple
from jax import Array


class GameState(NamedTuple):
    """Environment state: the index of the chain's current state."""

    # Every episode begins in state zero; the choice is arbitrary.
    state: Array = jnp.int32(0)


class ExampleThree:
    """
    100-state Markov chain that has a random walk structure. The transition
    function obeys P(i|i) = 0.1 and P(i + 1|i) = 0.45 = P(i - 1|i) for all i
    in [1, 98]. The walk reflects at the ends: P(0|0) = 0.1, P(1|0) = 0.9,
    and symmetrically P(99|99) = 0.1, P(98|99) = 0.9.

    The costs are computed as follows:
        - if(i < 89): cost ~ uniform(0, 1)
        - if(i >= 89): cost ~ uniform(0, 1) + i/30

    Observations are fixed feature vectors: row `i` of a (100, 3) matrix of
    standard-normal entries drawn once in the constructor, returned as a
    (3, 1) column. The walk is driven entirely by the environment, so
    `num_actions` is only used as the observation feature dimension.
    """
    num_states: Array = jnp.int32(100)
    # Used as the observation feature dimension (see `obs`).
    num_actions: Array = jnp.int32(3)

    def __init__(self, rng_key: chex.PRNGKey):
        """Sample the observation embedding and build the transition matrix.

        Args:
            rng_key: key used to draw the (num_states, num_actions)
                standard-normal observation embedding.
        """
        n = int(self.num_states)
        self.embedding_matrix = jax.random.normal(
            rng_key, shape=(n, int(self.num_actions)))

        # Build the tridiagonal matrix on the host in one shot rather than
        # issuing ~100 functional `.at[].set` updates; `jnp.asarray` then
        # converts to JAX's default float dtype, as `jnp.zeros` did.
        p = np.zeros((n, n))
        interior = np.arange(1, n - 1)
        p[interior, interior - 1] = 0.45
        p[interior, interior] = 0.1
        p[interior, interior + 1] = 0.45
        # Reflecting boundaries.
        p[0, :2] = (0.1, 0.9)
        p[n - 1, n - 2:] = (0.9, 0.1)
        self.transition_matrix = jnp.asarray(p)

    def init(self) -> GameState:
        """Return the initial state (state 0)."""
        return GameState()
    
    def step(
            self,
            state: GameState,
            rng_key: chex.PRNGKey,
        ) -> Tuple[GameState, Array, chex.PRNGKey]:
        """Advance the chain by one transition.

        Args:
            state: current state.
            rng_key: key to consume; it is split and one half is returned.

        Returns:
            (next_state, prob, rng_key): `prob` is the transition probability
            P(next_state | state) of the move that was actually taken, and
            `rng_key` is a fresh key for subsequent calls.
        """
        rng_key, rand_key = jax.random.split(rng_key, 2)
        rand = jax.random.uniform(rand_key, shape=())
        # Displacement: -1 w.p. 0.45, 0 w.p. 0.1, +1 w.p. 0.45.
        action = jnp.select(
            condlist=[rand < 0.45, rand < 0.55, rand <= 1],
            choicelist=[-1, 0, 1],
        )
        # Reflect at the boundaries: a step off either end becomes a step
        # inward, which yields P(1|0) = P(98|99) = 0.9.
        action = jnp.where(
            jnp.logical_and(action == -1, state.state == 0),
            1,
            action,
        )
        action = jnp.where(
            jnp.logical_and(action == 1, state.state == 99),
            -1,
            action,
        )
        next_state = state.state + action
        # Index by the destination state, not the displacement: indexing by
        # -1/0/1 would read columns 99/0/1 of the current row.
        prob = self.transition_matrix[state.state, next_state]
        return (state._replace(state=next_state), prob, rng_key)
    
    def obs(self, state: GameState) -> Array:
        """Return the feature vector of `state` as a (num_actions, 1) column."""
        obs = self.embedding_matrix[state.state]
        return obs.reshape(self.num_actions, 1)

    def costs(
            self,
            state: GameState,
            rng_key: chex.PRNGKey,
        ) -> Tuple[Array, chex.PRNGKey]:
        """Sample the stage cost of `state`.

        The cost is uniform(0, 1) noise, plus i/30 for states i >= 89.

        Returns:
            (cost, rng_key) where `rng_key` is a fresh key for later calls.
        """
        rng_key, rand_key = jax.random.split(rng_key, 2)
        rand = jax.random.uniform(rand_key, shape=())
        # Threshold matches the spec in the class docstring (i >= 89).
        cost = jnp.where(state.state >= 89, rand + state.state / 30, rand)
        return (cost, rng_key)


if __name__ == "__main__":
    key = jax.random.key(123)
    
    # Test 1: just do some steps in the environment with a known key.
    # NOTE: `key` both seeds the embedding and drives the walk; the expected
    # final state below depends on exactly this key usage.
    env = ExampleThree(key)
    state = env.init()
    jitted_step = jax.jit(env.step)
    for _ in range(5000):
        state, prob, key = jitted_step(state, key)
    assert state.state == 76, "Random walk failed! Did not walk deterministically enough :o"

    # Test 2: do a bunch of actually random walks.
    key = jax.random.key(
        np.random.randint(low=0, high=np.iinfo(np.int32).max))
    state = env.init()
    for _ in range(5000):
        state, prob, key = jitted_step(state, key)
    # The reflecting walk must never leave [0, 99].
    assert 0 <= state.state <= 99, "Random walk left the state space!"
    
    # Test 3: ensure vmap() works.
    key = jax.random.key(
        np.random.randint(low=0, high=np.iinfo(np.int32).max))
    key, env_key, step_key = jax.random.split(key, 3)
    env_keys = jax.random.split(env_key, 100)
    # `init` takes no key, so the keys only fix the batch size; vmap
    # broadcasts the (unbatched) initial state to one copy per key.
    states = jax.vmap(lambda _: env.init())(env_keys)
    del env_key
    del env_keys

    step_keys = jax.random.split(step_key, 100)
    del step_key

    vmapped_step = jax.vmap(env.step)
    vmapped_costs = jax.vmap(env.costs)
    for _ in range(50):
        # The key returned by `step` feeds `costs`, whose returned key feeds
        # the next `step`.
        states, probs, cost_keys = vmapped_step(states, step_keys)
        costs, step_keys = vmapped_costs(states, cost_keys)
    assert states.state.shape == (100,), "vmap produced the wrong batch shape!"
    assert jnp.all((states.state >= 0) & (states.state <= 99)), \
        "Batched random walk left the state space!"
    assert costs.shape == (100,), "vmap produced the wrong cost shape!"
    
    # More tests can be added for the correctness of the costs, but I will not
    # do this.
    print("All tests passed!")

