import gym
from gym import spaces
from collections import defaultdict 
import numpy as np
from tqdm import tqdm


class Dynamics:
    """Common interface for the dynamics models defined in this module.

    The base implementation does nothing: both ``get_model`` and ``update``
    accept arbitrary arguments and return ``None``. Subclasses override the
    methods they need.
    """

    def __init__(self):
        # The base class keeps no state.
        return None

    def get_model(self, *args, **kwargs):
        """Return the current model; the base class has none, so ``None``."""
        return None

    def update(self, *args, **kwargs):
        """Incorporate new experience; a no-op in the base class."""
        return None

class Transition_Matrix(Dynamics):
    """Expose an already-known transition matrix through the ``Dynamics`` interface.

    Attributes:
        matrix: the object passed to the constructor, stored as-is
            (it is neither copied nor validated).
    """

    def __init__(self, matrix):
        # Keep a reference to the caller's matrix; nothing is learned here.
        self.matrix = matrix

    def get_model(self):
        """Return the stored matrix unchanged."""
        model = self.matrix
        return model

class True_Successor(Dynamics):
    """Successor-slot representation of a known transition matrix.

    The dense input is indexed as ``matrix[next_state, state, action]``. Rather
    than keeping one row per possible next state, two parallel tables with one
    row per *slot* are built:

    * ``successor_state_matrix[slot, state]`` holds the id of a successor of
      ``state``; ``-1`` marks a slot that is not used for that state.
    * ``successor_counts[slot, state, action]`` holds the probability mass of
      reaching that successor from ``state`` under ``action``.

    Slot 0 is pre-filled with each state as its own successor (with zero mass);
    further slots are appended on demand.

    Attributes:
        matrix: the dense transition matrix given to the constructor.
        n_states: ``matrix.shape[1]``.
        n_actions: ``matrix.shape[2]``.
        successor_state_matrix: int32 array of shape ``(n_slots, n_states)``.
        successor_counts: float32 array of shape ``(n_slots, n_states, n_actions)``.
    """

    def __init__(self, matrix):
        self.matrix = matrix

        self.n_states = matrix.shape[1]
        self.n_actions = matrix.shape[2]

        self.successor_state_matrix = np.arange(self.n_states, dtype=np.int32).reshape(1, -1)
        self.successor_counts = np.zeros((1, self.n_states, self.n_actions), dtype=np.float32)

        self._initialize()

    def _initialize(self):
        """Copy every strictly positive entry of ``matrix`` into the slot tables."""
        for state in tqdm(range(self.n_states)):
            for action in range(self.n_actions):
                next_states = np.flatnonzero(self.matrix[:, state, action] > 0.0)
                for next_state in next_states:
                    self._update(next_state, state, action, self.matrix[next_state, state, action])

    def _grow(self):
        """Append one unused slot (ids ``-1``, zero mass) to both tables."""
        # Build the new rows with the dtype of the existing tables: stacking a
        # default float64 row would silently turn the integer state ids into
        # floats (and widen the counts) the first time the tables grow.
        empty_ids = np.full((1, self.n_states), -1, dtype=self.successor_state_matrix.dtype)
        empty_mass = np.zeros((1, self.n_states, self.n_actions), dtype=self.successor_counts.dtype)
        self.successor_state_matrix = np.vstack([self.successor_state_matrix, empty_ids])
        self.successor_counts = np.vstack([self.successor_counts, empty_mass])

    def _update(self, next_state, state, action, prob):
        """Add ``prob`` to the mass of the transition ``state --action--> next_state``.

        The slot already assigned to ``next_state`` for this ``state`` is reused
        if there is one; otherwise the first unused slot is claimed, and if
        there is none a new slot is appended first.
        """
        column = self.successor_state_matrix[:, state]
        matches = np.flatnonzero(column == next_state)
        if matches.size:
            slot = matches[0]
        else:
            unused = np.flatnonzero(column == -1)
            if unused.size:
                slot = unused[0]
            else:
                self._grow()
                slot = self.successor_state_matrix.shape[0] - 1
            self.successor_state_matrix[slot, state] = next_state
        self.successor_counts[slot, state, action] += prob

    def get_model(self):
        """Return ``(successor_state_matrix, probabilities)``.

        ``probabilities`` is ``successor_counts`` normalised over the slot axis,
        so the entries of each (state, action) pair sum to one. Pairs that carry
        no mass at all are returned as all zeros instead of NaN.
        """
        totals = np.sum(self.successor_counts, axis=0)
        # Avoid 0/0 for (state, action) pairs without any recorded transition.
        totals[totals == 0.0] = 1.0
        return (self.successor_state_matrix, self.successor_counts / totals)

class Tabular_Dynamics_Successor(Dynamics):

    """Implements a simple tabular dynamics model using successor states and counts,
       rather than computing the full matrix.

    Input attributes:
        n_states: the number of states
        n_actions: the number of actions
        prior: accepted for signature compatibility but currently ignored

    Other attributes:
        successor_state_matrix: int32 array of shape (n_slots, n_states);
            entry [slot, state] is the id of a successor observed from `state`,
            or -1 when the slot is unused for that state
        successor_counts: float32 array of shape (n_slots, n_states, n_actions);
            entry [slot, state, action] is the pseudo count of the matching successor

    Slot 0 starts with every state as its own successor and a pseudo count of
    one for every action; further slots are appended on demand.
    """

    def __init__(self, n_states, n_actions, prior=None):
        self.n_states = n_states
        self.n_actions = n_actions

        self.successor_state_matrix = np.arange(self.n_states, dtype=np.int32).reshape(1, -1)
        self.successor_counts = np.ones((1, self.n_states, self.n_actions), dtype=np.float32)

    def _grow(self):
        """Append one unused slot (ids -1, zero counts) to both tables."""
        # Build the new rows with the dtype of the existing tables: stacking a
        # default float64 row would silently turn the integer state ids into
        # floats (and widen the counts) the first time the tables grow.
        empty_ids = np.full((1, self.n_states), -1, dtype=self.successor_state_matrix.dtype)
        empty_counts = np.zeros((1, self.n_states, self.n_actions), dtype=self.successor_counts.dtype)
        self.successor_state_matrix = np.vstack([self.successor_state_matrix, empty_ids])
        self.successor_counts = np.vstack([self.successor_counts, empty_counts])

    def update(self, next_state, state, action):
        """Record one observed transition state --action--> next_state.

        The slot already assigned to `next_state` for this `state` is reused if
        there is one; otherwise the first unused slot is claimed, and if there
        is none a new slot is appended first.
        """
        column = self.successor_state_matrix[:, state]
        matches = np.flatnonzero(column == next_state)
        if matches.size:
            slot = matches[0]
        else:
            unused = np.flatnonzero(column == -1)
            if unused.size:
                slot = unused[0]
            else:
                self._grow()
                slot = self.successor_state_matrix.shape[0] - 1
            self.successor_state_matrix[slot, state] = next_state
        self.successor_counts[slot, state, action] += 1.0

    def get_model(self):
        """Return (successor_state_matrix, probabilities).

        `probabilities` is successor_counts normalised over the slot axis. The
        initial pseudo count in slot 0 keeps every column sum at least one, so
        the division is always defined.
        """
        totals = np.sum(self.successor_counts, axis=0)
        return (self.successor_state_matrix, self.successor_counts / totals)

class Tabular_Dynamics(Dynamics):

    """Implements a simple tabular dynamics model, stores the visitation counts
       and returns an approximate model of the environment dynamics (using maximum likelihood)

    Input attributes:
        n_states: the number of states
        n_actions: the number of actions
        prior: optional array-like of shape (n_states, n_states, n_actions) with
               the initial pseudo counts; it is copied, so the caller's object
               is never modified. Defaults to all zeros.

    Other attributes:
        pseudo_counts: the number of times [s' s  a] has been observed
    """
    def __init__(self, n_states, n_actions, prior=None):
        self.n_states = n_states
        self.n_actions = n_actions
        if prior is None:
            self.pseudo_counts = np.zeros((n_states, n_states, n_actions))
        else:
            # np.array copies: update() writes into pseudo_counts in place and
            # must not leak those writes back into the caller's prior.
            counts = np.array(prior)
            expected_shape = (self.n_states, self.n_states, self.n_actions)
            assert counts.shape == expected_shape, (
                "prior must have shape %s, got %s" % (expected_shape, counts.shape)
            )
            self.pseudo_counts = counts

    def get_model(self):
        """return the maximum likelihood model"""
        return self._get_empirical_probs()

    def update(self, next_state, state, action):
        """update the model with some experience"""
        self._update_counts(next_state, state, action)

    def _get_empirical_probs(self):
        """compute the maximum likelihood state action transition matrix

        Counts are normalised over the next-state axis. (state, action) pairs
        that were never observed are returned as all zeros.
        """
        totals = np.sum(self.pseudo_counts, axis=0)
        # Avoid 0/0 for unvisited pairs; their numerators are zero anyway.
        totals[totals == 0.0] = 1.0
        return self.pseudo_counts / totals

    def _update_counts(self, next_state, state, action):
        """update the pseudo counts with some experience"""
        self.pseudo_counts[next_state, state, action] += 1.0

        