import numpy as np
from math import log
from BanditAlgorithm import BanditAlgorithm


class D_UCB(BanditAlgorithm):
    """Discounted UCB bandit policy ("D-UCB").

    On every update, the discounted reward sums ``S`` and pull counts ``N``
    of *all* arms are multiplied by ``gamma``. The new observation is added
    after that, so older observations are down-weighted geometrically. When
    no arm has a zero discounted count, the arm maximising the following
    index is played (ties broken uniformly at random)::

        S[i] / N[i] + (reward_max - reward_min)
                      * sqrt(alpha * log(1 + sum(N)) / N[i])

    Any arm whose discounted count is zero is played first.

    Undiscounted per-arm totals (``TotalNumber``, ``TotalSum``) and the raw
    per-arm reward history (``SUMS``) are also kept, alongside the
    discounted statistics.
    """

    def __init__(
        self,
        num_actions: int,
        horizon: int,
        gamma: float = 0.99,
        alpha: float = 2.0,
        reward_min: float = 0.0,
        reward_max: float = 1.0,
    ):
        """Create a D-UCB policy.

        Args:
            num_actions: Number of arms; forwarded to ``BanditAlgorithm``
                (presumably it sets ``self.num_actions`` — used below).
            horizon: Forwarded to ``BanditAlgorithm``.
            gamma: Discount factor applied to every arm's statistics on
                each update. Must satisfy ``0 < gamma <= 1``; ``1``
                disables discounting.
            alpha: Exploration coefficient inside the confidence bonus;
                must be non-negative.
            reward_min: Lower end of the reward range.
            reward_max: Upper end of the reward range; must exceed
                ``reward_min``. The width ``reward_max - reward_min``
                scales the confidence bonus.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        super().__init__(num_actions, horizon)

        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.reward_min = float(reward_min)
        self.reward_max = float(reward_max)

        # Written as negated comparisons so that NaN inputs are rejected too.
        if not 0.0 < self.gamma <= 1.0:
            raise ValueError(f"gamma must be in (0, 1], got {gamma!r}")
        if not self.alpha >= 0.0:
            raise ValueError(f"alpha must be non-negative, got {alpha!r}")
        if not self.reward_max > self.reward_min:
            raise ValueError(
                f"reward_max ({reward_max!r}) must exceed "
                f"reward_min ({reward_min!r})"
            )

        # Original (uncast) arguments, replayed by reset().
        self.init_params = {
            "num_actions": num_actions,
            "horizon": horizon,
            "gamma": gamma,
            "alpha": alpha,
            "reward_min": reward_min,
            "reward_max": reward_max,
        }

        # Discounted per-arm statistics.
        self.S = np.zeros(self.num_actions, dtype=float)  # reward sums
        self.N = np.zeros(self.num_actions, dtype=float)  # pull counts
        self.means = np.zeros(self.num_actions, dtype=float)

        # Undiscounted per-arm bookkeeping.
        self.SUMS = {i: [] for i in range(self.num_actions)}  # reward history
        self.TotalNumber = {i: 0 for i in range(self.num_actions)}
        self.TotalSum = {i: 0 for i in range(self.num_actions)}
        self.chosen_arm = 0

        self.all_arms = None
        self.arms = None

    def select_arm(self, arms) -> int:
        """Return the index of the arm to play next.

        Args:
            arms: Stored on ``self.all_arms`` / ``self.arms`` (as a NumPy
                array, or ``None``) but not otherwise used by this policy.
                The choice is always among ``range(num_actions)``.

        Returns:
            A uniformly random arm among those whose discounted count is
            zero, if there are any. Otherwise, the arm with the largest UCB
            index, with ties broken uniformly at random.
        """
        self.all_arms = np.array(arms) if arms is not None else None
        self.arms = self.all_arms

        # Arms with no (remaining) discounted pulls are explored first.
        zero_arms = np.where(self.N <= 0.0)[0]
        if zero_arms.size > 0:
            return int(np.random.choice(zero_arms))

        # From here on every N[i] > 0, so plain division is safe.
        self.means = self.S / self.N

        total_n = float(self.N.sum())
        reward_range = self.reward_max - self.reward_min
        # log(1 + n) >= 0 for n >= 0. Tiny discounted counts are floored so
        # that the bonus stays finite.
        bonus = reward_range * np.sqrt(
            self.alpha * log(1.0 + total_n) / np.maximum(self.N, 1e-12)
        )
        ucb = self.means + bonus

        # lexsort orders primarily by the last key (ucb). The random secondary
        # key breaks exact ties uniformly at random.
        tie_breaker = np.random.random(ucb.size)
        return int(np.lexsort((tie_breaker, ucb))[-1])

    def update_statistics(self, x, y) -> None:
        """Record reward ``y`` observed after playing arm ``x``.

        First, every arm's discounted sum and count decay by ``gamma``.
        Then the played arm's discounted sum and count are incremented, and
        its undiscounted totals and reward history are updated.

        Args:
            x: Index of the played arm; converted with ``int``.
            y: Observed reward; converted with ``float``.

        Raises:
            IndexError: If ``x`` is not in ``range(num_actions)``. The check
                runs before any state is modified, so a bad index cannot
                leave the statistics half-updated.
        """
        arm = int(x)
        if not 0 <= arm < self.num_actions:
            raise IndexError(
                f"arm index {arm} out of range for {self.num_actions} arms"
            )
        reward = float(y)

        self.chosen_arm = arm

        self.S *= self.gamma
        self.N *= self.gamma
        self.S[arm] += reward
        self.N[arm] += 1.0

        self.TotalNumber[arm] += 1
        self.TotalSum[arm] += reward
        self.SUMS[arm].append(reward)

    def reset(self):
        """Restore the freshly constructed state.

        Re-runs ``__init__`` with the original constructor arguments.
        """
        self.__init__(**self.init_params)

    def re_init(self):
        """Alias for :meth:`reset`."""
        self.reset()

    def __str__(self):
        return "D-UCB"