import numpy as np

#%%

class hinge:
    """Multiclass hinge loss on the score margin, truncated at ``beta``.

    Scores are ``Kvec = np.dot(x, Wmat)``, one per column (class) of
    ``Wmat``; ``x`` is expected to be a 1-D feature vector of length
    ``d = Wmat.shape[0]``. The margin of label ``y`` is
    ``Kvec[y] - max_{k != y} Kvec[k]``; the loss is ``max(0, 1 - margin)``
    while ``margin <= beta`` and 0 once the margin exceeds ``beta``.

    Gradients are flat vectors of length ``Wmat.size`` built with
    ``np.kron(class_vector, x)``, so the entries for class ``k`` occupy the
    contiguous slice ``[k*d:(k+1)*d]`` (the ``Wmat.flatten(order='F')``
    layout).
    """

    def __init__(self, beta = 0.5):
        # Margin above which the loss (and its gradient) is forced to zero.
        self.beta = beta
        self.name = 'hinge'

    @staticmethod
    def _competitor(Kvec, y):
        """Return the index of the highest-scoring class other than ``y``.

        ``y`` is masked out before ``argmax`` so the result can never be
        ``y`` itself, even when several classes tie for the top score
        (e.g. an all-zero ``Wmat``). Ties among the remaining classes
        resolve to the lowest index, as ``np.argmax`` does.

        Raises:
            ValueError: if there are fewer than two classes, in which case
                no margin is defined.
        """
        if len(Kvec) < 2:
            raise ValueError('the margin needs at least two classes')
        masked = np.array(Kvec, dtype=float)
        masked[y] = -np.inf
        return int(np.argmax(masked))

    def m(self, Wmat, x, y):
        """Return the margin ``Kvec[y] - max_{k != y} Kvec[k]`` of label ``y``.

        Positive iff ``y`` is the unique top-scoring class; zero on a tie.
        """
        y = int(y)
        Kvec = np.dot(x, Wmat)
        return Kvec[y] - Kvec[self._competitor(Kvec, y)]

    def loss(self, y, x, Wmat):
        """Return ``max(0, 1 - margin)`` if ``margin <= beta``, else 0."""
        y = int(y)
        mt = self.m(Wmat, x, y)
        if mt <= self.beta:
            return np.max([0, 1 - mt])
        return 0

    def gradient(self, y, x, Wmat):
        """Return the (sub)gradient of ``loss`` w.r.t. ``Wmat``, flattened.

        Where the loss equals ``1 - margin`` (``margin <= min(beta, 1)``) this
        is ``kron(e_tildek - e_y, x)`` with ``tildek`` the strongest competing
        class; elsewhere the loss is flat and a zero vector is returned.
        """
        y = int(y)
        n_classes = Wmat.shape[1]
        Kvec = np.dot(x, Wmat)
        tildek = self._competitor(Kvec, y)
        mt = Kvec[y] - Kvec[tildek]
        # Check beta regardless of whether y is the predicted class (a
        # negative beta zeroes the loss of some misclassified points), and
        # check 1 too (with beta > 1 the hinge is flat for 1 < margin <= beta).
        if mt > self.beta or mt > 1:
            return np.zeros(Wmat.size)
        etildek = np.zeros(n_classes)
        etildek[tildek] = 1
        ey = np.zeros(n_classes)
        ey[y] = 1
        return np.kron(etildek - ey, x)

    def minloss(self, Wmat, x):
        """Return the loss evaluated at the predicted (top-scoring) label."""
        ystart = np.argmax(np.dot(x, Wmat))
        return self.loss(ystart, x, Wmat)
    
    
    
    
class smoothhinge:
    """Multiclass smoothed hinge loss on the score margin.

    Scores are ``Kvec = np.dot(x, Wmat)``, one per column (class) of
    ``Wmat``; ``x`` is expected to be a 1-D feature vector of length
    ``d = Wmat.shape[0]``. With margin ``m = Kvec[y] - max_{k != y} Kvec[k]``
    the loss is ``0`` for ``m >= 1``, ``(1 - m)**2`` for ``0 < m < 1`` and
    ``1 - 2*m`` for ``m <= 0``. The pieces agree in value and slope at
    ``m = 0`` and ``m = 1``, so the loss is continuously differentiable in m.

    Gradients are flat vectors of length ``Wmat.size`` built with
    ``np.kron(class_vector, x)``, so the entries for class ``k`` occupy the
    contiguous slice ``[k*d:(k+1)*d]`` (the ``Wmat.flatten(order='F')``
    layout).
    """

    def __init__(self):
        self.name = 'smooth hinge'

    @staticmethod
    def _competitor(Kvec, y):
        """Return the index of the highest-scoring class other than ``y``.

        ``y`` is masked out before ``argmax`` so the result can never be
        ``y`` itself, even when several classes tie for the top score
        (e.g. an all-zero ``Wmat``). Ties among the remaining classes
        resolve to the lowest index, as ``np.argmax`` does.

        Raises:
            ValueError: if there are fewer than two classes, in which case
                no margin is defined.
        """
        if len(Kvec) < 2:
            raise ValueError('the margin needs at least two classes')
        masked = np.array(Kvec, dtype=float)
        masked[y] = -np.inf
        return int(np.argmax(masked))

    def m(self, Wmat, x, y):
        """Return the margin ``Kvec[y] - max_{k != y} Kvec[k]`` of label ``y``.

        Positive iff ``y`` is the unique top-scoring class; zero on a tie.
        """
        y = int(y)
        Kvec = np.dot(x, Wmat)
        return Kvec[y] - Kvec[self._competitor(Kvec, y)]

    def loss(self, y, x, Wmat):
        """Return the piecewise smooth-hinge loss of label ``y``."""
        y = int(y)
        mt = self.m(Wmat, x, y)
        if mt >= 1:
            return 0
        if mt > 0:
            return (1 - mt)**2
        # Linear branch for misclassified or tied points; equals 1 at m == 0,
        # matching the quadratic branch.
        return 1 - 2 * mt

    def gradient(self, y, x, Wmat):
        """Return the gradient of ``loss`` w.r.t. ``Wmat``, flattened.

        This is ``mult * kron(e_tildek - e_y, x)`` with ``tildek`` the
        strongest competing class and ``mult = -d(loss)/d(margin)``, i.e.
        ``2*(1 - m)`` on ``0 < m < 1`` and ``2`` for ``m <= 0``. A zero vector
        is returned once ``m >= 1``.
        """
        y = int(y)
        n_classes = Wmat.shape[1]
        Kvec = np.dot(x, Wmat)
        tildek = self._competitor(Kvec, y)
        mt = Kvec[y] - Kvec[tildek]
        if mt >= 1:
            return np.zeros(Wmat.size)
        # Clamping m at 0 selects the linear branch's constant slope.
        mult = 2 * (1 - max(mt, 0))
        etildek = np.zeros(n_classes)
        etildek[tildek] = 1
        ey = np.zeros(n_classes)
        ey[y] = 1
        return mult * np.kron(etildek - ey, x)

    def minloss(self, Wmat, x):
        """Return the loss evaluated at the predicted (top-scoring) label."""
        ystart = np.argmax(np.dot(x, Wmat))
        return self.loss(ystart, x, Wmat)
    
    
    
class logistic:
    """Multiclass logistic (softmax cross-entropy) loss in log base ``base``.

    With scaled scores ``Kvec = alpha * np.dot(x, Wmat)`` (one per column of
    ``Wmat``; ``x`` is expected to be a 1-D feature vector of length
    ``d = Wmat.shape[0]``), the loss of label ``y`` is
    ``-(1/alpha) * log_base(softmax(Kvec)[y])``.

    Gradients are flat vectors of length ``Wmat.size`` built with
    ``np.kron(class_vector, x)``, so the entries for class ``k`` occupy the
    contiguous slice ``[k*d:(k+1)*d]`` (the ``Wmat.flatten(order='F')``
    layout).
    """

    def __init__(self, base = 2):
        self.name = 'logistic'
        # Logarithm base of the loss (2 reports it in bits).
        self.base = base

    def _logsoftmaxvec(self, x, Wmat, alpha = 1):
        """Return ``log(softmaxvec(x, Wmat, alpha))`` without underflow."""
        Kvec = alpha * np.dot(x, Wmat)
        Kvecstable = Kvec - np.max(Kvec)
        # The largest shifted term is exp(0) = 1, so the sum is >= 1 and its
        # log is finite even when other probabilities underflow to 0.
        return Kvecstable - np.log(np.sum(np.exp(Kvecstable)))

    def softmax(self, y, x, Wmat, alpha = 1):
        """Return the softmax probability of class ``y``."""
        return self.softmaxvec(x, Wmat, alpha)[int(y)]

    def softmaxvec(self, x, Wmat, alpha = 1):
        """Return the softmax probabilities of all classes."""
        Kvec = alpha * np.dot(x, Wmat)
        # Shift by the max so np.exp cannot overflow.
        expK = np.exp(Kvec - np.max(Kvec))
        return expK / np.sum(expK)

    def loss(self, y, x, Wmat, alpha = 1):
        """Return ``-(1/alpha) * log_base(softmax(y))``.

        Uses the log-softmax directly: ``np.log(softmax)`` returns ``-inf``
        (an infinite loss) once the true-class probability underflows to 0,
        e.g. for scores ``[0, 1000]`` with ``y = 0``.
        """
        y = int(y)
        logp = self._logsoftmaxvec(x, Wmat, alpha)[y]
        return -1/alpha * logp/np.log(self.base)

    def gradient(self, y, x, Wmat, alpha = 1):
        """Return the gradient of ``loss`` w.r.t. ``Wmat``, flattened.

        Equals ``kron(softmaxvec - e_y, x) / ln(base)``; the ``alpha`` factors
        from the scaled scores and the ``1/alpha`` prefactor cancel.
        """
        y = int(y)
        ey = np.zeros(Wmat.shape[1])
        ey[y] = 1
        return 1/np.log(self.base) * np.kron(self.softmaxvec(x, Wmat, alpha) - ey, x)

    def minloss(self, Wmat, x, alpha = 1):
        """Return the loss evaluated at the predicted (top-scoring) label."""
        ystart = np.argmax(np.dot(x, Wmat))
        return self.loss(ystart, x, Wmat, alpha)
        