'''
Finite horizon tabular agents.

This is a collection of some of the classic benchmark algorithms for efficient
reinforcement learning in a tabular MDP with little/no prior knowledge.
We provide implementations of:

- PSRL
- Gaussian PSRL
- UCBVI
- BEB
- BOLT
- UCRL2
- Epsilon-greedy
'''
import copy
from scipy.special import gammaln, softmax
import random
import time

import numpy as np
import cvxpy as cp
import scipy.optimize.linesearch
from scipy.stats import gmean
from scipy.optimize import minimize
import math
from scipy.optimize import SR1
from scipy.optimize import BFGS
from .agent import *


class FiniteHorizonFiniteMDPAgent(FiniteHorizonMDPAgent):
    '''
    A Finite Horizon Finite MDP Bayesian learner.

    Provides dynamic-programming helpers (policy evaluation and value
    iteration) on top of the tabular MDP produced by
    bayesian_model.to_tabular_MDP.

    Child agents will mainly implement:
        update_policy

    '''

    def __init__(self, bayesian_model, epLen, gamma=0.998, seed=0, **kwargs):
        '''
        Args:
            bayesian_model - Bayesian model, must expose nState and nAction
            epLen - episode length
            gamma - double - stored as self.gamma; FiniteBOO uses it as the
                retention weight when averaging model_inner_dist. It has a
                default so that child agents which do not use it (and call
                this constructor without it) can still be constructed.
            seed - the seed used to generate the random generator

        Returns:
            a Bayesian learner, to be inherited from
        '''

        # Instantiate the Bayes learner
        self.nState = bayesian_model.nState
        self.nAction = bayesian_model.nAction
        self.logp_gap = 0
        # Constant logits: the softmax below is a uniform distribution over actions.
        self.policy_matrix = np.ones((self.nState, epLen, self.nAction))
        self.policy_entropy = softmax(self.policy_matrix, axis=-1)
        # Uniform distribution over (state, action) pairs.
        self.model_inner_dist = 1 / (self.nState * self.nAction) * np.ones((self.nState, self.nAction))
        self.gamma = gamma

        super().__init__(bayesian_model, epLen, seed)

    def policy(self, state):
        '''
        Deterministic policy: greedy action w.r.t. self.qVals at the current
        time period.
        '''
        return np.argmax(self.qVals[state, self.time_period])

    def softmax_policy_evaluation(self, model_param, policy):
        '''
        Compute the qVals and vVals of a given stochastic policy under the
        given model parameters.

        Args:
            model_param - numpy.ndarray - the parameter of a specific model
            policy - an S*H*A tensor of action probabilities

        Returns:
            qVals - an S*H*A tensor of qVals values
            vVals - an H*S matrix of values under the given policy
        '''
        P, R = self.bayesian_model.to_tabular_MDP(model_param)

        qVals = np.repeat(R[:, np.newaxis], self.epLen, axis=1)
        vVals = np.zeros((self.epLen, self.nState))

        # Last period: no continuation value, V is the policy-weighted reward.
        last = self.epLen - 1
        vVals[last] = np.sum(qVals[:, last, :] * policy[:, last, :], axis=-1)

        # Backward induction over the remaining periods.
        for h in reversed(range(last)):
            qVals[:, h, :] += P @ vVals[h + 1]
            vVals[h] = np.sum(qVals[:, h, :] * policy[:, h, :], axis=-1)
        return qVals, vVals

    def value_iteration(self, model_param):
        '''
        Compute the optimal qVals values for given model parameters by
        backward induction.

        Args:
            model_param - numpy.ndarray - the parameter of a specific model

        Returns:
            qVals - an S*H*A tensor of qVals values
            vVals - an H*S matrix of optimal values
        '''
        P, R = self.bayesian_model.to_tabular_MDP(model_param)

        qVals = np.repeat(R[:, np.newaxis], self.epLen, axis=1)
        vVals = np.zeros((self.epLen, self.nState))

        vVals[self.epLen - 1] = np.max(qVals[:, self.epLen - 1, :], axis=-1)

        for h in reversed(range(self.epLen - 1)):
            qVals[:, h, :] += P @ vVals[h + 1]
            vVals[h] = np.max(qVals[:, h, :], axis=-1)

        return qVals, vVals
    # def model_policy(self,model_param):
    #     model_param=model_param.reshape(self.nState,self.nAction,-1)
    #     qVals, vVals=self.value_iteration(model_param)
    #     return np.argmax(qVals,axis=-1)


class FiniteHorizonTabularMDPAgent(FiniteHorizonFiniteMDPAgent):
    '''
    Simple tabular Bayesian learner from Tabula Rasa.

    Child agents will mainly implement:
        update_policy
    '''

    def extended_value_iteration(self, model_param, P_slack, R_slack):
        '''
        Compute the qVals values for a given set of model parameter by extended value
        iteration. It should be noted that extended value iteration works only in
        discounted MDP.

        At every period the transition mass is moved optimistically: the
        next state with the highest value receives its nominal probability
        plus half of P_slack, and the remaining mass is handed out to the
        other states in decreasing order of value.

        Args:
            model_param - numpy.ndarray - the parameter of a specific model
            P_slack - transition slackness of size S*A
            R_slack - reward slackness of size S*A

        Returns:
            qVals - an S*H*A tensor of qVals values
            vVals - an H*S matrix of optimal values
        '''

        P, R = self.bayesian_model.to_tabular_MDP(model_param)
        # Out-of-place addition: never modify the array handed back by the model.
        R = R + R_slack
        half_slack = np.asarray(P_slack) * 0.5

        qVals = np.repeat(R[:, np.newaxis], self.epLen, axis=1)
        vVals = np.zeros((self.epLen, self.nState))

        vVals[self.epLen - 1] = np.max(qVals[:, self.epLen - 1, :], axis=-1)

        for h in reversed(range(self.epLen - 1)):
            # Probability mass still to be assigned for each (state, action).
            remaining_prob = np.ones((self.nState, self.nAction))

            # Next states sorted by increasing value; the last one is the best.
            pInd = np.argsort(vVals[h + 1])
            optS = pInd[-1]

            # The best next state gets its nominal probability plus the slack.
            opts_prob = np.minimum(remaining_prob, P[:, :, optS] + half_slack)
            remaining_prob -= opts_prob
            qVals[:, h, :] += opts_prob * vVals[h + 1, optS]

            # Hand out what is left to the other states, best first.
            for optS in reversed(pInd[:-1]):
                if not np.any(remaining_prob > 0.0):
                    break

                opts_prob = np.minimum(remaining_prob, P[:, :, optS])
                remaining_prob -= opts_prob
                qVals[:, h, :] += opts_prob * vVals[h + 1, optS]

            vVals[h] = np.max(qVals[:, h, :], axis=-1)

        return qVals, vVals


# -----------------------------------------------------------------------------
# Finite Bayesian Optimistic Optimization
# -----------------------------------------------------------------------------
class FiniteBOO(FiniteHorizonFiniteMDPAgent):
    '''
    Balance Finite Bayesian Optimistic Optimization for finite-horizon finite MDPs.

    The agent keeps a single model estimate (self.last_opt) that is improved
    once per call to update_policy by gradient steps on an objective mixing a
    soft (entropy-regularised) value term with the log posterior, and acts
    with the soft-optimal stochastic policy of that model.
    '''

    def __init__(self, bayesian_model, epLen, scaling=1.0, v1=0.5, v2=0, l1=0.5, l2=-1, zeta=1, gamma=0.998, seed=0, **kwargs):
        '''
        As per the tabular learner, but added tunable scaling.

        Args:
            bayesian_model - Bayesian model; map_model() provides the initial model
            epLen - episode length
            scaling, v1, v2 - doubles - the posterior weight is
                lambda_k = scaling * k**(-v1) * log(k + 1)**(-v2), with k = self.epIdx
            l1, l2 - doubles - the entropy temperature is
                zeta / (k**l1 * log(k + 1)**l2)
            zeta - double - control the scale of the entropy term
            gamma - double - retention weight of the running average kept in
                self.model_inner_dist
            seed - the seed used to generate the random generator
        '''
        self.scaling = scaling
        self.v1 = v1
        self.v2 = v2
        self.l1 = l1
        self.l2 = l2

        self.zeta = zeta
        self.last_opt = bayesian_model.map_model()

        super().__init__(bayesian_model, epLen, gamma, seed)

    def policy(self, state):
        '''
        Nondeterministic policy: sample an action from self.policy_entropy at
        the current time period.
        '''
        return self.randomgenerator.choice(self.nAction, p=self.policy_entropy[state, self.time_period, :])

    def update_policy(self):
        '''
        Solve for a Bayesian optimistic solution (log P + V)

        Runs line-search gradient descent when bayesian_model.batch_size is 0
        and mini-batch gradient descent otherwise, then stores the new model
        in self.last_opt and the new policy in self.policy_entropy.

        Works in place with no arguments.
        '''

        def posterior_weight():
            """
            Returns:
                lambda_k, the weight of the mean-reward and log-posterior terms
            """
            return self.scaling * self.epIdx ** (-self.v1) * np.log(self.epIdx + 1) ** (-self.v2)

        def entropy_temperature():
            """
            Returns:
                the entropy temperature used during optimisation
            """
            value_variation = self.epIdx ** (self.l1) * np.log(self.epIdx + 1) ** (self.l2)
            return self.zeta / value_variation

        def get_entropt_v(model_param, zeta):
            """
            Soft (log-sum-exp) value iteration.

            Args:
                model_param: MDP model parameter, reshaped to (nState, nAction, -1)
                zeta: control the scale of the entropy term

            Returns:
                the optimal soft Q function (S*H*A) and V function (H*S)
            """
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            P, R = self.bayesian_model.to_tabular_MDP(model_param)
            qVals = np.repeat(R[:, np.newaxis], self.epLen, axis=1)
            vVals = np.zeros((self.epLen, self.nState))

            last = self.epLen - 1
            vVals[last] = zeta * scipy.special.logsumexp(qVals[:, last, :] / zeta, axis=-1)

            for h in reversed(range(last)):
                qVals[:, h, :] += P @ vVals[h + 1]
                vVals[h] = zeta * scipy.special.logsumexp(qVals[:, h, :] / zeta, axis=-1)
            return qVals, vVals

        def optimal_object(model_param, zeta, mini_batch=False):
            """
            Args:
                model_param: MDP model parameter (any shape reshapeable to
                    (nState, nAction, -1))
                zeta: control the scale of the entropy term
                mini_batch - bool - evaluate the posterior on the mini batch or not

            Returns:
                the BOO object (to be minimised)
            """
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            qVals, _ = get_entropt_v(model_param, zeta)

            log_posterior = self.bayesian_model.log_posterior_probability(model_param, mini_batch=mini_batch)

            lambda_k = posterior_weight()
            mean_reward_item = np.mean(qVals[:, self.epLen - 1, :]) * self.epLen * np.log(1 + self.epIdx)

            # First-period soft Q values weighted by the running state-action distribution.
            occ_reward = np.sum(self.model_inner_dist * qVals[:, 0, :])
            return -occ_reward - lambda_k * (mean_reward_item + log_posterior)

        def get_optimal_policy(model_param, zeta):
            """
            Args:
                model_param: MDP model parameter
                zeta: control the scale of the entropy term

            Returns:
                the optimal policy matrix (mean-centred log probabilities)
                and the optimal policy (S*H*A action probabilities)
            """
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            qVals, vVals = get_entropt_v(model_param, zeta)
            # pi(a | s, h) = exp((Q(s, h, a) - V(h, s)) / zeta)
            policy = np.exp((qVals - vVals.T[:, :, np.newaxis]) / zeta)
            policy_matrix = np.log(policy)
            policy_matrix = policy_matrix - np.mean(policy_matrix, axis=-1)[:, :, np.newaxis]
            return policy_matrix, policy

        def get_occupancy_measure(model_param, zeta):
            """
            Args:
                model_param: MDP model parameter
                zeta: weight of the entropy item

            Returns:
                occupancy_measure: S*A, summed over the periods of one episode
                value_weighted_occupancy_measure: S*A*S, the (s, a, s')
                    occupancy weighted by the next-period soft value of s'
            """
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            _, vVals = get_entropt_v(model_param, zeta)
            P, _ = self.bayesian_model.to_tabular_MDP(model_param)
            _, policy = get_optimal_policy(model_param, zeta)

            occupancy_measure = np.zeros((self.nState, self.nAction))
            value_weighted_occupancy_measure = np.zeros((self.nState, self.nAction, self.nState))

            # The first period starts from the running state-action distribution;
            # later periods follow the soft-optimal policy.
            current_sa_dist = self.model_inner_dist
            current_s_dist = None
            for h in range(self.epLen):
                if h > 0:
                    current_sa_dist = current_s_dist[:, np.newaxis] * policy[:, h, :]
                occupancy_measure += current_sa_dist

                # No continuation value after the last period.
                if h == self.epLen - 1:
                    break

                current_sasp_dist = current_sa_dist[:, :, np.newaxis] * P
                value_weighted_occupancy_measure += current_sasp_dist * vVals[np.newaxis, np.newaxis, h + 1, :]
                current_s_dist = np.sum(current_sasp_dist, axis=(0, 1))
            return occupancy_measure, value_weighted_occupancy_measure

        def model_posterior_gradient(model_param, zeta, mini_batch=False, flag="natural"):
            """
            Args:
                model_param: MDP model parameter. Along the last axis the
                    final entry is treated as the reward parameter and the
                    preceding entries as the transition parameters.
                zeta: control the scale of the entropy term
                mini_batch - bool - mini_batch gradient descent or not
                flag:
                    "none": Calculate the standard gradient
                    "natural":  Calculate the natural gradient

            Returns:
                The flattened gradient of BOO object.
            """
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            lambda_k = posterior_weight()

            # Calculate the occupancy measure and value-weighted occupancy measure
            occupancy_measure, value_weighted_occupancy_measure = get_occupancy_measure(model_param, zeta)

            # Standard gradient of the value term ...
            value_grad = self.bayesian_model.value_gradient(
                model_param, occupancy_measure, value_weighted_occupancy_measure
            ).reshape(self.nState, self.nAction, -1)
            # ... plus the mean-reward term, weighted by lambda_k exactly as in
            # optimal_object so the gradient matches the objective for every
            # choice of scaling, v1 and v2.
            value_grad[:, :, -1] -= lambda_k * self.epLen * np.log(1 + self.epIdx) / (self.nState * self.nAction)

            # The posterior grad of the history data (the whole history or the mini_batch data)
            posterior_grad = self.bayesian_model.log_posterior_gradient(
                model_param, mini_batch
            ).reshape(self.nState, self.nAction, -1) * lambda_k
            grad = value_grad + posterior_grad

            # Calculate the natural gradient
            if flag == "natural":
                tikhonov_regularizer = 0.01
                P, _ = self.bayesian_model.to_tabular_MDP(model_param)
                F = P[:, :, :, np.newaxis] @ P[:, :, np.newaxis, :]
                # np.diagonal hands back a read-only view of F; make it
                # writable so the += below updates F's diagonal in place.
                diag = np.diagonal(F, axis1=-2, axis2=-1)
                diag.setflags(write=True)
                diag += P + tikhonov_regularizer

                # add 1 to avoid zero
                if not mini_batch:
                    posterior_trans = self.bayesian_model.posterior_param[:, :, :-2] - self.bayesian_model.prior_param[:, :, :-2] + 1
                    reward_stat = self.bayesian_model.history[:, :, -3] + 1
                else:
                    posterior_trans = self.bayesian_model.batch_history[:, :, :-2] + 1
                    reward_stat = self.bayesian_model.batch_history[:, :, -3] + 1
                total_count = np.sum(posterior_trans)
                # normalize occupancy measure
                occ = occupancy_measure / np.sum(occupancy_measure)

                # Per (s, a) weight of the transition Fisher information matrix
                trans_fisher = np.sum(posterior_trans, axis=-1) / total_count + occ
                for s in range(self.nState):
                    for a in range(self.nAction):
                        F[s, a] = np.linalg.pinv(F[s, a] * trans_fisher[s, a])

                # Fisher information of the reward parameter
                reward_fisher = reward_stat / np.sum(reward_stat) + tikhonov_regularizer

                # natural gradient
                grad[:, :, :-1] = (F @ grad[:, :, :-1, np.newaxis])[:, :, :, 0]
                grad[:, :, -1] = 1 / (self.bayesian_model.precision * reward_fisher) * grad[:, :, -1]

            return grad.reshape(-1)

        def save_model_inner_distribution(model_param, zeta):
            """
            Roll the soft-optimal policy of the model forward from state 0,
            normalise the resulting state-action occupancy and blend it into
            self.model_inner_dist with retention weight self.gamma.
            """
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            P, _ = self.bayesian_model.to_tabular_MDP(model_param)
            _, policy = get_optimal_policy(model_param, zeta)

            occupancy_measure = np.zeros((self.nState, self.nAction))

            # initial state distribution: all mass on state 0
            current_s_dist = np.zeros(self.nState)
            current_s_dist[0] = 1
            for h in range(self.epLen):
                current_sa_dist = current_s_dist[:, np.newaxis] * policy[:, h, :]
                occupancy_measure += current_sa_dist

                if h == self.epLen - 1:
                    break
                current_s_dist = np.sum(current_sa_dist[:, :, np.newaxis] * P, axis=(0, 1))

            model_inner_distribution = occupancy_measure / np.sum(occupancy_measure)
            self.model_inner_dist = self.model_inner_dist * self.gamma + (1 - self.gamma) * model_inner_distribution

        def mini_batch_gradient_decent(model_param, max_evals=1):
            """
            mini_batch gradient decent

            Args:
                model_param: flattened MDP model parameter
                max_evals: the maximum number of gradient steps per mini batch

            Returns:
                new model and new optimal policy
            """
            step_size = 0.001
            zeta = entropy_temperature()
            last_model_param = model_param.copy()

            # Divide the dataset by batch_size
            for mini_batch in self.bayesian_model.sample_batchdata(self.epLen):
                # update the history and posterior param according to the mini_batch data
                self.bayesian_model.update_batch(mini_batch)

                last_obj = np.inf
                for _ in range(max_evals):
                    last_grad = -model_posterior_gradient(last_model_param, mini_batch=True, zeta=zeta)
                    new_model_param = last_model_param + last_grad * step_size
                    new_obj = optimal_object(new_model_param, zeta, mini_batch=True)
                    # Stop on this batch as soon as a step fails to improve the objective.
                    if not new_obj < last_obj:
                        break
                    last_model_param = new_model_param
                    last_obj = new_obj

            # Get the policy which will used in the environment, where entropy item should't affect policy too much.
            _, policy_entropy = get_optimal_policy(last_model_param, 0.1 * zeta)

            return last_model_param, policy_entropy

        def gradient_decent(model_param, max_evals=1):
            """
            Natural-gradient descent with a Wolfe line search.

            Args:
                model_param: flattened MDP model parameter
                max_evals: the maximum number of line-search steps

            Returns:
                new model and new optimal policy
            """
            zeta = entropy_temperature()
            last_model_param = model_param.copy()

            def objective(model):
                return optimal_object(model, zeta)

            def standard_grad(model):
                return model_posterior_gradient(model, zeta, flag="none")

            evals = 0
            while evals < max_evals:
                # The search direction is the negative natural gradient.
                last_grad = -model_posterior_gradient(last_model_param, zeta)
                result = scipy.optimize.line_search(objective, standard_grad, last_model_param, last_grad)

                # line search fail
                if result[0] is None:
                    break

                # update model
                last_model_param = last_model_param + last_grad * result[0]

                # Stop once the objective decrease (old value - new value)
                # drops below 1e-4, to speed up.
                if result[4] - result[3] < 1e-4:
                    break
                evals += 1

            # Get the policy which will used in the environment, where entropy item should't affect policy too much.
            _, policy_entropy = get_optimal_policy(last_model_param, 0.1 * zeta)
            save_model_inner_distribution(last_model_param, zeta)

            return last_model_param, policy_entropy

        if self.bayesian_model.batch_size == 0:
            result = gradient_decent(self.last_opt.reshape(-1))
        else:
            # mini_batch-gradient decent
            result = mini_batch_gradient_decent(self.last_opt.reshape(-1))

        # update model and policy
        self.last_opt = result[0]
        self.policy_entropy = result[1]

    def model_value_iteration(self):
        '''
        Returns:
            the optimal value of state 0 at the first period under the
            current model self.last_opt
        '''
        qVals, vVals = self.value_iteration(self.last_opt.reshape(self.nState, self.nAction, -1))
        return vVals[0, 0]


# -----------------------------------------------------------------------------
# Finite Bayesian Optimistic Optimization
# -----------------------------------------------------------------------------

class FiniteCBOO(FiniteHorizonFiniteMDPAgent):
    '''
    Finite Bayesian Optimistic Optimization for finite-horizon finite MDPs.

    Each policy update maximises the optimal value of state 0 over the models
    that satisfy the feasibility and HPD constraints supplied by the
    Bayesian model.
    '''

    def __init__(self, bayesian_model, epLen, delta=0.05, gamma=0.998, seed=0, **kwargs):
        '''
        As per the tabular learner, but added tunable scaling.

        Args:
            bayesian_model - Bayesian model; map_model() provides the initial model
            epLen - episode length
            delta - double - level of confidence
            gamma - double - forwarded to the parent constructor, which
                requires it
            seed - the seed used to generate the random generator
        '''
        self.delta = delta
        self.last_opt = bayesian_model.map_model()
        super().__init__(bayesian_model, epLen, gamma, seed)

    def compute_hpd_quantile(self):
        '''
        Compute the HPD quantile by Monte Carlo.

        Returns:
            the quantile returned by bayesian_model.monte_carlo_quantile
        '''
        # Compute the level of confidence
        alpha_k = min(self.delta / self.epIdx, 1.0)

        sample_size = math.ceil(100 * self.nState / alpha_k)
        hpd_quantile = self.bayesian_model.monte_carlo_quantile(self.bayesian_model.hpd_func, alpha_k, sample_size)

        return hpd_quantile

    def update_policy(self):
        '''
        Construct a credible region and solve for a Bayesian optimistic solution

        Works in place with no arguments.
        '''

        hpd_quantile = self.compute_hpd_quantile()

        bounds, model_constraints = self.bayesian_model.param_feasible_region()
        hpd_constraints = self.bayesian_model.hpd_constraints(hpd_quantile)

        def optimal_object(model_param):
            # Negated so that minimising it maximises the value of state 0.
            qVals, vVals = self.value_iteration(model_param.reshape(self.nState, self.nAction, -1))
            return - vVals[0, 0]

        def model_gradient(model_param):
            model_param = model_param.reshape(self.nState, self.nAction, -1)
            qVals, vVals = self.value_iteration(model_param)
            P, R = self.bayesian_model.to_tabular_MDP(model_param)

            occupancy_measure = np.zeros((self.nState, self.nAction))
            value_weighted_occupancy_measure = np.zeros((self.nState, self.nAction, self.nState))

            # Every episode is assumed to start in state 0.
            current_s_dist = np.zeros(self.nState)
            current_s_dist[0] = 1.0
            action_onehot = np.eye(self.nAction)

            for h in range(self.epLen):
                # Greedy action of every state, one-hot encoded.
                greedy = action_onehot[np.argmax(qVals[:, h, :], axis=-1)]
                current_sa_dist = current_s_dist[:, np.newaxis] * greedy
                occupancy_measure += current_sa_dist

                # No transition is taken after the last period.
                if h == self.epLen - 1:
                    break

                current_sasp_dist = current_sa_dist[:, :, np.newaxis] * P
                value_weighted_occupancy_measure += current_sasp_dist * vVals[np.newaxis, np.newaxis, h + 1, :]

                current_s_dist = np.sum(current_sasp_dist, axis=(0, 1))

            value_grad = self.bayesian_model.value_gradient(model_param, occupancy_measure, value_weighted_occupancy_measure)
            return value_grad

        # We use the last optimal model as a starting point, but it may cause the optimization stuck at a local optimum
        result = minimize(optimal_object, self.last_opt.reshape(-1), method='SLSQP', jac=model_gradient,
                          constraints=model_constraints + hpd_constraints, options={'ftol': 1e-3, 'disp': False},
                          bounds=bounds)

        self.last_opt = result.x
        qVals, vVals = self.value_iteration(self.last_opt.reshape(self.nState, self.nAction, -1))

        # Update the Agent's Q-values
        self.qVals = qVals
        self.vVals = vVals


# -----------------------------------------------------------------------------
# Finite Bayesian Optimistic Optimization with Approximate Credible Region
# -----------------------------------------------------------------------------

class FiniteCBOOACR(FiniteCBOO):
    '''
    FiniteCBOO variant for finite-horizon finite MDPs that uses an
    Approximate Credible Region: the HPD quantile is taken from a single
    posterior sample instead of a Monte Carlo estimate.
    '''

    def compute_hpd_quantile(self):
        '''
        Draw one model from the posterior and evaluate the HPD function on it.

        The sampled model is also stored in self.last_opt.

        Returns:
            the value of bayesian_model.hpd_func at the sampled model
        '''
        sampled_model = self.bayesian_model.posterior_sampling()
        quantile = self.bayesian_model.hpd_func(sampled_model)

        # The sample replaces the previous solution kept in last_opt.
        self.last_opt = sampled_model
        return quantile


# -----------------------------------------------------------------------------
# Tabular Bayesian Optimistic Optimization
# -----------------------------------------------------------------------------

class TabularCBOO(FiniteHorizonTabularMDPAgent):
    '''
    Tabular Bayesian Optimistic Optimization for finite-horizon tabular MDPs.

    The credible region is factorized: every (state, action) pair gets its
    own transition constraint and its own optimistic reward.
    '''

    def __init__(self, bayesian_model, epLen, delta=0.05, **kwargs):
        '''
        Set up the learner and the per-pair optimization caches.

        Args:
            bayesian_model - Bayesian model handed to the parent learner
            epLen - int - episode length
            delta - double - level of confidence
        '''
        self.delta = delta

        # Caches keyed by (state, action), filled in by update_obs:
        #   V        - cvxpy Parameter holding the next-step value vector
        #   quantile - cvxpy Parameter holding the transition HPD quantile
        #   prob     - cvxpy Problem maximizing the expected next value
        self.V = {}
        self.quantile = {}
        self.prob = {}

        super().__init__(bayesian_model, epLen)

    def update_obs(self, oldState, action, reward, newState, done):
        '''
        Record one transition and refresh what depends on it.

        On a non-terminal step the optimization problem of the visited
        (oldState, action) pair is rebuilt from the current posterior.
        On a terminal step the episode counters are advanced and the policy
        is recomputed.

        Args:
            oldState - state the action was taken in
            action - action taken
            reward - reward observed
            newState - state reached
            done - bool - whether the episode ended with this step
        '''
        transition = (oldState, action, reward, newState, done)
        self.bayesian_model.update_history([transition])
        self.time_period += 1

        if done:
            # Episode finished: reset the step counter and re-plan.
            self.time_period = 0
            self.epIdx += 1
            self.update_policy()
            return

        key = (oldState, action)

        # Decision variable: a candidate next-state distribution.
        next_dist = cp.Variable(self.nState)
        self.V[key] = cp.Parameter(self.nState)
        self.quantile[key] = cp.Parameter()

        objective = cp.Maximize(next_dist @ self.V[key])

        # The posterior values are baked into the problem as constants, which
        # is why the problem is rebuilt each time this pair is observed.
        weights = self.bayesian_model.posterior_param[oldState, action, :-2] - 1
        constraints = [
            next_dist >= 0,
            cp.sum(next_dist) == 1,
            # Weighted geometric mean of the distribution must reach the
            # quantile; max_denom (2**30) caps the denominator cvxpy uses for
            # its rational approximation of the weights.
            cp.geo_mean(next_dist, weights, max_denom=1073741824) >= self.quantile[key],
        ]
        self.prob[key] = cp.Problem(objective, constraints)

    def compute_hpd_quantile(self):
        '''
        Estimate the HPD quantile by Monte Carlo.

        Returns:
            hpd_quantile - result of bayesian_model.monte_carlo_quantile
                applied to bayesian_model.independent_hpd_func
        '''
        # Confidence level for this episode, split over all (s, a) pairs
        episode_delta = min(self.delta / self.epIdx, 1.0)
        alpha_k = episode_delta / (2 * self.nState * self.nAction)

        # Number of Monte Carlo samples grows as alpha_k shrinks
        sample_size = math.ceil(100 * self.nState / alpha_k)

        return self.bayesian_model.monte_carlo_quantile(
            self.bayesian_model.independent_hpd_func, alpha_k, sample_size)

    def update_policy(self):
        '''
        Build the credible region and solve for the Bayesian optimistic
        Q-values and values.

        Works in place with no arguments; writes self.qVals and self.vVals.
        '''
        quantile = self.compute_hpd_quantile()

        # First entry: transition quantile. Second entry is negated to give
        # the optimistic reward.
        P_quantile = quantile[0]
        R_opt = -quantile[1]

        # Drop the last two posterior columns to get the transition part
        trans_posterior = self.bayesian_model.posterior_param[:, :, :-2]

        self.qVals, self.vVals = self.extended_value_iteration_opt(
            trans_posterior, P_quantile, R_opt)

    def extended_value_iteration_opt(self, trans_posterior, P_quantile, R_opt):
        '''
        Extended value iteration where the optimistic transition of every
        (state, action) pair is found by convex optimization.

        Args:
            trans_posterior - Dirichlet transition posterior of size S*A*S
            P_quantile - transition HPD quantile of size S*A
            R_opt - optimistic reward of size S*A

        Returns:
            qVals - an S*H*A tensor of qVals values
            vVals - an H*S matrix of optimal values
        '''
        final_step = self.epLen - 1

        # Every step starts from the optimistic reward
        qVals = np.repeat(R_opt[:, np.newaxis], self.epLen, axis=1)
        vVals = np.zeros((self.epLen, self.nState))
        vVals[final_step] = np.max(qVals[:, final_step, :], axis=-1)

        # Backward induction over the remaining steps
        for h in range(final_step - 1, -1, -1):
            next_values = vVals[h + 1]

            for s, a in np.ndindex(self.nState, self.nAction):
                if np.all(trans_posterior[s, a] == 1):
                    # All-ones posterior (presumably the untouched prior -
                    # TODO confirm): take the best next-step value.
                    optV = np.max(next_values)
                else:
                    # Plug the current numbers into the cached problem
                    self.V[s, a].value = next_values
                    self.quantile[s, a].value = P_quantile[s, a]
                    optV = self.prob[s, a].solve()

                # Bellman backup with the optimistic R and P
                qVals[s, h, a] += optV

            vVals[h] = np.max(qVals[:, h, :], axis=-1)

        return qVals, vVals


# -----------------------------------------------------------------------------
# Tabular Bayesian Optimistic Optimization with Approximate Credible Region
# -----------------------------------------------------------------------------

class TabularCBOOACR(TabularCBOO):
    '''
    TabularCBOO variant that works with an Approximate Credible Region.

    Replaces the Monte Carlo quantile of TabularCBOO by the independent HPD
    function evaluated at a single posterior draw.
    '''

    def compute_hpd_quantile(self):
        '''
        Approximate the HPD quantile from one posterior sample.

        Returns:
            hpd_quantile - bayesian_model.independent_hpd_func evaluated at
                the sampled model parameters
        '''
        sampled_model = self.bayesian_model.posterior_sampling()
        return self.bayesian_model.independent_hpd_func(sampled_model)


# -----------------------------------------------------------------------------
# PSRL
# -----------------------------------------------------------------------------

class PSRL(FiniteHorizonFiniteMDPAgent):
    '''
    Posterior Sampling for Reinforcement Learning
    '''

    def update_policy(self):
        '''
        Draw one MDP from the posterior, solve it by value iteration and
        adopt its Q-values and values.

        Also records in self.logp_gap how far the log posterior probability
        of the draw lies below that of the MAP model.

        Works in place with no arguments.
        '''
        model = self.bayesian_model
        param_shape = (self.nState, self.nAction, -1)

        # Sample the MDP and solve it
        sampled_param = model.posterior_sampling()
        qVals, vVals = self.value_iteration(sampled_param)

        # Log posterior probability of the MAP model versus the sample
        map_logp = model.log_posterior_probability(
            model.map_model().reshape(param_shape))
        sample_logp = model.log_posterior_probability(
            sampled_param.reshape(param_shape))
        self.logp_gap = map_logp - sample_logp

        self.qVals = qVals
        self.vVals = vVals

    def model_value_iteration(self):
        '''
        Placeholder; always returns 0.
        '''
        return 0


# -----------------------------------------------------------------------------
# Optimistic PSRL
# -----------------------------------------------------------------------------

class OptimisticPSRL(PSRL):
    '''
    Optimistic Posterior Sampling for Reinforcement Learning
    '''

    def __init__(self, bayesian_model, epLen, nSamp=2, **kwargs):
        '''
        Just like PSRL but we sample multiple models and take the optimistic one

        Args:
            bayesian_model - Bayesian model
            epLen - int - episode length
            nSamp - int - number of samples to use for optimism
            **kwargs - forwarded to the base learner; this is how the
                required gamma (and the optional seed) must be supplied
        '''
        self.nSamp = nSamp
        # FiniteHorizonFiniteMDPAgent.__init__ requires gamma, so the extra
        # keyword arguments have to be passed on rather than dropped.
        super().__init__(bayesian_model, epLen, **kwargs)

    def update_policy(self):
        '''
        Take multiple samples and keep the one with the highest start value.

        Samples are compared on vVals[0, 0], i.e. the value of state 0 at
        step 0; ties are resolved in favour of the later sample.

        Works in place with no arguments.
        '''
        # Sample the MDPs
        param_samples = self.bayesian_model.posterior_sampling(self.nSamp)

        # The first sample is the initial incumbent
        qVals, vVals = self.value_iteration(param_samples[0])
        self.qVals = qVals
        self.vVals = vVals

        for i in range(1, self.nSamp):
            # Solve another sample and keep it if it is at least as optimistic
            qVals, vVals = self.value_iteration(param_samples[i])
            if vVals[0, 0] >= self.vVals[0, 0]:
                # The assumption here is that the first state is the start state
                self.qVals = qVals
                self.vVals = vVals


# -----------------------------------------------------------------------------
# Optimistic Envelope PSRL
# -----------------------------------------------------------------------------

class OptimisticEnvelopePSRL(OptimisticPSRL):
    '''
    Optimistic Envelope Posterior Sampling for Reinforcement Learning
    '''

    def update_policy(self):
        '''
        Solve several posterior samples and keep, entry by entry, the largest
        Q-value and the largest value found among them.

        Works in place with no arguments.
        '''
        samples = self.bayesian_model.posterior_sampling(self.nSamp)

        # Seed the envelope with the first sample
        self.qVals, self.vVals = self.value_iteration(samples[0])

        # Raise the envelope with each remaining sample
        for idx in range(1, self.nSamp):
            sample_q, sample_v = self.value_iteration(samples[idx])
            self.vVals = np.maximum(sample_v, self.vVals)
            self.qVals = np.maximum(sample_q, self.qVals)


# -----------------------------------------------------------------------------
# BOSS
# -----------------------------------------------------------------------------

class BOSS(OptimisticPSRL):
    '''
    Agent that plans in a single MDP merged from several posterior samples.
    '''

    def update_policy(self):
        '''
        Draw nSamp models from the posterior and solve the merged MDP.

        Works in place with no arguments.
        '''
        sampled_models = self.bayesian_model.posterior_sampling(self.nSamp)
        self.qVals, self.vVals = self.optimistic_value_iteration(sampled_models)

    def optimistic_value_iteration(self, model_params):
        '''
        Value iteration over the merged MDP: at every backup each
        (state, action) pair uses whichever sampled transition gives the
        largest expected next value, together with the largest sampled reward.

        Args:
            model_params - numpy.ndarray - set of sampled models

        Returns:
            qVals - an S*H*A tensor of qVals values
            vVals - an H*S matrix of optimal values
        '''
        n_models = model_params.shape[0]
        trans_stack = np.zeros((n_models, self.nState, self.nAction, self.nState))
        best_reward = -np.inf

        # Convert every sample to tabular form, keeping all transitions and
        # the element-wise maximum of the rewards.
        for idx, sample in enumerate(model_params):
            sample_trans, sample_reward = self.bayesian_model.to_tabular_MDP(sample)
            best_reward = np.maximum(best_reward, sample_reward)
            trans_stack[idx] = sample_trans

        final_step = self.epLen - 1
        qVals = np.repeat(best_reward[:, np.newaxis], self.epLen, axis=1)
        vVals = np.zeros((self.epLen, self.nState))
        vVals[final_step] = np.max(qVals[:, final_step, :], axis=-1)

        # Backward induction; the max over axis 0 picks the best sample
        for h in range(final_step - 1, -1, -1):
            expected_next = trans_stack @ vVals[h + 1]
            qVals[:, h, :] += np.max(expected_next, axis=0)
            vVals[h] = np.max(qVals[:, h, :], axis=-1)

        return qVals, vVals


# -----------------------------------------------------------------------------
# Rational Optimistic PSRL
# -----------------------------------------------------------------------------

class ROPSRL(OptimisticPSRL):
    '''
    Rational Optimistic Posterior Sampling for Reinforcement Learning
    '''

    def update_policy(self):
        '''
        Start from the solution of the MAP model and raise it, entry by
        entry, with the solutions of nSamp posterior samples.

        Works in place with no arguments.
        '''
        # Baseline: value iteration on the MAP model
        self.qVals, self.vVals = self.value_iteration(self.bayesian_model.map_model())

        # Sample the MDPs
        samples = self.bayesian_model.posterior_sampling(self.nSamp)

        # Element-wise envelope over the baseline and every sample
        for idx in range(self.nSamp):
            sample_q, sample_v = self.value_iteration(samples[idx])
            self.vVals = np.maximum(sample_v, self.vVals)
            self.qVals = np.maximum(sample_q, self.qVals)


# -----------------------------------------------------------------------------
# UCRL2
# -----------------------------------------------------------------------------

class UCRL2(FiniteHorizonTabularMDPAgent):
    '''Classic benchmark optimistic algorithm'''

    def __init__(self, bayesian_model, epLen, delta=0.05, scaling=1.0, **kwargs):
        '''
        Initialize the UCRL2 agent.

        Args:
            bayesian_model - Bayesian model handed to the parent learner
            epLen - int - episode length
            delta - double - probability scale parameter
            scaling - double - rescale default confidence sets
        '''
        self.delta, self.scaling = delta, scaling
        super().__init__(bayesian_model, epLen)

    def get_slack(self):
        '''
        Compute the confidence widths for UCRL2.

        The widths grow with the episode index self.epIdx and shrink with
        the per-pair totals taken from bayesian_model.history.

        Returns:
            P_slack - P_slack[s, a] is the confidence width for UCRL2 transition
            R_slack - R_slack[s, a] is the confidence width for UCRL2 reward
        '''
        log_term = np.log(self.nState * self.nAction * (self.epIdx + 1) / self.delta)

        # Sum the history over its last axis (minus the final two columns);
        # presumably these are visit counts per (s, a) - TODO confirm.
        # Clamp to 1 to avoid dividing by zero.
        totals = np.sum(self.bayesian_model.history[:, :, :-2], axis=-1)
        totals = np.maximum(totals, 1)

        R_slack = self.scaling * np.sqrt(log_term / totals)
        P_slack = self.scaling * np.sqrt(self.nState * log_term / totals)

        # Note the order: transition width first, reward width second
        return P_slack, R_slack

    def update_policy(self):
        '''
        Compute UCRL2 Q-values via extended value iteration.

        Works in place; writes self.qVals and self.vVals.
        '''
        # Centre of the confidence set: the MAP estimate of the MDP
        centre_model = self.bayesian_model.map_model()

        # Widths of the confidence set
        P_slack, R_slack = self.get_slack()

        self.qVals, self.vVals = self.extended_value_iteration(
            centre_model, P_slack, R_slack)


# -----------------------------------------------------------------------------
# Epsilon-Greedy
# -----------------------------------------------------------------------------

class EpsilonGreedy(FiniteHorizonTabularMDPAgent):
    '''Epsilon greedy agent'''

    def __init__(self, bayesian_model, epLen, epsilon=0.1, **kwargs):
        '''
        An epsilon-greedy learner

        Args:
            bayesian_model - Bayesian model handed to the parent learner
            epLen - int - episode length
            epsilon - double - value passed to egreedy when picking actions
        '''
        self.epsilon = epsilon
        super().__init__(bayesian_model, epLen)

    def update_policy(self):
        '''
        Solve the MAP estimate of the MDP by value iteration.

        Works in place; writes self.qVals and self.vVals.
        '''
        # Point estimate of the MDP
        point_estimate = self.bayesian_model.map_model()

        # Plain value iteration on that estimate
        self.qVals, self.vVals = self.value_iteration(point_estimate)

    def pick_action(self, state):
        '''
        Default is to use egreedy for action selection

        Args:
            state - state to act in

        Returns:
            the action chosen by self.egreedy with self.epsilon
        '''
        return self.egreedy(state, self.epsilon)
class BOOviaPS(PSRL):
    '''
    Balanced Posterior Sampling for Reinforcement Learning

    Draws several models from the posterior and follows the one that best
    trades off its start value against its log posterior probability.
    '''

    def __init__(self, bayesian_model, epLen, nSamp=2, scaling=1.0, alpha=0.5, **kwargs):
        '''
        Just like PSRL but we sample multiple models and pick one by priority

        Args:
            bayesian_model - Bayesian model
            epLen - int - episode length
            nSamp - int - number of posterior samples per policy update
            scaling - double - multiplier of the value term of the priority
            alpha - double - the value term is weighted by epIdx ** (1 - alpha)
            **kwargs - forwarded to the base learner; this is how the
                required gamma (and the optional seed) must be supplied
        '''
        self.nSamp = nSamp
        self.scaling = scaling
        self.alpha = alpha
        # FiniteHorizonFiniteMDPAgent.__init__ requires gamma, so the extra
        # keyword arguments have to be passed on rather than dropped.
        super().__init__(bayesian_model, epLen, **kwargs)

    def update_policy(self):
        '''
        Take multiple samples and adopt the one with the highest priority.

        priority = scaling * V[0, 0] * epIdx ** (1 - alpha)
                   + log(epIdx + 1) * log_posterior_probability

        Works in place with no arguments.
        '''
        Q = np.zeros((self.nSamp, self.nState, self.epLen, self.nAction))
        V = np.zeros((self.nSamp, self.epLen, self.nState))

        # Sample MDPs and score each one under the posterior
        param_samples = self.bayesian_model.posterior_sampling(self.nSamp)
        log_prob = np.array([self.bayesian_model.log_posterior_probability(param)
                             for param in param_samples])

        # Solve every sampled MDP
        for i in range(self.nSamp):
            Q[i], V[i] = self.value_iteration(param_samples[i])

        # Value of state 0 at step 0 (the assumed start state) for each sample.
        # NOTE(review): this term uses epIdx while the probability term uses
        # epIdx + 1, so both vanish when epIdx == 0 and sample 0 is chosen.
        val_part = self.scaling * V[:, 0, 0] * self.epIdx ** (1 - self.alpha)
        prob_part = np.log(self.epIdx + 1) * log_prob

        best = np.argmax(val_part + prob_part)

        self.vVals = V[best]
        self.qVals = Q[best]