import numpy as np
from .online_solver import OnlineSolver
import time
class OnlineGradientDescent(OnlineSolver):
    """Online gradient descent that moves along ``problem.mfd.exp``.

    Each round ``t`` records the loss ``problem.f(t, X_t)``, evaluates the
    gradient ``problem.g(t, X_t)`` and sets
    ``X_{t+1} = problem.mfd.exp(X_t, -eta_t * grad_t)``.

    Two step-size schedules are available and ``optimize`` picks between
    them based on ``problem.mu``:

    * ``mu > 0``: ``eta_t = 1 / (mu * (t + 1))`` (``gradient_solver_sc``).
    * otherwise:  ``eta_t = D / (L * sqrt(zeta) * sqrt(t + 1))``
      (``gradient_solver``).

    The containers ``self.X``, ``self.value_histories`` and ``self.time`` are
    written by this class but not created here; presumably
    ``initial_with_problem`` (inherited from ``OnlineSolver``) allocates
    them -- confirm against the base class.
    """

    def __init__(self) -> None:
        # NOTE(review): the base-class initializer is not invoked here; that
        # matches the original behaviour -- confirm OnlineSolver needs none.
        self.solver_type = 'OGD'
        # np.linalg.norm of every gradient evaluated so far. This list is
        # never cleared, so it keeps growing across repeated optimize() calls.
        self.grad_norm_his = []

    def optimize(self, problem, X_0):
        """Run the solver on ``problem`` starting from ``X_0``.

        Args:
            problem: object exposing ``time`` (number of rounds), ``mu``,
                ``mfd``, ``f(t, X)`` and ``g(t, X)``; when ``mu <= 0`` it
                must also expose ``D``, ``L`` and ``zeta``.
            X_0: initial iterate, stored as ``self.X[0]``.

        Raises:
            ValueError: if an iterate produced by ``mfd.exp`` contains NaN.
        """
        T = problem.time
        track_list = ['X']
        self.initial_with_problem(T, X_0, track_list)
        self.X[0] = X_0
        if problem.mu > 0:
            self.gradient_solver_sc(problem, X_0)
        else:
            self.gradient_solver(problem, X_0)

    def gradient_solver(self, problem, X_0):
        """Run the rounds with step size ``D / (L * sqrt(zeta)) / sqrt(t + 1)``.

        Args:
            problem: see ``optimize``; ``D``, ``L`` and ``zeta`` are read here.
            X_0: unused; iteration starts from ``self.X[0]``. Kept so the
                signature stays compatible with existing callers.

        Raises:
            ValueError: if an iterate contains NaN.
        """
        eta = problem.D / (problem.L * problem.zeta ** (1/2))
        self._descend(problem, lambda t: eta / ((t + 1) ** 0.5))

    def gradient_solver_sc(self, problem, X_0):
        """Run the rounds with step size ``(1 / mu) / (t + 1)``.

        Used by ``optimize`` when ``problem.mu > 0`` (presumably the strongly
        convex case, going by the ``_sc`` suffix -- confirm).

        Args:
            problem: see ``optimize``; ``mu`` is read here.
            X_0: unused; iteration starts from ``self.X[0]``. Kept so the
                signature stays compatible with existing callers.

        Raises:
            ValueError: if an iterate contains NaN.
        """
        eta = 1 / problem.mu
        self._descend(problem, lambda t: eta / (t + 1))

    def _descend(self, problem, step_size):
        """Shared round loop; ``step_size(t)`` returns ``eta_t`` for round ``t``."""
        mfd = problem.mfd
        for t in range(problem.time):
            # perf_counter is monotonic, so the recorded duration cannot be
            # distorted by system clock adjustments.
            started = time.perf_counter()
            X_t = self.X[t]
            # suffer the loss
            self.value_histories[t] = problem.f(t, X_t)
            # update
            eta_t = step_size(t)
            grad_t = problem.g(t, X_t)
            self.grad_norm_his.append(np.linalg.norm(grad_t))
            X_next = mfd.exp(X_t, -eta_t * grad_t)
            if np.isnan(X_next).any():
                raise ValueError(
                    f"NaN encountered in iterate after round {t} "
                    f"(step size {eta_t})"
                )
            # Writes index t + 1, so self.X must hold problem.time + 1 entries.
            self.X[t + 1] = X_next
            self.time[t] = time.perf_counter() - started
        
