import numpy as np

class BBThompson(object):
    """Thompson sampling policy with an independent Beta posterior per arm.

    Each arm starts from a Beta(1, 1) prior. ``update`` adds ``reward`` to the
    first Beta parameter and ``1 - reward`` to the second, so rewards are
    presumably meant to lie in [0, 1] (binary or fractional).

    Attributes:
        prior: list of length ``num_arms``; entry ``i`` is a float array
            ``[alpha_i, beta_i]`` holding the current Beta parameters.
        counts: number of times each arm has been updated.
        rewards: cumulative reward recorded for each arm.
        name: display name of the policy.
        k: number of arms.
        rng: NumPy random generator used for all sampling.
        order: a shuffled permutation of the arm indices.
    """

    def __init__(self, num_arms=2, seed=None):
        """Create the policy.

        Args:
            num_arms: number of arms.
            seed: optional seed for the random generator; ``None`` draws
                fresh entropy.
        """
        # Beta(1, 1) prior for each arm. The parameters are stored as floats
        # so that update() can add float-valued rewards in place; an integer
        # array would make NumPy reject the in-place add of a float array.
        self.prior = [np.ones(2, dtype=float) for _ in range(num_arms)]
        self.counts = np.zeros(num_arms)
        self.rewards = np.zeros(num_arms)
        self.name = "Thompson"
        self.k = num_arms
        # default_rng(None) is the unseeded generator, so no branch is needed.
        self.rng = np.random.default_rng(seed)
        # get_action() does not read this permutation; it is still drawn so
        # the attribute remains available and the seeded random stream that
        # follows is the same as before.
        self.order = np.arange(0, self.k)
        self.rng.shuffle(self.order)

    def get_action(self):
        """Draw one sample from each arm's Beta posterior and return the argmax.

        Returns:
            Index of the arm with the largest sampled value (first one on
            exact ties, as ``np.argmax`` does).
        """
        samples = [self.rng.beta(alpha, beta) for alpha, beta in self.prior]
        return np.argmax(samples)

    def update(self, action, reward):
        """Record ``reward`` for ``action`` and update that arm's posterior.

        Args:
            action: index of the arm that was played.
            reward: observed reward; added to alpha, and ``1 - reward`` is
                added to beta.
        """
        self.prior[action] += np.array([reward, 1 - reward], dtype=float)
        self.counts[action] += 1
        self.rewards[action] += reward

class UCB1(object):
    """Upper-confidence-bound policy.

    Every arm is played once, in a randomly shuffled order, before any index
    is computed. After that the arm maximising
    ``empirical_mean + eta / sqrt(count)`` is chosen, with ties broken
    uniformly at random.

    Attributes:
        counts: number of updates recorded per arm.
        rewards: cumulative reward recorded per arm.
        name: display name of the policy.
        debug: when true, ``get_action`` prints the index components.
        k: number of arms.
        eta: multiplier applied to the exploration bonus.
        rng: NumPy random generator.
        order: shuffled permutation of arm indices used for the initial pulls.
    """

    def __init__(self, num_arms=2, debug=False, seed=None, eta=1):
        """Set up empty statistics and a (possibly seeded) random generator."""
        self.counts = np.zeros(num_arms)
        self.rewards = np.zeros(num_arms)
        self.name = "UCB1"
        self.debug = debug
        self.k = num_arms
        self.eta = eta

        # A seed of None yields the same unseeded generator as calling
        # default_rng() with no argument.
        self.rng = np.random.default_rng(seed)
        self.order = np.arange(self.k)
        self.rng.shuffle(self.order)

    def get_action(self):
        """Return the next arm to play.

        Returns:
            The first arm (in ``order``) that has not been pulled yet, or,
            once all arms have a pull, a uniformly random arm among those
            with the maximal index.
        """
        # Initial round: walk the shuffled order until an unpulled arm shows up.
        unpulled = next((arm for arm in self.order if self.counts[arm] < 1), None)
        if unpulled is not None:
            return unpulled

        emp_means = self.rewards / self.counts
        bonus = 1.0 / np.sqrt(self.counts)
        index = emp_means + self.eta * bonus

        # Random tie-breaking among all arms attaining the maximum index.
        maximisers = np.flatnonzero(index == index.max())
        action = self.rng.choice(maximisers)
        if self.debug:
            print(f"[DEBUG] emp_means={emp_means}, bonus={bonus}, index={index}",flush=True)
        return action

    def update(self, action, reward):
        """Add one pull of ``action`` with the observed ``reward``."""
        self.rewards[action] += reward
        self.counts[action] += 1

class Greedy(object):
    """Greedy policy with a fixed warm-up of ``n0`` pulls per arm.

    Arms are visited in a randomly shuffled order until each has ``n0``
    recorded pulls; afterwards the arm with the highest empirical mean is
    chosen, ties broken uniformly at random.

    Attributes:
        n0: number of pulls each arm receives before greedy selection starts.
        rewards: cumulative reward recorded per arm.
        counts: number of updates recorded per arm.
        name: display name, ``"Greedy-<n0>"``.
        k: number of arms.
        rng: NumPy random generator.
        order: shuffled permutation of arm indices used during warm-up.
    """

    def __init__(self, num_arms=2, n0=1, seed=None):
        """Set up empty statistics and a (possibly seeded) random generator."""
        self.n0 = n0
        self.rewards = np.zeros(num_arms)
        self.counts = np.zeros(num_arms)
        self.name = f"Greedy-{n0}"
        self.k = num_arms

        # Passing None is the same as asking for an unseeded generator.
        self.rng = np.random.default_rng(seed)
        self.order = np.arange(self.k)
        self.rng.shuffle(self.order)

    def get_action(self):
        """Return the next arm to play.

        Returns:
            During warm-up, the first arm in ``order`` with fewer than ``n0``
            pulls; afterwards a uniformly random arm among those with the
            maximal empirical mean.
        """
        # Arms still short of n0 pulls, listed in the shuffled visiting order.
        pending = self.order[self.counts[self.order] < self.n0]
        if pending.size > 0:
            return pending[0]

        # NOTE: with n0 < 1 an arm can reach this point with a zero count,
        # in which case the division below is by zero.
        averages = self.rewards / self.counts
        top_arms = np.flatnonzero(averages == averages.max())
        return self.rng.choice(top_arms)

    def update(self, action, reward):
        """Add one pull of ``action`` with the observed ``reward``."""
        self.rewards[action] += reward
        self.counts[action] += 1

class EGreedy(object):
    """Epsilon-greedy policy.

    While any arm has no recorded pull, one of those arms is picked uniformly
    at random. Afterwards, with probability ``eps`` a uniformly random arm is
    played; otherwise the arm with the highest empirical mean is played, ties
    broken uniformly at random.

    Attributes:
        eps: exploration probability.
        rewards: cumulative reward recorded per arm.
        counts: number of updates recorded per arm.
        name: display name, ``"eGreedy-<eps to two decimals>"``.
        k: number of arms.
        rng: NumPy random generator.
    """

    def __init__(self, num_arms=2, eps=0.1, seed=None):
        """Set up empty statistics and a (possibly seeded) random generator."""
        self.eps = eps
        self.rewards = np.zeros(num_arms)
        self.counts = np.zeros(num_arms)
        self.name = f"eGreedy-{eps:0.2f}"
        self.k = num_arms

        # Passing None is the same as asking for an unseeded generator.
        self.rng = np.random.default_rng(seed)

    def get_action(self):
        """Return the next arm to play (see the class description)."""
        # Arms that have never been pulled take priority over everything else.
        never_pulled = [arm for arm in range(self.k) if self.counts[arm] == 0]
        if never_pulled:
            return self.rng.choice(never_pulled)

        # One Bernoulli(eps) draw decides between exploring and exploiting.
        explore = self.rng.binomial(1, self.eps)
        if explore:
            return self.rng.choice(self.k)

        averages = self.rewards / self.counts
        top_arms = np.flatnonzero(averages == averages.max())
        return self.rng.choice(top_arms)

    def update(self, action, reward):
        """Add one pull of ``action`` with the observed ``reward``."""
        self.rewards[action] += reward
        self.counts[action] += 1
            
                            
class Unif(object):
    """Policy that plays an arm chosen uniformly at random on every step.

    Attributes:
        k: number of arms.
        counts: number of updates recorded per arm.
        name: display name of the policy.
        rng: NumPy random generator.
    """

    def __init__(self, num_arms=2, seed=None):
        """Set up pull counters and a (possibly seeded) random generator."""
        self.k = num_arms
        self.counts = np.zeros(num_arms)
        self.name = "unif"

        # Passing None is the same as asking for an unseeded generator.
        self.rng = np.random.default_rng(seed)

    def get_action(self):
        """Return an arm index drawn uniformly from ``range(k)``."""
        return self.rng.choice(self.k)

    def update(self, act, rew):
        """Count one pull of ``act``; the reward ``rew`` is not stored."""
        self.counts[act] += 1
