import os
import sys
sys.path.append(os.getcwd())

import numpy as np
from lib.function import frechet_mean_h
from online_problem import OnlineProblem
from solver.offline_solver import OfflineSolver

import config

# Offline run: build the online problem from the saved data set, run the
# offline solver over the selected time indices, and store its histories
# together with the index list under the same `foldname` prefix.
#
# The working directory is already appended to sys.path ahead of the project
# imports at the top of this file, so it is not appended a second time here.

# Experiment settings, all taken from the `config` module.
T = config.T
foldname = config.foldname
Hn = config.Hn
X_0 = config.X_0

# `foldname` is used as a plain string prefix for every file name below
# (presumably a directory path ending with a separator -- confirm in config).
A = np.load(foldname + 'data_A.npy')

ol_problem = OnlineProblem(
    mfd=Hn,
    data=A,
    time=T,
    loss=frechet_mean_h.func,
    grad=frechet_mean_h.grad,
    diameter=config.diameter,
    lipschitz=config.lipschitz,
    # NOTE(review): curvature is fed from config.n; confirm this is the
    # intended setting and not a dedicated curvature value.
    curvature=config.n,
    bound=config.bound,
    mu=0,
    _sum_f=frechet_mean_h.sum_f,
    _sum_grad=frechet_mean_h.sum_grad,
)

solver = OfflineSolver(type='GD', mingrad=1e-3)

# Time indices passed to the solver. The default below uses every t in
# [0, T). For a faster run, a coarser grid that always includes the last
# index can be used instead:
#   grid_len = config.len
#   list_T = list(range(0, T, grid_len))
#   list_T.append(T - 1)
list_T = list(range(T))

solver.optimize(ol_problem, X_0, list_T)

# np.save appends the '.npy' extension to these names when it is missing.
np.save(foldname + 'data_offline', solver.offline_histories)
np.save(foldname + 'list_T', list_T)
print('offline solver completed')
