import os
import sys
sys.path.append(os.getcwd())

import numpy as np
from lib.function import frechet_mean_h
from online_problem import OnlineProblem
from solver.online_bandit import OnlineBandit

import config

# Experiment parameters, read from the shared ``config`` module.
# ``n`` and ``block`` are not used below; they are kept as module-level
# names in case other code reads them from this module.
n = config.n
T = config.T
block = config.block
foldname = config.foldname
Hn = config.Hn

# Number of independent bandit runs to average over, and the ``mul`` argument
# forwarded to ``OnlineBandit.optimize``.  These used to be hard-coded (100 and
# 4); ``config.bandit_rounds`` / ``config.bandit_mul`` may now override them,
# and the original values remain the defaults.
rounds = getattr(config, 'bandit_rounds', 100)
bandit_mul = getattr(config, 'bandit_mul', 4)
if rounds < 1:
    # Averaging divides by ``rounds``; zero runs would silently write NaNs.
    raise ValueError(f'bandit_rounds must be a positive integer, got {rounds!r}')

# Data set and starting point for the online problem.  The path is built by
# plain concatenation, so ``foldname`` must already end with a separator
# (or be an intended filename prefix).
A = np.load(foldname + 'data_A.npy')
X_0 = config.X_0

# Online problem over ``Hn`` (passed as ``mfd``), with loss/gradient callbacks
# and their summed variants taken from ``lib.function.frechet_mean_h``.
# NOTE(review): the meaning of ``mu=0`` is defined by OnlineProblem — confirm.
ol_fre_prob = OnlineProblem(
    mfd=Hn,
    data=A,
    time=T,
    loss=frechet_mean_h.func,
    grad=frechet_mean_h.grad,
    diameter=config.diameter,
    lipschitz=config.lipschitz,
    curvature=config.curvature,
    bound=config.bound,
    mu=0,
    _sum_f=frechet_mean_h.sum_f,
    _sum_grad=frechet_mean_h.sum_grad,
)

# Running sums, one entry per time step, of the per-run histories; divided by
# ``rounds`` after the loop to obtain the mean over all runs.
aver_values = np.zeros(T)
aver_time = np.zeros(T)
solver = OnlineBandit()
for i in range(rounds):
    # The same solver instance is reused for every run.
    # NOTE(review): assumes ``optimize`` restarts from X_0 and overwrites
    # ``aver_value_histories`` / ``time_sum`` rather than appending — confirm.
    solver.optimize(ol_fre_prob, X_0, mul=bandit_mul)
    solver.calculate_aver_value()
    solver.sum_time()
    aver_values += solver.aver_value_histories
    aver_time += solver.time_sum
aver_values = aver_values / rounds
aver_time = aver_time / rounds

# np.save appends the ``.npy`` extension to these names.
np.save(foldname + 'data_bandit', aver_values)
np.save(foldname + 'time_bandit', aver_time)
print('bandit solver completed')
