import torch
import math

def batch_gen(loader):
    """Yield batches from ``loader`` endlessly, re-iterating it after each pass.

    ``loader`` is iterated with a fresh ``for`` loop on every pass, so it
    should be re-iterable (e.g. a list or a data-loader object).

    The generator finishes when a whole pass yields nothing (an empty loader,
    or a one-shot iterator that has been exhausted) instead of spinning in
    an infinite loop that never yields.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Nothing came out of this pass; every further pass would do the
            # same, so stop rather than busy-loop forever.
            return


class Discounted_UCB1_tuned():
    """Discounted UCB1-tuned multi-armed bandit.

    Reference: https://www.lri.fr/~sebag/Slides/Venice/Kocsis.pdf

    Each call to ``update`` multiplies every arm's accumulated reward and pull
    count (and the global pull count) by ``gamma`` before recording the new
    observation. Older observations therefore carry geometrically less weight.

    Rewards are assumed to lie in [0, 1]. The variance term
    ``avg * (1 - avg)`` used in ``select`` is the variance of a Bernoulli
    variable with that mean, which bounds the variance of any [0, 1]-valued
    reward.

    NOTE(review): canonical UCB1-tuned adds ``sqrt(2 ln n / n_j)`` to an
    empirical variance estimate. This implementation uses
    ``avg * (1 - avg)`` floored at ``min_variance`` instead.

    Args:
        arms: iterable of hashable arm identifiers. It is materialised into a
            list (duplicates removed, first occurrence kept), so a one-shot
            iterator is accepted.
        gamma: discount factor applied to all statistics on every update.
        eps: small positive floor that guards divisions by a pull count and
            the argument of the square root.
        inf: magnitude of the initial "best score" in ``select``.
        min_variance: lower bound on the variance estimate in ``select``.
            The default 2e-3 is the value previously hard-coded.
    """

    def __init__(self, arms, gamma = 0.99, eps = 1e-5, inf = 1e9, min_variance = 2e-3):
        # dict.fromkeys keeps insertion order and drops duplicate arms. A
        # duplicate would otherwise be decayed several times per update.
        self.arms = list(dict.fromkeys(arms))
        self.gamma = gamma
        self.eps = eps
        self.inf = inf
        self.min_variance = min_variance

        # Discounted sum of rewards observed for each arm.
        self.sum_rewards = {arm: 0. for arm in self.arms}
        # Discounted number of pulls of each arm.
        self.sum_tries = {arm: 0. for arm in self.arms}
        # Discounted number of pulls over all arms.
        self.total_tries = 0.

    def _avg_reward(self, arm):
        """Return the discounted mean reward of ``arm`` (0 for a never-pulled arm)."""
        return self.sum_rewards[arm] / max(self.sum_tries[arm], self.eps)

    def select(self):
        """Return the arm with the highest upper confidence bound.

        Ties go to the earliest arm in ``self.arms``. Returns ``None`` if
        there are no arms. A never-pulled arm gets an exploration bonus that
        is inflated by the ``eps`` divisor, so it is strongly favoured once
        ``total_tries`` exceeds 1.
        """
        # Loop-invariant: the log of the discounted global pull count.
        log_total = math.log(max(self.total_tries, 1.))

        ret_score = -self.inf
        ret_idx = None
        for arm in self.arms:
            avg_reward = self._avg_reward(arm)
            variance = max(avg_reward * (1 - avg_reward), self.min_variance)
            confidence = variance * log_total / max(self.sum_tries[arm], self.eps)
            score = avg_reward + math.sqrt(max(confidence, self.eps))

            if ret_score < score:
                ret_score = score
                ret_idx = arm

        return ret_idx

    def update(self, pulled, reward):
        """Decay all statistics by ``gamma``, then record ``reward`` for ``pulled``.

        Raises:
            KeyError: if ``pulled`` is not one of the arms. The check happens
                before any state is modified, so a bad call leaves the bandit
                unchanged.
        """
        if pulled not in self.sum_rewards:
            raise KeyError(pulled)

        self.total_tries = self.total_tries * self.gamma + 1.

        for arm in self.arms:
            self.sum_rewards[arm] *= self.gamma
            self.sum_tries[arm] *= self.gamma

        self.sum_rewards[pulled] += reward
        self.sum_tries[pulled] += 1

    def print(self):
        """Print each arm's discounted mean reward as ``[arm:mean]``, one per line."""
        # Inside a method, ``print`` resolves to the builtin, not this method.
        lines = ['[%s:%.4f]\n' % (arm, self._avg_reward(arm)) for arm in self.arms]
        print(''.join(lines))
