import numpy as np
from algorithm import *
import matplotlib

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

# Experiment configuration.
narms, ndata, nseeds, nlength = (
    10,      # arms per bandit instance
    1000,    # random-action pulls added to the offline dataset
    50,      # independent trials (a fresh bandit and dataset per trial)
    100000,  # online horizon given to each algorithm; length of regret curves
)


class Bandit:
    """Stochastic k-armed Bernoulli bandit.

    Arm ``i`` pays reward 1 with probability ``theta[i]`` and 0 otherwise.
    The success probabilities are drawn from Beta(100, 100), so they are
    concentrated near 0.5 and the arms are hard to tell apart.
    """

    def __init__(self, k, noise=1):
        """Create a bandit with ``k`` arms and print its true arm means.

        Args:
            k: number of arms.
            noise: stored as ``self.noise`` for backward compatibility; it is
                not used by the Bernoulli reward model in ``pull`` (it only
                applied to an earlier clipped-Gaussian reward model).
        """
        self.theta = np.random.beta(100, 100, size=k)
        print('theta: ', self.theta)
        self.noise = noise
        self.k = k

    def pull(self, x):
        """Pull arm ``x`` and return a Bernoulli(``theta[x]``) reward (0 or 1).

        Raises:
            ValueError: if ``x`` is not an arm index in ``[0, k)``.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and a negative index would then silently wrap around
        # to a different arm instead of failing.
        if not 0 <= x < self.k:
            raise ValueError(f'arm index must be in [0, {self.k}), got {x}')
        return np.random.binomial(1, self.theta[x])


def create_offline_dataset(bandit, size):
    """Collect an offline dataset by pulling arms uniformly at random.

    Every arm is first pulled once (a warm start), then ``size`` additional
    pulls are made on uniformly random arms.

    Args:
        bandit: object exposing ``k`` (number of arms) and ``pull(arm)``.
        size: number of random-action pulls after the warm start.

    Returns:
        ``(warm_start, offline)``, two independent lists with one list of
        rewards per arm. ``warm_start`` holds only the single warm-start
        reward per arm. ``offline`` holds that same reward followed by the
        rewards from the random pulls.
    """
    # One pull per arm, in arm order, so no arm starts without data.
    warm_start = [[bandit.pull(arm)] for arm in range(bandit.k)]
    # Separate list objects so extending `offline` never touches `warm_start`.
    offline = [list(samples) for samples in warm_start]
    for _ in range(size):
        arm = np.random.randint(0, bandit.k)
        offline[arm].append(bandit.pull(arm))
    return warm_start, offline


# Per-timestep regret curves, summed over all seeds (divide by nseeds for the
# mean). The *_pure_online and *_noupdate accumulators are only filled by the
# disabled baselines inside the loop and read by the commented-out plotting code.
total_regret_ucb = np.zeros(nlength)
total_regret_ucb_pure_online = np.zeros(nlength)
total_regret_lcb = np.zeros(nlength)
total_regret_lcb_pure_online = np.zeros(nlength)
total_regret_ts = np.zeros(nlength)
total_regret_ts_pure_online = np.zeros(nlength)
total_regret_ts_noupdate = np.zeros(nlength)


def _copy_dataset(data):
    """Return an independent per-arm copy of an offline dataset (list of lists)."""
    return [list(arm_rewards) for arm_rewards in data]


for seed in range(nseeds):
    print(seed)
    # Seed NumPy's global RNG so every trial is reproducible: the bandit's arm
    # means, the offline data, the reward draws (and presumably any np.random
    # use inside the algorithms) all come from this RNG.
    np.random.seed(seed)
    bandit = Bandit(narms)
    # empty_dataset holds just one warm-start reward per arm; it is only used
    # by the disabled pure-online baselines below.
    empty_dataset, dataset = create_offline_dataset(bandit, ndata)
    # Give each algorithm its own copy of the offline data. If UCB/LCB/TS
    # append their online observations to the lists they receive (TODO confirm
    # in algorithm.py), sharing one object would leak one algorithm's online
    # samples into the next algorithm's starting dataset.
    regret_ucb = UCB(_copy_dataset(dataset), bandit, nlength)
    regret_lcb = LCB(_copy_dataset(dataset), bandit, nlength)
    regret_ts = TS(_copy_dataset(dataset), bandit, nlength)
    # regret_ts_noupdate = TS_NOUPDATE(_copy_dataset(dataset), bandit, nlength)
    total_regret_ucb += regret_ucb
    total_regret_lcb += regret_lcb
    total_regret_ts += regret_ts
    # total_regret_ts_noupdate += regret_ts_noupdate

    print(total_regret_ucb.shape, total_regret_lcb.shape, total_regret_ts.shape)
    # Checkpoint the running sums over seeds 0..seed (not this seed's curve
    # alone); np.save appends the .npy extension.
    np.save(f'ucb_{seed}', total_regret_ucb)
    np.save(f'lcb_{seed}', total_regret_lcb)
    np.save(f'ts_{seed}', total_regret_ts)

    # Pure-online baselines (disabled): start each algorithm from only the
    # one-sample-per-arm warm start.
    # regret_ts_pure_online = TS(_copy_dataset(empty_dataset), bandit, nlength)
    # regret_ucb_pure_online = UCB(_copy_dataset(empty_dataset), bandit, nlength)
    # regret_lcb_pure_online = LCB(_copy_dataset(empty_dataset), bandit, nlength)
    # total_regret_ts_pure_online += regret_ts_pure_online
    # total_regret_ucb_pure_online += regret_ucb_pure_online
    # total_regret_lcb_pure_online += regret_lcb_pure_online

# x = np.arange(1,nlength+1)
# plt.plot(x, total_regret_ucb)
# plt.plot(x, total_regret_lcb)
# plt.plot(x, total_regret_ts)

# plt.plot(x,total_regret_ts_noupdate)
# plt.plot(x,total_regret_ucb_pure_online)
# plt.plot(x,total_regret_lcb_pure_online)
# plt.plot(x,total_regret_ts_pure_online)
# plt.legend()

# plt.legend(["ucb", "lcb", "ts", "ucb_pure online", "lcb_pure online", 'ts pure online'])
# plt.legend(["ucb", "lcb", "ts"])
# plt.rc('legend', fontsize=40)  # using a size in points

# plt.show()
