import numpy as np
from tqdm import tqdm
from alg import *
from mlsr_env import *
import matplotlib.pyplot as plt

# Load the preprocessed MSLR arrays.
N_FEATURES = 136  # feature columns per row (matches dim=136 used for the algorithms)
N_PRETRAIN = 10   # number of leading rows used to pretrain the reward model

# np.load on an .npz returns an NpzFile that holds the archive open; using it
# as a context manager closes the file once the arrays have been extracted.
with np.load('mslrdataprocessed.npz') as predata:
    # Each NpzFile lookup reads the member from the archive, so read each once.
    features = predata['features'].reshape(-1, N_FEATURES)
    relevances = predata['relevances'].reshape(-1, 1)

precontextall = features
prerewardsall = relevances
# Small warm-start slice; these are views into the full arrays above.
precontext = features[:N_PRETRAIN]
prerewards = relevances[:N_PRETRAIN]

# Experiment configuration: horizon T and total budget bg. Each algorithm is
# given the environment's budget divided evenly over the T rounds.
T = 5000
bg = 30  # alternatives tried: 100, 600
env = mlsr(budget=bg)
budget = env.budget / T

# The three algorithms are constructed with identical arguments.
_shared_alg_kwargs = {'K': 20, 'dim': 136, 'T': T, 'b': budget}
alg1 = AUPD(**_shared_alg_kwargs)
alg2 = SquareCBwK(**_shared_alg_kwargs)
alg3 = PGD(**_shared_alg_kwargs)

# Reward model (depth-5 gb5), pretrained on the (precontext, prerewards) slice,
# plus the cmodel instance shared by all runs.
rmodel = gb5(depth=5)
rmodel.pretrain(X_train=precontext, y_train=prerewards)
cmodel = meanvalue(K=20)

def _run_policy(alg, desc):
    """Run one algorithm against the shared environment for up to T rounds.

    Each round the environment supplies a context. While ``env.stopping`` is 0,
    the algorithm picks an action, the environment returns ``(reward, cost)``,
    and the algorithm updates on that feedback. Once ``env.stopping`` is
    non-zero, the round index is printed, ``env.remain`` is called with the
    number of rounds left, and the loop ends.

    After the loop, ``env.returnresult()`` is collected. Then the reward model
    is re-pretrained on the same slice, and ``env`` and ``cmodel`` are reset.
    This is presumably so the next algorithm starts from the same state.

    Args:
        alg: object exposing ``take_action`` and ``update`` (AUPD/PGD/SquareCBwK).
        desc: label shown on the tqdm progress bar.

    Returns:
        Whatever ``env.returnresult()`` returns for this run.
    """
    for i in tqdm(range(T), desc=desc):
        context = env.sendcontext()
        if env.stopping == 0:
            action = alg.take_action(context=context, rmodel=rmodel, cmodel=cmodel)
            reward, cost = env.feedback(action)
            alg.update(context=context, action=action, rmodel=rmodel,
                       cmodel=cmodel, reward=reward, cost=cost)
        else:
            print(i)
            env.remain(T - i - 1)
            break

    result = env.returnresult()
    # Restore shared state between runs (same calls as the original script).
    rmodel.pretrain(X_train=precontext, y_train=prerewards)
    env.reset()
    cmodel.reset()
    return result


# Runs execute in the original order: AUPD, then PGD, then SquareCBwK.
AUPDreward = _run_policy(alg1, 'Running AUPD')
PGDreward = _run_policy(alg3, 'Running PGD')
Sqreward = _run_policy(alg2, 'Running SquareCBwK')


# Overlay the result of each run (as returned by env.returnresult()) on one figure.
for _series, _label in (
    (AUPDreward, "AUPD"),
    (PGDreward, "PGD"),
    (Sqreward, "SquareCBwK"),
):
    plt.plot(_series, label=_label)
plt.title("reward")
plt.legend()
plt.show()
