import sys
import os
sys.path.append(os.getcwd())
import numpy as np
from pymanopt.manifolds import PositiveDefinite
from online_problem import OnlineProblem 
from solver.offline_solver import OfflineSolver
from lib.function import frechet_mean
import config

# Experiment settings read from the shared config module.
n, T, block, foldname = config.n, config.T, config.block, config.foldname

# Manifold of n x n symmetric positive-definite matrices (a single copy, k=1)
# and the input data that is handed to OnlineProblem below.
SPD = PositiveDefinite(n, k=1)
A = np.load(foldname + 'data_A.npy')

# Initial iterate for the offline solver: the n x n identity matrix.
X_0 = np.eye(n)

# Online problem on the SPD manifold. The loss/gradient callbacks (and their
# summed variants) come from lib.function.frechet_mean; the geometric and
# Lipschitz constants come from the config module.
os_problem = OnlineProblem(
    mfd=SPD,
    data=A,
    time=T,
    loss=frechet_mean.func,
    grad=frechet_mean.grad,
    diameter=config.diameter,
    lipschitz=config.lipschitz,
    # NOTE(review): curvature is fed the matrix dimension config.n rather
    # than a dedicated curvature setting -- confirm this is intended.
    curvature=config.n,
    bound=config.bound,
    # presumably a strong-convexity-type constant, set to 0 here -- confirm
    # against OnlineProblem.
    mu=0,
    _sum_f=frechet_mean.sum_f,
    _sum_grad=frechet_mean.sum_grad,
)

# Time steps at which the offline solution is computed: every step 0..T-1.
# For a faster, coarser run, a grid could be used instead, e.g. the steps
# range(0, T, config.len) with T-1 appended at the end.
list_T = list(range(T))

# mingrad=1e-3 is presumably the gradient-norm stopping tolerance -- confirm
# against OfflineSolver.
solver = OfflineSolver(mingrad=1e-3)
solver.optimize(os_problem, X_0, list_T)

# Persist the solver histories and the checkpoint list next to the input
# data; np.save appends the '.npy' extension to each file name.
np.save(file=foldname + 'data_offline', arr=solver.offline_histories)
np.save(file=foldname + 'list_T', arr=list_T)
print('offline solver completed')