from collections import deque

import numpy as np
from numpy.linalg import pinv

from src.core.bandit import BanditAlgorithm

class LB_WindowUCB(BanditAlgorithm):
    """Sliding-window linear UCB bandit.

    Keeps a ridge-regression estimate ``theta_hat`` fitted only on the most
    recent ``w`` (arm, reward) observations and scores every candidate arm
    ``x`` as ``x @ theta_hat + zeta_t * sqrt(x @ inv_V @ x)``.

    Args:
        num_actions: Forwarded to ``BanditAlgorithm.__init__``.
        horizon: Forwarded to ``BanditAlgorithm.__init__``. ``_auto_tune``
            reads ``self.T``, which is assumed to be set from ``horizon`` by
            the base class (TODO confirm).
        d: Dimension of every arm vector.
        delta: Confidence parameter appearing inside the log of ``zeta_t``.
        r_lambda: Ridge regularization; ``V`` starts at ``r_lambda * I``.
        S: Multiplies ``sqrt(r_lambda)`` in ``zeta_t``; presumably a bound on
            the norm of the unknown parameter (verify).
        L: Appears as ``w * L**2 / r_lambda`` in ``zeta_t``; presumably a
            bound on arm norms (verify).
        R: Scales the log term of ``zeta_t``; presumably the noise
            sub-Gaussian constant (verify).
        w: Window length. When ``None`` it is chosen by ``_auto_tune`` on the
            first ``select_arm`` call of each run.

    Raises:
        ValueError: If an explicit ``w`` is smaller than 1 after ``int()``.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, w=None):
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.w = int(w) if w is not None else None
        if self.w is not None and self.w < 1:
            # A zero/negative window would make update_statistics evict from
            # an empty buffer on the very first update.
            raise ValueError(f"window length w must be >= 1, got {w!r}")
        # self.w becomes an int after _auto_tune, so remember separately
        # whether the user fixed it; re_init needs this to allow re-tuning.
        self._w_fixed = self.w is not None

        self._reset_state()

        self.zeta_t = None
        self.params_set = self._w_fixed
        if self.params_set:
            self._calc_zeta()

    def _reset_state(self):
        """Reset the per-run estimator to its no-observation prior."""
        self.internal_t = 0
        self.theta_hat = np.zeros(self.d)
        self.V = self.r_lambda * np.identity(self.d)
        self.inv_V = 1 / self.r_lambda * np.identity(self.d)
        self.z = np.zeros(self.d)
        # Arms and rewards currently inside the window, oldest first.
        self.x_w = deque()
        self.y_w = deque()
        self.last_arms = None

    def _auto_tune(self, P_T):
        """Set ``w = int(max(1, d**0.25 * sqrt(T) / sqrt(1 + P_T)))`` and refresh ``zeta_t``."""
        tau_val = (self.d ** 0.25) * (self.T ** 0.5) * ((1 + P_T) ** (-0.5))
        self.w = int(max(1, tau_val))
        self._calc_zeta()
        self.params_set = True

    def _calc_zeta(self):
        """Compute the confidence-radius multiplier ``zeta_t`` for the current ``w``."""
        self.zeta_t = np.sqrt(self.r_lambda) * self.S + self.R * \
            np.sqrt(self.d * np.log((1 + self.w * self.L**2 / self.r_lambda) / self.delta))

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the highest UCB score.

        Args:
            arms: Sequence of ``d``-dimensional arm vectors (or a ``(K, d)``
                array). Kept so ``update_statistics`` can look up the arm.
            pt: Forwarded to ``_auto_tune`` when ``w`` has not been set for
                this run; ``None`` is treated as ``0.0``.
            **kwargs: Ignored.

        Returns:
            Index into ``arms``. Exact ties are broken uniformly at random.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)

        X = np.asarray(arms, dtype=float)
        X = X.reshape(X.shape[0], -1)  # one row per arm, also for d == 1 scalars
        means = X @ self.theta_hat
        # Row-wise quadratic form x^T inv_V x; clip round-off below zero so
        # the square root never produces NaN.
        quad = np.sum((X @ self.inv_V) * X, axis=1)
        ucb_s = means + self.zeta_t * np.sqrt(np.maximum(quad, 0.0))

        # lexsort uses the last key (ucb_s) as primary; the random mixer
        # breaks exact ties, and the final element is the maximum.
        mixer = np.random.random(ucb_s.size)
        order = np.lexsort((mixer, ucb_s))
        return int(order[-1])

    def update_statistics(self, arm, reward, **kwargs):
        """Add the observation for ``arm`` to the window and refit ``theta_hat``.

        When the window already holds ``w`` observations, the oldest one is
        removed from ``V`` and ``z`` after the new one is added.
        """
        # Copy so a caller that later reuses/mutates its arms buffer cannot
        # corrupt the vectors stored in the window.
        x = np.array(self.last_arms[arm], dtype=float).reshape(-1)
        y = reward

        self.V = self.V + np.outer(x, x)
        self.z = self.z + y * x
        if len(self.x_w) >= self.w:
            x_old = self.x_w.popleft()
            y_old = self.y_w.popleft()
            self.V = self.V - np.outer(x_old, x_old)
            self.z = self.z - y_old * x_old
        self.x_w.append(x)
        self.y_w.append(y)

        # pinv rather than inv for robustness to accumulated round-off.
        self.inv_V = pinv(self.V)
        self.theta_hat = self.inv_V @ self.z
        self.internal_t += 1

    def re_init(self):
        """Reset for a new run; an auto-tuned ``w`` is re-tuned on the next ``select_arm``."""
        super().re_init()
        self._reset_state()
        if not self._w_fixed:
            # self.w is already an int after _auto_tune, so key off whether
            # the user supplied it rather than testing `self.w is None`.
            self.params_set = False

    def __str__(self):
        return f'LB-WindowUCB(w={self.w})'