import numpy as np

class OPT_OJGD:
    """Helpers for Online Joint Gradient Descent (OJGD).

    Both methods are stateless and are declared as static methods, so they can
    be called on the class (``OPT_OJGD.ojgd(...)``) or on an instance.
    """

    @staticmethod
    def projection2simplex(y):
        """Project ``y`` onto the probability simplex.

        Computes ``max(y - t, 0)`` for a scalar threshold ``t`` found by the
        sort-based method: walk the entries in descending order, and stop at
        the first prefix whose candidate threshold exceeds the next entry.

        Args:
            y: array-like of real values (assumed 1-D -- the threshold search
               walks a single sorted axis; TODO confirm callers never pass N-D).

        Returns:
            A float ``numpy.ndarray`` with the same shape as ``y``, with
            non-negative entries.
        """
        y = np.asarray(y, dtype=float)
        p = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)  # descending order
        # Fallback threshold, used when no earlier prefix qualifies below.
        threshold = (np.sum(y) - 1.0) / p
        running_sum = 0.0
        for i in range(p - 1):
            running_sum += sorted_y[i]
            candidate = (running_sum - 1.0) / (i + 1.0)
            # Once the candidate exceeds the next-largest entry, the i+1
            # largest entries are exactly the ones that stay positive.
            if candidate > sorted_y[i + 1]:
                threshold = candidate
                break
        return np.maximum(y - threshold, 0.0)

    @staticmethod
    def ojgd(lam, lam1, etad, F, alp, tol=1e-4):
        r"""Online Joint Gradient Descent (OJGD) step.

        Forms ``dg = F + alp * (lam1 - lam) / ||lam1 - lam||``, moves to
        ``lam + etad * dg``, and projects the result back onto the simplex.
        The caller's ``lam`` array is not modified.

        Args:
            lam: current weights vector (array-like).
            lam1: initial weights vector (array-like, same shape as ``lam``).
            etad: weights stepsize.
            F: direction term added to the step (presumably a per-weight
               gradient/loss vector; same shape as ``lam`` -- TODO confirm).
            alp: trade-off factor scaling the pull toward ``lam1``.
            tol: when ``||lam1 - lam|| <= tol`` the pull term is treated as
                zero, avoiding division by a (near-)zero norm. Defaults to
                ``1e-4``, the previously hard-coded value.

        Returns:
            The updated weights as a float ``numpy.ndarray`` on the simplex.
        """
        lam = np.asarray(lam, dtype=float)
        diff = np.asarray(lam1, dtype=float) - lam
        dist = np.linalg.norm(diff)  # computed once, reused below
        if dist <= tol:
            term2 = np.zeros(lam.shape)
        else:
            term2 = diff / dist  # unit vector from lam toward lam1
        dg = F + alp * term2
        # Out-of-place update: `lam += ...` would overwrite the caller's array
        # and fail for integer dtypes.
        lam = lam + etad * dg
        return OPT_OJGD.projection2simplex(lam)