import numpy as np
from math import log
from numpy.linalg import pinv
from src.core.bandit import BanditAlgorithm

class LB_RestartUCB(BanditAlgorithm):
    """Linear-bandit UCB policy that periodically restarts its estimator.

    The policy keeps a ridge-regression estimate ``theta_hat`` built from the
    Gram matrix ``V`` and the reward-weighted feature sum ``z``.  Every ``H``
    updates (one *epoch*) the statistics are thrown away and learning starts
    again from the regularised prior ``V = r_lambda * I``.

    ``H`` is either supplied by the caller or, when ``H`` is ``None``, chosen
    by :meth:`_auto_tune` on the first :meth:`select_arm` call of a run.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, L, R, H=None):
        """Store the hyper-parameters and initialise the per-run state.

        Args:
            num_actions: Forwarded to ``BanditAlgorithm.__init__``.
            horizon: Forwarded to ``BanditAlgorithm.__init__``.
            d: Dimension of the feature vectors / parameter estimate.
            delta: Enters the confidence width through ``2 * log(1 / delta)``.
            r_lambda: Ridge regularisation; ``V`` starts at ``r_lambda * I``.
            S: Multiplies ``sqrt(r_lambda)`` in the confidence width
                (presumably a bound on the parameter norm -- confirm).
            L: Enters the log term of the confidence width squared
                (presumably a bound on the arm norm -- confirm).
            R: Scales the square-root term of the confidence width
                (presumably the noise scale -- confirm).
            H: Epoch length in rounds.  ``None`` means "auto-tune on the
                first ``select_arm`` call of every run".
        """
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.L = L
        self.R = R
        self.H = int(H) if H is not None else None
        # Remember who chose H: once _auto_tune has run, self.H is no longer
        # None, so this flag is the only way re_init can tell the cases apart.
        self._auto_H = H is None

        # Constant parts of the confidence width used in select_arm.
        self.c_1 = np.sqrt(self.r_lambda) * self.S
        self.c_2 = 2 * log(1 / self.delta)

        self.params_set = not self._auto_H
        self._reset_run_state()

    def _epoch_end(self):
        """Return the index of the last round of the current epoch."""
        return self.tau + self.H

    def _reset_estimator(self):
        """Reset the regression statistics to the regularised prior."""
        self.V = self.r_lambda * np.identity(self.d)
        # Closed-form inverse of r_lambda * I.
        self.inv_V = 1 / self.r_lambda * np.identity(self.d)
        self.z = np.zeros(self.d)
        self.theta_hat = np.zeros(self.d)

    def _reset_run_state(self):
        """Reset everything that is specific to one run of the policy."""
        self.internal_t = 1  # 1-based index of the round about to be played
        self.j = 1           # 1-based epoch counter
        self.tau = 0         # number of rounds completed before this epoch
        self._reset_estimator()
        self.last_arms = None
        if self._auto_H:
            # Force the next select_arm call to tune H again.  self.H keeps
            # its previous value (if any) so __str__ stays informative.
            self.t_epoch = None
            self.params_set = False
        else:
            self.t_epoch = self._epoch_end()

    def _auto_tune(self, P_T):
        """Choose the epoch length ``H`` from ``P_T``.

        Sets ``H = max(1, int(d**0.25 * T**0.5 * (1 + P_T)**-0.5))``.
        ``self.T`` is not assigned in this class; it is assumed to be set by
        ``BanditAlgorithm.__init__`` (presumably the horizon -- confirm).

        Args:
            P_T: Non-stationarity measure supplied by the caller via the
                ``pt`` argument of ``select_arm``.
        """
        tau_val = (self.d**0.25) * (self.T**0.5) * ((1 + P_T)**(-0.5))
        self.H = int(max(1, tau_val))
        self.t_epoch = self._epoch_end()
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs):
        """Return the index of the arm with the largest upper confidence bound.

        Args:
            arms: Indexable collection of 1-D feature vectors of length ``d``.
                It is remembered so ``update_statistics`` can look up the
                features of the arm that was played.
            pt: Value handed to ``_auto_tune`` when ``H`` still has to be
                chosen; ``None`` is treated as ``0.0``.
            **kwargs: Ignored.

        Returns:
            Index into ``arms``; ties are broken uniformly at random.
        """
        self.last_arms = arms

        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)

        if self.t_epoch is None:
            self.t_epoch = self._epoch_end()

        # The confidence width grows with the number of rounds in this epoch.
        rounds_in_epoch = self.internal_t - self.tau
        log_arg = 1 + (rounds_in_epoch * self.L**2) / (self.r_lambda * self.d)
        beta = self.c_1 + self.R * np.sqrt(self.c_2 + self.d * log(log_arg))

        ucb_s = np.zeros(len(arms))
        for i, x in enumerate(arms):
            estimate = np.dot(x.T, self.theta_hat)
            width = np.sqrt(np.dot(x.T, np.dot(self.inv_V, x)))
            ucb_s[i] = estimate + beta * width

        # lexsort sorts by its last key (the UCB) and uses the random key only
        # to order equal UCB values, so the final index is a random arg-max.
        mixer = np.random.random(ucb_s.size)
        return np.lexsort((mixer, ucb_s))[-1]

    def update_statistics(self, arm, reward, **kwargs):
        """Fold one observation into the estimator and restart if due.

        Must be called after ``select_arm``, which records the arm features.

        Args:
            arm: Index (into the arms of the last ``select_arm`` call) of the
                arm that was played.
            reward: Observed reward for that arm.
            **kwargs: Ignored.
        """
        x = self.last_arms[arm]

        self.V = self.V + np.outer(x, x)
        self.z = self.z + reward * x
        self.inv_V = pinv(self.V)
        self.theta_hat = self.inv_V @ self.z
        self.internal_t += 1

        # internal_t now names the next round.  The epoch covers rounds
        # tau + 1 .. tau + H, so restart only once the next round lies beyond
        # t_epoch == tau + H.  This gives every epoch exactly H updates.
        if self.internal_t > self.t_epoch:
            self.j += 1
            self.tau = (self.j - 1) * self.H
            self.internal_t = self.tau + 1
            self.t_epoch = self._epoch_end()
            self._reset_estimator()

    def re_init(self):
        """Reset the policy for a fresh run.

        A caller-supplied ``H`` is kept.  An auto-tuned ``H`` is re-tuned on
        the next ``select_arm`` call, so a new ``pt`` takes effect.
        """
        super().re_init()
        self._reset_run_state()

    def __str__(self):
        return f'LB-RestartUCB(H={self.H})'