import numpy as np
from copy import deepcopy


def get_statistics(buffer):
    """Summarise per-arm reward histories as empirical means and counts.

    Args:
        buffer: Sequence with one entry per arm; each entry is a sized
            collection of the rewards observed so far for that arm. An
            entry may be empty.

    Returns:
        Tuple ``(mean, nks)`` of NumPy arrays where ``mean[k]`` is the
        average of ``buffer[k]`` and ``nks[k]`` is ``len(buffer[k])``.
        An arm with no observations gets a mean of ``0.0``.
    """
    nks = [len(x) for x in buffer]
    # np.mean of an empty sequence is NaN (plus a RuntimeWarning); a NaN mean
    # would win every argmax and never recover through update_statistics
    # because nan * 0 is still nan. Use a neutral 0.0 for unseen arms instead.
    mean = [np.mean(x) if n else 0.0 for x, n in zip(buffer, nks)]
    return np.array(mean), np.array(nks)


def ts_sampling(mean, nks):
    """Draw one Beta-distributed sample per arm from its summary statistics.

    Args:
        mean: Per-arm empirical mean reward.
        nks: Per-arm number of observations.

    Returns:
        Array of draws from ``Beta(mean * nks + 1, nks - mean * nks + 1)``,
        one per arm, taken from NumPy's global random state.
    """
    # mean * nks is the summed reward per arm; the remainder of the count
    # forms the second shape parameter.
    successes = nks * mean
    failures = nks - successes
    return np.random.beta(1 + successes, 1 + failures)


def update_statistics(mean, nks, action, reward):
    """Fold one new reward into the running statistics of a single arm.

    Both ``mean`` and ``nks`` are modified in place and the same objects
    are returned.

    Args:
        mean: Per-arm running mean reward.
        nks: Per-arm observation count.
        action: Index of the arm that was pulled.
        reward: Reward observed for that pull.

    Returns:
        Tuple ``(mean, nks)`` after the update.
    """
    # Rebuild the reward sum from mean * count, add the new reward, and
    # divide by the incremented count.
    reward_sum = mean[action] * nks[action] + reward
    mean[action] = reward_sum / (nks[action] + 1)
    nks[action] += 1
    return mean, nks


def get_bonus(nks):
    """Compute the per-arm confidence width ``sqrt(log(T) / n_k)``.

    ``T`` is the total number of observations across all arms and ``n_k``
    the count for arm ``k``. Both are clamped to at least 1, so unseen arms
    do not divide by zero and an all-zero count vector yields a bonus of 0
    for every arm rather than NaN.

    Args:
        nks: Per-arm observation counts.

    Returns:
        Array of bonuses, one per arm.
    """
    # Without the clamp, T == 0 gives log(0) = -inf and sqrt(-inf) = nan,
    # which would poison every subsequent argmax over mean +/- bonus.
    total = np.maximum(np.sum(nks), 1)
    return np.sqrt(np.log(total) / np.maximum(nks, 1))


def UCB(buffer, bandit, runtime):
    """Run an upper-confidence-bound policy and return its cumulative regret.

    Each step pulls the arm maximising ``mean + bonus`` and then folds the
    observed reward into the statistics.

    Args:
        buffer: Per-arm reward histories used to initialise the statistics.
        bandit: Object providing ``pull(action)`` and an indexable ``theta``
            (presumably the true per-arm mean rewards; confirm against the
            bandit class).
        runtime: Number of interaction steps.

    Returns:
        Array of length ``runtime`` with the running sum of
        ``max(theta) - theta[action]``.
    """
    estimates, counts = get_statistics(buffer)
    gaps = []
    for _ in range(runtime):
        # Widths depend only on the counts, so refresh them before choosing.
        width = get_bonus(counts)
        arm = np.argmax(estimates + width)
        payoff = bandit.pull(arm)
        gaps.append(np.max(bandit.theta) - bandit.theta[arm])
        estimates, counts = update_statistics(estimates, counts, arm, payoff)

    return np.cumsum(np.array(gaps))

def LCB(buffer, bandit, runtime):
    """Run a lower-confidence-bound policy and return its cumulative regret.

    Each step pulls the arm maximising ``mean - bonus`` and then folds the
    observed reward into the statistics.

    Args:
        buffer: Per-arm reward histories used to initialise the statistics.
        bandit: Object providing ``pull(action)`` and an indexable ``theta``
            (presumably the true per-arm mean rewards; confirm against the
            bandit class).
        runtime: Number of interaction steps.

    Returns:
        Array of length ``runtime`` with the running sum of
        ``max(theta) - theta[action]``.
    """
    estimates, counts = get_statistics(buffer)
    gaps = []
    for _ in range(runtime):
        # Widths depend only on the counts, so refresh them before choosing.
        width = get_bonus(counts)
        arm = np.argmax(estimates - width)
        payoff = bandit.pull(arm)
        gaps.append(np.max(bandit.theta) - bandit.theta[arm])
        estimates, counts = update_statistics(estimates, counts, arm, payoff)

    return np.cumsum(np.array(gaps))


def TS(buffer, bandit, runtime):
    """Run Thompson sampling and return its cumulative regret.

    Each step draws one sample per arm via ``ts_sampling``, pulls the arm
    with the largest draw, and folds the observed reward into the
    statistics.

    Args:
        buffer: Per-arm reward histories used to initialise the statistics.
        bandit: Object providing ``pull(action)`` and an indexable ``theta``
            (presumably the true per-arm mean rewards; confirm against the
            bandit class).
        runtime: Number of interaction steps.

    Returns:
        Array of length ``runtime`` with the running sum of
        ``max(theta) - theta[action]``.
    """
    estimates, counts = get_statistics(buffer)
    gaps = []
    for _ in range(runtime):
        draws = ts_sampling(estimates, counts)
        arm = np.argmax(draws)
        payoff = bandit.pull(arm)
        gaps.append(np.max(bandit.theta) - bandit.theta[arm])
        estimates, counts = update_statistics(estimates, counts, arm, payoff)

    return np.cumsum(np.array(gaps))


def TS_NOUPDATE(buffer, bandit, runtime):
    """Run Thompson sampling with frozen statistics and return its regret.

    Same arm selection as ``TS``, but the statistics computed from
    ``buffer`` are never updated with the rewards observed during the run.

    Args:
        buffer: Per-arm reward histories used to compute the fixed statistics.
        bandit: Object providing ``pull(action)`` and an indexable ``theta``
            (presumably the true per-arm mean rewards; confirm against the
            bandit class).
        runtime: Number of interaction steps.

    Returns:
        Array of length ``runtime`` with the running sum of
        ``max(theta) - theta[action]``.
    """
    estimates, counts = get_statistics(buffer)
    gaps = []
    for _ in range(runtime):
        draws = ts_sampling(estimates, counts)
        arm = np.argmax(draws)
        # The arm is still pulled (the bandit may track state or consume
        # randomness), but the reward is intentionally not fed back.
        bandit.pull(arm)
        gaps.append(np.max(bandit.theta) - bandit.theta[arm])

    return np.cumsum(np.array(gaps))