import numpy as np
from lib.function.std_func import zeta

class OnlineProblem:
    """Online optimisation problem whose per-round losses come from ``data``.

    Round ``t`` uses the sample ``data[t]`` (indexing along the first axis).
    The loss and gradient for that round are obtained by applying the
    user-supplied ``loss`` / ``grad`` callables to ``(data[t], X)``.

    Args:
        mfd: Manifold-like object; only ``mfd.dim`` is read here (stored as
            ``int(mfd.dim)`` in ``self.dim``).
        data: Array-like exposing ``shape``; ``data.shape[0]`` must equal
            ``time``.
        time: Number of rounds; must equal ``data.shape[0]``.
        loss: Callable ``loss(sample, X)``; its result is what ``f`` returns.
        grad: Callable ``grad(sample, X)``; its result is what ``g`` returns.
        diameter: Stored as ``D``; ``r`` is initialised to the same value.
        lipschitz: Stored as ``L`` (presumably a Lipschitz constant -- TODO
            confirm).
        curvature: Stored as ``kappa``; passed together with ``D`` to the
            project helper ``zeta``.
        bound: Stored as ``C``.
        mu: Stored as ``mu``; defaults to 0.
        _sum_f: Stored verbatim as ``self._sum_f``; no method of this class
            reads it.
        _sum_grad: Stored verbatim as ``self._sum_grad``; no method of this
            class reads it.

    Raises:
        TypeError: If ``time`` differs from ``data.shape[0]``.
    """

    def __init__(self, mfd, data, time, loss, grad, diameter, lipschitz,
                 curvature, bound, mu=0, _sum_f=None, _sum_grad=None) -> None:
        # Reject a horizon that does not match the number of data rows
        # before any attribute is stored.
        n_rows = data.shape[0]
        if time != n_rows:
            raise TypeError("data error: dimension not matched")

        # Problem definition: manifold, samples and per-sample callables.
        self.mfd = mfd
        self.data = data
        self.loss = loss
        self.grad = grad
        self.time = time
        self.dim = int(mfd.dim)

        # Scalar constants supplied by the caller.
        self.D, self.L = diameter, lipschitz
        self.r = self.D  # r starts out equal to the diameter D
        self.kappa, self.C, self.mu = curvature, bound, mu

        # Derived constant from the project helper ``zeta``; presumably the
        # curvature-dependent geometric constant of the analysis -- TODO
        # confirm against lib.function.std_func.
        self.zeta = zeta(self.kappa, self.D)

        # Optional user-provided aggregates, kept as attributes only.
        self._sum_f = _sum_f
        self._sum_grad = _sum_grad

    def f(self, time, X):
        """Return ``loss(data[time], X)``, the loss of round ``time`` at ``X``."""
        return self.loss(self.data[time], X)

    def g(self, time, X):
        """Return ``grad(data[time], X)``, the gradient of round ``time`` at ``X``."""
        return self.grad(self.data[time], X)

    def sum_f(self, time, X):
        """Return the cumulative loss ``f(0, X) + f(1, X) + ... + f(time, X)``.

        The accumulator starts at the integer ``0``, so a negative ``time``
        (empty range) yields ``0``.
        """
        total = 0
        for t in range(time + 1):
            total = total + self.f(t, X)
        return total

    def sum_grad(self, time, X):
        """Return the cumulative gradient ``g(0, X) + ... + g(time, X)``.

        ``g(0, X)`` seeds the accumulator, so it is always evaluated: for
        ``time <= 0`` the result is just ``g(0, X)``. The accumulator is
        rebound with ``total = total + ...`` rather than ``+=`` so the object
        returned by ``g(0, X)`` is not modified in place.
        """
        total = self.g(0, X)
        for t in range(1, time + 1):
            total = total + self.g(t, X)
        return total

    def random_grad(self, time, X):
        """Return ``g(i, X)`` for one index ``i`` drawn uniformly from ``0..time``.

        The index comes from NumPy's global legacy RNG
        (``np.random.randint``), so seed it with ``np.random.seed`` for
        reproducible runs.
        """
        idx = np.random.randint(time + 1)
        return self.g(idx, X)