import os
import sys

import numpy as np

sys.path.append(os.getcwd())

import numpy as np
from lib.function import oper_scal
from online_problem import OnlineProblem
from solver.online_bandit import OnlineBandit

import config

# Experiment parameters. T is fixed here; it is later passed as `time` to
# OnlineProblem, used to slice the data, and used to size the result arrays.
T = 10000

# Remaining settings come from the shared config module.
# NOTE(review): `n` and `block` are not referenced in the visible part of
# this script; they are kept in case other code relies on them.
n = config.n
block = config.block
foldname = config.foldname
SPD = config.SPD


# Load the data array stored under `foldname` and fetch the initial point
# that is later handed to the solver.
A = np.load(foldname + 'data_A.npy')
X_0 = config.X_0

# Describe the online problem: only the first T entries of the data are
# used; loss and gradient come from `oper_scal`, and the remaining
# constants are read from config.
os_problem = OnlineProblem(
    mfd=SPD,
    data=A[:T],
    time=T,
    loss=oper_scal.func,
    grad=oper_scal.grad,
    diameter=config.diameter,
    lipschitz=config.lipschitz,
    curvature=config.curvature,
    bound=config.bound,
)



# Repeat the bandit solver several times and average its per-step results.
rounds = 100
solver = OnlineBandit()

# Running sums over all rounds, one slot per time step.
aver_values = np.zeros(T)
aver_time = np.zeros(T)

for _ in range(rounds):
    # Each round starts from X_0 on the same problem instance.
    # NOTE(review): the meaning of `mul = 5` is defined by OnlineBandit,
    # which is not visible here -- confirm against the solver.
    solver.optimize(os_problem, X_0, mul=5)
    solver.calculate_aver_value()
    solver.sum_time()

    # Accumulate in place so the arrays keep their (T,) float layout.
    aver_values += solver.aver_value_histories
    aver_time += solver.time_sum

# Turn the sums into means over the rounds.
aver_values /= rounds
aver_time /= rounds


# Show the round-averaged values in an interactive matplotlib window.
import matplotlib.pyplot as plt

plt.plot(aver_values)
plt.show()

# Disabled: uncomment to save the averaged values and times under `foldname`.
#np.save( foldname + 'data_bandit',aver_values)
#np.save( foldname + 'time_bandit',aver_time)
#print('bandit solver completed')