import gym
from gym import spaces
from collections import defaultdict 
import numpy as np


class Dynamics:
    """Common interface shared by all dynamics models.

    Concrete models override ``get_model`` to expose their current model of
    the environment and ``update`` to absorb new experience. The versions
    defined here accept any arguments, store nothing, and return ``None``.
    """

    def __init__(self):
        """Create a model that holds no state of its own."""

    def get_model(self, *args, **kwargs):
        """Return the current model; the base class has none, hence ``None``."""
        return None

    def update(self, *args, **kwargs):
        """Incorporate new experience; a no-op in the base class."""
        return None

class Transition_Matrix(Dynamics):
    """Wrap a fixed, pre-computed transition matrix in the ``Dynamics`` interface.

    The matrix is stored exactly as given (no copy, no validation) and is
    returned unchanged by ``get_model``. ``update`` is inherited from
    ``Dynamics`` and does nothing, so the model never changes after
    construction.

    Attributes:
        matrix: the wrapped transition matrix. NOTE(review): presumably laid
            out like ``Tabular_Dynamics.get_model()`` output, i.e. indexed
            ``[next_state, state, action]``; this is not enforced here.
    """

    def __init__(self, matrix):
        super().__init__()
        self.matrix = matrix

    def get_model(self):
        """Return the wrapped transition matrix (the same object, not a copy)."""
        return self.matrix

class Tabular_Dynamics(Dynamics):
    """Tabular dynamics model estimated from visitation counts.

    Observed transitions ``(state, action) -> next_state`` are counted in a
    dense array, and ``get_model`` turns the counts into the maximum
    likelihood (empirical) transition matrix.

    Input attributes:
        n_states: the number of states.
        n_actions: the number of actions.
        prior: optional initial pseudo-counts with shape
            ``(n_states, n_states, n_actions)``. The array is copied (as
            float), so later updates never modify the caller's array.

    Other attributes:
        pseudo_counts: float array of shape ``(n_states, n_states, n_actions)``.
            ``pseudo_counts[s', s, a]`` is the prior count plus the number of
            times the transition ``s --a--> s'`` has been observed.
    """

    def __init__(self, n_states, n_actions, prior=None):
        super().__init__()
        self.n_states = n_states
        self.n_actions = n_actions
        shape = (n_states, n_states, n_actions)
        if prior is None:
            self.pseudo_counts = np.zeros(shape)
        else:
            # Copy so that update() never mutates the caller's array in place.
            # Casting to float keeps fractional pseudo-counts intact.
            counts = np.array(prior, dtype=float)
            if counts.shape != shape:
                raise ValueError(
                    f"prior must have shape {shape}, got {counts.shape}"
                )
            self.pseudo_counts = counts

    def get_model(self):
        """Return the maximum likelihood transition model.

        Returns:
            Array ``P`` shaped like ``pseudo_counts`` with
            ``P[s', s, a] = pseudo_counts[s', s, a] / sum_x pseudo_counts[x, s, a]``.
            For (state, action) pairs with zero total count, the column
            ``P[:, s, a]`` is all zeros rather than a probability distribution.
        """
        return self._get_empirical_probs()

    def update(self, next_state, state, action):
        """Record one observed transition ``state --action--> next_state``.

        The arguments are used directly as indices into ``pseudo_counts``,
        so they are expected to be integer state/action indices.
        """
        self._update_counts(next_state, state, action)

    def _get_empirical_probs(self):
        """Compute the maximum likelihood state-action transition matrix."""
        # totals[s, a] = total count of transitions out of (s, a); shape
        # (n_states, n_actions). This broadcasts over the leading next-state axis.
        totals = self.pseudo_counts.sum(axis=0)
        # Unvisited (s, a) pairs have a zero total. Dividing by 1 instead
        # leaves their column all zeros rather than producing NaNs.
        totals = np.where(totals == 0.0, 1.0, totals)
        return self.pseudo_counts / totals

    def _update_counts(self, next_state, state, action):
        """Increment the pseudo-count of the observed transition by one."""
        self.pseudo_counts[next_state, state, action] += 1.0

        