from typing import *
from .modelbasedagent import ModelBasedAgent
import numpy as np

class RMAXAgent(ModelBasedAgent):
    """R-MAX agent for a single-agent MDP (not a stochastic game).

    Restricting to an MDP keeps the data structures simple: the model is a
    pair of dense ``(num_states + 1, num_actions, num_states + 1)`` arrays.
    Every state ``s`` is stored at array index ``s + 1``; index ``0`` is the
    fictitious state that unvisited state-action pairs are assumed to lead to.

    ``num_states``, ``num_actions``, ``max_reward``, ``T``, ``policy_step``,
    ``last_state``, ``last_action`` and the value-iteration helpers are not
    defined in this class; they are expected to come from ``ModelBasedAgent``.
    """

    def __init__(self, min_visit_count, use_jax: bool, env_name: str, use_max_reward=True, **kwargs):
        """Build the optimistic model and the set of absorbing states.

        Args:
            min_visit_count: Number of visits of a state-action pair that
                triggers a policy recomputation in ``interact``.
            use_jax: If true, ``_compute_policy`` calls
                ``jax_value_iteration`` instead of ``_value_iteration``.
            env_name: One of ``"Bipolar"``, ``"GridWorld"`` or ``"LazyChain"``;
                selects which array indexes are treated as terminal.
            use_max_reward: If true, the optimistic reward is
                ``self.max_reward``. Otherwise it starts at 50 and is raised
                whenever a larger reward is observed.
            **kwargs: Forwarded unchanged to ``ModelBasedAgent.__init__``.

        Raises:
            NotImplementedError: If ``env_name`` is not a supported name.
        """
        super().__init__(**kwargs)
        self.min_visit_count = min_visit_count
        self.use_max_reward = use_max_reward
        self.use_jax = use_jax

        if use_max_reward:
            # Optimistic value taken from the attribute set by the base class.
            self.Rmax = self.max_reward
        else:
            # Arbitrary starting guess (!); update_model raises it on demand.
            self.Rmax = 50

        model_shape = (self.num_states + 1, self.num_actions, self.num_states + 1)
        # Reward of every (s, a, s') starts at the optimistic value.
        self.reward = np.full(model_shape, self.Rmax, dtype=np.float32)
        # Visit counts of every (s, a, s').
        self.transition_observations = np.zeros(model_shape, dtype=np.float32)
        self.value_table = np.zeros((self.num_states + 1, self.num_actions), dtype=np.float32)

        if env_name in ("Bipolar", "GridWorld"):
            # Cells of a 9x9 grid numbered row by row; the border is wall.
            side = 9
            line = np.arange(side)
            wall_index = np.concatenate([
                line,                      # top row
                (side - 1) * side + line,  # bottom row
                line * side,               # left column
                line * side + (side - 1),  # right column
            ])
            goal_index = np.array([10, 70])

            # +1 converts state numbers to array indexes (index 0 is fictitious).
            self.terminal_indexes = np.concatenate([wall_index, goal_index]) + 1
        elif env_name == "LazyChain":
            # Array indexes 1 and -1, i.e. the first and the last real state.
            self.terminal_indexes = np.array([1, -1])
        else:
            raise NotImplementedError(f"Unsupported env_name: {env_name!r}")

    def reset(self):
        """Reset the base agent and forget the learned model.

        Note that ``self.Rmax`` is not restored: if it was raised by
        ``update_model``, the rewards are refilled with the raised value.
        """
        super().reset()
        self.reward.fill(self.Rmax)
        self.transition_observations.fill(0)
        self.value_table.fill(0)

    def update_model(self, reward, next_state):
        """Record the transition ``(last_state, last_action) -> next_state``.

        Args:
            reward: Reward received for the transition, or ``None`` when there
                is nothing to record (the model is then left untouched).
            next_state: State reached by the transition.
        """
        if reward is None:
            return

        s, a, s_next = self.last_state + 1, self.last_action, next_state + 1

        # A transition is new when it has never been counted. The visit count
        # is used rather than comparing the stored reward with Rmax, because a
        # real reward may equal Rmax and float32/float equality is unreliable.
        if self.transition_observations[s, a, s_next] == 0:
            self.reward[s, a, s_next] = reward
            if not self.use_max_reward and self.Rmax < reward:
                # A reward above the optimistic bound was seen: lift the bound
                # for every transition that is still unobserved. Terminal rows
                # touched here are zeroed again by _compute_policy before use.
                self.reward[self.transition_observations == 0] = reward
                self.Rmax = reward

        # Update the set of states reached by playing a.
        self.transition_observations[s, a, s_next] += 1

    def interact(self, reward, next_state):
        """Learn from the last transition and pick the next action.

        Args:
            reward: Reward for the previous action (see ``update_model``).
            next_state: State the agent is now in.

        Returns:
            The action with the highest value for ``next_state`` in
            ``self.value_table``, chosen by ``_argmax_breaking_ties_randomly``.
        """
        self.update_model(reward, next_state)

        # Re-plan when the previous T-step policy has been fully executed, or
        # when the last state-action pair just reached min_visit_count visits.
        visits = self.transition_observations[self.last_state + 1, self.last_action].sum()
        if self.policy_step == self.T or visits == self.min_visit_count:
            self._compute_policy()

        next_action = self._argmax_breaking_ties_randomly(self.value_table[next_state + 1])

        self.policy_step += 1
        self.last_state = next_state
        self.last_action = next_action

        return next_action

    def _compute_policy(self):
        """Compute an optimal T-step policy from the current model.

        Builds the empirical transition probabilities, sends unvisited
        state-action pairs to the fictitious state, makes terminal states
        absorbing with zero reward, and runs value iteration. The rewards of
        the terminal rows of ``self.reward`` are zeroed in place.
        """
        self.policy_step = 0

        # Empirical transition probabilities (prevent dividing by zero).
        divisor = self.transition_observations.sum(axis=2, keepdims=True)
        divisor[divisor == 0] = 1
        transition_probs = self.transition_observations / divisor

        # State-action pairs with no counts have zero probability everywhere;
        # give them probability 1 of reaching the fictitious state (index 0).
        indice_non_visited = transition_probs.sum(axis=-1) == 0
        transition_probs[indice_non_visited, 0] = 1

        # Terminal states are absorbing: every action loops back to the same
        # state (the two index arrays are paired element-wise) with reward 0.
        transition_probs[self.terminal_indexes] = 0
        transition_probs[self.terminal_indexes, :, self.terminal_indexes] = 1
        self.reward[self.terminal_indexes] = 0

        if self.use_jax:
            self.jax_value_iteration(self.reward, transition_probs)
        else:
            self._value_iteration(self.reward, transition_probs)
