import numpy as np
from scipy.special import softmax, logsumexp
from scipy.optimize import brentq

def random_argmax(a):
    '''
    Index of the largest entry of ``a`` (over the flattened array), with
    ties between equally-large entries broken uniformly at random.
    '''
    is_best = (a == a.max())
    candidates = np.flatnonzero(is_best)
    return np.random.choice(candidates)


class EGreedyPolicy(object):
    """
    Epsilon-greedy policy: probability mass ``epsilon`` is spread uniformly
    over all actions and the remaining ``1 - epsilon`` goes to one greedy action.
    """

    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon

    def get_p(self, q):
        """
        Return probability distribution p over actions representing a stochastic policy

        When several actions tie for the maximum value, the greedy mass is
        assigned to ONE of them, picked at random by ``random_argmax``, so
        repeated calls with the same q can return different distributions.

        arguments:
            q: values for each action for a fixed state (flattened internally)

        returns:
            p: probability of each action (1-D array)
        """

        q = q.flatten()
        num_actions = q.shape[0]
        p = np.ones(q.shape) * (self.epsilon / num_actions)

        argmax_a = random_argmax(q)
        p[argmax_a] += (1 - self.epsilon)

        return p

    def get_action(self, q):
        """
        Select action according to policy

        arguments:
            q: values for each action for a fixed state

        returns:
            a: the selected action (index into the flattened q)
        """

        p = self.get_p(q)
        # Size the draw from p, not from q: get_p flattens q, so q.shape[0]
        # is wrong for inputs such as a (1, n) array.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a


class SoftmaxPolicy(object):
    """
    Softmax (Boltzmann) policy: p = softmax(q / temp_), where temp_ is
    ``temp`` optionally rescaled according to ``normalize``:

        'return': temp_ = temp / (g_max - g_min)
        'td':     temp_ = temp / (|zeta| + td_epsilon)
        other:    temp_ = temp
    """

    def __init__(self, temp=1, normalize=False, g_min=0, g_max=1, td_epsilon=10e-9, td_step_size=0.1, zeta=0):
        self.temp = temp
        self.normalize = normalize
        # All normalisation parameters are stored unconditionally so that
        # update_zeta() works whatever mode was chosen.
        self.g_min = g_min
        self.g_max = g_max
        self.td_epsilon = td_epsilon
        self.td_step_size = td_step_size
        # Exponential moving average of the values passed to update_zeta().
        # NOTE(review): the original comment called this a running average of
        # the *squared* td error, but update_zeta averages delta as given.
        self.zeta = zeta

    def get_p(self, q):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened internally)

        returns:
            p: probability of each action (1-D array)
        """
        q = q.flatten()
        # This must be a single if/elif/else chain: with two separate `if`s
        # the trailing `else` overwrote the 'return' temperature with the
        # raw one, silently disabling that normalisation.
        if self.normalize == 'return':
            temp_ = self.temp / (self.g_max - self.g_min)
        elif self.normalize == 'td':
            temp_ = self.temp / (np.abs(self.zeta) + self.td_epsilon)
        else:
            temp_ = self.temp

        p = softmax((1 / temp_) * q)  # probability of choosing each action

        return p

    def get_action(self, q):
        """
        Select action according to policy

        arguments:
            q: values for each action for a fixed state

        returns:
            a: the selected action (index into the flattened q)
        """

        p = self.get_p(q)
        # Size the draw from p, not from q: get_p flattens q, so q.shape[0]
        # is wrong for inputs such as a (1, n) array.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a

    def update_zeta(self, delta):
        """
        Update zeta when a new td error is received

        zeta <- (1 - td_step_size) * zeta + td_step_size * delta

        arguments:
            delta: td_error
        """
        self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * delta


class GPolicy(object):
    """
    Policy that selects actions with a softmax over q (see ``get_p``) and
    additionally exposes ``g``, a log-mean-exp aggregate of the action values.

    The temperature used by ``get_p`` is ``temp`` optionally rescaled
    according to ``normalize``:

        'return': temp_ = temp / (g_max - g_min)
        'td':     temp_ = temp / (|zeta| + td_epsilon)
        other:    temp_ = temp
    """

    def __init__(self, temp=1, normalize=False, g_min=0, g_max=1, td_epsilon=10e-9, td_step_size=0.1, zeta=0):
        self.temp = temp
        self.normalize = normalize
        # All normalisation parameters are stored unconditionally so that
        # update_zeta() works whatever mode was chosen.
        self.g_min = g_min
        self.g_max = g_max
        self.td_epsilon = td_epsilon
        self.td_step_size = td_step_size
        # Exponential moving average of the values passed to update_zeta().
        # NOTE(review): the original comment called this a running average of
        # the *squared* td error, but update_zeta averages delta as given.
        self.zeta = zeta

    def get_p(self, q):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened internally)

        returns:
            p: probability of each action (1-D array)
        """
        q = q.flatten()
        # This must be a single if/elif/else chain: with two separate `if`s
        # the trailing `else` overwrote the 'return' temperature with the
        # raw one, silently disabling that normalisation.
        if self.normalize == 'return':
            temp_ = self.temp / (self.g_max - self.g_min)
        elif self.normalize == 'td':
            temp_ = self.temp / (np.abs(self.zeta) + self.td_epsilon)
        else:
            temp_ = self.temp

        p = softmax((1 / temp_) * q)  # probability of choosing each action

        return p

    def g(self, q):
        """
        Return (1 / temp) * log(mean(exp(temp * q))) over the flattened q.

        Here ``temp`` multiplies q (it plays the role of beta), whereas
        ``get_p`` divides q by the temperature. The un-normalised ``temp``
        is always used, regardless of ``normalize``.

        arguments:
            q: values for each action for a fixed state

        returns:
            scalar aggregate of the action values
        """
        q = q.flatten()
        rho = 1 / len(q)  # uniform weight on each action
        # log(rho) + logsumexp(x) == log(mean(exp(x))), evaluated stably
        update = np.log(rho) + logsumexp(self.temp * q)
        # normalizing by beta
        return 1 / self.temp * update

    def get_action(self, q):
        """
        Select action according to policy

        arguments:
            q: values for each action for a fixed state

        returns:
            a: the selected action (index into the flattened q)
        """

        p = self.get_p(q)
        # Size the draw from p, not from q: get_p flattens q, so q.shape[0]
        # is wrong for inputs such as a (1, n) array.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a

    def update_zeta(self, delta):
        """
        Update zeta when a new td error is received

        zeta <- (1 - td_step_size) * zeta + td_step_size * delta

        arguments:
            delta: td_error
        """
        self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * delta


class ResMaxPolicy(object):
    """
    ResMax policy: every action gets probability inversely related to its
    gap from the best value, and one greedy action receives whatever
    probability mass is left over so that the distribution sums to one.
    """

    def __init__(self, eta=12, normalize=False, g_min=0, g_max=1, td_epsilon=10e-9, td_step_size=0.1, zeta=0):
        self.eta = eta
        self.normalize = normalize
        # All normalisation parameters are stored unconditionally so that
        # update_zeta() works whatever mode was chosen.
        self.g_min = g_min
        self.g_max = g_max
        self.td_epsilon = td_epsilon
        self.td_step_size = td_step_size
        # Exponential moving average of the values passed to update_zeta().
        # NOTE(review): the original comment called this a running average of
        # the *squared* td error, but update_zeta averages delta as given.
        self.zeta = zeta

    def get_p(self, q):
        """
        Return probability distribution p over actions representing a stochastic policy
        (used for action selection)

        When several actions tie for the maximum value, the residual mass is
        assigned to ONE of them, picked at random by ``random_argmax``, so
        repeated calls with the same q can return different distributions.

        arguments:
            q: values for each action for a fixed state (flattened internally)

        returns:
            p: probability of each action (1-D array)
        """

        q = q.flatten()
        num_actions = q.shape[0]

        q_max = np.max(q)
        argmax_a = random_argmax(q)
        gap = q_max - q  # non-negative distance of each action from the best

        if self.normalize == 'return':
            # legacy formula; marked by the original author as not updated
            normalized_eta = self.eta / (self.g_max - self.g_min)
            p = 1 / (num_actions + normalized_eta * gap)
        elif self.normalize == 'td':
            # legacy formula; marked by the original author as not updated
            normalized_eta = self.eta / (np.abs(self.zeta) + self.td_epsilon)
            p = 1 / (num_actions + normalized_eta * gap)
        elif self.normalize == 'non-expansion':
            # current implementation: the action count is scaled by the value
            # range, floored at 1
            p = 1 / (num_actions * np.max([q_max - np.min(q), 1]) + (1 / self.eta) * gap)
        else:
            p = 1 / (num_actions + (1 / self.eta) * gap)

        # the greedy action takes the residual mass, making p sum to one
        p[argmax_a] = 1 - np.sum(np.delete(p, [argmax_a]))

        return p

    def get_action(self, q):
        """
        Select action according to policy

        arguments:
            q: values for each action for a fixed state

        returns:
            a: the selected action (index into the flattened q)
        """
        p = self.get_p(q)
        # Size the draw from p, not from q: get_p flattens q, so q.shape[0]
        # is wrong for inputs such as a (1, n) array.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a

    def update_zeta(self, delta):
        """
        Update zeta when a new td error is received

        zeta <- (1 - td_step_size) * zeta + td_step_size * delta

        arguments:
            delta: td_error
        """
        self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * delta


class MellowMaxPolicy(object):
    '''
    from https://arxiv.org/pdf/1612.05628.pdf

    The policy is softmax(beta * q), with beta chosen per state as a root of
        sum_a exp(beta * (q[a] - mm)) * (q[a] - mm) = 0,
    where mm is the mellowmax of q with parameter ``omega``. At such a root
    the expected value of q under the policy equals mm.
    '''

    # How many times the root bracket may be widened (x10 each time) before
    # giving up; stops an unsolvable problem from looping forever.
    _MAX_BRACKET_EXPANSIONS = 20

    def __init__(self, omega=1):
        self.omega = omega

    def get_p(self, q):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened internally)

        returns:
            p: probability of each action (1-D array)

        raises:
            ValueError: if no sign-changing bracket for beta is found, for
                example when q contains NaN
        """

        q = q.flatten()

        # Advantage over the mellowmax value. It does not depend on beta, so
        # it is computed once rather than on every evaluation of f.
        diff = q - self.mellow_max(q)

        def f(beta):
            '''
            function for brentq to find a root of
            '''
            scaled = beta * diff
            # Multiplying by the positive factor exp(-max(scaled)) leaves the
            # sign and the roots unchanged but keeps every exponent <= 0, so
            # exp cannot overflow even at the extreme ends of the bracket.
            return np.sum(np.exp(scaled - np.max(scaled)) * diff)

        bounds = np.array([-1e8, 1e8])

        for _ in range(self._MAX_BRACKET_EXPANSIONS):
            try:
                # generous maxiter: the initial bracket is very wide
                beta = brentq(f, bounds[0], bounds[1], maxiter=500)
                break
            except ValueError:
                # brentq reports an unusable bracket with ValueError; widen
                # it and retry. Any other error propagates to the caller.
                bounds *= 10
        else:
            raise ValueError(
                'could not bracket the mellowmax beta for q={}'.format(q))

        p = softmax(beta * q)
        return p

    def mellow_max(self, q):
        """
        Return the mellowmax of q:  omega * log(mean(exp(q / omega)))

        arguments:
            q: values for each action for a fixed state

        returns:
            scalar mellowmax value
        """
        c = np.max(q)  # to avoid overflow
        return c + np.log(np.mean(np.exp((1 / self.omega) * (q - c)))) / (1 / self.omega)

    def get_action(self, q):
        """
        Select action according to policy

        arguments:
            q: values for each action for a fixed state

        returns:
            a: the selected action (index into the flattened q)

        raises:
            ValueError: if the computed distribution cannot be sampled from
        """

        p = self.get_p(q)
        p /= p.sum()  # normalize for fp issues with sampling
        try:
            # Size the draw from p, not from q: get_p flattens q, so
            # q.shape[0] is wrong for inputs such as a (1, n) array.
            a = np.random.choice(p.shape[0], p=p)
        except ValueError as err:
            # Report the offending distribution in the error rather than
            # printing it and tripping a bare assert.
            raise ValueError(
                'invalid action distribution p={}'.format(p)) from err
        return a
