import numpy as np
from math import log
from numpy.linalg import pinv
from src.core.bandit import BanditAlgorithm

class LB_DLinUCB(BanditAlgorithm):
    """Discounted LinUCB (D-LinUCB) for non-stationary linear bandits.

    Past observations are down-weighted by ``gamma`` at every step.  Two
    design matrices are maintained, following Russac, Vernade & Cappe,
    "Weighted Linear Bandits for Non-Stationary Environments" (NeurIPS 2019):

    * ``cov``         V_t  = gamma   V_{t-1}  + a a^T + (1 - gamma)   lambda I
    * ``cov_squared`` V~_t = gamma^2 V~_{t-1} + a a^T + (1 - gamma^2) lambda I

    with ``b_t = gamma b_{t-1} + r a`` and ``theta_hat = V_t^+ b_t``.  An arm
    ``a`` is scored by ``a^T theta_hat + beta_t * sqrt(a^T V^+ V~ V^+ a)``.

    Args:
        num_actions: forwarded to ``BanditAlgorithm``.
        horizon: forwarded to ``BanditAlgorithm``.  ``_auto_tune`` reads
            ``self.T``, presumably set by the base class from this value
            (TODO confirm).
        d: feature dimension.
        delta: confidence parameter; enters beta_t through 2 log(1/delta).
        r_lambda: ridge regularisation lambda (should be > 0).
        S: assumed bound on the norm of the unknown parameter.
        R: noise scale multiplying the stochastic part of beta_t.
        gamma: discount factor.  If ``None`` it is tuned from the ``pt``
            argument of the first ``select_arm`` call after construction or
            after ``re_init``.
    """

    def __init__(self, num_actions, horizon, d, delta, r_lambda, S, R, gamma=None):
        super().__init__(num_actions, horizon)
        self.d = d
        self.delta = delta
        self.r_lambda = r_lambda
        self.S = S
        self.R = R
        self.gamma = gamma
        # Remember whether gamma must be (re)tuned: _auto_tune overwrites
        # self.gamma, so "self.gamma is None" cannot be used for this later.
        self._auto_gamma = gamma is None
        self.c_delta = 2 * log(1 / self.delta)

        self._reset_dlinucb_state()
        self.params_set = not self._auto_gamma

    def _reset_dlinucb_state(self):
        """Restore the regularised prior: V = V~ = lambda I, b = 0, theta_hat = 0."""
        eye = np.identity(self.d)
        self.theta_hat = np.zeros(self.d)
        self.cov = self.r_lambda * eye
        self.cov_squared = self.r_lambda * eye
        self.invcov = eye / self.r_lambda
        self.b = np.zeros(self.d)
        self.gamma2_t = 1.0  # gamma ** (2 t), updated in update_statistics
        self.last_arms = None

    def _auto_tune(self, P_T):
        """Set gamma = 1 - max(1/T, sqrt(P_T / (d T))), clamped to [0, 1).

        ``P_T`` is the ``pt`` value handed to ``select_arm``.  Presumably it
        measures the amount of non-stationarity over the horizon; verify
        against callers.  Negative values are treated as 0, and the decay
        rate is capped at 1 so that gamma never becomes negative.
        """
        val = np.sqrt(max(P_T, 0.0) / (self.d * self.T))
        self.gamma = 1.0 - min(max(1.0 / self.T, val), 1.0)
        self.params_set = True

    def select_arm(self, arms, pt=None, **kwargs) -> int:
        """Return the index of the arm with the largest upper confidence bound.

        Args:
            arms: sequence of 1-D feature vectors of length ``d``.  It is kept
                so that ``update_statistics`` can look up the played arm.
            pt: value forwarded to ``_auto_tune`` while gamma still needs
                tuning (0.0 if omitted).
            **kwargs: ignored.

        Ties between equal UCB values are broken uniformly at random.
        """
        self.last_arms = arms
        if not self.params_set:
            self._auto_tune(pt if pt is not None else 0.0)

        # beta_t = sqrt(lambda) S
        #          + R sqrt(2 log(1/delta) + d log(1 + (1 - gamma^{2t}) / (d lambda (1 - gamma^2)))).
        # The paper's L^2 factor is taken as 1 here, which presumably assumes
        # ||a|| <= 1.  The 1e-12 keeps the division finite when gamma == 1.
        const1 = np.sqrt(self.r_lambda) * self.S
        term_log = 1 + (1 - self.gamma2_t) / (self.d * self.r_lambda * (1 - self.gamma**2 + 1e-12))
        beta_t = const1 + self.R * np.sqrt(self.c_delta + self.d * np.log(term_log))

        # V^+ V~ V^+ does not depend on the arm, so compute it once.
        weighted_inv = self.invcov @ self.cov_squared @ self.invcov

        ucb_s = np.zeros(len(arms))
        for i, a in enumerate(arms):
            # The quadratic form is >= 0 in exact arithmetic.  Clip round-off
            # negatives so sqrt cannot return NaN (NaN sorts last in lexsort
            # and would be picked as the best arm).
            width = max(float(np.dot(a, weighted_inv @ a)), 0.0)
            ucb_s[i] = np.dot(self.theta_hat, a) + beta_t * np.sqrt(width)

        # The primary sort key is the UCB value; the random mixer breaks ties.
        mixer = np.random.random(ucb_s.size)
        return int(np.lexsort((mixer, ucb_s))[-1])

    def update_statistics(self, arm, reward, **kwargs) -> None:
        """Fold the observed ``reward`` for arm index ``arm`` into the estimator.

        ``arm`` indexes the ``arms`` passed to the most recent ``select_arm``.
        """
        x = self.last_arms[arm]
        aat = np.outer(x, x)
        eye = np.identity(self.d)
        g = self.gamma
        self.gamma2_t *= g ** 2

        self.cov = g * self.cov + aat + (1 - g) * self.r_lambda * eye
        # Discount the previous V~ (not the freshly updated V) by gamma^2.
        self.cov_squared = g ** 2 * self.cov_squared + aat + (1 - g ** 2) * self.r_lambda * eye
        self.b = g * self.b + reward * x
        self.invcov = pinv(self.cov)
        self.theta_hat = self.invcov @ self.b

    def re_init(self) -> None:
        """Reset learned statistics; an auto-tuned gamma is re-tuned on next use."""
        super().re_init()
        self._reset_dlinucb_state()
        if self._auto_gamma or self.gamma is None:
            self.params_set = False

    def __str__(self) -> str:
        gamma = 'auto' if self.gamma is None else f'{self.gamma:.4f}'
        return f'LB-DLinUCB(gamma={gamma})'