import numpy as np
from scipy.special import softmax
from scipy.optimize import brentq
from utils.snippets import RootFinder


class DefaultPolicy(object):
    """
    Base class for the policies in this module.

    Provides a do-nothing update_state hook for subclasses to inherit.
    """

    def update_state(self, alpha, delta, state):
        """
        Hook that accepts alpha, delta and state and ignores all of them.

        returns:
            None
        """
        return None

class EGreedyPolicy(DefaultPolicy):
    """
    Epsilon-greedy policy.

    The exploration mass epsilon is spread uniformly over all actions and the
    remaining (1 - epsilon) is given to one maximal action chosen by
    random_argmax.
    """

    def __init__(self, epsilon=0.1):
        """
        arguments:
            epsilon: initial exploration rate (replaced on every get_p call)
        """
        self.epsilon = epsilon

    def get_p(self, q, exp_value=0.1):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened before use)
            exp_value: exploration rate; stored in self.epsilon on every call

        returns:
            1-D array holding one probability per action
        """
        # NOTE(review): exp_value always overwrites self.epsilon, so the value
        # passed to __init__ only survives until the first call - confirm that
        # this is intended.
        self.epsilon = exp_value
        q = q.flatten()
        num_actions = q.shape[0]

        # every action receives an equal share of the exploration mass
        p = np.ones(q.shape) * (self.epsilon / num_actions)

        # the remaining mass goes to a single maximal action
        argmax_a = random_argmax(q)
        p[argmax_a] += (1 - self.epsilon)

        return p

    def get_action(self, q, exp_value=0.1):
        """
        Sample an action index from the distribution returned by get_p

        arguments:
            q: values for each action for a fixed state
            exp_value: exploration rate forwarded to get_p

        returns:
            index of the sampled action (position in the flattened q)
        """
        p = self.get_p(q, exp_value=exp_value)
        # get_p flattens q, so the number of actions has to come from p;
        # q.shape[0] is wrong for inputs such as shape (1, n) and made
        # np.random.choice raise because its a and p sizes differed.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a


class SoftmaxPolicy(DefaultPolicy):
    """
    Softmax (Boltzmann) policy with optional normalization of the coefficient
    that multiplies the action values.

    normalization modes:
        "fixed":       divide by (g_max - g_min)
        "td_squared":  divide by (zeta + td_epsilon)
        "td_absolute": divide by (|zeta| + td_epsilon)
        anything else: no normalization
    """

    def __init__(self, temp=1, normalization="none", g_min=0, g_max=1, td_epsilon=10e-9, td_step_size=0.1, zeta=0):
        """
        arguments:
            temp: initial coefficient applied to q (replaced on every get_p call)
            normalization: name of the normalization mode
            g_min, g_max: bounds used by the "fixed" mode
            td_epsilon: small constant added to the td-based denominators
            td_step_size: step size of the running average kept in zeta
            zeta: initial value of the running td error statistic
        """
        self.temp = temp
        self.normalization = normalization
        self.g_min = g_min
        self.g_max = g_max
        self.td_epsilon = td_epsilon
        self.td_step_size = td_step_size
        self.zeta = zeta

    def get_p(self, q, exp_value=0.1):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened before use)
            exp_value: coefficient multiplying q; stored in self.temp on every call

        returns:
            1-D array holding one probability per action
        """
        # NOTE(review): exp_value always overwrites self.temp, so the value
        # passed to __init__ only survives until the first call - confirm that
        # this is intended.
        self.temp = exp_value
        q = q.flatten()

        if self.normalization == "fixed":
            normalized_temp = self.temp / (self.g_max - self.g_min)
        elif self.normalization == "td_squared":
            normalized_temp = self.temp / (self.zeta + self.td_epsilon)
        elif self.normalization == "td_absolute":
            normalized_temp = self.temp / (np.abs(self.zeta) + self.td_epsilon)
        else:  # self.normalization == "none":
            normalized_temp = self.temp

        # self.temp multiplies q, so larger values give a greedier policy
        p = softmax(normalized_temp * q)  # probability of choosing each action
        return p

    def get_action(self, q, exp_value=0.1):
        """
        Sample an action index from the distribution returned by get_p

        arguments:
            q: values for each action for a fixed state
            exp_value: coefficient forwarded to get_p

        returns:
            index of the sampled action (position in the flattened q)
        """
        p = self.get_p(q, exp_value=exp_value)
        # get_p flattens q, so the number of actions has to come from p;
        # q.shape[0] is wrong for inputs such as shape (1, n) and made
        # np.random.choice raise because its a and p sizes differed.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a

    def update_zeta(self, delta):
        """
        Update zeta when a new td error is received

        Does nothing unless normalization is "td_squared" or "td_absolute".

        arguments:
            delta: td_error
        """
        if self.normalization == "td_squared":
            # exponential moving average of the squared td error
            self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * (delta**2)
        elif self.normalization == "td_absolute":
            # NOTE(review): this averages the signed td error; the absolute
            # value is only taken later in get_p - confirm whether |delta|
            # was meant here.
            self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * delta
        

class MellowmaxPolicy(DefaultPolicy):
    """
    Policy that applies a softmax whose coefficient is obtained from
    RootFinder.mellow_max_root_finder(q, omega), with optional normalization
    of that coefficient.

    normalization modes:
        "fixed":       divide by (g_max - g_min)
        "td_squared":  divide by (zeta + td_epsilon)
        "td_absolute": divide by (|zeta| + td_epsilon)
        anything else: no normalization
    """

    def __init__(self, omega=1, normalization="none", g_min=0, g_max=1, td_epsilon=10e-9, td_step_size=0.1, zeta=0):
        """
        arguments:
            omega: initial mellowmax parameter (replaced on every get_p call)
            normalization: name of the normalization mode
            g_min, g_max: bounds used by the "fixed" mode
            td_epsilon: small constant added to the td-based denominators
            td_step_size: step size of the running average kept in zeta
            zeta: initial value of the running td error statistic
        """
        self.omega = omega
        self.normalization = normalization
        self.g_min = g_min
        self.g_max = g_max
        self.td_epsilon = td_epsilon
        self.td_step_size = td_step_size
        self.zeta = zeta
        self.root_finder = RootFinder()

    def get_p(self, q, exp_value=0.1):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened before use)
            exp_value: mellowmax parameter; stored in self.omega on every call

        returns:
            1-D array holding one probability per action
        """
        # NOTE(review): exp_value always overwrites self.omega, so the value
        # passed to __init__ only survives until the first call - confirm that
        # this is intended.
        self.omega = exp_value
        q = q.flatten()

        # presumably the softmax coefficient that matches the mellowmax of q
        # for this omega - verify against RootFinder
        temp = self.root_finder.mellow_max_root_finder(q, self.omega)

        if self.normalization == "fixed":
            normalized_temp = temp / (self.g_max - self.g_min)
        elif self.normalization == "td_squared":
            normalized_temp = temp / (self.zeta + self.td_epsilon)
        elif self.normalization == "td_absolute":
            normalized_temp = temp / (np.abs(self.zeta) + self.td_epsilon)
        else:  # self.normalization == "none":
            normalized_temp = temp

        p = softmax(normalized_temp * q)  # probability of choosing each action
        return p

    def get_action(self, q, exp_value=0.1):
        """
        Sample an action index from the distribution returned by get_p

        arguments:
            q: values for each action for a fixed state
            exp_value: mellowmax parameter forwarded to get_p

        returns:
            index of the sampled action (position in the flattened q)
        """
        p = self.get_p(q, exp_value=exp_value)
        # get_p flattens q, so the number of actions has to come from p;
        # q.shape[0] is wrong for inputs such as shape (1, n) and made
        # np.random.choice raise because its a and p sizes differed.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a

    def update_zeta(self, delta):
        """
        Update zeta when a new td error is received

        Does nothing unless normalization is "td_squared" or "td_absolute".

        arguments:
            delta: td_error
        """
        if self.normalization == "td_squared":
            # exponential moving average of the squared td error
            self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * (delta**2)
        elif self.normalization == "td_absolute":
            # NOTE(review): this averages the signed td error; the absolute
            # value is only taken later in get_p - confirm whether |delta|
            # was meant here.
            self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * delta


class ResMaxPolicy(DefaultPolicy):
    """
    ResMax policy.

    Every action gets probability 1 / (num_actions + eta * (q_max - q)); one
    maximal action chosen by random_argmax then receives whatever mass is
    left so that the distribution sums to one.

    normalization modes (applied to eta):
        "fixed":       divide by (g_max - g_min)
        "td_squared":  divide by (zeta + td_epsilon)
        "td_absolute": divide by (|zeta| + td_epsilon)
        anything else: no normalization
    """

    def __init__(self, eta=12, normalization="none", g_min=0, g_max=1, td_epsilon=10e-9, td_step_size=0.1, zeta=0):
        """
        arguments:
            eta: initial exploitation pressure (replaced on every get_p call)
            normalization: name of the normalization mode
            g_min, g_max: bounds used by the "fixed" mode
            td_epsilon: small constant added to the td-based denominators
            td_step_size: step size of the running average kept in zeta
            zeta: initial value of the running td error statistic
        """
        self.eta = eta
        self.normalization = normalization
        self.g_min = g_min
        self.g_max = g_max
        self.td_epsilon = td_epsilon
        self.td_step_size = td_step_size
        self.zeta = zeta  # running average of the td error statistic (see update_zeta)

    def get_p(self, q, exp_value=0.1):
        """
        Return probability distribution p over actions representing a stochastic policy

        arguments:
            q: values for each action for a fixed state (flattened before use)
            exp_value: eta coefficient; stored in self.eta on every call

        returns:
            1-D array holding one probability per action
        """
        # NOTE(review): exp_value always overwrites self.eta, so the value
        # passed to __init__ only survives until the first call - confirm that
        # this is intended.
        self.eta = exp_value
        q = q.flatten()
        num_actions = q.shape[0]

        q_max = np.max(q)
        argmax_a = random_argmax(q)

        if self.normalization == "fixed":
            normalized_eta = self.eta / (self.g_max - self.g_min)
        elif self.normalization == "td_squared":
            normalized_eta = self.eta / (self.zeta + self.td_epsilon)
        elif self.normalization == "td_absolute":
            normalized_eta = self.eta / (np.abs(self.zeta) + self.td_epsilon)
        else:  # self.normalization == "none":
            normalized_eta = self.eta

        # the further an action is below q_max, the smaller its probability
        p = 1 / (num_actions + normalized_eta * (q_max - q))
        # the chosen maximal action absorbs the remaining probability mass
        p[argmax_a] = 1 - np.sum(np.delete(p, [argmax_a]))

        return p

    def get_action(self, q, exp_value=0.1):
        """
        Sample an action index from the distribution returned by get_p

        arguments:
            q: values for each action for a fixed state
            exp_value: eta coefficient forwarded to get_p

        returns:
            index of the sampled action (position in the flattened q)
        """
        p = self.get_p(q, exp_value=exp_value)
        # get_p flattens q, so the number of actions has to come from p;
        # q.shape[0] is wrong for inputs such as shape (1, n) and made
        # np.random.choice raise because its a and p sizes differed.
        num_actions = p.shape[0]
        a = np.random.choice(num_actions, p=p)
        return a

    def update_zeta(self, delta):
        """
        Update zeta when a new td error is received

        Does nothing unless normalization is "td_squared" or "td_absolute".

        arguments:
            delta: td_error
        """
        if self.normalization == "td_squared":
            # exponential moving average of the squared td error
            self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * (delta**2)
        elif self.normalization == "td_absolute":
            # NOTE(review): this averages the signed td error; the absolute
            # value is only taken later in get_p - confirm whether |delta|
            # was meant here.
            self.zeta = (1 - self.td_step_size) * self.zeta + self.td_step_size * delta

def random_argmax(a):
    '''
    Pick the index of a maximal entry of a, choosing among tied maxima at random.

    arguments:
        a: numpy array of values

    returns:
        index into the flattened array of one of its largest entries
    '''
    peak = a.max()
    # flat positions of every entry equal to the maximum
    tied = np.flatnonzero(np.equal(a, peak))
    return np.random.choice(tied)