import pandas as pd
import numpy as np


class RLAgent:
    """Reinforcement-learning agent with tabular or linear value functions and policies.

    Supported ``agent_type`` values:
      * ``'td'`` -- one-step TD prediction of state values.
      * ``'q_learning'``, ``'sarsa'``, ``'expected_sarsa'`` -- one-step control methods
        learning state-action values.

    Targets are either discounted (``use_average_reward=False``) or differential, i.e.
    ``reward - avg_reward + next_value`` (``use_average_reward=True``).  Rewards can
    optionally be replaced by a CVaR-style reward built from a running VaR estimate
    (``use_cvar=True``).

    Policies (``policy_type``):
      * ``'tabular'`` -- a user-defined ``policy[state][action]`` probability table, or
        epsilon-greedy over the current action values when ``policy`` is None.
      * ``'softmax'`` -- softmax over the action values with temperature ``policy_softmax_tau``.
      * ``'linear'``  -- softmax over linear action preferences, trained with a
        policy-gradient update inside ``agent_step``.

    State representation:
      * tabular values: ``states`` are hashable keys of ``self.values``.
      * linear approximators: only ``len(states)`` is used (as the feature dimension);
        each state passed to the methods appears to be a column feature vector of shape
        ``(len(states), 1)`` -- NOTE(review): confirm with callers.
    """

    def __init__(self, agent_type, states, actions, policy=None, use_average_reward=True,
                 policy_type='tabular', policy_softmax_tau=1.0, policy_update_type='stochastic_gradient_descent',
                 value_type='tabular', value_update_type='stochastic_gradient_descent',
                 beta_m_adam=0.9, beta_v_adam=0.999, epsilon_adam=1e-8,
                 initial_avg_reward=0.0, use_cvar=True, var_quantile=0.1, initial_var_reward=0.0):
        """Create the agent.

        Args:
            agent_type: 'td', 'q_learning', 'sarsa' or 'expected_sarsa'.
            states: tabular state keys; for linear approximators only ``len(states)``
                (the feature dimension) is used.
            actions: list of actions; ``actions.index`` maps an action to its weight row.
            policy: optional user-defined tabular policy ``policy[state][action] -> prob``.
                When None and ``policy_type == 'tabular'`` an epsilon-greedy policy is used.
            use_average_reward: use differential (average-reward) targets instead of
                discounted ones.
            policy_type: 'tabular', 'softmax' or 'linear' (parameterized softmax).
            policy_softmax_tau: temperature for ``policy_type == 'softmax'``.
            policy_update_type: 'stochastic_gradient_descent' or 'adam'.
            value_type: 'tabular' or 'linear'.
            value_update_type: 'stochastic_gradient_descent' or 'adam'.
            beta_m_adam, beta_v_adam, epsilon_adam: Adam hyper-parameters.
            initial_avg_reward: initial average-reward estimate.
            use_cvar: replace rewards with the CVaR reward (see ``_cvar_reward``).
            var_quantile: quantile level used by the CVaR reward and VaR update.
            initial_var_reward: initial VaR estimate.
        """
        self.agent_type = agent_type
        self.states = states
        self.actions = actions

        # average reward (the estimate only exists in the average-reward setting)
        self.use_avg_reward = use_average_reward
        if self.use_avg_reward:
            self.avg_reward = initial_avg_reward

        # CVaR reward transformation and its running VaR estimate
        self.use_cvar = use_cvar
        self.var_quantile = var_quantile
        self.var_reward = initial_var_reward

        # Adam hyper-parameters and bias-correction products (beta ** t, starting at t = 1).
        # Always initialized: agent_step advances the products whenever either update type
        # is 'adam', which previously raised AttributeError when value_type == 'tabular'.
        self.beta_m = beta_m_adam
        self.beta_v = beta_v_adam
        self.beta_m_product = beta_m_adam
        self.beta_v_product = beta_v_adam
        self.epsilon_adam = epsilon_adam

        #################################
        # state and state-action values
        #################################
        self.value_type = value_type
        self.value_update_type = value_update_type

        if self.value_type == 'tabular':
            if self.agent_type == 'td':
                self.values = {state: 0 for state in self.states}
            elif self.agent_type in ('q_learning', 'sarsa', 'expected_sarsa'):
                self.values = {state: {action: 0 for action in self.actions} for state in self.states}
            else:
                self.values = {}
        else:
            # function approximation
            self.weights_value, self.m_value, self.v_value = self.initialize_weights(weights_type='value_approx')

        ################
        # agent policy
        ################
        self.policy_type = policy_type
        self.policy_softmax_tau = policy_softmax_tau
        self.policy_update_type = policy_update_type
        # Multiplier applied to policy gradients. agent_start resets it to 1; in the
        # discounted setting policy_update multiplies it by `discount` after each update.
        # Starting at 1 (not None) keeps a policy update before agent_start from crashing.
        self.policy_discount = 1

        if self.policy_type == 'tabular':
            self.policy = policy
        elif self.policy_type != 'softmax':
            # parameterized softmax policy ('softmax' needs no parameters of its own)
            self.weights_policy, self.m_policy, self.v_policy = self.initialize_weights(weights_type='policy')

    @staticmethod
    def _is_missing(value):
        """Return True for None or a scalar NaN/NA, False for anything else.

        ``pd.isnull`` alone is unsafe on NumPy feature-vector states: it returns an
        element-wise array whose truth value is ambiguous.
        """
        if value is None:
            return True
        return bool(pd.api.types.is_scalar(value) and pd.isnull(value))

    def initialize_weights(self, weights_type):
        """Return zero-initialized ``(weights, adam_m, adam_v)`` for an approximator.

        Args:
            weights_type: 'value_approx' or 'policy'.

        Returns:
            Three dicts of the form ``{weights_type: {'W': array}}``. ``W`` has one row per
            action (a single row for TD state values) and ``len(self.states)`` columns.
            All three dicts are empty when the model type is not 'linear'.

        Raises:
            ValueError: for an unknown ``weights_type``.
        """
        if weights_type == 'value_approx':
            model_type = self.value_type
        elif weights_type == 'policy':
            model_type = self.policy_type
        else:
            raise ValueError(f"unknown weights_type: {weights_type!r}")

        if weights_type == 'value_approx' and self.agent_type == 'td':
            n_rows = 1
        else:
            n_rows = len(self.actions)
        shape = (n_rows, len(self.states))

        weights, m, v = {}, {}, {}
        if model_type == 'linear':
            weights[weights_type] = {'W': np.zeros(shape)}
            m[weights_type] = {'W': np.zeros(shape)}
            v[weights_type] = {'W': np.zeros(shape)}
        return weights, m, v

    @staticmethod
    def argmax(values):
        """Return a key of ``values`` with the largest value, breaking ties uniformly at random."""
        top_value = float("-inf")
        ties = []
        for action, q in values.items():
            if q > top_value:
                top_value = q
                ties = [action]
            elif q == top_value:
                ties.append(action)
        return np.random.choice(ties)

    def get_value(self, state, action=None):
        """Return the state value (``action`` None / 'td') or the state-action value.

        Returns None for value types other than 'tabular' and 'linear'.
        """
        if self.value_type == 'tabular':
            if self.agent_type == 'td':
                return self.values[state]
            if self.agent_type in ('q_learning', 'sarsa', 'expected_sarsa'):
                return self.values[state][action]
            return None

        # value-function approximation: one weight row per action, row 0 for state values
        value_index = 0 if self._is_missing(action) else self.actions.index(action)
        if self.value_type == 'linear':
            return np.dot(self.weights_value['value_approx']['W'][value_index], state)[0]
        return None

    def get_target_value(self, state, action=None):
        """Return the value used when bootstrapping targets (currently the same as get_value)."""
        return self.get_value(state, action)

    def get_softmax_probabilities(self, state):
        """Return ``{action: probability}`` for the softmax policy at ``state``.

        For ``policy_type == 'softmax'`` the preferences are the action values divided by
        ``policy_softmax_tau``. For ``'linear'`` they are the linear preferences
        ``W[a] . state``. The maximum preference is subtracted before exponentiating for
        numerical stability.
        """
        if self.policy_type == 'softmax':
            preferences = [self.get_value(state, a) / self.policy_softmax_tau for a in self.actions]
        elif self.policy_type == 'linear':
            W = self.weights_policy['policy']['W']
            preferences = [np.dot(W[i], state)[0] for i in range(len(self.actions))]
        else:
            preferences = []

        preferences = np.asarray(preferences, dtype=float)
        exp_prefs = np.exp(preferences - np.max(preferences))
        probs = exp_prefs / np.sum(exp_prefs)
        return dict(zip(self.actions, probs))

    def get_policy(self, state, action, argmax_action=None, epsilon=None, softmax_probs=None):
        """Return the probability of taking ``action`` in ``state`` under the current policy.

        Args:
            argmax_action: greedy action (epsilon-greedy tabular policy only).
            epsilon: exploration rate (epsilon-greedy tabular policy only).
            softmax_probs: output of get_softmax_probabilities (non-tabular policies only).
        """
        if self.policy_type == 'tabular':
            if self._is_missing(self.policy):
                # epsilon-greedy policy
                if action == argmax_action:
                    return (1 - epsilon) + (epsilon / len(self.actions))
                return epsilon / len(self.actions)
            # user-defined tabular policy
            return self.policy[state][action]
        # softmax, possibly parameterized, policy
        return softmax_probs[action]

    def choose_action_from_policy(self, state, epsilon=None):
        """Sample an action for ``state`` from the current policy.

        Args:
            state: current state.
            epsilon: exploration rate for the epsilon-greedy tabular policy. None means
                purely greedy. Ignored by the other policy types.
        """
        if self.policy_type != 'tabular':
            # softmax, possibly parameterized, policy
            probs = self.get_softmax_probabilities(state)
            return np.random.choice(self.actions, p=list(probs.values()))

        if not self._is_missing(self.policy):
            # user-defined tabular policy
            action_probs = [self.policy[state][a] for a in self.actions]
            return np.random.choice(self.actions, p=action_probs)

        # epsilon-greedy: `<` gives exploration probability exactly epsilon
        if epsilon is not None and np.random.random() < epsilon:
            return np.random.choice(self.actions)

        if self.value_type == 'tabular':
            values = self.values[state]
        else:
            values = {a: self.get_value(state, a) for a in self.actions}
        return self.argmax(values)

    def get_gradient(self, state, weights, gradient_type, weight_type, action=None):
        """Return the gradient of a linear value/preference w.r.t. its weights.

        The result has the shape of ``weights[weight_type]['W']``. It is zero everywhere
        except the row of ``action`` (row 0 for TD state values), which equals ``state.T``.
        Returns None for non-linear ``gradient_type``.
        """
        if weight_type == 'value_approx' and self.agent_type == 'td':
            value_index = 0
        else:
            value_index = self.actions.index(action)

        if gradient_type == 'linear':
            gradient = np.zeros(weights[weight_type]['W'].shape)
            gradient[value_index] = state.T
            return {weight_type: {'W': gradient}}
        return None

    def get_softmax_policy_gradient(self, state, action, weights, gradient_type, weight_type):
        """Return grad ln pi(action | state, theta) for the softmax preference policy.

        grad ln pi(a | s) = grad h(s, a) - sum_b pi(b | s) * grad h(s, b),
        where h are the action preferences.
        """
        softmax_probs = self.get_softmax_probabilities(state)

        # gradient of each action's preference
        all_gradients = {a: self.get_gradient(state=state,
                                              action=a,
                                              weights=weights,
                                              gradient_type=gradient_type,
                                              weight_type=weight_type)
                         for a in self.actions}

        gradients = {}
        for item, params in all_gradients[action].items():
            gradients[item] = {}
            for param, grad_taken in params.items():
                expected_grad = sum(softmax_probs[b] * all_gradients[b][item][param] for b in self.actions)
                gradients[item][param] = np.array(grad_taken) - expected_grad
        return gradients

    def update_weights(self, weights_type, weights, step_size, delta, gradient, m, v, update_type, policy_discount=None):
        """Apply one gradient-ascent step along ``delta * gradient`` and return ``(weights, m, v)``.

        For ``weights_type == 'policy'`` the gradient is first scaled by ``policy_discount``.
        'stochastic_gradient_descent' leaves ``m``/``v`` untouched. 'adam' updates them and
        uses the bias-correction products ``self.beta_m_product`` / ``self.beta_v_product``.
        Any other ``update_type`` returns the inputs unchanged.
        """
        if weights_type == 'policy':
            gradient = policy_discount * gradient

        if update_type == 'stochastic_gradient_descent':
            weights = weights + step_size * delta * gradient

        elif update_type == 'adam':
            g = delta * gradient
            m = self.beta_m * m + (1 - self.beta_m) * g
            v = self.beta_v * v + (1 - self.beta_v) * (g * g)
            m_hat = m / (1 - self.beta_m_product)
            v_hat = v / (1 - self.beta_v_product)
            weights = weights + step_size * m_hat / (np.sqrt(v_hat) + self.epsilon_adam)

        return weights, m, v

    def value_tabular_update(self, last_state, last_action, step_size, estimate, target):
        """Move the tabular value of ``last_state`` (/``last_action``) toward ``target``."""
        delta = target - estimate
        if self.agent_type == 'td':
            self.values[last_state] = self.values[last_state] + step_size['value'] * delta
        elif self.agent_type in ('q_learning', 'sarsa', 'expected_sarsa'):
            self.values[last_state][last_action] = self.values[last_state][last_action] + step_size['value'] * delta

    def value_approx_update(self, last_state, last_action, step_size, estimate, target):
        """Semi-gradient update of the value-function weights toward ``target``."""
        delta = target - estimate
        gradient = self.get_gradient(state=last_state,
                                     action=last_action,
                                     weights=self.weights_value,
                                     gradient_type=self.value_type,
                                     weight_type='value_approx')

        if self.value_type != 'linear':
            # only linear approximators allocate weights (initialize_weights)
            return

        # Only the row of the evaluated action (row 0 for TD) is updated. The gradient is
        # zero elsewhere, and with Adam this also leaves the other rows' moments untouched.
        row = 0 if self.agent_type == 'td' else self.actions.index(last_action)
        rows = slice(row, row + 1)

        for item, params in self.weights_value.items():
            for param in params:
                w, m, v = self.update_weights(weights_type='value',
                                              weights=self.weights_value[item][param][rows],
                                              step_size=step_size['value'],
                                              delta=delta,
                                              gradient=gradient[item][param][rows],
                                              m=self.m_value[item][param][rows],
                                              v=self.v_value[item][param][rows],
                                              update_type=self.value_update_type)
                self.weights_value[item][param][rows] = w
                self.m_value[item][param][rows] = m
                self.v_value[item][param][rows] = v

    def policy_update(self, last_state, last_action, step_size, discount, estimate, target):
        """Policy-gradient update of the parameterized softmax policy.

        theta += step_size * policy_discount * delta * grad ln pi(last_action | last_state).
        The softmax log-gradient is non-zero in every action's row (row b equals
        x * (1[a == b] - pi(b | s))), so the whole weight matrix is updated.
        Previously only the taken action's row was applied, discarding the -pi(b | s) * x
        terms for the other actions.
        """
        delta = target - estimate
        policy_gradient = self.get_softmax_policy_gradient(state=last_state,
                                                           action=last_action,
                                                           weights=self.weights_policy,
                                                           gradient_type=self.policy_type,
                                                           weight_type='policy')

        for item, params in self.weights_policy.items():
            for param in params:
                w, m, v = self.update_weights(weights_type='policy',
                                              weights=self.weights_policy[item][param],
                                              step_size=step_size['policy'],
                                              delta=delta,
                                              gradient=policy_gradient[item][param],
                                              policy_discount=self.policy_discount,
                                              m=self.m_policy[item][param],
                                              v=self.v_policy[item][param],
                                              update_type=self.policy_update_type)
                self.weights_policy[item][param] = w
                self.m_policy[item][param] = m
                self.v_policy[item][param] = v

        # discounted setting: the policy-gradient multiplier decays by `discount` each step
        if not self.use_avg_reward:
            self.policy_discount = discount * self.policy_discount

    def agent_start(self, init_state, init_action, epsilon=None):
        """Start an episode and return the first action.

        Args:
            init_state: initial state, or None if unknown.
            init_action: forced first action, or None to choose one.
            epsilon: exploration rate for an epsilon-greedy tabular policy (None = greedy).

        Returns:
            ``init_action`` if given. Otherwise a uniformly random action when the state is
            unknown, or an action sampled from the policy at ``init_state``.
        """
        self.policy_discount = 1

        if not self._is_missing(init_action):
            return init_action
        if self._is_missing(init_state):
            return np.random.choice(self.actions)
        return self.choose_action_from_policy(init_state, epsilon=epsilon)

    def _cvar_reward(self, raw_reward):
        """Return the CVaR-transformed reward for ``raw_reward`` given the current VaR estimate.

        Base term: VaR - (1 / quantile) * max(0, VaR - raw_reward), plus a branch-dependent
        correction. The correction terms are reproduced exactly from the original
        implementation -- NOTE(review): confirm against the reference derivation.
        """
        tau = self.var_quantile
        reward = self.var_reward - (1 / tau) * max(0, self.var_reward - raw_reward)
        if raw_reward >= self.var_reward:
            reward = reward - (tau - 0) - (1 - tau) * ((1 - tau) - 1)
        else:
            reward = reward - (tau - 1) - (1 - tau) * ((1 - tau) - 0)
        return reward

    def _target_and_estimate(self, last_state, last_action, state, action, reward, terminal, discount, epsilon):
        """Return ``(target, estimate)`` for the configured ``agent_type``."""
        if self.agent_type == 'td':
            target = self.get_td_target(state, reward, terminal, discount)
            return target, self.get_value(last_state)

        if self.agent_type == 'sarsa':
            target = self.get_sarsa_target(state, action, reward, terminal, discount)
        elif self.agent_type == 'q_learning':
            target, _ = self.get_q_learning_target(state, reward, terminal, discount)
        elif self.agent_type == 'expected_sarsa':
            target = self.get_expected_sarsa_target(state, reward, terminal, discount, epsilon)
        else:
            raise ValueError(f"unknown agent_type: {self.agent_type!r}")
        return target, self.get_value(last_state, last_action)

    def agent_step(self, last_state, last_action, state, raw_reward, terminal, epsilon, step_size, discount):
        """Learn from one transition and return the next action.

        Args:
            last_state, last_action: previous state and the action taken there.
            state: state reached.
            raw_reward: reward received (before any CVaR transformation).
            terminal: whether ``state`` is terminal (no bootstrapping if so).
            epsilon: exploration rate for an epsilon-greedy tabular policy.
            step_size: dict of step sizes. Keys read: 'value' (always), 'policy'
                (parameterized policies), 'avg_reward' (use_avg_reward) and 'var' (use_cvar).
            discount: discount factor for discounted targets and the policy multiplier.

        Returns:
            The action to take in ``state``.
        """
        # control methods pick the next action before updating (SARSA's target depends on it)
        if self.agent_type in ('q_learning', 'sarsa', 'expected_sarsa'):
            action = self.choose_action_from_policy(state, epsilon=epsilon)
        else:
            action = None

        reward = self._cvar_reward(raw_reward) if self.use_cvar else raw_reward

        target, estimate = self._target_and_estimate(last_state, last_action, state, action,
                                                     reward, terminal, discount, epsilon)
        delta = target - estimate

        # the average-reward estimate is driven by the same TD error as the values
        if self.use_avg_reward:
            self.avg_reward += step_size['avg_reward'] * delta

        # VaR estimate update; compares against the VaR estimate used to build `reward`
        if self.use_cvar:
            if raw_reward >= self.var_reward:
                self.var_reward -= step_size['var'] * delta
            else:
                self.var_reward += step_size['var'] * (self.var_quantile / (1 - self.var_quantile)) * delta

        if self.value_type == 'tabular':
            self.value_tabular_update(last_state=last_state, last_action=last_action, step_size=step_size,
                                      estimate=estimate, target=target)
        else:
            self.value_approx_update(last_state=last_state, last_action=last_action, step_size=step_size,
                                     estimate=estimate, target=target)

        # parameterized policy update
        if self.policy_type not in ('tabular', 'softmax'):
            self.policy_update(last_state, last_action, step_size, discount, estimate, target)

        # advance Adam's bias-correction products (shared by value and policy updates)
        if self.value_update_type == 'adam' or self.policy_update_type == 'adam':
            self.beta_m_product *= self.beta_m
            self.beta_v_product *= self.beta_v

        # TD prediction chooses its action only now, after learning
        if self._is_missing(action):
            action = self.choose_action_from_policy(state, epsilon=epsilon)
        return action

    def _bootstrap(self, reward, terminal, discount, next_value):
        """Return the one-step target.

        ``next_value`` is a zero-argument callable, evaluated only for non-terminal
        transitions.
        """
        if terminal:
            return reward
        if self.use_avg_reward:
            return reward - self.avg_reward + next_value()
        return reward + discount * next_value()

    def _target_action_values(self, state):
        """Return ``{action: value}`` at ``state`` used for bootstrapping."""
        if self.value_type == 'tabular':
            return self.values[state]
        return {a: self.get_target_value(state, a) for a in self.actions}

    def get_td_target(self, state, reward, terminal, discount):
        """One-step TD target using the state value of ``state``."""
        return self._bootstrap(reward, terminal, discount,
                               lambda: self.get_target_value(state))

    def get_sarsa_target(self, state, action, reward, terminal, discount):
        """SARSA target using the value of the next action actually chosen."""
        return self._bootstrap(reward, terminal, discount,
                               lambda: self.get_target_value(state, action))

    def get_q_learning_target(self, state, reward, terminal, discount):
        """Q-learning target using the greedy next action. Returns ``(target, argmax_action)``."""
        values = self._target_action_values(state)
        argmax_action = self.argmax(values)
        target = self._bootstrap(reward, terminal, discount, lambda: values[argmax_action])
        return target, argmax_action

    def get_expected_sarsa_target(self, state, reward, terminal, discount, epsilon):
        """Expected-SARSA target: expectation of the next action value under the current policy."""
        values = self._target_action_values(state)
        argmax_action = self.argmax(values)

        softmax_probs = self.get_softmax_probabilities(state) if self.policy_type != 'tabular' else None

        expected_value = sum(
            self.get_policy(state, a, argmax_action=argmax_action, epsilon=epsilon,
                            softmax_probs=softmax_probs) * values[a]
            for a in self.actions
        )
        return self._bootstrap(reward, terminal, discount, lambda: expected_value)