

import numpy as np
#%%

class Gappletron:
    """Randomized multiclass predictor built on top of a linear learner ``algW``.

    Each round the greedy label ``ystart`` (argmax of ``x @ Wt``) receives most
    of the probability mass. The rest goes to one of two places:

    * when the gap-map value ``at`` is at least the exploration rate
      ``gammat``, ``at`` is spread uniformly over all ``K`` labels;
    * otherwise ``gammat`` is spread over the exploration set ``domset``.

    Parameters
    ----------
    gapmap : callable
        ``gapmap(Wt, x)`` returning a scalar ``at``. It is used as a probability
        weight, so presumably it lies in [0, 1] -- TODO confirm.
    algW : object
        Underlying learner. It must expose ``name``, ``Wmat()`` (class vectors
        in columns) and ``update(y, x, loss, scaling)``.
    gamma : float
        Scale of the exploration schedule ``min(0.5, gamma / sqrt(n))``, where
        ``n`` counts the rounds that required exploration.
    beta : float
        When positive, rounds with ``at > beta`` skip exploration and do not
        advance the schedule. When ``<= 0`` this threshold is disabled.
    outcomes : sequence of int
        The labels; expected to be ``0 .. K-1``.
    domset : sequence of int, optional
        Exploration set. ``None`` or empty selects the bandit setting, i.e.
        exploration is uniform over all outcomes.
    reveal : sequence of int, optional
        Actions whose play reveals the true label.
    domsetdict : dict, optional
        Maps ``str(label)`` to the list of actions that reveal that label.
        Used only together with a non-empty ``domset``.
    """

    def __init__(self, gapmap, algW, gamma, beta, outcomes, domset=None, reveal=None, domsetdict=None):
        self.name = "Gappletron" + "-" + algW.name
        self.gapmap = gapmap
        self.algW = algW
        self.gamma = gamma
        self.beta = beta
        # Number of rounds in which exploration was needed; it drives the
        # decaying exploration rate gamma / sqrt(inversegamma).
        self.inversegamma = 0
        self.outcomes = outcomes  # The outcomes start with category 0 and end with category K-1
        self.K = len(outcomes)
        # len() instead of truthiness so NumPy arrays are accepted too.
        if domset is None or len(domset) == 0:
            # Bandit scenario: explore uniformly over every outcome.
            self.domset = outcomes
            self.domsetvec = np.ones(self.K) / self.K
            self.domsetdict = {}
        else:
            self.domset = domset
            self.domsetvec = np.zeros(self.K)
            self.domsetvec[domset] = 1.0
            self.domsetvec = self.domsetvec / np.sum(self.domsetvec)
            # Expected form: {'0': actions revealing label 0, '1': ..., ...}
            self.domsetdict = {} if domsetdict is None else domsetdict
        # Fresh list per instance: no shared mutable default.
        self.reveal = [] if reveal is None else reveal  # list of revealing actions

    def predict(self, x):
        """Sample a label for input ``x``.

        Stores the sampling distribution in ``self.ptprime`` and the sampled
        label in ``self.ytprime``, then returns the sampled label.
        """
        Wt = self.algW.Wmat()  # class vectors in columns
        ystart = np.argmax(np.dot(x, Wt))
        at = self.gapmap(Wt, x)

        # No exploration is needed when the gap exceeds the (enabled)
        # threshold, or when playing the greedy label already reveals the
        # outcome. Otherwise advance the decaying exploration schedule.
        # NOTE(review): for beta <= 0 the original code crashed (gammat was
        # never assigned); here only the reveal rule applies -- confirm this
        # matches the intended algorithm.
        if (self.beta > 0 and at > self.beta) or ystart in self.reveal:
            gammat = 0.0
        else:
            self.inversegamma += 1
            gammat = min(0.5, self.gamma / np.sqrt(self.inversegamma))

        eystart = np.zeros(self.K)
        eystart[ystart] = 1.0
        if gammat <= at:
            # Keep 1 - at on the greedy label and spread at uniformly.
            self.ptprime = (1 - at) * eystart + at / self.K * np.ones(self.K)
        else:
            # Keep 1 - gammat on the greedy label and explore over domset.
            self.ptprime = (1 - gammat) * eystart + gammat * self.domsetvec
        self.ytprime = np.random.choice(self.outcomes, 1, p=self.ptprime)[0]
        return self.ytprime

    def computePt(self, y):
        """Return the probability that the true label ``y`` is observed.

        The result is 1.0 when every outcome is a revealing action. Otherwise
        it is ``ptprime[y]``, or, when ``domsetdict`` is given, the total mass
        of the actions listed for ``str(y)``. Requires a prior ``predict``.
        """
        y = int(y)
        # Order-insensitive and safe for NumPy ``outcomes``; the previous
        # list equality broke on arrays and on permuted ``reveal`` lists.
        if set(self.outcomes).issubset(self.reveal):
            return 1.0
        if not self.domsetdict:
            return self.ptprime[y]
        return np.sum(self.ptprime[self.domsetdict[str(y)]])

    def update(self, y, x, loss):
        """Forward an importance-weighted update (scaling ``1 / Pt``) to ``algW``.

        Passing ``y == K + 1`` signals that the true label is unknown; in that
        case no update is made.
        """
        if y == self.K + 1:
            return
        Pt = self.computePt(y)
        self.algW.update(y, x, loss, 1 / Pt)
        
        
#%%
def zeromap(Wt, x):
    """Constant gap map: report a gap of 0 for any weights and input.

    Both arguments are accepted only to match the ``gapmap(Wt, x)`` calling
    convention; neither is inspected.
    """
    del Wt, x  # intentionally unused
    return 0