# from libmab.visualization import plotci, bar_plot
import numpy as np
import abc


class Learner(abc.ABC):
    """Abstract base class for bandit learners.

    Keeps, for every arm, the number of recorded pulls and the running
    mean of the rewards observed for it. Concrete learners implement
    :meth:`pull_arm` to choose the next arm to play.

    Attributes:
        n_arms: Number of arms.
        t: Number of calls made to :meth:`update` so far.
        arm_pulls: Float array of length ``n_arms``; ``arm_pulls[i]`` is
            the number of updates recorded for arm ``i``.
        estimates: Float array of length ``n_arms``; ``estimates[i]`` is
            the running mean of the rewards recorded for arm ``i``
            (``0`` until the arm is first updated).
        rewards: Every reward passed to :meth:`update`, in call order.
    """

    def __init__(self, n_arms: int) -> None:
        """Initialise the per-arm statistics for ``n_arms`` arms."""
        self.n_arms = n_arms
        self.t = 0
        self.arm_pulls = np.zeros(n_arms)
        self.estimates = np.zeros(n_arms)
        self.rewards = []

    def update(self, reward, arm) -> None:
        """Record ``reward`` as the outcome of playing ``arm``.

        Args:
            reward: Observed reward; appended to ``rewards`` and folded
                into the running mean of ``arm``.
            arm: Index of the arm that was played.
        """
        self.rewards.append(reward)

        # Rebuild the sum of past rewards from (mean * count), add the new
        # reward, and divide by the new count to get the updated mean.
        pulls_so_far = self.arm_pulls[arm]
        reward_sum = self.estimates[arm] * pulls_so_far + reward
        self.estimates[arm] = reward_sum / (pulls_so_far + 1)

        self.arm_pulls[arm] += 1
        self.t += 1

    @abc.abstractmethod
    def pull_arm(self) -> int:
        """Return the index of the arm to play next.

        Must be implemented by every concrete learner; this base
        implementation does nothing and returns ``None``.
        """

class UCB(Learner):
    """Upper-confidence-bound learner.

    Plays each arm once in index order, then always picks the arm whose
    mean-reward estimate plus an exploration bonus is largest. The bonus
    for arm ``i`` at round ``t`` is
    ``3 * sigma * sqrt(log(t) / arm_pulls[i])``.

    Attributes:
        sigma: Multiplier applied to the exploration bonus.
    """

    def __init__(self, n_arms: int, sigma: float = 1) -> None:
        """Create a learner over ``n_arms`` arms with bonus scale ``sigma``."""
        super().__init__(n_arms)
        self.sigma = sigma

    def pull_arm(self) -> int:
        """Return the index of the arm to play next.

        Returns:
            ``t`` during the first ``n_arms`` rounds (initial round-robin);
            afterwards the first arm that has no recorded pulls, if any,
            otherwise the arm maximising estimate plus exploration bonus.
            The result is always a built-in ``int``.
        """
        # Initial round-robin: arm t is played on round t.
        if self.t < self.n_arms:
            return self.t

        # If updates did not follow the round-robin, some arm may still
        # have zero pulls. Dividing by that count would yield inf/nan and
        # a RuntimeWarning, so select such an arm explicitly instead.
        unpulled = np.flatnonzero(self.arm_pulls == 0)
        if unpulled.size:
            return int(unpulled[0])

        bonus = 3 * self.sigma * np.sqrt(np.log(self.t) / self.arm_pulls)
        # np.argmax returns a NumPy integer; convert to match the annotation.
        return int(np.argmax(self.estimates + bonus))