import jax
import jax.numpy as jnp
import numpy as np
from . import ChainWorld

class ChainWorldSolver(object):
    """Solve a deterministic chain-world MDP by value iteration.

    States ``0 .. num_states - 1`` lie on a line; states ``0`` and
    ``num_states - 1`` are absorbing. In every interior state, action ``0``
    steps left, action ``1`` steps right and action ``2`` stays put, so the
    model assumes ``num_actions == 3``.

    Two solvers are provided: a NumPy one that sweeps states in place
    (Gauss-Seidel style) and a JAX one that performs synchronous (Jacobi)
    sweeps inside ``jax.lax.while_loop``.
    """

    def __init__(self, num_states, num_actions, discount_factor):
        self.num_states = num_states
        self.num_actions = num_actions
        self.discount_factor = discount_factor

    @staticmethod
    def _check_backend(tp):
        """Raise ``ValueError`` unless ``tp`` names a supported array backend."""
        if tp not in ("numpy", "jax"):
            raise ValueError(f"unknown backend {tp!r}; expected 'numpy' or 'jax'")

    def get_transition_matrix(self, tp="numpy"):
        """Return the deterministic transition tensor ``P[s, a, s']``.

        Args:
            tp: ``"numpy"`` for a NumPy array, ``"jax"`` for a JAX array.

        Returns:
            Array of shape ``(num_states, num_actions, num_states)`` holding 1
            where action ``a`` in state ``s`` leads to ``s'`` and 0 elsewhere.

        Raises:
            ValueError: if ``tp`` is not a supported backend.
            AssertionError: if some (state, action) row does not sum to 1
                (checked only when assertions are enabled).
        """
        self._check_backend(tp)
        n = self.num_states
        interior = np.arange(1, n - 1)

        transition_probs = np.zeros((n, self.num_actions, n))
        # Action 0: interior states step left; state 0 is absorbing for all actions.
        transition_probs[interior, 0, interior - 1] = 1
        transition_probs[0, :, 0] = 1
        # Action 1: interior states step right; the last state is absorbing.
        transition_probs[interior, 1, interior + 1] = 1
        transition_probs[n - 1, :, n - 1] = 1
        # Action 2: interior states stay where they are.
        transition_probs[interior, 2, interior] = 1

        assert (np.sum(transition_probs, axis=2) == 1).all(), "transition model ill-posed!"

        if tp == "jax":
            # Same values; jnp.asarray picks JAX's default float dtype, exactly
            # as jnp.zeros would.
            return jnp.asarray(transition_probs)
        return transition_probs

    def get_reward(self, tp="numpy"):
        """Return the reward tensor ``R[s, a, s']``.

        Every interior step left (action 0) or right (action 1) costs ``-1``,
        except that stepping left from state 1 into state 0 pays
        ``(num_states - 3) // 2`` and stepping right from state
        ``num_states - 2`` into the last state pays ``num_states - 2``.
        All other entries (including action 2, "stay") are 0.

        Args:
            tp: ``"numpy"`` for a NumPy array, ``"jax"`` for a JAX array.

        Raises:
            ValueError: if ``tp`` is not a supported backend.
        """
        self._check_backend(tp)
        n = self.num_states
        interior = np.arange(1, n - 1)

        reward = np.zeros((n, self.num_actions, n))
        reward[interior, 0, interior - 1] = -1
        # Written after the -1 fill so it overrides that entry.
        reward[1, 0, 0] = (n - 3) // 2

        reward[interior, 1, interior + 1] = -1
        reward[n - 2, 1, n - 1] = n - 2

        if tp == "jax":
            return jnp.asarray(reward)
        return reward

    def solve_optimal_value_function_numpy(self, *, tol=1e-2, max_iter=1_000_000):
        """Run in-place (Gauss-Seidel) value iteration with NumPy.

        Args:
            tol: stop once a full sweep changes every state value by less
                than ``tol``.
            max_iter: give up after more than this many sweeps.

        Returns:
            ``(q_values, value)`` where ``value`` has shape ``(num_states,)`` and
            ``q_values`` has shape ``(num_states, num_actions)`` and holds the
            one-step look-ahead action values under the final ``value``.

        Raises:
            RuntimeError: if value iteration does not converge within
                ``max_iter`` sweeps.
        """
        transition_probs = self.get_transition_matrix()
        value = np.zeros(transition_probs.shape[0])

        rewards = self.get_reward()
        # The two end states are terminal: no reward is earned from them.
        rewards[0] = 0
        rewards[-1] = 0

        gamma = self.discount_factor
        k = 0
        while True:
            diff = 0
            # In-place sweep: later states already see this sweep's updates.
            for s in range(value.shape[0]):
                old = value[s]
                # (A, S) * ((A, S) + (S,)) broadcasts value over every action.
                q_s = np.sum(transition_probs[s] * (rewards[s] + gamma * value), axis=1)
                value[s] = np.max(q_s)
                diff = max(diff, abs(old - value[s]))
            k += 1
            if diff < tol:
                break
            if k > max_iter:
                raise RuntimeError(
                    f"Value iteration not converging. Stopped at {max_iter} iterations."
                )

        q_values = np.sum(transition_probs * (rewards + gamma * value), axis=2)
        return q_values, value

    def solve_optimal_value_function_jax(self, *, tol=1e-2, max_iter=1_000_000):
        """Run synchronous (Jacobi) value iteration with ``jax.lax.while_loop``.

        Args:
            tol: keep iterating while the largest per-sweep change exceeds
                ``tol``.
            max_iter: maximum number of sweeps.

        Returns:
            ``(q_values, value)`` with the same shapes as the NumPy solver.

        Raises:
            RuntimeError: if the loop stops at ``max_iter`` without converging.
        """
        transition_probs = self.get_transition_matrix("jax")
        rewards = self.get_reward("jax")
        # The two end states are terminal: no reward is earned from them.
        rewards = rewards.at[0].set(0)
        rewards = rewards.at[-1].set(0)
        gamma = self.discount_factor

        def q_from(value):
            # value (S,) broadcasts over the next-state axis of (S, A, S);
            # summing that axis yields Q of shape (S, A).
            return jnp.sum(transition_probs * (rewards + gamma * value), axis=2)

        def cond_fn(state):
            _, diff, k = state
            return (diff > tol) & (k < max_iter)

        def body_fn(state):
            value, _, k = state
            new_value = jnp.max(q_from(value), axis=1)
            diff = jnp.max(jnp.abs(new_value - value))
            return new_value, diff, k + 1

        value0 = jnp.zeros(transition_probs.shape[0])
        # Explicit carry dtypes so the loop state is stable across iterations.
        init = (value0, jnp.asarray(jnp.inf, dtype=value0.dtype), jnp.int32(0))
        value, diff, _ = jax.lax.while_loop(cond_fn, body_fn, init)

        if bool(diff > tol):
            raise RuntimeError(
                f"Value iteration not converging. Stopped at {max_iter} iterations."
            )

        return q_from(value), value
       

def main(discount_factor=0.9):
    """Solve the chain world with both the NumPy and JAX solvers and print results.

    Args:
        discount_factor: discount factor handed to ``ChainWorldSolver``.
            NOTE(review): 0.9 is a placeholder default — set it to the value
            the experiment intends.
    """
    # NOTE(review): assumes ChainWorld (imported from this package) is a class
    # exposing num_states / num_actions — confirm.
    env = ChainWorld()
    solver = ChainWorldSolver(env.num_states, env.num_actions, discount_factor)

    print(" --- numpy solver ---")
    q1, v1 = solver.solve_optimal_value_function_numpy()
    print(f"Q values {q1}")
    print(f"V values {v1}")

    print("\n --- jax solver ---")
    q2, v2 = solver.solve_optimal_value_function_jax()
    print(f"Q values {q2}")
    print(f"V values {v2}")


if __name__ == '__main__':
    main()
