import numpy as np

class BernoulliBandit(object):
    """Multi-armed bandit whose arms pay 0/1 rewards.

    Attributes:
        K: number of arms, exactly as supplied by the caller.
        ps: per-arm success probabilities passed to ``rng.binomial``.
        rng: NumPy ``Generator`` used for every draw.
    """

    def __init__(self, K, ps=None, seed=None):
        """Store the arm probabilities and create the random generator.

        Args:
            K: number of arms.
            ps: success probability of each arm. When omitted, arm 0 gets
                0.75 and every other arm 0.25. A supplied value is stored
                as-is: it is neither copied nor checked against ``K``.
            seed: optional seed forwarded to ``np.random.default_rng``.
        """
        self.K = K
        if ps is not None:
            self.ps = ps
        else:
            default_ps = np.full(self.K, 0.25)
            default_ps[0] = 0.75
            self.ps = default_ps
        # default_rng(None) is the same call as default_rng() with no seed.
        self.rng = np.random.default_rng(seed)

    def get_rewards(self):
        """Draw one Bernoulli reward per entry of ``ps`` in a single call."""
        draws = self.rng.binomial(n=1, p=self.ps)
        return draws

    def get_best(self):
        """Return the 1-based position of the largest probability in ``ps``.

        ``np.argmax`` reports the first maximum, so ties resolve to the
        lowest-numbered arm.
        """
        best_zero_based = np.argmax(self.ps)
        return best_zero_based + 1


class StackedBernoulliBandit(object):
    """Bernoulli bandit whose rewards are all sampled at construction.

    A ``(T, K)`` matrix of 0/1 rewards is drawn once in ``__init__``. Each
    arm then walks down its own column, consuming one row per pull.

    Attributes:
        K: number of arms.
        T: number of pre-drawn rewards per arm (rows of ``rewards``).
        ps: per-arm success probabilities passed to ``rng.binomial``.
        rng: NumPy ``Generator`` used to fill ``rewards``.
        rewards: the pre-drawn reward matrix.
        ptrs: list holding, for each arm, the row to be read on its next pull.
    """

    def __init__(self, K, T, ps=None, seed=None):
        """Store the parameters and pre-draw the whole reward matrix.

        Args:
            K: number of arms.
            T: number of rewards to pre-draw for each arm.
            ps: success probability of each arm. When omitted, arm 0 gets
                0.75 and every other arm 0.25. A supplied value is stored
                as-is: it is neither copied nor checked against ``K``.
            seed: optional seed forwarded to ``np.random.default_rng``.
        """
        self.K = K
        self.T = T
        if ps is not None:
            self.ps = ps
        else:
            default_ps = np.full(self.K, 0.25)
            default_ps[0] = 0.75
            self.ps = default_ps
        # default_rng(None) is the same call as default_rng() with no seed.
        self.rng = np.random.default_rng(seed)
        self.rewards = self._init_reward_matrix()
        # Every arm starts reading from the first row.
        self.ptrs = [0 for _ in range(self.K)]

    def _init_reward_matrix(self):
        """Sample the ``(T, K)`` reward matrix; ``ps`` broadcasts across rows."""
        shape = (self.T, self.K)
        return self.rng.binomial(n=1, p=self.ps, size=shape)

    def get_reward(self, act):
        """Return the next unread reward for arm ``act`` and advance its pointer.

        Args:
            act: 0-based column index of the arm to pull.

        The reward is looked up before the pointer moves, so a failed lookup
        (for example once the arm has used up its column) leaves ``ptrs``
        unchanged.
        """
        row = self.ptrs[act]
        reward = self.rewards[row, act]
        self.ptrs[act] = row + 1
        return reward
        
