import numpy as np
from math import log


class Base_LB(object):
    """Linear UCB learner built on a ridge-regression estimate of theta.

    The learner maintains

    * ``V = r_lambda * I + sum_s x_s x_s^T`` (regularized Gram matrix),
    * ``z = sum_s y_s * x_s``,
    * ``inv_V = inv(V)``, kept up to date with the Sherman-Morrison formula,
    * ``theta_hat = inv_V @ z``, refreshed in :meth:`select_arm`,

    and scores each arm ``x`` by ``x . theta_hat + beta * sqrt(x^T inv_V x)``
    with the confidence radius

        beta = sqrt(r_lambda) * S
               + R * sqrt(2 * log(1 / delta)
                          + d * log(1 + t * L**2 / (r_lambda * d)))

    where ``t`` is the number of updates since the last (re)initialization.

    Parameters
    ----------
    d : int
        Dimension of the feature vectors and of theta.
    delta : float
        Confidence parameter; enters beta as ``2 * log(1 / delta)``.
    r_lambda : float
        Ridge regularization; ``V`` starts at ``r_lambda * I``.
    S : float
        Enters beta as ``sqrt(r_lambda) * S`` (presumably an upper bound on
        ``||theta||`` -- confirm with callers).
    L : float
        Enters beta through ``t * L**2`` (presumably an upper bound on
        ``||x||`` -- confirm with callers).
    R : float
        Multiplies the square-root term of beta (presumably the
        sub-Gaussian noise scale -- confirm with callers).
    s, e
        Stored unchanged together with ``len = e - s + 1``; this class does
        not otherwise use them (presumably the start/end of the period this
        learner covers -- TODO confirm).
    """

    def __init__(self, d, delta, r_lambda, S, L, R, s, e):
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.s = s
        self.e = e
        self.len = self.e - self.s + 1

        # Parts of beta that do not depend on the update count t.
        self.c_1 = np.sqrt(self.r_lambda) * self.S
        # 1.0 / ... forces true division even for integer delta on Python 2.
        self.c_2 = 2 * log(1.0 / self.delta)

        self._reset_state()

    def _reset_state(self):
        """Set the regression statistics to their no-data (prior) values."""
        self.t = 0
        self.z = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        # inv(r_lambda * I); float division so an integer r_lambda is not
        # truncated to 0 under Python 2.
        self.inv_V = 1.0 / self.r_lambda * np.identity(self.d)
        self.theta_hat = np.zeros(self.d)

    def select_arm(self, arms):
        """Return the arm with the largest upper confidence bound.

        Parameters
        ----------
        arms : list of numpy.ndarray
            Candidate feature vectors, each of length ``d``.

        Returns
        -------
        chosen_arm : int
            Index into ``arms`` of a maximizing arm (a NumPy integer).
            Exact ties are broken uniformly at random via ``np.random``.
        max_ub : float
            The largest upper confidence bound.
        theta_hat : numpy.ndarray
            The refreshed ridge estimate ``inv_V @ z``.

        Raises
        ------
        AssertionError
            If ``arms`` is not a list.
        IndexError
            If ``arms`` is empty.
        """
        assert isinstance(arms, list), 'List of arms as input required'
        # np.inner contracts inv_V's last axis with z, i.e. inv_V @ z.
        self.theta_hat = np.inner(self.inv_V, self.z)
        ucb_s = np.zeros(len(arms))
        beta = self.c_1 + self.R * np.sqrt(
            self.c_2
            + self.d * log(1 + (self.t * self.L * self.L)
                           / float(self.r_lambda * self.d)))
        for i, x in enumerate(arms):
            ucb_s[i] = (np.dot(x.T, self.theta_hat)
                        + beta * np.sqrt(np.dot(x.T, np.dot(self.inv_V, x))))
        # Sort primarily by UCB and secondarily by an i.i.d. uniform key, so
        # the last index is a maximizer chosen uniformly among exact ties.
        mixer = np.random.random(ucb_s.size)
        order = np.lexsort((mixer, ucb_s))
        chosen_arm = order[-1]
        max_ub = ucb_s.max()
        return chosen_arm, max_ub, self.theta_hat

    def update_state(self, x, y):
        """Add the observation ``(x, y)`` to the regression statistics.

        ``V`` and ``z`` are updated directly. ``inv_V`` follows the
        Sherman-Morrison rank-one update
        ``inv(V + x x^T) = inv_V - (inv_V x)(inv_V x)^T / (1 + x^T inv_V x)``,
        which relies on ``inv_V`` being symmetric. ``theta_hat`` is not
        recomputed here; :meth:`select_arm` refreshes it.

        Parameters
        ----------
        x : numpy.ndarray
            Feature vector of the played arm (assumed 1-D of length ``d``).
        y : float
            Observed reward.

        Raises
        ------
        AssertionError
            If ``x`` is not a NumPy array.
        """
        assert isinstance(x, np.ndarray), 'np.array required'
        self.V = self.V + np.outer(x, x)
        self.z = self.z + y * x
        # Must use the inverse from *before* this update.
        inv_V_x = np.dot(self.inv_V, x)
        den = 1.0 / (1 + np.dot(x, inv_V_x))
        self.inv_V = self.inv_V - den * np.outer(inv_V_x, inv_V_x)
        self.t += 1

    def re_init(self):
        """Discard all observations and return to the freshly-built state.

        Resets ``t``, ``z``, ``V``, ``inv_V`` and ``theta_hat``; the
        constructor parameters and ``c_1``/``c_2`` are kept.
        """
        self._reset_state()

    def __str__(self):
        return 'Base-LB'