'''
Implementation of Bayesian models
'''
import time

import numpy as np
from scipy.special import gammaln, softmax
from scipy.stats import gmean
from scipy.linalg import block_diag
from scipy.optimize import Bounds
from scipy.optimize import LinearConstraint
from collections import Counter
#-------------------------------------------------------------------------------


class BayesianModel:
    '''
    A Bayesian model consists of a prior and a likelihood function.
    Some useful methods are implemented to support posterior sampling
    and Bayesian optimistic optimization.
    '''

    def __init__(self, prior_param,batch_size=0,seed=0):
        '''
        Initialize a new Bayesian model.

        Args:
            prior_param - numpy.ndarray - parameter of the prior distribution
            batch_size - int - the size of the mini batch data.
            seed - int - The seed used to generate the random generator

        Returns:
            a Bayesian model object
        '''
        self.prior_param = prior_param
        self.posterior_param = None
        self.history = [] #S*H*(S+n)
        self.transition_memory=[] #Use to sample mini_batch data
        self.reward_memory = []  # Use to sample mini_batch data
        self.randomgenerator=np.random.default_rng(seed)
        self.batch_size=batch_size
        self.batch_len=0
        self.history_num=0
        self.batch_history=[]
        #Used to store the state buffer
        self.state_buffer=[]
        self.buffer_index=0



    def log_prior_probability(self, model_param,flag=None):
        '''
        Calculate the log prior probability of a particular model

        Args:
            model_param - parameter of a specific model
        '''
        raise NotImplementedError("Subclasses should implement `log_prior_probability`")

    def log_likelihood(self, model_param):
        '''
        Calculate the log likelihood of the entire history
        
        Returns:
            model_param - parameter of a specific model
            log_likelihood - float - log likehood of the history
        '''
        raise NotImplementedError("Subclasses should implement `log_likelihood`")
    
    def log_posterior_probability(self, model_param):
        '''
        Calculate the log posterior probability of a particular model
        
        Returns:
            model_param - parameter of a specific model
        '''
        raise NotImplementedError("Subclasses should implement `log_posterior_probability`")

    def update_history(self, episode):
        '''
        Add a new episode trajectory into the history.
        Update the posterior distribution if possible.

        Args:
            episode - list - list of (s, a, r, sp, done) tuple
        '''
        self.history.append(episode)
        if self.posterior_param != None:
            self.update_posterior(episode)

    def update_posterior(self, episode):
        '''
        Update the posterior using a new episodic trajectory

        Args:
            episode - list - list of (s, a, r, sp, done) tuple
        '''
        raise NotImplementedError("Subclasses should implement `update_posterior`")
    
    def map_model(self):
        '''
        Calculate the MAP solution of the posterior

        Returns:
            map_model_param - numpy.ndarray - parameter of the MAP model
        '''
        raise NotImplementedError("Subclasses should implement `update_posterior`")

    def posterior_sampling(self, size=None):
        '''
        Sample from the posterior distribution
        (This is inefficient.)

        Args:
            size - int - sample size
        
        Returns:
            params - numpy.ndarray - samples are stored along the first axis
        '''
        iterations = 1000

        map_theta = self.map_model()

        n_sample = 1 if size == None else size
        samples = np.zeros((n_sample,) + map_theta.shape)
        for i in range(n_sample):
            theta = map_theta
            log_p = self.log_prior_probability(theta)
            assert log_p > -np.inf
            log_p += self.log_likelihood(theta)

            for h in range(iterations):
                new_log_p = -np.inf
                while new_log_p == -np.inf:
                    #new_theta = theta + np.random.normal(size=theta.shape)
                    new_theta = theta + self.randomgenerator.normal(size=theta.shape)
                    new_log_p = self.log_prior_probability(new_theta)
                new_log_p += self.log_likelihood(new_theta)
                #if np.random.rand() < np.exp(log_p - new_log_p):
                if self.randomgenerator.random() < np.exp(log_p - new_log_p):
                    theta = new_theta
                    log_p = new_log_p
            
            samples[i] = theta
        
        return samples[0] if size == None else samples
    
    def hpd_func(self, model_param):
        '''
        The hpd function is an incresing function of the posterior probability

        Args:
            model_param - parameter of a specific model
        
        Returns:
            hpd_value - float - the value of the HPD function at the specific model parameter
        '''
        raise NotImplementedError("Subclasses should implement `hpd_func`")
    
    def hpd_constraints(self, hpd_thres):
        '''
        Used in optimization. Specify the HPD region.

        Args:
            hpd_thres - float - the HPD threshold
        
        Returns:
            constraints - list of constraint used for optimization
        '''
        raise NotImplementedError("Subclasses should implement `hpd_constraints`")

    def monte_carlo_quantile(self, func, quantile, sample_size):
        """
        Compute the Monte Carlo quantile of posterior random variable transformed by a function

        Args:
            func - function - a function that transforms a random variable
            quantile - double or list of double - quantile to compute
            sample_size - int - number of samples to draw

        Returns:
            double or list of double - the quantile value
        """
        samples = self.posterior_sampling(sample_size)
        f_samples = np.array([func(s) for s in samples])
        return np.quantile(f_samples, quantile, axis=0, method='median_unbiased')
    
    def param_feasible_region(self):
        '''
        Used in optimization. Specify the feasible region of the model parameter.
        
        Returns:
            constraints - list of constraint used for optimization
        '''
        raise NotImplementedError("Subclasses should implement `param_feasible_region`")


    def sample_reward(self, model_param, state, action, size=None):
        '''
        Sample a reward given the current state and action

        Args:
            model_param - parameter of a specific model
            state
            action
            size - int or None - number of reward samples
        
        Returns:
            reward - float - a single reward if `size=None` or a list of rewards if otherwise
        '''
        raise NotImplementedError("Subclasses should implement `sample_reward`")
    

    def sample_next_state(self, model_param, state, action, size=None):
        '''
        Sample the next state given the current state and action

        Args:
            model_param - parameter of a specific model
            state
            action
            size - int or None- number of next state samples
        
        Returns:
            next_state - a single state if `size=None` or a list of states if otherwise
        '''
        raise NotImplementedError("Subclasses should implement `sample_next_state`")

    @classmethod
    def to_tabular_MDP(cls, model_param):
        '''
        Convert a finite MDP to a deterministic reward tabular MDP.

        Args:
            model_param - parameter of a specific model
        
        Returns:
            tabularMDP - tuple - tuple of S*A*S transition matrix and S*A reward matrix
        '''

        trans = np.concatenate((model_param[:, :, :-1], 1.0 - np.sum(model_param[:, :, :-1], axis=-1, keepdims=True)), axis=-1)
        rewards = model_param[:, :, -1]

        return (trans, rewards)

    @classmethod
    def mean_reward_gradient(cls, model_param, state, action):
        '''
        The gradient of the mean reward function at the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action
        
        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `mean_reward_gradient`")

    @classmethod
    def log_transition_gradient(cls, model_param, state, action, next_state):
        '''
        The gradient of the log transition function at the given state-action-next_state pair

        Args:
            model_param - parameter of a specific model
            state
            action
            next_state
        
        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `log_transition_gradient`")

    
    def log_reward_gradient(self, model_param, state, action):
        '''
        The gradient of the log reward function at the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action
        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `log_reward_gradient`")

    def value_gradient(self,model_param,occupancy_measure,value_weighted_occupancy_measure):
        '''
        The gradient of the value

        Args:
            model_param - parameter of a specific model
            occupancy_measure
            value_weighted_occupancy_measure
        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `value_gradient`")
    def log_prior_transition_gradient(self, model_param):
        '''
        The gradient of the log transition function at the given prior

        Args:
            model_param - parameter of a specific model


        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `log_prior_transition_gradient`")

    
    def log_prior_reward_gradient(self, model_param):
        '''
        The gradient of the log reward function at the given prior

        Args:
            model_param - parameter of a specific model

        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `log_prior_reward_gradient`")

    def log_posterior_gradient(self, model_param):
        '''
        The gradient of the log posterior probability

        Args:
            model_param - parameter of a specific model

        Returns:
            grad - numpy.ndarray - gradient
        '''
        raise NotImplementedError("Subclasses should implement `log_posterior_gradient`")

    @classmethod
    def reward_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the reward distribution of the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action
        
        Returns:
            fisher_info - numpy.ndarray - the fisher information
        '''
        raise NotImplementedError("Subclasses should implement `reward_fisher_information`")

    @classmethod
    def transition_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the transition distribution of the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action
        
        Returns:
            fisher_info - numpy.ndarray - the fisher information
        '''
        raise NotImplementedError("Subclasses should implement `transition_fisher_information`")

    @classmethod
    def model_log_likelihood(cls, std_param, history, flag=None):
        '''
        Calculate the log probability of a particular model in given distribution

        Args:
            std_param - standard parameter of a specific model
            history - the observation history
        '''
        raise NotImplementedError("Subclasses should implement `model_log_likelihood`")


#-------------------------------------------------------------------------------
# Bayesian Tabular Model

class BayesianTabularModel(BayesianModel):
    '''
    The Bayesian tabular model assumes finite state and action spaces.
    The reward and transition of all state-action pairs are made independent.
    Meanwhile, the reward is assumed to follow a Bernoulli distribution.

    The parameter is an S*A*S matrix, where the last dimentionality of 
    the last axis represent the mean reward, and other dimensitionalities
    determine the transition probability to the first S-1 states.
    '''

    def __init__(self, prior_param,batch_size=0,seed=0):
        '''
        Initialize a new Bayesian tabular model.
        The prior parameter is an S*A*(S+2) matrix, where the last two dimensionality
        of the last axis denote the Beta prior of the reward, and the others represent
        the Dirichlet prior of the transition.

        Args:
            prior_param - numpy.ndarray - parameter of the prior distribution

        Returns:
            a Bayesian model object
        '''
        super().__init__(prior_param,batch_size,seed)
        self.nState, self.nAction, _ = prior_param.shape
        self.posterior_param = np.copy(self.prior_param)
        self.history = np.zeros_like(self.prior_param)
    
    @classmethod
    def to_standard_param(self, model_param):
        '''
        Convert the model_param matrix into the standard parameters
        of the transition and reward distributions. The standard transition
        parameter is an S*A*S matrix specifying the full transition probability.
        The standard reward parameter is an S*A*2 matrix specifying the
        probability of receving rewards 1 and 0. These two matrices are
        concatenated along the last axis.

        Args:
            model_param - numpy.ndarray - parameter matrix of size S*A*S
        
        Returns:
            std_param - numpy.ndarray - parameter matrix of size S*A*(S+2)
        '''
        std_param = np.concatenate((model_param[:, :, :-1], 1.0 - np.sum(model_param[:, :, :-1], axis=-1, keepdims=True),
                                    model_param[:, :, -1:], 1.0 - model_param[:, :, -1:]), axis=-1)
        std_param = np.clip(std_param, 1e-9, 1.0 - 1e-9)
        return std_param

    def log_prior_probability(self, model_param,ignore_normalfactor=None):
        '''
        Calculate the log prior probability of a particular model

        Args:
            model_param - parameter of a specific model

        Returns:
            log_probability - float - log probability of the model in the prior
        '''
        std_param = self.to_standard_param(model_param)
        return self.model_log_probability(std_param, self.prior_param,ignore_normalfactor)
    
    def log_likelihood(self, model_param,ignore_normalfactor=True):
        '''
        Calculate the log likelihood of the entire history
        
        Args:
            model_param - parameter of a specific model

        Returns:
            log_likelihood - float - log likehood of the history
        '''
        std_param = self.to_standard_param(model_param)
        return self.model_log_likelihood(std_param, self.history,ignore_normalfactor)
    
    def log_posterior_probability(self, model_param,mini_batch=False,ignore_normalfactor=True):
        '''
        Calculate the log posterior probability of a particular model

        Args:
            model_param - parameter of a specific model
            mini_batch - Use mini_batch gradient decent or not
            ignore_normalfactor - Normalizing factors in probability distributions are ignored or not
        
        Returns:
            model_param - parameter of a specific model
        '''
        std_param = self.to_standard_param(model_param)
        if mini_batch==False:
            # Lop p of the model under the whole history
            return self.model_log_probability(std_param, self.posterior_param,ignore_normalfactor)
        else:
            # Lop p of the model under the mini_batch data
            log_prior=self.model_log_probability(std_param, self.prior_param,ignore_normalfactor)
            log_likelyhood=self.model_log_likelihood(std_param,self.batch_history,ignore_normalfactor)
            return log_likelyhood*self.history_num/(self.batch_len)+log_prior

    def update_history(self, episode):
        '''
        Add a new episode trajectory into the history.
        Update the posterior distribution if possible.

        Args:
            episode - list - list of (s, a, r, sp, done) tuple
        '''
        for (s, a, r, sp, done) in episode:
            if not done:
                self.history[s, a, sp] += 1

            self.history[s, a, self.nState + (r == 0)] += 1

        self.update_posterior(episode)

    def update_posterior(self, episode):
        '''
        Update the posterior using a new episodic trajectory

        Args:
            episode - list - list of (s, a, r, sp, done) tuple
        '''
        for (s, a, r, sp, done) in episode:
            if not done:
                self.posterior_param[s, a, sp] += 1
            self.posterior_param[s, a, self.nState + (r == 0)] += 1
    
    def map_model(self):
        '''
        Calculate the MAP solution of the posterior

        Returns:
            map_model_param - numpy.ndarray - parameter of the MAP model
        '''
        temp_param = self.posterior_param - 1.0 + 1e-9
        trans_normalizer = np.sum(temp_param[:, :, :-2], axis=-1, keepdims=True)
        reward_normalizer = np.sum(temp_param[:, :, -2:], axis=-1, keepdims=True)
        return np.concatenate((temp_param[:, :, :-3] / trans_normalizer, temp_param[:, :, -2] / reward_normalizer), axis=-1)#fixed bug

    def posterior_sampling(self, size=None):
        '''
        Sample from the posterior distribution

        Args:
            size - int - sample size
        
        Returns:
            params - numpy.ndarray - samples are stored along the first axis
        '''
        n_sample = 1 if size == None else size
        samples = np.zeros((n_sample, self.nState, self.nAction, self.nState))

        for s in range(self.nState):
            for a in range(self.nAction):
                # samples[:, s, a, :] = np.random.dirichlet(self.posterior_param[s, a, :-2], n_sample)
                # samples[:, s, a, -1] = np.random.beta(self.posterior_param[s, a, -2], self.posterior_param[s, a, -1], n_sample)
                samples[:, s, a, :] = self.randomgenerator.dirichlet(self.posterior_param[s, a, :-2], n_sample)
                samples[:, s, a, -1] = self.randomgenerator.beta(self.posterior_param[s, a, -2], self.posterior_param[s, a, -1], n_sample)
        return samples[0] if size == None else samples
    
    def hpd_func(self, model_param):
        '''
        The hpd function is an incresing function of the posterior probability

        Args:
            model_param - parameter of a specific model
        
        Returns:
            hpd_value - float - the value of the HPD function at the specific model parameter
        '''

        std_param = self.to_standard_param(model_param)
        posterior_param = self.posterior_param - 1.0

        hpd_value = np.sum(np.log(std_param) * posterior_param) / np.sum(posterior_param)
        return hpd_value

    def independent_hpd_func(self, model_param):
        '''
        The independent hpd function is a tensor-valued function.
        It calculates the hpd value for the transition function
        of each state action pair independently. For the reward
        function, the negative mean reward is used, whose lower
        quantile is just the negative upper quantile of the mean
        reward.

        Args:
            model_param - parameter of a specific model
        
        Returns:
            hpd_value - float - the value of the HPD function at the specific model parameter
        '''

        std_param = self.to_standard_param(model_param)

        trans_param = std_param[:, :, :-2]
        mean_reward = std_param[:, :, -2]
        posterior_dirichlet_param = self.posterior_param[:, :, :-2] - 1.0 + 1e-9

        trans_hpd_value = gmean(trans_param, weights=posterior_dirichlet_param, axis=-1)

        return np.array([trans_hpd_value, - mean_reward])
    
    def hpd_constraints(self, hpd_thres):
        '''
        Used in optimization. Specify the HPD region.

        Args:
            hpd_thres - float - the HPD threshold
        
        Returns:
            constraints - list of constraint used for optimization
        '''
        posterior_param = self.posterior_param - 1.0
        posterior_param /= np.sum(posterior_param)

        def hpd_jacobian(model_param):
            '''
            Calculate the jacobian of the HPD function

            Args:
                model_param - parameter of a specific model
            
            Returns:
                jacobian - numpy.ndarray - jacobian of the HPD function
            '''
            std_param = self.to_standard_param(model_param.reshape(self.nState, self.nAction, -1))
            inverse_std_param = 1.0 / std_param
            gradient = inverse_std_param * posterior_param
            jacobian = np.zeros((self.nState, self.nAction, self.nState))
            jacobian[:, :, :-1] = gradient[:, :, :-3] - gradient[:, :, -3:-2]
            jacobian[:, :, -1] = gradient[:, :, -2] - gradient[:, :, -1]

            return jacobian.reshape(1, -1)

        return [{
                'type': 'ineq',
                'fun': lambda param : self.hpd_func(param.reshape(self.nState, self.nAction, -1)) - hpd_thres,
                'jac': hpd_jacobian
                },]
    
    def param_feasible_region(self):
        '''
        Used in optimization. Specify the feasible region of the model parameter.

        Returns:
            constraints - list of constraint used for optimization
        '''
        bounds = Bounds(np.full((self.nState * self.nAction * self.nState), 1e-9),
                        np.full((self.nState * self.nAction * self.nState), 1.0 - 1e-9))
                       
        grad = np.full(self.nState, -1.0)
        grad[-1] = 0
        grad = self.nState * self.nAction * [grad]
        jacobian = block_diag(*grad)
        constraints =  [{
                'type': 'ineq',
                'fun': lambda param : 1.0 - np.sum(param.reshape(self.nState, self.nAction, -1)[:, :, :-1], axis=-1).reshape(-1),
                'jac': lambda param: jacobian
                },]
        
        return bounds, constraints


    def sample_reward(self, model_param, state, action, size=None):
        '''
        Sample a reward given the current state and action

        Args:
            model_param - parameter of a specific model
            state
            action
            size - int or None - number of reward samples
        
        Returns:
            reward - float - a single reward if `size=None` or a list of rewards if otherwise
        '''

        return self.randomgenerator.binomial(1, model_param[state, action, -1], size=size)

    

    def sample_next_state(self, model_param, state, action, size=None):
        '''
        Sample the next state given the current state and action

        Args:
            model_param - parameter of a specific model
            state
            action
            size - int or None- number of next state samples
        
        Returns:
            next_state - a single state if `size=None` or a list of states if otherwise
        '''

        return self.randomgenerator.choice(model_param.shape[0], size=size, p=model_param[state, action, :-1])

    @classmethod
    def to_tabular_MDP(cls, model_param):
        '''
        Convert a finite MDP to a deterministic reward tabular MDP.

        Args:
            model_param - parameter of a specific model
        
        Returns:
            tabularMDP - tuple - tuple of S*A*S transition matrix and S*A reward matrix
        '''

        trans_param = model_param[:, :, :-1]
        reward_param = model_param[:, :, -1]
        trans = np.concatenate((trans_param, 1.0 - np.sum(trans_param, axis=-1, keepdims=True)), axis=-1)
        rewards = reward_param

        return (trans, rewards)

    @classmethod
    def mean_reward_gradient(cls, model_param, state, action):
        '''
        The gradient of the mean reward function at the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action

        Returns:
            grad - numpy.ndarray - gradient
        '''
        grad = np.zeros_like(model_param)
        grad[state, action, -1] = 1.0
        return grad

    @classmethod
    def log_transition_gradient(cls, model_param, state, action, next_state):
        '''
        The gradient of the log transition function at the given state-action-next_state pair

        Args:
            model_param - parameter of a specific model
            state
            action
            next_state

        Returns:
            grad - numpy.ndarray - gradient
        '''

        trans_param = model_param[:, :, :-1]
        grad = np.zeros_like(model_param)
        if next_state == model_param.shape[0] - 1:
            grad[state, action, :-1] = -1.0 / (1.0 - np.sum(trans_param[state,action,:]+1e-9))#fixed bug
        else:
            grad[state, action, next_state] = 1.0 / (trans_param[state, action, next_state]+1e-9)
        return grad

    def value_gradient(self,model_param,occupancy_measure,value_weighted_occupancy_measure):
        '''
        The gradient of the value

        Args:
            model_param - parameter of a specific model
            occupancy_measure
            value_weighted_occupancy_measure
        Returns:
            grad - numpy.ndarray - gradient
        '''
        jacobian = np.zeros_like(model_param)
        for s in range(self.nState):
            for a in range(self.nAction):
                jacobian[s,a,-1] -= occupancy_measure[s, a] * 1.0
                for sp in range(self.nState):
                    if sp == self.nState - 1:
                        jacobian[s, a, :-1] -= value_weighted_occupancy_measure[s, a, sp] *\
                        (-1.0 / (1.0 - np.sum(model_param[s,a,:-1]+1e-9)))#fixed bug
                    else:
                        jacobian[s, a, sp] -= value_weighted_occupancy_measure[s, a, sp] * \
                                              (1.0 / (model_param[s, a, sp]+1e-9))
        return jacobian.reshape(-1)
    def log_reward_gradient(self, model_param, state, action):
        '''
        The gradient of the log reward function at the given state-action pair
        The reward is assumed to follow a Bernoulli distribution

        Args:
            model_param - parameter of a specific model
            state
            action
        Returns:
            grad - numpy.ndarray - gradient
        '''
        
        grad=np.zeros_like(model_param)
        grad[state,action,-1]=self.history[:,:,-2].sum()/(model_param[state,action,-1]+1e-9)-self.history[:,:,-1].sum()/(1-model_param[state,action,-1]+1e-9)
        return grad


    def log_prior_transition_gradient(self, model_param):
        '''
        The gradient of the log transition function at the given prior

        Args:
            model_param - parameter of a specific model


        Returns:
            grad - numpy.ndarray - gradient
        '''
        trans_param = model_param[:, :, :-1]
        grad = np.zeros_like(model_param)
        for state in range(self.nState):
            for action in range(self.nAction):
                for next_state in range(self.nState):
                    if next_state == model_param.shape[0] - 1:
                        grad[state, action, :-1] += -1.0*(self.prior_param[state,action,next_state]-1) / (1.0 - np.sum(trans_param[state, action, :])+1e-9)  # fixed bug
                    else:
                        grad[state, action, next_state] += 1.0*(self.prior_param[state,action,next_state]-1) / (trans_param[state, action, next_state]+1e-9)
        return grad

    
    def log_prior_reward_gradient(self, model_param):
        '''
        The gradient of the log reward function at the given prior

        Args:
            model_param - parameter of a specific model

        Returns:
            grad - numpy.ndarray - gradient
        '''
        grad = np.zeros_like(model_param)
        for state in range(self.nState):
            for action in range(self.nAction):
                grad[state, action, -1] = (self.prior_param[state,action,self.nState]-1) / (model_param[state, action, -1]+1e-9) - \
                                          (self.prior_param[state,action,self.nState+1]-1) / (1 - model_param[state, action, -1]+1e-9)
        return grad

    def log_posterior_gradient(self, model_param):
        '''
        The gradient of the log posterior probability

        Args:
            model_param - parameter of a specific model

        Returns:
            grad - numpy.ndarray - gradient
        '''
        jacobian = np.zeros_like(model_param)
        log_likely_gradient = np.zeros_like(model_param)
        for s in range(self.nState):
            for a in range(self.nAction):
                log_likely_gradient[s,a,-1] -= (self.history[:,:,-2].sum()/(model_param[s,a,-1]+1e-9)-self.history[:,:,-1].sum()/(1-model_param[s,a,-1]+1e-9))
                for sp in range(self.nState):
                    if sp == self.nState - 1:
                        log_likely_gradient[s, a, :-1] -= (self.history[s, a, sp]) * \
                                               (-1.0 / (1.0 - np.sum(model_param[s, a, :-1] + 1e-9)))  # fixed bug
                    else:
                        log_likely_gradient[s, a, sp] -= (self.history[s, a, sp]) * \
                                              (1.0 / (model_param[s, a, sp] + 1e-9))
        jacobian = log_likely_gradient - self.log_prior_transition_gradient(model_param) - self.log_prior_reward_gradient(model_param)
        return jacobian.reshape(-1)
    @classmethod
    def reward_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the reward distribution of the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action
        
        Returns:
            fisher_info - numpy.ndarray - the fisher information
        '''
        
        reward_param = model_param[:, :, -1]

        fisher_info = np.zeros_like(model_param)
        fisher_info[state, action, -1] = 1.0 / (reward_param * (1.0 - reward_param))
        fisher_info = np.diag(fisher_info)
        return fisher_info

    @classmethod
    def transition_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the transition distribution of the given state-action pair

        Args:
            model_param - parameter of a specific model
            state
            action
        
        Returns:
            fisher_info - numpy.ndarray - the fisher information
        '''
        trans_param = model_param[:, :, :-1]

        nState, nAction, _ = model_param.shape
        fisher_info = np.zeros((nState, nAction, nState, nState))
        fisher_info[:, :, :-1, :-1] =  np.expand_dims(1.0 / (1.0 - np.sum(trans_param, axis=-1, keepdims=True)), axis=-1)
        for s in range(nState):
            for a in range(nAction):
                for sp in range(nState-1):
                    fisher_info[s, a, sp, sp] += 1.0 / trans_param[s, a, sp]
        fisher_info = block_diag(*fisher_info.reshape(-1, nState, nState))
        return fisher_info

    @classmethod
    def model_log_probability(cls, std_param, dist_param,ignore_normalfactor):
        '''
        Calculate the log probability of a particular model in a given distribution

        Args:
            std_param - standard parameter of a specific model
            dist_param - parameter of the distribution
            ignore_normalfactor - Normalizing factors in probability distributions are ignored or not

        '''
        if not np.all(std_param >= 0):
            return -np.inf

        # Calculate the log Beta function of the Dirichlet and Beta distribution
        if ignore_normalfactor==False:
            gammaln_prior_param = gammaln(dist_param)
            log_p = np.sum(gammaln(np.sum(dist_param[:, :, :-2], axis=-1))) + np.sum(gammaln(np.sum(dist_param[:, :, -2:], axis=-1))) - np.sum(gammaln_prior_param)
        else:
            log_p=0
        # Calculate the weighted log probability part
        log_std_param = np.log(std_param).reshape(-1)
        model_param_coeff = (dist_param - 1.0).reshape(-1)
        log_p += np.dot(model_param_coeff, log_std_param)


        return log_p

    @classmethod
    def model_log_likelihood(cls, std_param, history,ignore_normalfactor):
        '''
        Calculate the log probability of a particular model in given distribution

        Args:
            std_param - standard parameter of a specific model
            history - the observation history
            ignore_normalfactor - Normalizing factors in probability distributions are ignored or not
        '''
        if not np.all(std_param >= 0):
            return np.nan
        
        transition_history = history[:, :, :-2]
        reward_history = history[:, :, -2:]

        # Calculate the log Beta part
        if ignore_normalfactor==False:
            gammaln_history = gammaln(history)
            log_p = np.sum(gammaln(np.sum(transition_history, axis=-1))) + np.sum(gammaln(np.sum(reward_history, axis=-1))) - np.sum(gammaln_history)
        else:
            log_p=0
        # Calculate the weighted log probability part
        log_model_param = np.log(std_param).reshape(-1)
        model_param_coeff = history.reshape(-1)
        log_p += np.dot(model_param_coeff, log_model_param)

        return log_p


#-------------------------------------------------------------------------------
# Softmax Bayesian Tabular Model

class SoftmaxBayesianTabularModel(BayesianTabularModel):
    '''
    The softmax Bayesian tabular model assumes finite state and action spaces.
    The reward and transition of all state-action pairs are made independent,
    where the transition and reward functions adopt a softmax parameterization.
    Meanwhile, the reward is assumed to follow a Bernoulli distribution.

    The parameter is an S*A*(S+2) matrix of logits. For each state-action pair,
    the first S entries of the last axis are transition logits and the last two
    entries are the logits of receiving reward 1 and reward 0, respectively.
    '''

    @classmethod
    def to_standard_param(cls, model_param):
        '''
        Convert the model_param matrix into the standard parameters
        of the transition and reward distributions. The standard transition
        parameter is an S*A*S matrix specifying the full transition probability.
        The standard reward parameter is an S*A*2 matrix specifying the
        probability of receiving rewards 1 and 0. These two matrices are
        concatenated along the last axis.

        Args:
            model_param - numpy.ndarray - logit matrix of size S*A*(S+2)

        Returns:
            std_param - numpy.ndarray - probability matrix of size S*A*(S+2)
        '''
        trans_param = model_param[:, :, :-2]
        reward_param = model_param[:, :, -2:]
        std_param = np.concatenate((softmax(trans_param, axis=-1),
                                    softmax(reward_param, axis=-1)), axis=-1)
        # Keep probabilities strictly inside (0, 1) so that log() stays finite.
        std_param = np.clip(std_param, 1e-9, 1.0 - 1e-9)
        return std_param

    def map_model(self):
        '''
        Calculate the MAP solution of the posterior.

        The Dirichlet mode is proportional to (alpha - 1). Softmax is invariant
        to a per-row additive shift, so the log of the unnormalized mode is
        already a valid logit parameter.

        Returns:
            map_model_param - numpy.ndarray - parameter of the MAP model
        '''
        return np.log(self.posterior_param - 1.0 + 1e-9)

    def posterior_sampling(self, size=None):
        '''
        Sample from the posterior distribution.

        Each state-action pair draws its transition probabilities and its reward
        probabilities from independent Dirichlet posteriors. The samples are
        returned as logits (element-wise log of the probabilities).

        Args:
            size - int or None - sample size

        Returns:
            params - numpy.ndarray - an S*A*(S+2) logit matrix if `size` is None,
                     otherwise samples stacked along the first axis
        '''
        n_sample = 1 if size is None else size
        samples = np.zeros((n_sample, self.nState, self.nAction, self.nState + 2))

        trans_param = self.posterior_param[:, :, :-2]
        reward_param = self.posterior_param[:, :, -2:]
        for s in range(self.nState):
            for a in range(self.nAction):
                # dirichlet() needs the 1-D concentration vector of this (s, a) pair
                samples[:, s, a, :-2] = self.randomgenerator.dirichlet(trans_param[s, a], n_sample)
                samples[:, s, a, -2:] = self.randomgenerator.dirichlet(reward_param[s, a], n_sample)

        samples = np.log(samples)
        return samples[0] if size is None else samples

    def hpd_constraints(self, hpd_thres):
        '''
        Used in optimization. Specify the HPD region.

        NOTE(review): the jacobian assumes the inherited `hpd_func` equals
        sum((alpha - 1) * log(std_param)) / sum(alpha - 1). Confirm this against the base class.

        Args:
            hpd_thres - float - the HPD threshold

        Returns:
            constraints - list of constraint used for optimization
        '''
        weights = self.posterior_param - 1.0
        weights /= np.sum(weights)

        def hpd_jacobian(model_param):
            '''
            Calculate the jacobian of the HPD function with respect to the logits.

            Args:
                model_param - parameter of a specific model

            Returns:
                jacobian - numpy.ndarray - jacobian of the HPD function, shape (1, S*A*(S+2))
            '''
            std_param = self.to_standard_param(model_param.reshape(self.nState, self.nAction, -1))
            jacobian = np.empty_like(std_param)
            # For one softmax block p = softmax(theta):
            #     d/d theta_j sum_i c_i log p_i = c_j - p_j * sum_i c_i
            # The transition block and the reward block are independent softmaxes.
            for block in (np.s_[:, :, :-2], np.s_[:, :, -2:]):
                coeff = weights[block]
                jacobian[block] = coeff - std_param[block] * np.sum(coeff, axis=-1, keepdims=True)
            return jacobian.reshape(1, -1)

        return [{
                'type': 'ineq',
                'fun': lambda param : self.hpd_func(param.reshape(self.nState, self.nAction, -1)) - hpd_thres,
                'jac': hpd_jacobian
                },]

    def param_feasible_region(self):
        '''
        Used in optimization. Specify the feasible region of the model parameter.
        Logits are unconstrained.

        Returns:
            bounds - scipy.optimize.Bounds - (-inf, inf) for every parameter
            constraints - list of constraint used for optimization (empty)
        '''
        n_param = self.nState * self.nAction * (self.nState + 2)
        bounds = Bounds(np.full(n_param, -np.inf), np.full(n_param, np.inf))
        constraints = []
        return bounds, constraints

    def sample_reward(self, model_param, state, action, size=None):
        '''
        Sample a Bernoulli reward given the current state and action.

        Args:
            model_param - logit parameter of a specific model
            state
            action
            size - int or None - number of reward samples

        Returns:
            reward - int - a single reward if `size=None` or an array of rewards otherwise
        '''
        # The parameters are logits, so convert them to P(r = 1) before sampling.
        prob_one = softmax(model_param[state, action, -2:])[0]
        return self.randomgenerator.binomial(1, prob_one, size=size)

    def sample_next_state(self, model_param, state, action, size=None):
        '''
        Sample the next state given the current state and action.

        Args:
            model_param - logit parameter of a specific model
            state
            action
            size - int or None - number of next state samples

        Returns:
            next_state - a single state if `size=None` or an array of states otherwise
        '''
        probs = softmax(model_param[state, action, :-2])
        return self.randomgenerator.choice(probs.shape[0], size=size, p=probs)

    @classmethod
    def to_tabular_MDP(cls, model_param):
        '''
        Convert a finite MDP to a deterministic reward tabular MDP.

        Args:
            model_param - parameter of a specific model

        Returns:
            tabularMDP - tuple - tuple of S*A*S transition matrix and S*A reward matrix,
                         where the reward is the probability of receiving reward 1
        '''
        trans = softmax(model_param[:, :, :-2], axis=-1)
        rewards = softmax(model_param[:, :, -2:], axis=-1)[:, :, 0]
        return (trans, rewards)

    @classmethod
    def mean_reward_gradient(cls, model_param, state, action):
        '''
        The gradient of the mean reward function at the given state-action pair.
        With p = softmax(reward logits), d p_0 / d theta_0 = p_0 p_1 = -d p_0 / d theta_1.

        Args:
            model_param - parameter of a specific model
            state
            action

        Returns:
            grad - numpy.ndarray - gradient
        '''
        grad = np.zeros_like(model_param)
        grad[state, action, -2] = np.prod(softmax(model_param[state, action, -2:], axis=-1))
        grad[state, action, -1] = -grad[state, action, -2]
        return grad

    @classmethod
    def log_transition_gradient(cls, model_param, state, action, next_state):
        '''
        The gradient of the log transition function at the given state-action-next_state triple:
        d log softmax_{s'}(theta) / d theta = e_{s'} - softmax(theta).

        Args:
            model_param - parameter of a specific model
            state
            action
            next_state

        Returns:
            grad - numpy.ndarray - gradient
        '''
        grad = np.zeros_like(model_param)
        grad[state, action, :-2] = - softmax(model_param[state, action, :-2], axis=-1)
        grad[state, action, next_state] += 1.0
        return grad

    def value_gradient(self, model_param, occupancy_measure, value_weighted_occupancy_measure):
        '''
        The gradient of the value, returned negated.

        The value gradient is
            sum_{s,a} occupancy[s,a] * mean_reward_gradient(s, a)
            + sum_{s,a,s'} weighted[s,a,s'] * log_transition_gradient(s, a, s').
        Every term is subtracted, presumably so that the value can be maximized
        by a minimizer.

        Args:
            model_param - parameter of a specific model, shape S*A*(S+2)
            occupancy_measure - S*A array
            value_weighted_occupancy_measure - S*A*S array
        Returns:
            grad - numpy.ndarray - flattened negative gradient
        '''
        occupancy = np.asarray(occupancy_measure)
        weighted = np.asarray(value_weighted_occupancy_measure)
        jacobian = np.zeros_like(model_param)

        reward_slope = np.prod(softmax(model_param[:, :, -2:], axis=-1), axis=-1)
        jacobian[:, :, -2] = -occupancy * reward_slope
        jacobian[:, :, -1] = occupancy * reward_slope

        # -sum_{s'} w[s'] * (e_{s'} - p) = p * sum(w) - w
        trans_probs = softmax(model_param[:, :, :-2], axis=-1)
        jacobian[:, :, :-2] = trans_probs * np.sum(weighted, axis=-1, keepdims=True) - weighted
        return jacobian.reshape(-1)

    @classmethod
    def reward_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the reward distributions with respect to the logits.
        It is returned as a block-diagonal matrix over all state-action pairs; `state`
        and `action` are accepted for interface compatibility but not used.

        Args:
            model_param - parameter of a specific model
            state
            action

        Returns:
            fisher_info - numpy.ndarray - the fisher information
        '''
        nState, nAction, _ = model_param.shape
        probs = softmax(model_param[:, :, -2:], axis=-1)
        fisher_info = np.zeros((nState, nAction, nState + 2, nState + 2))
        # Categorical fisher information w.r.t. logits: diag(p) - p p^T
        fisher_info[:, :, -2:, -2:] = -probs[..., :, None] * probs[..., None, :]
        idx = np.arange(nState, nState + 2)
        fisher_info[:, :, idx, idx] += probs
        return block_diag(*fisher_info.reshape(-1, nState + 2, nState + 2))

    @classmethod
    def transition_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the transition distributions with respect to the logits.
        It is returned as a block-diagonal matrix over all state-action pairs; `state`
        and `action` are accepted for interface compatibility but not used.

        Args:
            model_param - parameter of a specific model
            state
            action

        Returns:
            fisher_info - numpy.ndarray - the fisher information
        '''
        nState, nAction, _ = model_param.shape
        probs = softmax(model_param[:, :, :-2], axis=-1)
        fisher_info = np.zeros((nState, nAction, nState + 2, nState + 2))
        # Categorical fisher information w.r.t. logits: diag(p) - p p^T
        fisher_info[:, :, :-2, :-2] = -probs[..., :, None] * probs[..., None, :]
        idx = np.arange(nState)
        fisher_info[:, :, idx, idx] += probs
        return block_diag(*fisher_info.reshape(-1, nState + 2, nState + 2))

#-------------------------------------------------------------------------------
# Gaussian Reward Bayesian Tabular Model

class GaussianRewardBayesianTabularModel(BayesianTabularModel):
    '''
    The Gaussian-reward Bayesian tabular model assumes finite state and action
    spaces. The reward and transition of all state-action pairs are made independent.
    Meanwhile, the reward is assumed to follow a Gaussian distribution with
    a known precision (class attribute `precision`, 1 by default).

    The parameter is an S*A*S matrix. The last entry of the last axis is the
    mean reward. The other entries are the transition probabilities to the
    first S-1 states; the probability of the last state is one minus their sum.
    '''
    precision = 1.0

    def __init__(self, prior_param, batch_size=0, seed=0):
        '''
        Initialize a new Bayesian tabular model.
        The prior parameter is an S*A*(S+2) matrix. The last two entries of the
        last axis hold the Gaussian prior of the reward (mean, precision); the
        others hold the Dirichlet prior of the transition.

        Args:
            prior_param - numpy.ndarray - parameter of the prior distribution
            batch_size - int - the size of the mini batch data
            seed - int - the seed of the random generator

        Returns:
            a Bayesian model object
        '''
        super().__init__(prior_param, batch_size, seed)
        self.nState, self.nAction, _ = prior_param.shape
        self.posterior_param = np.copy(self.prior_param)
        # Per (s, a): S next-state counts, then visit count, reward sum, reward square sum
        self.history = np.zeros((self.nState, self.nAction, self.nState + 3))

    @classmethod
    def to_standard_param(cls, model_param):
        '''
        Convert the model_param matrix into the standard parameters
        of the transition and reward distributions. The standard transition
        parameter is an S*A*S matrix specifying the full transition probability.
        The standard reward parameter is an S*A*2 matrix specifying the
        mean and precision of the Gaussian distribution. These two matrices
        are concatenated along the last axis.

        Args:
            model_param - numpy.ndarray - parameter matrix of size S*A*S

        Returns:
            std_param - numpy.ndarray - parameter matrix of size S*A*(S+2)
        '''
        trans_param = model_param[:, :, :-1]
        reward_param = model_param[:, :, -1:]
        std_param = np.concatenate((trans_param, 1.0 - np.sum(trans_param, axis=-1, keepdims=True),
                                    reward_param, cls.precision * np.ones_like(reward_param)), axis=-1)
        # Keep transition probabilities strictly inside (0, 1) so that log() stays finite.
        std_param[:, :, :-2] = np.clip(std_param[:, :, :-2], 1e-9, 1.0 - 1e-9)
        return std_param

    def update_history(self, episode):
        '''
        Add a new episode trajectory into the history.
        Update the posterior distribution if possible.

        Args:
            episode - list - list of (s, a, r, sp, done) tuple
        '''
        self.history_num += 1
        for (s, a, r, sp, done) in episode:
            if not done:
                self.history[s, a, sp] += 1
            self.history[s, a, self.nState] += 1  # visit count
            self.history[s, a, self.nState + 1] += r  # reward sum
            self.history[s, a, self.nState + 2] += r * r  # reward square sum

        # Transition memory and reward memory are used in stochastic gradient descent
        if self.batch_size != 0:
            for (s, a, r, sp, done) in episode:
                if not done:
                    self.transition_memory.append((s, a, sp))
                self.reward_memory.append((s, a, r))

        self.update_posterior(episode)

    def update_posterior(self, episode):
        '''
        Update the posterior using a new episodic trajectory.
        Transition counts go into the Dirichlet parameters. The reward uses the
        conjugate Gaussian update with known precision.

        Args:
            episode - list - list of (s, a, r, sp, done) tuple
        '''
        for (s, a, r, sp, done) in episode:
            if not done:
                self.posterior_param[s, a, sp] += 1

            # Update precision
            self.posterior_param[s, a, self.nState + 1] += self.precision
            # Update mean
            self.posterior_param[s, a, self.nState] += self.precision * (r - self.posterior_param[s, a, self.nState]) / self.posterior_param[s, a, self.nState + 1]

    def map_model(self):
        '''
        Calculate the MAP solution of the posterior

        Returns:
            map_model_param - numpy.ndarray - parameter of the MAP model
        '''
        trans_param = np.copy(self.posterior_param[:, :, :-2]) - 1.0 + 1e-9
        mean_reward = np.copy(self.posterior_param[:, :, -2:-1])

        trans_normalizer = np.sum(trans_param, axis=-1, keepdims=True)
        return np.concatenate((trans_param[:, :, :-1] / trans_normalizer, mean_reward), axis=-1)

    def posterior_sampling(self, size=None):
        '''
        Sample from the posterior distribution

        Args:
            size - int or None - sample size

        Returns:
            params - numpy.ndarray - an S*A*S parameter if `size` is None,
                     otherwise samples stacked along the first axis
        '''
        n_sample = 1 if size is None else size
        samples = np.zeros((n_sample, self.nState, self.nAction, self.nState))

        trans_param = self.posterior_param[:, :, :-2]
        mean_reward = self.posterior_param[:, :, -2]
        reward_scale = 1.0 / np.sqrt(self.posterior_param[:, :, -1] + 1e-9)

        for s in range(self.nState):
            for a in range(self.nAction):
                # The full Dirichlet sample is written, then its last entry
                # (implied by the others) is overwritten with the mean reward.
                samples[:, s, a, :] = self.randomgenerator.dirichlet(trans_param[s, a], n_sample)
                samples[:, s, a, -1] = self.randomgenerator.normal(mean_reward[s, a], reward_scale[s, a], n_sample)

        return samples[0] if size is None else samples

    def hpd_func(self, model_param):
        '''
        The hpd function is an increasing function of the posterior probability

        Args:
            model_param - parameter of a specific model

        Returns:
            hpd_value - float - the value of the HPD function at the specific model parameter
        '''
        std_param = self.to_standard_param(model_param)
        trans_param = std_param[:, :, :-2]
        mean_reward = std_param[:, :, -2]

        posterior_dirichlet_param = self.posterior_param[:, :, :-2] - 1.0
        posterior_gaussian_mean = self.posterior_param[:, :, -2]
        posterior_gaussian_precision = self.posterior_param[:, :, -1]

        hpd_value = np.sum(np.log(trans_param) * posterior_dirichlet_param) - \
                    np.sum(posterior_gaussian_precision * np.square(mean_reward - posterior_gaussian_mean)) / 2.0
        hpd_value /= np.sum(posterior_dirichlet_param) + np.sum(posterior_gaussian_precision)
        return hpd_value

    def hpd_constraints(self, hpd_thres):
        '''
        Used in optimization. Specify the HPD region.

        Args:
            hpd_thres - float - the HPD threshold

        Returns:
            constraints - list of constraint used for optimization
        '''
        # Work on copies. Slices of posterior_param are views, and the in-place
        # normalization below must not overwrite the stored posterior precision.
        dirichlet_coeff = self.posterior_param[:, :, :-2] - 1.0
        gaussian_mean = np.copy(self.posterior_param[:, :, -2])
        gaussian_precision = np.copy(self.posterior_param[:, :, -1])

        normalizing_factor = np.sum(dirichlet_coeff) + np.sum(gaussian_precision)
        dirichlet_coeff /= normalizing_factor
        gaussian_precision /= normalizing_factor

        def hpd_jacobian(model_param):
            '''
            Calculate the jacobian of the HPD function

            Args:
                model_param - parameter of a specific model

            Returns:
                jacobian - numpy.ndarray - jacobian of the HPD function, shape (1, S*A*S)
            '''
            std_param = self.to_standard_param(model_param.reshape(self.nState, self.nAction, -1))
            trans_param = std_param[:, :, :-2]
            mean_reward = std_param[:, :, -2]

            # The last state's probability is 1 - sum(free params), so its term
            # enters every free coordinate with a minus sign.
            trans_gradient = dirichlet_coeff / trans_param
            # hpd_func subtracts precision * (mean - m)^2 / 2, hence the minus sign here.
            reward_gradient = -gaussian_precision * (mean_reward - gaussian_mean)

            jacobian = np.zeros((self.nState, self.nAction, self.nState))
            jacobian[:, :, :-1] = trans_gradient[:, :, :-1] - trans_gradient[:, :, -1:]
            jacobian[:, :, -1] = reward_gradient

            return jacobian.reshape(1, -1)

        return [{
                'type': 'ineq',
                'fun': lambda param : self.hpd_func(param.reshape(self.nState, self.nAction, -1)) - hpd_thres,
                'jac': hpd_jacobian
                },]

    def param_feasible_region(self):
        '''
        Used in optimization. Specify the feasible region of the model parameter.
        The mean-reward entries are made unbounded on top of the base-class region.

        Returns:
            bounds - the bounds of every (flattened) parameter
            constraints - list of constraint used for optimization
        '''
        bounds, constraints = super().param_feasible_region()
        for s in range(self.nState):
            for a in range(self.nAction):
                # C-order flat index of model_param[s, a, S-1] in an S*A*S array
                idx = (s * self.nAction + a) * self.nState + self.nState - 1
                bounds.lb[idx] = -np.inf
                bounds.ub[idx] = np.inf

        return bounds, constraints

    def sample_reward(self, model_param, state, action, size=None):
        '''
        Sample a reward given the current state and action

        Args:
            model_param - parameter of a specific model
            state
            action
            size - int or None - number of reward samples

        Returns:
            reward - float - a single reward if `size=None` or an array of rewards otherwise
        '''
        # numpy's `scale` is the standard deviation, i.e. 1 / sqrt(precision).
        return self.randomgenerator.normal(model_param[state, action, -1], 1.0 / np.sqrt(self.precision), size=size)

    @classmethod
    def reward_fisher_information(cls, model_param, state, action):
        '''
        The fisher information of the reward distribution of the given state-action pair.
        For a Gaussian with known precision tau, the fisher information of the mean is tau.

        Args:
            model_param - parameter of a specific model
            state
            action

        Returns:
            fisher_info - numpy.ndarray - diagonal matrix over all (flattened) parameters
        '''
        fisher_info = np.zeros(np.shape(model_param))
        fisher_info[state, action, -1] = cls.precision
        # np.diag only accepts 1-D/2-D input, so flatten first.
        return np.diag(fisher_info.reshape(-1))

    @classmethod
    def model_log_probability(cls, std_param, dist_param, ignore_normalfactor=True):
        '''
        Calculate the log probability of a particular model in a given distribution

        Args:
            std_param - standard parameter of a specific model
            dist_param - parameter of the distribution
            ignore_normalfactor - Normalizing factors in probability distributions are ignored or not

        Returns:
            log_p - float - the log probability, or -inf for invalid transition parameters
        '''
        trans_param = std_param[:, :, :-2]
        mean_reward = std_param[:, :, -2]

        if not np.all(trans_param >= 0):
            return -np.inf

        posterior_dirichlet_param = dist_param[:, :, :-2]
        posterior_gaussian_mean = dist_param[:, :, -2]
        posterior_gaussian_precision = dist_param[:, :, -1]

        # Calculate the log Beta function of the Dirichlet density
        if not ignore_normalfactor:
            log_p = np.sum(gammaln(np.sum(posterior_dirichlet_param, axis=-1))) - np.sum(gammaln(posterior_dirichlet_param))
        else:
            log_p = 0
        # Calculate the weighted log probability part of the Dirichlet density
        log_trans_param = np.log(trans_param).reshape(-1)
        log_trans_coeff = (posterior_dirichlet_param - 1.0).reshape(-1)
        log_p += np.dot(log_trans_coeff, log_trans_param)

        # Calculate the coefficient part of the Gaussian density
        if not ignore_normalfactor:
            log_p += np.sum(np.log(posterior_gaussian_precision / (2 * np.pi)) / 2.0)

        # Calculate the square part of the Gaussian density
        log_p -= np.sum(posterior_gaussian_precision * np.square(mean_reward - posterior_gaussian_mean)) / 2.0

        return log_p

    @classmethod
    def model_log_likelihood(cls, std_param, history, ignore_normalfactor):
        '''
        Calculate the log likelihood of the observation history under a particular model

        Args:
            std_param - standard parameter of a specific model
            history - the observation history: S next-state counts, then visit count,
                      reward sum and reward square sum per (s, a)
            ignore_normalfactor - Normalizing factors in probability distributions are ignored or not

        Returns:
            log_p - float - the log likelihood, or NaN for invalid transition parameters
        '''
        trans_param = std_param[:, :, :-2]
        mean_reward = std_param[:, :, -2]
        if not np.all(trans_param >= 0):
            return np.nan

        reward_precision = std_param[:, :, -1]

        transition_history = history[:, :, :-3]
        count_history = history[:, :, -3]
        reward_sum_history = history[:, :, -2]
        reward_square_sum_history = history[:, :, -1]

        # Log multinomial coefficient log(n! / prod x!) over the recorded transitions.
        # The "+ 1" avoids gammaln(0) = inf for unvisited cells.
        if not ignore_normalfactor:
            transition_count = np.sum(transition_history, axis=-1)
            log_p = np.sum(gammaln(transition_count + 1.0)) - np.sum(gammaln(transition_history + 1.0))
        else:
            log_p = 0

        # Calculate the weighted log probability part
        log_p += np.dot(transition_history.reshape(-1), np.log(trans_param).reshape(-1))

        # Calculate the coefficient part of the Gaussian density
        if not ignore_normalfactor:
            log_p += np.sum(count_history * np.log(reward_precision / (2. * np.pi))) / 2.0

        # Square part: sum_i (r_i - mu)^2 = sum r^2 - 2 mu sum r + n mu^2
        squared_error = reward_square_sum_history - 2.0 * reward_sum_history * mean_reward \
                        + count_history * np.square(mean_reward)
        log_p -= np.sum(squared_error * reward_precision) / 2.0

        return log_p

    def log_reward_gradient(self, model_param, state, action):
        '''
        The gradient of the log reward likelihood at the given state-action pair,
        tau * sum_i (r_i - mu). The reward is assumed to follow a Gaussian distribution.

        Args:
            model_param - parameter of a specific model
            state
            action
        Returns:
            grad - numpy.ndarray - gradient
        '''
        reward_history = self.history[:, :, -3:]
        std_param = self.to_standard_param(model_param.reshape(self.nState, self.nAction, -1))
        reward_precision = std_param[:, :, -1]
        grad = np.zeros_like(model_param)
        grad[state, action, -1] = (reward_history[state, action, 1] - reward_history[state, action, 0] * model_param[state, action, -1]) * reward_precision[state, action]
        return grad

    def log_posterior_gradient(self, model_param):
        '''
        The gradient of the negative log posterior probability

        Args:
            model_param - parameter of a specific model, shape S*A*S

        Returns:
            grad - numpy.ndarray - flattened gradient
        '''
        counts = self.history[:, :, self.nState]
        reward_sums = self.history[:, :, self.nState + 1]
        trans_counts = self.history[:, :, :self.nState]
        std_param = self.to_standard_param(model_param.reshape(self.nState, self.nAction, -1))
        reward_precision = std_param[:, :, -1]

        neg_log_likely_gradient = np.zeros_like(model_param)
        # Reward: -tau * sum_i (r_i - mu)
        neg_log_likely_gradient[:, :, -1] = -(reward_sums - counts * model_param[:, :, -1]) * reward_precision
        # Transitions: -n_sp / theta_sp for the free entries, plus n_last / (1 - sum(theta))
        # from the implied last state. Both denominators get the same +1e-9 guard.
        free_param = model_param[:, :, :-1]
        last_prob = 1.0 - np.sum(free_param, axis=-1, keepdims=True) + 1e-9
        neg_log_likely_gradient[:, :, :-1] = trans_counts[:, :, -1:] / last_prob \
                                             - trans_counts[:, :, :-1] / (free_param + 1e-9)

        jacobian = neg_log_likely_gradient - self.log_prior_transition_gradient(model_param) - self.log_prior_reward_gradient(model_param)
        return jacobian.reshape(-1)

    def log_prior_reward_gradient(self, model_param):
        '''
        The gradient of the log Gaussian reward prior, tau_0 * (m_0 - mu), where the
        prior mean m_0 and precision tau_0 are the last two entries of prior_param.

        Args:
            model_param - parameter of a specific model

        Returns:
            grad - numpy.ndarray - gradient
        '''
        prior_mean = self.prior_param[:, :, self.nState]
        prior_precision = self.prior_param[:, :, self.nState + 1]
        grad = np.zeros_like(model_param)
        grad[:, :, -1] = prior_precision * (prior_mean - model_param[:, :, -1])
        return grad

#-------------------------------------------------------------------------------
# Gaussian Reward Softmax Bayesian Tabular Model

class GaussianRewardSoftmaxBayesianTabularModel(GaussianRewardBayesianTabularModel):
    '''
    Gaussian-reward Bayesian tabular model with a softmax-parameterized
    transition function.

    State and action spaces are finite, and the reward and transition of
    every state-action pair are modeled independently. The reward follows a
    Gaussian distribution with known precision `precision` (1.0).

    A model parameter is an S*A*(S+1) array. The first S entries of the last
    axis are the softmax logits of the transition distribution, and the last
    entry is the mean reward.
    '''
    precision = 1.0

    def __init__(self, prior_param, batch_size=0, seed=0):
        '''
        Initialize a new Bayesian tabular model.

        The prior parameter is an S*A*(S+2) array. The last two entries of the
        last axis are the Gaussian prior of the reward (mean, precision), and
        the first S entries are the Dirichlet prior of the transition.

        Args:
            prior_param - numpy.ndarray - parameter of the prior distribution
            batch_size - int - size of the mini batch (number of trajectories)
            seed - int - seed of the random generator

        Returns:
            a Bayesian model object
        '''
        super().__init__(prior_param, batch_size, seed)
        self.nState, self.nAction, _ = prior_param.shape
        self.posterior_param = np.copy(self.prior_param)
        # Per (s, a): S next-state counts, then reward count, reward sum and
        # reward square sum (same layout that update_batch fills for batches).
        self.history = np.zeros((self.nState, self.nAction, self.nState + 3))
        # All (s, a, s') triples in C order, i.e. aligned with
        # self.history[:, :, :-3].reshape(-1); sample_batchdata relies on this.
        self.transition = [(s, a, sp)
                           for s in range(self.nState)
                           for a in range(self.nAction)
                           for sp in range(self.nState)]
        # Counts of the current mini batch, same layout as self.history.
        self.batch_history = np.zeros((self.nState, self.nAction, self.nState + 3))

    @classmethod
    def to_standard_param(cls, model_param):
        '''
        Convert a model parameter into standard distribution parameters.

        The standard transition parameter is an S*A*S array of transition
        probabilities (softmax of the logits). The standard reward parameter
        is an S*A*2 array holding the Gaussian mean and precision. The two
        are concatenated along the last axis.

        Args:
            model_param - numpy.ndarray - parameter array of size S*A*(S+1)

        Returns:
            std_param - numpy.ndarray - parameter array of size S*A*(S+2)
        '''
        trans_param = model_param[:, :, :-1]
        reward_param = model_param[:, :, -1:]
        std_param = np.concatenate((softmax(trans_param, axis=-1),
                                    reward_param,
                                    cls.precision * np.ones_like(reward_param)), axis=-1)
        # Keep probabilities strictly inside (0, 1), presumably so that logs
        # taken downstream stay finite.
        std_param[:, :, :-2] = np.clip(std_param[:, :, :-2], 1e-9, 1.0 - 1e-9)
        return std_param

    def map_model(self):
        '''
        Calculate the MAP solution of the posterior.

        The Dirichlet mode is proportional to (alpha - 1). Softmax logits are
        only defined up to an additive constant, so the log of the
        unnormalized mode is used directly. Pseudocounts below 1 are clipped
        to 0 (mode on the simplex boundary) instead of taking the log of a
        negative number.

        Returns:
            map_model_param - numpy.ndarray - S*A*(S+1) parameter of the MAP model
        '''
        pseudocount = np.maximum(self.posterior_param[:, :, :-2] - 1.0, 0.0)
        trans_param = np.log(pseudocount + 1e-20)
        mean_reward = self.posterior_param[:, :, -2:-1]
        return np.concatenate((trans_param, mean_reward), axis=-1)

    def posterior_mean_model(self):
        '''
        Calculate the posterior mean model.

        The Dirichlet mean is proportional to alpha, so log(alpha) serves as
        the softmax logits.

        Returns:
            posterior_mean_model_param - numpy.ndarray - S*A*(S+1) parameter
        '''
        trans_param = np.log(self.posterior_param[:, :, :-2])
        mean_reward = self.posterior_param[:, :, -2:-1]
        return np.concatenate((trans_param, mean_reward), axis=-1)

    def posterior_sampling(self, size=None):
        '''
        Sample from the posterior distribution.

        Transition probabilities are drawn from the Dirichlet posterior and
        stored as their logs (softmax logits). The mean reward is drawn from
        the Gaussian posterior with scale 1/sqrt(posterior precision).

        Args:
            size - int or None - sample size

        Returns:
            params - numpy.ndarray - a single S*A*(S+1) sample if size is None,
                otherwise samples stacked along the first axis
        '''
        n_sample = 1 if size is None else size
        samples = np.zeros((n_sample, self.nState, self.nAction, self.nState + 1))

        trans_param = self.posterior_param[:, :, :-2]
        mean_reward = self.posterior_param[:, :, -2]
        reward_scale = 1.0 / np.sqrt(self.posterior_param[:, :, -1])

        for s in range(self.nState):
            for a in range(self.nAction):
                samples[:, s, a, :-1] = self.randomgenerator.dirichlet(trans_param[s, a], n_sample)
                samples[:, s, a, -1] = self.randomgenerator.normal(mean_reward[s, a], reward_scale[s, a], n_sample)

        samples[:, :, :, :-1] = np.log(samples[:, :, :, :-1])

        return samples[0] if size is None else samples

    def hpd_constraints(self, hpd_thres):
        '''
        Used in optimization. Specify the HPD region.

        Args:
            hpd_thres - float - the HPD threshold

        Returns:
            constraints - list with one scipy 'ineq' constraint dict requiring
                self.hpd_func(param) - hpd_thres >= 0
        '''
        posterior_dirichlet_param = self.posterior_param[:, :, :-2] - 1.0
        posterior_gaussian_mean = self.posterior_param[:, :, -2]
        posterior_gaussian_scale = self.posterior_param[:, :, -1]

        normalizing_factor = np.sum(posterior_dirichlet_param) + np.sum(posterior_gaussian_scale)
        posterior_dirichlet_param = posterior_dirichlet_param / normalizing_factor
        # Not in place: posterior_gaussian_scale is a view into
        # self.posterior_param, and `/=` would corrupt the stored posterior.
        posterior_gaussian_scale = posterior_gaussian_scale / normalizing_factor

        def hpd_jacobian(model_param):
            '''
            Calculate the jacobian of the HPD function.

            Args:
                model_param - parameter of a specific model (flat or S*A*(S+1))

            Returns:
                jacobian - numpy.ndarray - 1 x (S*A*(S+1)) jacobian
            '''
            std_param = self.to_standard_param(model_param.reshape(self.nState, self.nAction, -1))
            trans_param = std_param[:, :, :-2]
            mean_reward = std_param[:, :, -2]

            jacobian = np.zeros((self.nState, self.nAction, self.nState + 1))
            # NOTE(review): if hpd_func contains sum_k c_k * log p_k, its
            # derivative w.r.t. logit j is c_j - p_j * sum_k c_k. The term
            # below keeps only the diagonal part; confirm against hpd_func.
            jacobian[:, :, :-1] = (1.0 - trans_param) * posterior_dirichlet_param
            jacobian[:, :, -1] = posterior_gaussian_scale * (mean_reward - posterior_gaussian_mean)

            return jacobian.reshape(1, -1)

        return [{
                'type': 'ineq',
                'fun': lambda param: self.hpd_func(param.reshape(self.nState, self.nAction, -1)) - hpd_thres,
                'jac': hpd_jacobian
                }, ]

    def param_feasible_region(self):
        '''
        Used in optimization. Specify the feasible region of the model parameter.

        Softmax logits and mean rewards are unconstrained, so the bounds are
        (-inf, inf) for every entry and there are no extra constraints.

        Returns:
            bounds - scipy.optimize.Bounds
            constraints - list of constraint used for optimization (empty)
        '''
        n_param = self.nState * self.nAction * (self.nState + 1)
        bounds = Bounds(np.full(n_param, -np.inf), np.full(n_param, np.inf))
        constraints = []
        return bounds, constraints

    def sample_next_state(self, model_param, state, action, size=None):
        '''
        Sample the next state given the current state and action.

        Args:
            model_param - S*A*(S+1) parameter of a specific model
            state
            action
            size - int or None - number of next state samples

        Returns:
            next_state - a single state if `size=None` or an array of states otherwise
        '''
        probs = softmax(model_param[state, action, :-1], axis=-1)
        return self.randomgenerator.choice(model_param.shape[0], size=size, p=probs)

    @classmethod
    def to_tabular_MDP(cls, model_param):
        '''
        Convert a model parameter to a deterministic-reward tabular MDP.

        Args:
            model_param - S*A*(S+1) parameter of a specific model

        Returns:
            tabularMDP - tuple - S*A*S transition matrix and S*A reward matrix
        '''
        trans_param = softmax(model_param[:, :, :-1], axis=-1)
        reward_param = model_param[:, :, -1]
        return (trans_param, reward_param)

    @classmethod
    def log_transition_gradient(cls, model_param, state, action, next_state):
        '''
        Gradient of log P(next_state | state, action) w.r.t. the softmax logits.

        For logit j this is 1[j == next_state] - softmax(logits)_j. All other
        entries are zero.

        Args:
            model_param - S*A*(S+1) parameter of a specific model
            state
            action
            next_state

        Returns:
            grad - numpy.ndarray - gradient shaped like model_param
        '''
        grad = np.zeros_like(model_param)
        grad[state, action, :-1] = - softmax(model_param[state, action, :-1], axis=-1)
        grad[state, action, next_state] += 1.0
        return grad

    def sample_batchdata(self, eplen):
        """
        Create mini batches from the history data.

        When 0 < batch_size < history_num, one mini batch is sampled:
        transitions are drawn from self.transition in proportion to their
        observed counts, and rewards uniformly from self.reward_memory.
        Otherwise, when any history exists, the full memories form the single
        batch.

        Args:
            eplen: the length of one episode

        Returns:
            mini_batches: list of (transition_data, reward_data, n_trajectories)
                tuples (at most one element)
        """
        mini_batches = []
        if 0 < self.batch_size < self.history_num:
            # Step 1: sample transitions proportionally to their counts.
            trans_counts = self.history[:, :, :-3].reshape(-1)
            total_count = np.sum(trans_counts)
            if total_count > 0:
                trans_data = self.randomgenerator.choice(a=self.transition,
                                                         size=self.batch_size * (eplen - 1),
                                                         p=trans_counts / total_count)
            else:
                # No transitions recorded: nothing to sample (avoids NaN probabilities).
                trans_data = np.empty((0, 3), dtype=int)
            # Step 2: sample the reward data uniformly.
            reward_data = self.randomgenerator.choice(a=self.reward_memory, size=self.batch_size * eplen)
            mini_batches.append((trans_data, reward_data, self.batch_size))
        elif self.history_num > 0:
            mini_batches.append((self.transition_memory, self.reward_memory, self.history_num))
        return mini_batches

    def update_batch(self, batch_data):
        """
        Rebuild the transition and reward counts of the current mini batch.

        Args:
            batch_data: (transition_data, reward_data, n_trajectories) tuple as
                produced by sample_batchdata; transitions are (s, a, s') and
                rewards are (s, a, r) records
        """
        # Reset the counts of the previous batch.
        self.batch_history = np.zeros((self.nState, self.nAction, self.nState + 3))
        # Number of trajectories the batch stands for.
        self.batch_len = batch_data[2]

        for (s, a, sp) in batch_data[0]:
            self.batch_history[int(s), int(a), int(sp)] += 1
        for (s, a, r) in batch_data[1]:
            s, a = int(s), int(a)
            self.batch_history[s, a, self.nState + 0] += 1  # count
            self.batch_history[s, a, self.nState + 1] += r  # reward sum
            self.batch_history[s, a, self.nState + 2] += r * r  # reward square sum

    def value_gradient(self, model_param, occupancy_measure, value_weighted_occupancy_measure):
        '''
        Gradient associated with the value of the model.

        The mean-reward entries are -occupancy_measure. For the transition
        logits with w = value_weighted_occupancy_measure and p = softmax(logits),
        the entries are p * sum(w) - w.

        Args:
            model_param - S*A*(S+1) parameter of a specific model
            occupancy_measure - S*A array
            value_weighted_occupancy_measure - S*A*S array

        Returns:
            grad - numpy.ndarray - flattened gradient
        '''
        jacobian = np.zeros_like(model_param)
        trans = softmax(model_param[:, :, :-1], axis=-1)

        jacobian[:, :, -1] -= occupancy_measure * 1.0
        jacobian[:, :, :-1] = trans * np.sum(value_weighted_occupancy_measure, axis=-1, keepdims=True) \
            - value_weighted_occupancy_measure

        return jacobian.reshape(-1)

    def log_posterior_gradient(self, model_param, mini_batch=False):
        '''
        Negative gradient of the log posterior probability.

        Transition term: for pseudocounts c and p = softmax(logits), the
        negative gradient of sum_k c_k log p_k is p * sum(c) - c.

        - Full-history mode uses c = posterior_param[..., :S] - 1.
        - Mini-batch mode uses (prior - 1) plus the raw batch counts, rescaled
          by history_num / batch_len.

        Reward term: computed from the reward count and sum, with the prior
        mean weighted as one extra observation.

        Args:
            model_param - parameter of a specific model (flat or S*A*(S+1))
            mini_batch - bool - use the current mini batch (see update_batch)
                instead of the whole history

        Returns:
            grad - numpy.ndarray - flattened gradient
        '''
        model_param = np.asarray(model_param).reshape(self.nState, self.nAction, -1)
        gradient = np.zeros_like(model_param)
        std_param = self.to_standard_param(model_param)
        reward_scale = std_param[:, :, -1]
        trans = softmax(model_param[:, :, :-1], axis=-1)
        prior_mean = self.prior_param[:, :, self.nState]
        mean_reward = model_param[:, :, -1]

        if not mini_batch:
            # Reward statistics and transition pseudocounts of the whole history.
            reward_history = self.history[:, :, -3:]
            trans_pseudocount = self.posterior_param[:, :, :self.nState] - 1

            gradient[:, :, -1] = -(reward_history[:, :, 1] + prior_mean
                                   - (reward_history[:, :, 0] + 1) * mean_reward) * reward_scale
            gradient[:, :, :-1] = trans * np.sum(trans_pseudocount, axis=-1, keepdims=True) - trans_pseudocount
        else:
            # Prior contribution.
            gradient[:, :, -1] = -(prior_mean - mean_reward) * reward_scale
            prior_pseudocount = self.prior_param[:, :, :self.nState] - 1
            gradient[:, :, :-1] = trans * np.sum(prior_pseudocount, axis=-1, keepdims=True) - prior_pseudocount
            # Likelihood contribution of the batch, rescaled to the full history.
            if self.history_num > 0 and self.batch_len > 0:
                batch_count = self.batch_history[:, :, self.nState + 0]
                batch_reward_sum = self.batch_history[:, :, self.nState + 1]
                gradient[:, :, -1] += -(batch_reward_sum - batch_count * mean_reward) \
                    * reward_scale * self.history_num / self.batch_len
                # Raw counts: the "-1" belongs to the Dirichlet prior only.
                likely_count = self.batch_history[:, :, :self.nState]
                gradient[:, :, :-1] += (trans * np.sum(likely_count, axis=-1, keepdims=True) - likely_count) \
                    * self.history_num / self.batch_len
        return gradient.reshape(-1)
    # def get_sample_distribution(self,epLen):
    #     #sd=np.zeros(self.nState)
    #     sad=np.zeros((self.nState,self.nAction))
    #     #get initial state distribution
    #     if len(self.state_buffer)==0:
    #         sad+=1/(self.nState*self.nAction)
    #         return sad
    #     for state_action_h_list in self.state_buffer:
    #         for state_action_h in state_action_h_list:
    #             sad[state_action_h[0],state_action_h[1]]+=1
    #
    #     sad/=np.sum(sad)
    #     return sad
    #
    # def get_model_inner_distribution(self,model_inner_distribution,epidx):
    #     #Returns the average distribution of the first n distributions
    #     if epidx<=500:
    #         dist=np.mean(model_inner_distribution[:epidx-1],axis=0)
    #     else:
    #         dist = np.mean(model_inner_distribution, axis=0)
    #     return dist





