from abc import abstractmethod

import numpy as np
import pdb
from solver.policy.base import FeedbackPolicy

def normalize(array):
    """
    Rescale an array so that its entries sum to one.
    :param array: object exposing ``.sum()`` and supporting division (e.g. a numpy array)
    :return: ``array`` divided by the sum of its entries
    """
    total = array.sum()
    return array / total

class FiniteFeedbackPolicy(FeedbackPolicy):
    """
    Feedback policy over a finite set of actions.

    Subclasses provide the action distribution through pmf; act draws an
    action index from that distribution.
    """

    def __init__(self, state_space, action_space):
        super().__init__(state_space, action_space)

    def act(self, t, x):
        """
        Sample an action index from the distribution returned by pmf(t, x)
        :param t: time
        :param x: observation
        :return: sampled action index (position within the pmf)
        """
        weights = self.pmf(t, x)
        candidates = range(len(weights))
        # Draw a length-1 sample, then unwrap it into a plain scalar.
        drawn = np.random.choice(candidates, 1, p=weights)
        return drawn.item()

    @abstractmethod
    def pmf(self, t, x):
        """
        Probability mass function over actions at time t given observation x
        :param t: time
        :param x: observation
        :return: action pmf
        """
        ...


class QMaxPolicy(FiniteFeedbackPolicy):
    """
    Greedy policy: all probability mass goes to the argmax of Qs[t][x].
    """

    def __init__(self, state_space, action_space, Qs):
        super().__init__(state_space, action_space)
        self.Qs = Qs

    def pmf(self, t, x):
        """
        Build a one-hot pmf of length action_space.n
        :param t: time, used as the first index into Qs
        :param x: observation, used as the second index into Qs
        :return: vector with a 1 at the argmax of Qs[t][x] and 0 elsewhere
        """
        one_hot = np.zeros(self.action_space.n)
        best_action = np.argmax(self.Qs[t][x])
        one_hot[best_action] = 1
        return one_hot


class QSoftMaxPolicy(FiniteFeedbackPolicy):
    """
    Softmax policy: action probabilities proportional to exp(tau * Qs[t][x]).
    """

    def __init__(self, state_space, action_space, Qs, tau, policy_array=None):
        super().__init__(state_space, action_space)
        self.Qs = Qs
        self.tau = tau
        self.policy_array = policy_array

    def update_policy_array(self, t, x, prob):
        """
        Store prob at policy_array[t, x]; does nothing when policy_array is None
        :param t: time
        :param x: observation
        :param prob: action pmf to record
        """
        if self.policy_array is None:
            return
        self.policy_array[t, x] = prob

    def pmf(self, t, x):
        """
        Softmax of tau * Qs[t][x]; the result is also recorded via update_policy_array
        :param t: time
        :param x: observation
        :return: action pmf
        """
        q_row = self.Qs[t][x]
        # Shifting by the largest Q-value cancels out after normalization.
        shifted = np.array(q_row) - max(q_row)
        weights = np.exp(self.tau * shifted)
        prob = normalize(weights)
        self.update_policy_array(t, x, prob)
        return prob
    
class ModifiedSoftmaxPolicy(QSoftMaxPolicy):
    """
    Policy whose pmf is proportional to a reference pmf times exp(Qs[t][x]).

    The reference pmf is read from pi_array. sigma_array is stored on the
    instance but not used by this class. policy_array is not forwarded to the
    parent constructor, so update_policy_array is a no-op unless policy_array
    is assigned after construction.
    """

    def __init__(self, state_space, action_space, etaQs, tau, pi_array, sigma_array=None):
        super().__init__(state_space, action_space, etaQs, tau)
        self.pi_array = pi_array
        self.sigma_array = sigma_array

    def pmf(self, t, x):
        """
        Reweight the reference pmf by exp(Qs[t][x]) and renormalize
        :param t: time
        :param x: observation
        :return: action pmf
        """
        # The constructor stores the reference policy as pi_array; no `pi`
        # attribute is assigned in this class hierarchy, so read the array.
        # NOTE(review): assumes pi_array is indexed [t, x] and holds a pmf
        # over actions, like policy_array -- confirm against callers.
        likelihood = self.pi_array[t, x] * np.exp(self.Qs[t][x])
        prob = normalize(likelihood)
        self.update_policy_array(t, x, prob)
        return prob

class RegularizedPolicy(ModifiedSoftmaxPolicy):
    """
    Policy with pmf proportional to
    pi ** (1 - lambda_eta) * sigma ** lambda_eta * exp(Qs[t][x]),
    where lambda_eta = reg_param * lr and pi / sigma are read from
    pi_array / sigma_array.
    """

    def __init__(self, *args, reg_param=0., lr=.1):
        super().__init__(*args)
        self.lambda_eta = reg_param * lr

    def pmf(self, t, x):
        """
        Combine the pi and sigma pmfs with exp(Qs[t][x]) and renormalize
        :param t: time
        :param x: observation
        :return: action pmf
        :raises ValueError: if sigma_array is None while lambda_eta is non-zero
        """
        # The constructor stores pi_array / sigma_array; no `pi` or `sigma`
        # attributes are assigned in this class hierarchy, so read the arrays.
        # NOTE(review): assumes both arrays are indexed [t, x] and hold pmfs
        # over actions, like policy_array -- confirm against callers.
        likelihood = self.pi_array[t, x] ** (1 - self.lambda_eta)
        if self.sigma_array is not None:
            likelihood = likelihood * self.sigma_array[t, x] ** self.lambda_eta
        elif self.lambda_eta != 0:
            raise ValueError("sigma_array is required when reg_param * lr is non-zero")
        # With lambda_eta == 0 the sigma factor is identically one, so it is
        # skipped when no sigma_array was supplied.
        likelihood = likelihood * np.exp(self.Qs[t][x])
        return normalize(likelihood)
