
import numpy as np
import loss


#%%
class perceptron:
    """Online multiclass perceptron with an optional Passive-Aggressive mode.

    The ``dim * K`` weights are stored as one flat vector; :meth:`Wmat`
    exposes them as a ``(dim, K)`` matrix with one column per class.
    """

    def __init__(self, dim, K, PA = False):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        PA : bool, optional
            If true, :meth:`update` uses the Passive-Aggressive step and the
            learner reports the name ``"Pass-Aggr"``.
        """
        self.name = "Pass-Aggr" if PA else "perceptron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.PA = PA

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Return the index of the class with the largest score ``x . W``."""
        return np.argmax(np.dot(x, self.Wmat()))

    def update(self, y, x):
        """Update the weights from the example ``(x, y)``.

        In Passive-Aggressive mode the weights move along the hinge gradient
        scaled by ``loss / ||x||^2``. Otherwise the plain perceptron step
        (subtract the hinge gradient) is taken, and only when the current
        prediction differs from ``y``. The two rules are mutually exclusive:
        previously a misclassified example in PA mode received both steps.

        Parameters
        ----------
        y : int
            True class index.
        x : array_like
            Feature vector of length ``dim``.
        """
        Wt = self.Wmat()
        # NOTE(review): loss.hinge is defined outside this file; it is assumed
        # to return a flat gradient laid out like self.weights - confirm.
        grad = loss.hinge(beta = 1).gradient(y, x, Wt)
        if self.PA:
            lt = loss.hinge(beta = 1).loss(y, x, Wt)
            xnorm2 = np.dot(x, x)
            # A zero feature vector would make the step size 0/0 and poison
            # the weights with NaN, so such examples are skipped.
            if xnorm2 > 0:
                self.weights = self.weights - lt/xnorm2 * grad
        elif np.argmax(np.dot(x, Wt)) != y:
            self.weights = self.weights - grad


            
#%%

class soperceptron:
    """Multiclass perceptron whose step is preconditioned by a full matrix.

    ``mat`` starts as ``(1 / alpha) * I`` and receives a rank-one
    Sherman-Morrison style correction on every mistake.
    """

    def __init__(self, dim, K, alpha = 1):
        """Set up zero weights and the initial ``(1 / alpha) * I`` matrix.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        alpha : float, optional
            Inverse of the initial diagonal value of ``mat``.
        """
        n_params = dim * K
        self.name = "soperceptron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(n_params)
        self.mat = 1 / alpha * np.eye(n_params)

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        shape = (self.dim, self.K)
        return self.weights.reshape(shape, order = 'F')

    def predict(self, x):
        """Return the index of the class with the largest score ``x . W``."""
        scores = np.dot(x, self.Wmat())
        return np.argmax(scores)

    def update(self, y, x):
        """Take a preconditioned gradient step if ``x`` is misclassified.

        Parameters
        ----------
        y : int
            True class index.
        x : array_like
            Feature vector of length ``dim``.
        """
        W = self.Wmat()
        # NOTE(review): loss.hinge lives outside this file; its gradient is
        # assumed to be a flat vector of length dim * K - confirm.
        g = loss.hinge(beta = 1).gradient(y, x, W)
        if np.argmax(np.dot(x, W)) == y:
            # Correct prediction: nothing to learn from this example.
            return
        mat_g = np.matmul(self.mat, g)
        # Rank-one correction of mat, applied before the weight step so the
        # step already uses the refreshed matrix.
        self.mat = self.mat - np.outer(mat_g, mat_g)/(1 + np.dot(mat_g, g))
        self.weights = self.weights - np.matmul(self.mat, g)
            



#%%

class sodiagperceptron:
    """Diagonal variant of the second-order perceptron.

    Instead of a full matrix, ``mat`` keeps one accumulator per weight:
    it starts at ``alpha`` and grows by the squared gradient on every
    mistake. Each weight's step is divided by its own accumulator.
    """

    def __init__(self, dim, K, alpha = 1.0):
        """Set up zero weights and per-weight accumulators equal to ``alpha``.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        alpha : float, optional
            Initial value of every entry of ``mat``.
        """
        # Distinct from the full-matrix class, which reports "soperceptron";
        # sharing one name made the two learners indistinguishable by name.
        self.name = "sodiagperceptron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.mat = alpha * np.ones(dim * K)

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Return the index of the class with the largest score ``x . W``."""
        return np.argmax(np.dot(x, self.Wmat()))

    def update(self, y, x):
        """Take a per-coordinate scaled gradient step if ``x`` is misclassified.

        Parameters
        ----------
        y : int
            True class index.
        x : array_like
            Feature vector of length ``dim``.
        """
        Wt = self.Wmat()
        # NOTE(review): loss.hinge lives outside this file; its gradient is
        # assumed to be a flat vector of length dim * K - confirm.
        grad = loss.hinge(beta = 1).gradient(y, x, Wt)
        ystart = np.argmax(np.dot(x, Wt))
        if ystart != y:
            # Accumulate first so the step uses the refreshed scaling.
            self.mat = self.mat + grad ** 2
            self.weights = self.weights - np.multiply(1/self.mat, grad)
  

#%%

class Banditron:
    """Banditron-style multiclass learner with epsilon-greedy exploration.

    :meth:`predict` samples a label from a mixture of the greedy choice and
    the uniform distribution and remembers that distribution in ``ptprime``;
    :meth:`update` uses it to importance-weight the observed label.
    """

    def __init__(self, dim, K, gamma = 0):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Probability mass spread uniformly over all classes for
            exploration.
        """
        self.name = "Banditron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.gamma = gamma
        self.outcomes = list(range(self.K))

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the sampling distribution.

        Returns
        -------
        int
            The sampled class index.
        """
        Wt = self.Wmat()
        ystart = np.argmax(np.dot(x, Wt))
        eystart = np.zeros(self.K)
        eystart[ystart] = 1
        self.ptprime = (1 - self.gamma) * eystart + self.gamma * 1/self.K * np.ones(self.K)
        return np.random.choice(self.outcomes, 1, p = self.ptprime)[0]

    def update(self, y, x):
        """Apply the Banditron update for the example ``x``.

        Requires a preceding :meth:`predict` call, which sets ``ptprime``.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown.
        x : array_like
            Feature vector of length ``dim``.
        """
        Wmat = self.Wmat()
        tildek = np.argmax(np.dot(x, Wmat))
        etildek = np.zeros(self.K)
        etildek[tildek] = 1
        ey = np.zeros(self.K)
        if y != (self.K + 1): # K + 1 signals that the right answer is unknown
            # Importance weight: inverse probability of the observed label.
            ey[y] = 1/self.ptprime[y]

        # Use the one-hot greedy vector (not the bare index), and put the
        # class vector first in the Kronecker product so that the entries
        # for class k land in column k of the column-major Wmat().
        grad = np.kron(etildek - ey, x)
        self.weights = self.weights - grad
        
#%%

class IWBanditron:
    """Bandit multiclass learner using an importance-weighted hinge gradient.

    :meth:`predict` samples a label from a mixture of the greedy choice and
    the uniform distribution and remembers that distribution in ``ptprime``.
    """

    def __init__(self, dim, K, gamma = 0):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Probability mass spread uniformly over all classes for
            exploration.
        """
        self.name = "IWBanditron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.gamma = gamma
        self.outcomes = list(range(self.K))

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the sampling distribution.

        Returns
        -------
        int
            The sampled class index.
        """
        Wt = self.Wmat()
        ystart = np.argmax(np.dot(x, Wt))
        eystart = np.zeros(self.K)
        eystart[ystart] = 1
        self.ptprime = (1 - self.gamma) * eystart + self.gamma * 1/self.K * np.ones(self.K)
        return np.random.choice(self.outcomes, 1, p = self.ptprime)[0]

    def update(self, y, x):
        """Subtract the hinge gradient divided by the probability of ``y``.

        Requires a preceding :meth:`predict` call, which sets ``ptprime``.
        Nothing happens when the label is unknown.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown.
        x : array_like
            Feature vector of length ``dim``.
        """
        if y == (self.K + 1): # K + 1 signals that the right answer is unknown
            return
        pt = self.ptprime[y]
        # NOTE(review): loss.hinge lives outside this file; its gradient is
        # assumed to be a flat vector of length dim * K - confirm.
        grad = 1/pt * loss.hinge(beta = 1).gradient(y, x, self.Wmat())
        self.weights = self.weights - grad

        
#%%

class RABanditron:
    """Bandit multiclass learner that explores through a fixed "reveal" action.

    :meth:`predict` plays the greedy label with probability ``1 - gamma`` and
    puts the remaining ``gamma`` mass on the ``reveal`` index.
    """

    def __init__(self, dim, K, reveal, gamma = 0):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        reveal : int
            Index (within ``range(K)``) that receives the exploration mass.
        gamma : float, optional
            Probability mass placed on ``reveal``.
        """
        self.name = "RABanditron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.gamma = gamma
        self.outcomes = list(range(self.K))
        self.reveal = reveal
        # Indicator vector of the reveal index, mixed into ptprime in predict.
        zoos = np.zeros(K)
        zoos[reveal] = 1
        self.revealvec = zoos

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the sampling distribution.

        Returns
        -------
        int
            The sampled class index.
        """
        Wt = self.Wmat()
        ystart = np.argmax(np.dot(x, Wt))
        eystart = np.zeros(self.K)
        eystart[ystart] = 1
        self.ptprime = (1 - self.gamma) * eystart + self.gamma * self.revealvec
        return np.random.choice(self.outcomes, 1, p = self.ptprime)[0]

    def update(self, y, x):
        """Subtract the hinge gradient divided by an observation probability.

        Requires a preceding :meth:`predict` call, which sets ``ptprime``.
        Nothing happens when the label is unknown.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown.
        x : array_like
            Feature vector of length ``dim``.
        """
        if y == (self.K + 1): # K + 1 signals that the right answer is unknown
            return
        # NOTE(review): the weight is gamma when y is the reveal index and
        # ptprime[y] + gamma otherwise; kept exactly as written - confirm
        # this matches the intended probability of observing y.
        if y == self.reveal:
            pt = self.gamma
        else:
            pt = self.ptprime[y] + self.gamma
        # NOTE(review): loss.hinge lives outside this file; its gradient is
        # assumed to be a flat vector of length dim * K - confirm.
        grad = 1/pt * loss.hinge(beta = 1).gradient(y, x, self.Wmat())
        self.weights = self.weights - grad
        
#%% 

 
class Newtron:
    """Newtron-style second-order learner with a full preconditioning matrix.

    ``mat`` starts as ``(1 / D) * I`` and receives a rank-one
    Sherman-Morrison style correction on every update; ``bt`` accumulates
    the gradient terms from which the weights are recomputed.
    """

    def __init__(self, dim, K, gamma = 0, D = 1, alpha = 1, beta = 1, FI = False, project = False, greedy = False):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Probability mass spread uniformly over all classes in ``predict``.
        D : float, optional
            ``mat`` starts as ``(1 / D) * I``; with ``project`` the weights
            are rescaled to Euclidean norm ``D``.
        alpha : float, optional
            Passed through to ``loss.logistic().softmaxvec``.
        beta : float, optional
            Scale applied to the second-order term in ``update``.
        FI : bool, optional
            Full-information mode: use ``kappat = 1`` for known labels.
        project : bool, optional
            Rescale the weights to norm ``D`` after every update.
        greedy : bool, optional
            Make ``predict`` return the most probable label instead of the
            sampled one.
        """
        self.name = "Newtron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.mat = 1/D * np.diag(np.repeat(1, dim * K))
        self.beta = beta
        self.alpha = alpha
        self.gamma = gamma
        self.FI = FI
        self.bt = np.zeros(dim * K)
        self.project = project
        self.outcomes = list(range(self.K))
        self.greedy = greedy
        self.D = D

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the distributions used.

        Sets ``pt``, ``ptprime`` (``pt`` mixed with the uniform distribution)
        and ``ytprime`` (the sampled label), which ``update`` reads.

        Returns
        -------
        int
            The sampled label, or the argmax of ``ptprime`` when ``greedy``.
        """
        Wt = self.Wmat()
        # NOTE(review): softmaxvec is defined outside this file; presumably
        # it returns a length-K probability vector - confirm.
        self.pt = loss.logistic().softmaxvec(x, Wt, alpha = self.alpha)
        self.ptprime = (1 - self.gamma) * self.pt + self.gamma/self.K * np.ones(self.K)
        ytprime = np.random.choice(self.outcomes, 1, p = self.ptprime)[0]
        self.ytprime = int(ytprime)
        if self.greedy: # only use in the full info setting.
            return np.argmax(self.ptprime)
        return ytprime

    def update(self, y, x):
        """Refresh ``mat``, ``bt`` and the weights from one round of feedback.

        Requires a preceding :meth:`predict` call, which sets ``pt``,
        ``ptprime`` and ``ytprime``.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown (the sampled label ``ytprime`` is then used instead).
        x : array_like
            Feature vector of length ``dim``.
        """
        y = int(y)
        uniform = 1/self.K * np.ones(self.K)
        if y != (self.K + 1): # K + 1 signals that the right answer is unknown
            kappat = 1 if self.FI else self.ptprime[y]
            ey = np.zeros(self.K)
            ey[y] = 1
            grad = (1 - self.pt[y]) / self.ptprime[y] * np.kron(uniform - ey, x)
        else:
            kappat = 1
            eytprime = np.zeros(self.K)
            eytprime[self.ytprime] = 1
            grad = self.pt[self.ytprime] / self.ptprime[self.ytprime] * np.kron(eytprime - uniform, x)
        gradprime = np.sqrt(kappat * self.beta) * grad
        q = np.dot(self.mat, gradprime)
        # Rank-one correction of mat by the scaled gradient.
        self.mat = self.mat - np.outer(q, q)/(1 + np.dot(q, gradprime))
        self.bt = self.bt + (1 - kappat * self.beta * np.dot(grad, self.weights)) * grad
        weightsprime = -np.matmul(self.mat, self.bt)
        if self.project:
            norm = np.sqrt(np.dot(weightsprime, weightsprime))
            # A zero vector has no direction to rescale; dividing by its
            # norm would turn every weight into NaN.
            if norm > 0:
                weightsprime = self.D * weightsprime / norm
        self.weights = weightsprime
            
#%%
class PNewtron:
    """Diagonal variant of Newtron.

    ``mat`` holds one accumulator per weight, starting at ``D`` and growing
    by the scaled squared gradient; the weights are ``-bt / mat``.
    """

    def __init__(self, dim, K, gamma = 0, D = 1, alpha = 1, beta = 1, FI = False, project = False, greedy = False):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Probability mass spread uniformly over all classes in ``predict``.
        D : float, optional
            Initial value of every entry of ``mat``; with ``project`` the
            weights are rescaled to Euclidean norm ``D``.
        alpha : float, optional
            Passed through to ``loss.logistic().softmaxvec``.
        beta : float, optional
            Scale applied to the second-order term in ``update``.
        FI : bool, optional
            Full-information mode: use ``kappat = 1`` for known labels.
        project : bool, optional
            Rescale the weights to norm ``D`` after every update.
        greedy : bool, optional
            Make ``predict`` return the most probable label instead of the
            sampled one.
        """
        self.name = "PNewtron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.mat = D * np.ones(dim * K)
        self.beta = beta
        self.alpha = alpha
        self.gamma = gamma
        self.FI = FI
        self.bt = np.zeros(dim * K)
        self.project = project
        self.outcomes = list(range(self.K))
        self.greedy = greedy
        self.D = D

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        return self.weights.reshape((self.dim, self.K), order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the distributions used.

        Sets ``pt``, ``ptprime`` (``pt`` mixed with the uniform distribution)
        and ``ytprime`` (the sampled label), which ``update`` reads.

        Returns
        -------
        int
            The sampled label, or the argmax of ``ptprime`` when ``greedy``.
        """
        Wt = self.Wmat()
        # NOTE(review): softmaxvec is defined outside this file; presumably
        # it returns a length-K probability vector - confirm.
        self.pt = loss.logistic().softmaxvec(x, Wt, alpha = self.alpha)
        self.ptprime = (1 - self.gamma) * self.pt + self.gamma/self.K * np.ones(self.K)
        ytprime = np.random.choice(self.outcomes, 1, p = self.ptprime)[0]
        self.ytprime = int(ytprime)
        if self.greedy: # only use in the full info setting.
            return np.argmax(self.ptprime)
        return ytprime

    def update(self, y, x):
        """Refresh ``mat``, ``bt`` and the weights from one round of feedback.

        Requires a preceding :meth:`predict` call, which sets ``pt``,
        ``ptprime`` and ``ytprime``.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown (the sampled label ``ytprime`` is then used instead).
        x : array_like
            Feature vector of length ``dim``.
        """
        y = int(y)
        uniform = 1/self.K * np.ones(self.K)
        if y != (self.K + 1): # K + 1 signals that the right answer is unknown
            kappat = 1 if self.FI else self.ptprime[y]
            ey = np.zeros(self.K)
            ey[y] = 1
            grad = (1 - self.pt[y]) / self.ptprime[y] * np.kron(uniform - ey, x)
        else:
            kappat = 1
            eytprime = np.zeros(self.K)
            eytprime[self.ytprime] = 1
            grad = self.pt[self.ytprime] / self.ptprime[self.ytprime] * np.kron(eytprime - uniform, x)
        # Per-coordinate accumulation replaces the full-matrix correction.
        self.mat = self.mat + kappat * self.beta * grad ** 2
        self.bt = self.bt + (1 - kappat * self.beta * np.dot(grad, self.weights)) * grad
        weightsprime = -np.multiply(1 / self.mat, self.bt)
        if self.project:
            norm = np.sqrt(np.dot(weightsprime, weightsprime))
            # A zero vector has no direction to rescale; dividing by its
            # norm would turn every weight into NaN.
            if norm > 0:
                weightsprime = self.D * weightsprime / norm
        self.weights = weightsprime


            
#%% 

class SOBA:
    """Second-order bandit learner with a full preconditioning matrix.

    ``mat`` starts as ``(1 / alpha) * I`` and gets a rank-one correction on
    each accepted update; the weights are ``mat`` applied to the running
    negative gradient sum.
    """

    def __init__(self, dim, K, gamma = 0, alpha = 1, gamma_adaptive = False):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Uniform exploration mass; overridden to 1 when
            ``gamma_adaptive`` is set.
        alpha : float, optional
            Inverse of the initial diagonal value of ``mat``.
        gamma_adaptive : bool, optional
            Recompute ``gamma`` after every accepted update.
        """
        n_params = dim * K
        self.name = "SOBA"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(n_params)
        self.mat = 1 / alpha * np.eye(n_params)
        self.gamma_adaptive = gamma_adaptive
        # Adaptive mode starts from pure exploration.
        self.gamma = 1 if gamma_adaptive else gamma
        self.mtotal = 0
        self.gradsum = np.zeros(n_params)
        self.dotsum = 1
        self.tp1 = 1
        self.outcomes = list(range(self.K))

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        shape = (self.dim, self.K)
        return self.weights.reshape(shape, order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the sampling distribution.

        Returns
        -------
        int
            The sampled class index.
        """
        greedy = np.argmax(np.dot(x, self.Wmat()))
        one_hot = np.eye(self.K)[greedy]
        self.ptprime = (1 - self.gamma) * one_hot + self.gamma * 1/self.K * np.ones(self.K)
        return np.random.choice(self.outcomes, 1, p = self.ptprime)[0]

    def update(self, y, x):
        """Apply one round of feedback.

        Requires a preceding :meth:`predict` call, which sets ``ptprime``.
        The round counter ``tp1`` advances in adaptive mode even when the
        label is unknown.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown.
        x : array_like
            Feature vector of length ``dim``.
        """
        if self.gamma_adaptive:
            self.tp1 += 1
        if y == (self.K + 1): # K + 1 signals that the right answer is unknown
            return

        W = self.Wmat()
        prob = self.ptprime[y]
        # NOTE(review): loss.hinge lives outside this file; its gradient is
        # assumed to be a flat vector of length dim * K - confirm.
        g = 1/prob * loss.hinge(beta = 1).gradient(y, x, W)
        z = np.sqrt(prob) * g
        mat_z = np.matmul(self.mat, z)
        w_dot_z = np.dot(self.weights, z)
        w_dot_g = np.dot(self.weights, g)
        m = (w_dot_z ** 2 + 2 * w_dot_g)/(1 + np.inner(mat_z, z))
        if m + self.mtotal < 0:
            # The running total would turn negative: skip this example.
            return

        self.mtotal = self.mtotal + m
        self.mat = self.mat - np.outer(mat_z, mat_z)/(1 + np.dot(mat_z, z))
        self.gradsum = self.gradsum - g
        self.weights = np.matmul(self.mat, self.gradsum)
        if self.gamma_adaptive:
            # Uses the already-corrected mat.
            mat_z = np.matmul(self.mat, z)
            self.dotsum = self.dotsum + np.dot(mat_z, z)
            self.gamma = np.min([np.sqrt(self.K * self.dotsum / self.tp1), 0.5])


#%%
class SOBAdiag:
    """Diagonal variant of SOBA.

    ``mat`` keeps one accumulator per weight, starting at ``alpha`` and
    growing by the squared scaled gradient on each accepted update; the
    weights are the running negative gradient sum divided by ``mat``.
    """

    def __init__(self, dim, K, gamma = 0, alpha = 1, gamma_adaptive = False):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Uniform exploration mass; overridden to 1 when
            ``gamma_adaptive`` is set.
        alpha : float, optional
            Initial value of every entry of ``mat``.
        gamma_adaptive : bool, optional
            Recompute ``gamma`` after every accepted update.
        """
        n_params = dim * K
        self.name = "SOBAdiag"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(n_params)
        self.mat = alpha * np.ones(n_params)
        self.gamma_adaptive = gamma_adaptive
        # Adaptive mode starts from pure exploration.
        self.gamma = 1 if gamma_adaptive else gamma
        self.mtotal = 0
        self.gradsum = np.zeros(n_params)
        self.dotsum = 1
        self.tp1 = 1
        self.outcomes = list(range(self.K))

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        shape = (self.dim, self.K)
        return self.weights.reshape(shape, order = 'F')

    def predict(self, x):
        """Sample a label for ``x`` and store the sampling distribution.

        Returns
        -------
        int
            The sampled class index.
        """
        scores = np.dot(x, self.Wmat())
        greedy = np.argmax(scores)
        one_hot = np.zeros(self.K)
        one_hot[greedy] = 1
        exploit = (1 - self.gamma) * one_hot
        explore = self.gamma * 1/self.K * np.ones(self.K)
        self.ptprime = exploit + explore
        return np.random.choice(self.outcomes, 1, p = self.ptprime)[0]

    def update(self, y, x):
        """Apply one round of feedback.

        Requires a preceding :meth:`predict` call, which sets ``ptprime``.
        The round counter ``tp1`` advances in adaptive mode even when the
        label is unknown.

        Parameters
        ----------
        y : int
            True class index, or ``K + 1`` to signal that the true label is
            unknown.
        x : array_like
            Feature vector of length ``dim``.
        """
        if self.gamma_adaptive:
            self.tp1 += 1
        if y == (self.K + 1): # K + 1 signals that the right answer is unknown
            return

        W = self.Wmat()
        prob = self.ptprime[y]
        # NOTE(review): loss.hinge lives outside this file; its gradient is
        # assumed to be a flat vector of length dim * K - confirm.
        g = 1/prob * loss.hinge(beta = 1).gradient(y, x, W)
        z = np.sqrt(prob) * g
        scaled_z = np.multiply(1/self.mat, z)
        w_dot_z = np.dot(self.weights, z)
        w_dot_g = np.dot(self.weights, g)
        m = (w_dot_z ** 2 + 2 * w_dot_g)/(1 + np.inner(scaled_z, z))
        if m + self.mtotal < 0:
            # The running total would turn negative: skip this example.
            return

        self.mtotal = self.mtotal + m
        self.mat = self.mat + z ** 2
        self.gradsum = self.gradsum - g
        self.weights = np.multiply(1/self.mat, self.gradsum)
        if self.gamma_adaptive:
            # Uses the already-grown accumulators.
            scaled_z = np.multiply(1/self.mat, z)
            self.dotsum = self.dotsum + np.dot(scaled_z, z)
            self.gamma = np.min([np.sqrt(self.K * self.dotsum / self.tp1), 0.5])
  
    
  
  #%%  
class selPerceptron:
    """Perceptron that randomly decides when to ask for the true label.

    :meth:`predict` returns either the greedy label or the value ``K``,
    which signals a request to see the true label.
    """

    def __init__(self, dim, K, gamma = 0, gamma_adaptive = False, beta = 1):
        """Create a learner with all weights set to zero.

        Parameters
        ----------
        dim : int
            Number of input features.
        K : int
            Number of classes.
        gamma : float, optional
            Controls the query probability ``gamma / (gamma + margin)``.
        gamma_adaptive : bool, optional
            Recompute ``gamma`` in every ``predict`` call.
        beta : float, optional
            Scale used when ``gamma`` is recomputed.
        """
        self.name = "selPerceptron"
        self.dim = dim
        self.K = K
        self.weights = np.zeros(dim * K)
        self.gamma = gamma
        self.outcomes = list(range(self.K))
        self.gada = gamma_adaptive
        self.beta = beta
        # Largest feature norm and number of updates seen so far; both are
        # advanced in update() and only in adaptive mode.
        self.Xt = 0
        self.Kt = 0

    def Wmat(self):
        """Return the flat weights reshaped column-major to ``(dim, K)``."""
        shape = (self.dim, self.K)
        return self.weights.reshape(shape, order = 'F')

    def predict(self, x):
        """Return the greedy label or ``K`` (a request for the true label).

        Returns
        -------
        numpy.ndarray
            Length-1 array holding the greedy class index or ``K``.
        """
        W = self.Wmat()
        greedy = np.argmax(np.dot(x, W))
        # NOTE(review): loss.hinge(...).m lives outside this file; presumably
        # it is the margin of the greedy label - confirm.
        margin = loss.hinge(1).m(W, x, greedy)
        if self.gada:
            x_norm = np.sqrt(np.inner(x, x))
            self.Xprime = np.max([self.Xt, x_norm])
            self.gamma = self.beta * self.Xprime**2 * np.sqrt(1 + self.Kt)
        # NOTE(review): gamma and margin both being zero would make this
        # 0/0 - confirm callers never hit that case.
        query_prob = self.gamma/(self.gamma + margin)
        # K signals that the true label should be shown.
        choices = [greedy, self.K]
        return np.random.choice(choices, 1, p = [1- query_prob, query_prob])

    def update(self, y, x):
        """Take a perceptron step when ``y`` differs from the greedy label.

        Parameters
        ----------
        y : int
            True class index.
        x : array_like
            Feature vector of length ``dim``.
        """
        W = self.Wmat()
        greedy = np.argmax(np.dot(x, W))
        if y != greedy:
            if self.gada:
                # Relies on Xprime set by the preceding predict() call.
                self.Xt = self.Xprime
                self.Kt += 1
            self.weights = self.weights - loss.hinge(1).gradient(y, x, W)
        