# -*- coding: utf-8 -*-
"""Safe_LinearBandit.ipynb

"""
import numpy as np 
#################################UCB Class##############################################
class SafeBanditUCB():
    """Linear UCB bandit that only plays actions whose cost bound is below ``tau``.

    Two regularized least-squares estimates are maintained from the same
    Gram matrix: ``weights`` (fitted on observed rewards) and ``gamma``
    (fitted on observed costs).  An action is considered safe when its
    estimated cost plus a ``beta_2``-scaled confidence width does not exceed
    ``tau``; among safe actions the one with the largest optimistic reward
    index is selected.
    """

    # Action set used when the caller passes ``action_set=None``.
    DEFAULT_ACTION_SET = ((0.1, 0), (0.8, 0), (1, 0), (0, 0.4), (0, 0.7))

    def __init__(self, feature, action_set, max_reward, epsilon, omega, L, tau, beta_1, beta_2, _lambda=1, time_step=1):
        """Initialise the estimators.

        Args:
            feature: Object exposing ``feature_dim`` and
                ``transform(state, action)``.  ``transform`` is called both
                with a single action and with an array of actions.
            action_set: Candidate actions, one per row.  ``None`` selects
                ``DEFAULT_ACTION_SET``.
            max_reward: Scale applied to the safety-related part of the bonus.
            epsilon: Enters the bonus through ``max(1, (1 - epsilon) / omega)``.
            omega: See ``epsilon``.
            L: Multiplier of the relative uncertainty inside the bonus.
            tau: Cost threshold an action must satisfy to be played.
            beta_1: Exploration parameter for the reward UCB.
            beta_2: Confidence-width multiplier for the cost estimate.
            _lambda: Regularisation placed on the diagonal of the Gram matrix.
            time_step: Number of columns of the weight and record matrices.
        """
        self.feature = feature
        # Honour the caller's action set instead of silently replacing it.
        if action_set is None:
            action_set = self.DEFAULT_ACTION_SET
        self.action_set = np.array(action_set)
        self.max_reward = max_reward

        self.time_step = time_step
        self.epsilon = epsilon
        self.omega = omega
        self.L = L
        self.tau = tau
        self.beta_1 = beta_1  # Exploration parameter for UCB
        self.beta_2 = beta_2
        self._lambda = _lambda

        dim = feature.feature_dim
        self.weights = np.zeros((dim, time_step))  # Reward weights
        self.gamma = np.zeros((dim, time_step))  # Safety weights

        self.gram_matrix = self._lambda * np.identity(dim)
        self.inv_gram_matrix = (1 / self._lambda) * np.identity(dim)

        # Running sums of feature * reward and feature * cost.
        self.feature_reward_rec = np.zeros((dim, time_step))
        self.feature_cost_rec = np.zeros((dim, time_step))

    def calc_action_cost(self, phi_state_action):
        """Return the upper confidence bound on the cost of one feature vector.

        The result is ``phi @ gamma + beta_2 * sqrt(phi^T G^{-1} phi)`` and has
        one entry per column of ``gamma``.
        """
        width = np.sqrt(phi_state_action @ self.inv_gram_matrix @ phi_state_action)
        return phi_state_action @ self.gamma + self.beta_2 * width

    def _bonus_terms(self, phi_state_actions):
        """Return ``(bonus, beta_2 * relative_uncertainty)`` for each feature row."""
        # Row-wise sqrt(phi^T G^{-1} phi).
        norm_feature_gram_matrix = np.sqrt(
            np.einsum('ji,ik,jk->j', phi_state_actions, self.inv_gram_matrix, phi_state_actions)
        )
        # NOTE(review): an all-zero feature row makes the ratios below 0/0 (NaN).
        feature_norms = np.linalg.norm(phi_state_actions, axis=1)

        term_1 = self.beta_1 * norm_feature_gram_matrix
        _frac = self.tau / (self.tau + 2 * self.beta_2 * self.L * norm_feature_gram_matrix / feature_norms)
        term_2 = 1 - _frac
        gamma_uncertainty = norm_feature_gram_matrix / feature_norms

        scale = max(1, (1 - self.epsilon) / self.omega)
        bonus = term_1 + scale * term_2 * self.max_reward
        return bonus, self.beta_2 * gamma_uncertainty

    def select_action(self, state, k):
        """Pick the safe action with the largest optimistic reward index.

        Args:
            state: Passed through to ``feature.transform``.
            k: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(action, bonus, gamma_uncertainty, act_ind)`` where
            ``bonus`` and ``gamma_uncertainty`` hold one entry per safe action
            and ``act_ind`` is the row of ``action`` in ``self.action_set``.

        Raises:
            ValueError: If no action satisfies the cost constraint.
        """
        # Track indices rather than re-locating the winner by float equality.
        safe_idx = [
            idx for idx, act in enumerate(self.action_set)
            if self.calc_action_cost(self.feature.transform(state, act)) <= self.tau
        ]
        if not safe_idx:
            raise ValueError(
                "no action satisfies the cost constraint (tau=%r)" % (self.tau,)
            )

        actions = self.action_set[safe_idx]
        phi_state_actions = self.feature.transform(state, actions)

        bonus, gamma_uncertainty = self._bonus_terms(phi_state_actions)
        ucb_values = np.einsum('ij,jk->i', phi_state_actions, self.weights) + bonus
        best = int(np.argmax(ucb_values))
        return actions[best], bonus, gamma_uncertainty, safe_idx[best]

    def update(self, state, action, reward, cost, next_state, done):
        """Update Gram-matrix, gamma, and weights with one observation.

        Args:
            state: State the action was taken in.
            action: Action that was played.
            reward: Observed reward.
            cost: Observed cost.
            next_state: Unused; kept for interface compatibility.
            done: Unused; kept for interface compatibility.
        """
        # Compute the feature vector once; it is reused by every update below.
        phi = self.feature.transform(state, action)

        self.gram_matrix += np.outer(phi, phi)  # Updating Gram-matrix
        self.inv_gram_matrix = np.linalg.pinv(self.gram_matrix)

        self.feature_reward_rec += np.outer(phi, reward)
        self.weights = self.inv_gram_matrix @ self.feature_reward_rec

        self.feature_cost_rec += np.outer(phi, cost)
        self.gamma = self.inv_gram_matrix @ self.feature_cost_rec
########################################################################################