import gym
from gym import spaces
import numpy as np
from collections import defaultdict 

"""Helper functions returning the action probabilities for a single state's vector of Q values."""

def boltzmann_probs(q_values, temp=1.0):
    """Return Boltzmann (softmax) action probabilities for one state's Q values.

    Args:
        q_values: 1-D array-like of Q values, one entry per action.
        temp: temperature; larger values flatten the distribution. A
            temperature of exactly 0 is treated as the greedy case: all
            probability mass goes to the first maximising action, matching
            ``greedy_probs``.

    Returns:
        1-D float array of action probabilities summing to 1.

    Raises:
        ValueError: if ``temp`` is negative.
    """
    q_values = np.asarray(q_values, dtype=float)
    if temp < 0:
        raise ValueError(f"temperature must be non-negative, got {temp}")
    if temp == 0:
        # Dividing by zero would produce NaN/inf; use the greedy distribution.
        probs = np.zeros(q_values.shape[0])
        probs[np.argmax(q_values)] = 1.0
        return probs
    # Shift by the max before exponentiating so np.exp cannot overflow.
    exp = np.exp((q_values - np.max(q_values)) / temp)
    return exp / np.sum(exp)

def epsilon_greedy_probs(q_values, epsilon=0.0):
    """Return epsilon-greedy action probabilities for one state's Q values.

    Every action receives ``epsilon / n_actions``. The first maximising
    action (per ``np.argmax``) additionally receives ``1 - epsilon``.
    """
    n_actions = q_values.shape[0]
    # Start from the uniform exploration share...
    distribution = np.full(n_actions, epsilon / n_actions, dtype=float)
    # ...and put the remaining mass on the greedy action.
    distribution[np.argmax(q_values)] += 1.0 - epsilon
    return distribution

def greedy_probs(q_values):
    """Return a one-hot probability vector on the first maximising action."""
    n_actions = q_values.shape[0]
    action_ids = np.arange(n_actions)
    return (action_ids == np.argmax(q_values)).astype(float)

"""Helper functions returning the full policy (action probabilities for every state) for a given Q table."""

def greedy_policy(Q):
    """Return the greedy policy for a Q table of shape (n_states, n_actions).

    Each row is one-hot on the first maximising action of that state.
    """
    n_actions = Q.shape[1]
    best_actions = np.argmax(Q, axis=1)
    action_ids = np.arange(n_actions)
    # Broadcast (1, n_actions) against (n_states, 1) to build the one-hot rows.
    one_hot = action_ids[np.newaxis, :] == best_actions[:, np.newaxis]
    return one_hot.astype(float)

def boltzmann_policy(Q, temp=0.0):
    """Return the Boltzmann (softmax) policy for every state of a Q table.

    Args:
        Q: 2-D array-like of shape (n_states, n_actions).
        temp: temperature. ``temp == 0`` (the default) would divide by zero,
            so it is treated as the greedy case. Each row is then one-hot on
            the first maximising action, matching ``greedy_policy``.

    Returns:
        Float array of shape (n_states, n_actions) whose rows sum to 1.

    Raises:
        ValueError: if ``temp`` is negative.
    """
    Q = np.asarray(Q, dtype=float)
    if temp < 0:
        raise ValueError(f"temperature must be non-negative, got {temp}")
    if temp == 0:
        return np.eye(Q.shape[1])[np.argmax(Q, axis=1)]
    # Subtract each row's max so np.exp cannot overflow.
    logits = (Q - np.max(Q, axis=1, keepdims=True)) / temp
    exp_logits = np.exp(logits)
    return exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

def epsilon_greedy_policy(Q, epsilon=0.0):
    """Return the epsilon-greedy policy for a Q table of shape (n_states, n_actions).

    In every state each action gets ``epsilon / n_actions``. The first
    maximising action additionally gets ``1 - epsilon``.
    """
    n_actions = Q.shape[1]
    # One-hot rows marking the greedy action of each state.
    one_hot = np.eye(n_actions)[np.argmax(Q, axis=1)]
    exploration_mass = np.ones_like(Q) * epsilon / n_actions
    return exploration_mass + one_hot * (1 - epsilon)
    
class Q_Learning:
    """
    Implements tabular Q-learning.

    Input attributes:
        n_states: number of states
        n_actions: number of actions
        alpha: learning rate
        discount: discount factor
        exploration: type of exploration ['boltzmann', 'epsilon-greedy']
        expl_parameter: either the Boltzmann temperature or the epsilon of
            epsilon-greedy exploration, depending on `exploration`

    Other attributes:
        Q: Q table of shape (n_states, n_actions), initialised to zeros
        expl_policy: callable mapping one state's Q values (and
            expl_parameter) to exploration action probabilities
        task_policy: callable mapping one state's Q values to greedy
            (exploit) action probabilities

    Raises:
        NotImplementedError: if `exploration` is not a supported type.
    """
    def __init__(self, n_states, n_actions, alpha=0.1, discount=0.9, exploration='boltzmann', expl_parameter=0.05):
        self.n_states = n_states
        self.n_actions = n_actions
        self.alpha = alpha
        self.discount = discount
        self.Q = np.zeros((self.n_states, self.n_actions))

        self.exploration = exploration

        if self.exploration == 'boltzmann':
            self.expl_policy = boltzmann_probs
        elif self.exploration == 'epsilon-greedy':
            self.expl_policy = epsilon_greedy_probs
        else:
            raise NotImplementedError

        self.expl_parameter = expl_parameter

        self.task_policy = greedy_probs

    def policy(self, state, mode='explore'):
        """Return the action probabilities for a single state.

        Args:
            state: index of the state (row of the Q table).
            mode: 'explore' applies the exploration policy with
                `expl_parameter`; 'exploit' applies the greedy task policy.

        Raises:
            NotImplementedError: for any other `mode`.
        """
        if mode == 'explore':
            return self.expl_policy(self.Q[state], self.expl_parameter)
        elif mode == 'exploit':
            return self.task_policy(self.Q[state])
        else:
            raise NotImplementedError

    def get_policy(self, mode='explore'):
        """Return the policy over the entire state space.

        Args:
            mode: 'explore' for the configured exploration policy, 'exploit'
                for the greedy policy.

        Returns:
            Array of shape (n_states, n_actions) of action probabilities.

        Raises:
            NotImplementedError: for an unsupported `mode` or exploration type.
        """
        if mode == 'explore':
            if self.exploration == 'boltzmann':
                return boltzmann_policy(self.Q, temp=self.expl_parameter)
            elif self.exploration == 'epsilon-greedy':
                return epsilon_greedy_policy(self.Q, epsilon=self.expl_parameter)
            else:
                raise NotImplementedError
        elif mode == 'exploit':
            return greedy_policy(self.Q)
        else:
            raise NotImplementedError

    def step(self, state, mode='explore'):
        """Sample an action index for `state` from the policy selected by `mode`."""
        probs = self.policy(state, mode=mode)
        return np.random.choice(self.n_actions, p=probs)

    def update(self, tup):
        """Apply one Q-learning update from a single transition.

        Args:
            tup: (state, action, reward, next_state, done). The bootstrap
                term is multiplied by ``1 - done``, so it is dropped when
                `done` is True/1.
        """
        state, action, reward, next_state, done = tup
        # Q values of the successor state; the TD target bootstraps from their max.
        next_Qs = self.Q[next_state]
        target = reward + (1 - done) * self.discount * np.max(next_Qs)
        self.Q[state, action] = (1 - self.alpha) * self.Q[state, action] \
            + self.alpha * target
                



