import numpy as np
from math import log
from numpy.linalg import pinv
from src.core.bandit import BanditAlgorithm

class LB_WeightUCB(BanditAlgorithm):
    """Linear UCB policy built on exponentially discounted ridge regression.

    Before each new observation is added, the design matrix ``V`` and the
    reward-weighted feature sum ``z`` are multiplied by a discount ``gamma``.
    Older observations therefore lose influence over time. ``gamma`` is tuned
    in :meth:`select_arm` from ``self.T`` and the number of change points:
    ``gamma = 1 - max(1/T, sqrt(P_T / (d T)))``. ``self.T`` is presumably the
    horizon stored by ``BanditAlgorithm``; verify against the base class.

    Args:
        num_actions: Forwarded to ``BanditAlgorithm``.
        horizon: Forwarded to ``BanditAlgorithm``.
        d: Dimension of the arm feature vectors.
        delta: Confidence parameter; enters the radius as ``2 * log(1 / delta)``.
        r_lambda: Ridge regularisation strength (lambda).
        S: Multiplies ``sqrt(r_lambda)`` in the radius (presumably a bound on
            the parameter norm).
        L: Enters the radius as ``L**2`` (presumably a bound on feature norms).
        R: Scales the log-determinant part of the radius (presumably the noise
            scale).
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R):
        super().__init__(num_actions, horizon)

        # Constructor arguments kept as-is; NOTE(review): presumably used by
        # the surrounding framework to rebuild the policy — confirm.
        self.init_params = {
            'num_actions': num_actions,
            'horizon': horizon,
            'd': d,
            'delta': delta,
            'r_lambda': r_lambda,
            'S': S,
            'L': L,
            'R': R
        }

        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R

        # Constant parts of the exploration radius beta_t.
        self.c_1 = np.sqrt(self.r_lambda) * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self._reset_weighted_statistics()

    def _reset_weighted_statistics(self):
        """Restore the learned statistics to their prior (no-data) state."""
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = (1 / self.r_lambda) * np.identity(self.d)
        self.z = np.zeros(self.d)
        # Running product of gamma**2, one factor per update
        # (gamma^(2(t-1)) when gamma stays constant).
        self.gamma_2t = 1
        self.theta_hat = np.zeros(self.d)

    def select_arm(self, arms, change_points):
        """Return the index of the arm with the largest optimistic score.

        Args:
            arms: Arm feature vectors, one row per arm, each of length ``d``.
            change_points: Either a collection of change points (list, tuple
                or array; only its length is used) or the number of change
                points itself.

        Returns:
            Index into ``arms`` of the chosen arm. Ties are broken uniformly at
            random.
        """
        arms_array = np.array(arms)

        # Accept any list/tuple/array of change points, not only lists;
        # a scalar is taken to already be the count.
        if isinstance(change_points, (list, tuple)) or np.ndim(change_points) > 0:
            PT = len(change_points)
        else:
            PT = change_points

        # Stored so update_statistics can look up the played arm's features.
        self.all_arms = arms_array
        self.gamma = 1 - max(1 / self.T, (PT / (self.d * self.T))**0.5)

        self.theta_hat = self.inv_V @ self.z

        # Row-wise x^T V^{-1} x without building an (arms x arms) matrix.
        arms_inv_V = arms_array @ self.inv_V
        quadratic_forms = np.sum(arms_array * arms_inv_V, axis=1)
        # Clip round-off negatives: sqrt would yield NaN, and lexsort places
        # NaN last, which would make that arm win every time.
        quadratic_forms = np.maximum(quadratic_forms, 0.0)
        beta_t = self.c_1 + self.R * np.sqrt(
            self.c_2 + self.d * np.log(1 + (1 - self.gamma_2t) * (self.L**2) / (self.d * self.r_lambda * (1 - self.gamma**2)))
        )

        ucb_s = arms_array @ self.theta_hat + beta_t * np.sqrt(quadratic_forms)

        # Sort by score, then by a random key, so ties are broken at random.
        mixer = np.random.random(ucb_s.size)
        ucb_indices = np.lexsort((mixer, ucb_s))
        chosen_arm = ucb_indices[-1]
        return chosen_arm

    def update_statistics(self, x, y):
        """Discount past statistics and add the observation (arm ``x``, reward ``y``).

        Args:
            x: Index into the arms passed to the most recent ``select_arm``.
            y: Observed reward.
        """
        x_arm = self.all_arms[x]
        aat = np.outer(x_arm, x_arm)

        self.gamma_2t *= self.gamma**2
        # The (1 - gamma) * lambda * I term cancels the discounting of the
        # prior, so V always equals lambda * I plus the discounted outer products.
        self.V = self.gamma * self.V + aat + (1 - self.gamma) * self.r_lambda * np.identity(self.d)
        self.z = self.gamma * self.z + y * x_arm
        self.inv_V = pinv(self.V)

    def re_init(self):
        """Reset the base-class state and all learned statistics."""
        super().re_init()
        self._reset_weighted_statistics()

    def __str__(self):
        return 'LB-WeightUCB'
