import numpy as np

#%%

class scalefreeMD:
    """Online mirror-descent learner with an adaptive ("scale-free") step size.

    Weights are stored as a flat vector of length ``dim * K`` and exposed as a
    ``(dim, K)`` matrix by :meth:`Wmat`. Each :meth:`update` takes one step
    through the supplied regularizer and, optionally, projects the result
    with ``regularizer.project``.

    Parameters
    ----------
    regularizer : object
        Must provide ``name``, ``startmin``, ``grad``, ``congrad``, ``norm``
        and ``project``; when ``name == "EGpm"`` it must also provide
        ``noproject`` (used when ``project`` is False).
    diameter : float
        Forwarded to ``regularizer.project`` and used to set
        ``maxregularizer`` (``0.5 * diameter``, or
        ``diameter * log(2 * K * dim)`` for ``"EGpm"``).
    dim, K : int
        Shape of the weight matrix returned by :meth:`Wmat`.
    lada : float
        Step size when ``eta == "constant"``; otherwise a scale factor inside
        the adaptive step ``sqrt(lada * maxregularizer / inverseeta2)``.
    eta : str
        ``"constant"`` selects the fixed step; any other value selects the
        adaptive step.
    project : bool
        Whether to call ``regularizer.project`` after every step.
    project_half : bool
        When True, forwarded to ``regularizer.project`` as its third argument.
    """

    def __init__(self, regularizer, diameter, dim, K, lada = 1, eta = "ada", project = True, project_half=False):
        self.regularizer = regularizer
        self.diameter = diameter
        self.dim = dim
        self.K = K
        self.lada = lada
        # Flat weight vector of length dim * K (see Wmat for the matrix view).
        self.weights = regularizer.startmin(dim * K)
        self.maxregularizer = 0.5 * diameter
        if regularizer.name == "EGpm":
            self.maxregularizer = diameter * np.log(2 * K * dim)
        # Running sum of squared gradient norms; the tiny positive start keeps
        # the adaptive step finite before the first non-zero gradient.
        self.inverseeta2 = 1e-8
        self.eta = eta
        self.name = "ada" + regularizer.name
        # Accumulated (negative, scaled) gradients on the doubled vector
        # (grad, -grad); only used by the "EGpm" regularizer.
        self.gradsum = np.zeros(2 * K * dim)
        self.project = project
        self.project_half = project_half

    def Wmat(self):
        """Return the flat weights reshaped to ``(dim, K)`` in column-major (Fortran) order."""
        return(self.weights.reshape((self.dim, self.K), order = 'F'))

    def update(self, y, x, loss, scaling = 1):
        """Perform one online update from the example ``(y, x)``.

        Parameters
        ----------
        y, x :
            Passed unchanged to ``loss.gradient``.
        loss : object
            Must provide ``gradient(y, x, W)`` where ``W`` is :meth:`Wmat`.
        scaling : float
            Multiplier applied to the gradient.
        """
        # NOTE(review): grad is assumed to be a flat vector of length dim * K,
        # ordered like self.weights — confirm against loss.gradient.
        grad = scaling * loss.gradient(y, x, self.Wmat())
        normgrad = self.regularizer.norm(grad)
        self.inverseeta2 = self.inverseeta2 + normgrad**2

        # One step size for both regularizer families, so that the constant
        # step also goes through the EGpm dual accumulator below.
        if self.eta == "constant":
            step = self.lada
        else:
            step = np.sqrt(self.lada*self.maxregularizer/self.inverseeta2)

        if self.regularizer.name == "EGpm":
            # EGpm keeps its own dual state on the doubled vector (grad, -grad),
            # which has length 2 * dim * K rather than dim * K.
            grad2 = np.concatenate((grad, - grad))
            self.gradsum = self.gradsum - step * grad2
            wtilde = self.gradsum
        else:
            theta = self.regularizer.congrad(self.weights)
            wtilde = self.regularizer.grad(theta - step * grad)

        if self.project:
            if self.project_half:
                self.weights = self.regularizer.project(wtilde, self.diameter, self.project_half)
            else:
                # Omit the flag so regularizers whose project() takes only
                # (w, diameter) keep working.
                self.weights = self.regularizer.project(wtilde, self.diameter)
        elif self.regularizer.name == "EGpm":
            self.weights = self.regularizer.noproject(wtilde)
        else:
            self.weights = wtilde

#%%

class L2reg:
    """Euclidean regularizer whose mirror maps are both the identity."""

    def __init__(self):
        self.name = "L2"

    def startmin(self, dim):
        """Return a length-``dim`` vector of zeros as the starting point."""
        origin = np.zeros(dim)
        return origin

    def grad(self, theta):
        """Dual-to-primal map; returns ``theta`` unchanged."""
        return theta

    def congrad(self, w):
        """Primal-to-dual map; returns ``w`` unchanged."""
        return w

    def norm(self, x):
        """Return ``sqrt(dot(x, x))``, the Euclidean length of a vector."""
        squared = np.dot(x, x)
        return np.sqrt(squared)

    def project(self, w, diameter, project_half=False):
        """Rescale ``w`` onto the Euclidean ball if it lies outside it.

        The ball's radius is ``diameter``, or ``0.5 * diameter`` when
        ``project_half`` is truthy. Points already inside (norm <= radius)
        are returned unchanged.
        """
        limit = 0.5 * diameter if project_half else diameter
        length = self.norm(w)
        if length <= limit:
            return w
        direction = w / length
        if project_half:
            return direction * 0.5 * diameter
        return direction * diameter
        
        
#%%
class EGpm:
    """Exponentiated-gradient (+/-) regularizer on a doubled vector.

    A dual vector of length ``2 * n`` is mapped to ``n`` primal weights as the
    difference between its first and second halves after exponentiation
    (normalized via softmax in :meth:`project`, unnormalized in
    :meth:`noproject`).
    """

    def __init__(self):
        self.name = "EGpm"

    def startmin(self, dim):
        """Return a length-``dim`` vector of zeros as the starting point."""
        return(np.zeros(dim))

    def grad(self, theta):
        """Dual-to-primal map; returns ``theta`` unchanged."""
        return(theta)

    def congrad(self, w):
        """Primal-to-dual map; returns ``w`` unchanged."""
        return(w)

    def norm(self, x):
        """Return the max-abs (L-infinity) norm of ``x``."""
        return(np.max(np.abs(x)))

    def project(self, w, diameter, project_half=False):
        """Map a doubled dual vector to ``len(w) // 2`` bounded primal weights.

        Computes ``p = softmax(w)`` and returns
        ``0.5 * diameter * (p[:n] - p[n:])`` with ``n = len(w) // 2``.

        ``project_half`` is accepted so this class can be used wherever
        ``project(w, diameter, project_half)`` is called (as for
        ``L2reg.project``); it has no effect here because the mapping always
        scales by ``0.5 * diameter``.
        """
        half = len(w) // 2
        # Shifting by the max leaves the softmax unchanged but keeps np.exp
        # from overflowing.
        expw = np.exp(w - np.max(w))
        p = expw / np.sum(expw)
        return 0.5 * diameter * (p[:half] - p[half:])

    def noproject(self, w):
        """Return ``exp(w[:n]) - exp(w[n:])`` with ``n = len(w) // 2``.

        Unlike :meth:`project` this is not normalized or shifted, so very
        large entries of ``w`` can overflow ``np.exp``.
        """
        p = np.exp(w)
        return(p[0:int(0.5 * len(p))] - p[int(0.5 * len(p)):])
        