from pymanopt import Problem
from pymanopt.solvers import *
import numpy as np
class OfflineSolver():
    """Solve a sequence of offline (batch) problems on a manifold with pymanopt.

    For each horizon ``t`` in ``list_T`` a pymanopt ``Problem`` is built from
    the cost/gradient of ``ol_problem`` restricted to its first ``t + 1``
    data points. That problem is solved starting from ``X_0``, and the cost at
    the returned point is stored in ``self.offline_histories``.
    """

    def __init__(self, type='CG', mingrad=10e-6, maxtime=10000) -> None:
        """Create the underlying pymanopt solver.

        Args:
            type: ``'CG'`` selects ``ConjugateGradient``; any other value
                selects ``SteepestDescent``. (The name shadows the builtin but
                is kept so existing keyword callers keep working.)
            mingrad: gradient-norm stopping tolerance, forwarded as
                ``mingradnorm``. Note that the default ``10e-6`` equals 1e-5.
            maxtime: time limit forwarded to the solver as ``maxtime``.
        """
        solver_cls = ConjugateGradient if type == 'CG' else SteepestDescent
        # Forward the caller's maxtime; it used to be ignored in favour of a
        # hard-coded 10000.
        self.solver = solver_cls(mingradnorm=mingrad, maxtime=maxtime)

    def optimize(self, ol_problem, X_0, list_T):
        """Run the offline solver for every horizon in ``list_T``.

        Args:
            ol_problem: problem object exposing ``mfd``, ``data`` and either
                ``_sum_f``/``_sum_grad`` or ``sum_f``/``sum_grad``.
            X_0: initial point handed to the solver for every round.
            list_T: sequence of integer horizons ``t``.

        Returns:
            The ``offline_histories`` array (also stored on ``self``), whose
            i-th entry is the cost at the solution for horizon ``list_T[i]``.
        """
        self.list_T = list_T
        self.offline_histories = np.zeros(len(list_T))
        self.offline_solver(ol_problem, X_0, list_T)
        return self.offline_histories

    @staticmethod
    def _make_cost_and_grad(ol_problem, t):
        """Build the (cost, gradient) callables for horizon ``t``."""
        # NOTE(review): the ``_sum_f``/``_sum_grad`` branch passes the data
        # prefix unscaled, while the fallback divides by (t + 1); presumably
        # ``_sum_f`` averages internally - confirm against ol_problem.
        if ol_problem._sum_f:
            def func(X):
                return ol_problem._sum_f(ol_problem.data[:t + 1], X)
        else:
            def func(X):
                return (1 / (t + 1)) * ol_problem.sum_f(t, X)

        if ol_problem._sum_grad:
            def grad(X):
                return ol_problem._sum_grad(ol_problem.data[:t + 1], X)
        else:
            def grad(X):
                return (1 / (t + 1)) * ol_problem.sum_grad(t, X)

        return func, grad

    def offline_solver(self, ol_problem, X_0, list_T):
        """Solve one offline problem per horizon in ``list_T``.

        Results are written into ``self.offline_histories``. That array is
        (re)allocated here if it is missing or sized for a different
        ``list_T``, so this method can also be called without ``optimize``.
        """
        histories = getattr(self, 'offline_histories', None)
        if histories is None or len(histories) != len(list_T):
            self.offline_histories = np.zeros(len(list_T))

        # Iterate the list actually passed in, not a possibly stale self.list_T.
        for i, t in enumerate(list_T):
            print('offline round:', t)
            func, grad = self._make_cost_and_grad(ol_problem, t)
            problem = Problem(manifold=ol_problem.mfd, cost=func, grad=grad)
            Xopt = self.solver.solve(problem, x=X_0)
            # Evaluate the (possibly expensive) cost only once per round.
            value = func(Xopt)
            print('value', value)
            self.offline_histories[i] = value