#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import scipy as sp

#%%
class lootOMD:
    """Clipped-loss online mirror descent over the probability simplex.

    After each round the instantaneous regrets ``r = <p, l> - l`` are turned
    into clipped surrogate losses ``-r * (|r| <= 1 / eta)`` and the next
    weight vector solves

        min_p  <p, ltil> + sum_i (p_i log(p_i / q_i) - p_i + q_i) / eta_i

    over the simplex, where ``q`` is the current weight vector and
    ``eta_i = beta / sqrt(max(V_i, Vbar))``. Here ``V_i`` is the cumulative
    squared regret of expert i, and ``Vbar`` is the cumulative
    ``<p, r ** 2>``. The problem is solved numerically with
    ``scipy.optimize.minimize``.

    Parameters
    ----------
    dim : int
        Number of experts.
    nit : int
        Stored as ``self.nit``; not read by the methods of this class.
    T : int
        Horizon; each weight is bounded below by ``1 / (dim * T) ** 2``.
    beta : float
        Learning-rate scale.
    alpha : float
        Accepted for a uniform constructor signature; unused.
    """

    def __init__(self, dim, nit, T, beta, alpha):
        self.name = "lootOMD"
        self.dim = dim
        self.nit = nit
        self.T = T
        self.beta = beta
        # Current weights (uniform at start) and per-expert accumulators.
        self.pars = np.ones(dim) / dim
        self.cumloss = np.zeros(dim)
        self.vbarsum = np.zeros(dim)  # the same scalar is added to every entry
        self.visum = np.zeros(dim)
        # Bookkeeping fields that the methods below do not read.
        self.t = 1
        self.eta = 1.0
        self.gamma = 1.0
        self.nlosses = 0.0

    def Objective(self, p, ltil, etavec):
        """Linear surrogate loss plus per-coordinate unnormalised KL to ``self.pars``."""
        divergence = (p * np.log(p / self.pars)) - p + self.pars
        return np.sum(p * ltil + 1 / etavec * divergence)

    def constraint1(self, x):
        """Equality constraint for the solver: zero when ``x`` sums to one."""
        total = np.sum(x)
        return total - 1

    def constraint2(self, x):
        """Identity map (non-negativity form); not passed to the solver here."""
        return x

    def compp(self, ltil, etavec):
        """Solve the mirror-descent step numerically and return the new weights."""
        # Strictly positive floor keeps log(p / q) finite inside Objective.
        floor = 1 / (self.dim * self.T) ** 2
        result = sp.optimize.minimize(
            self.Objective,
            x0=self.pars,
            args=(ltil, etavec),
            bounds=[(floor, None)] * self.dim,
            constraints=[{'type': 'eq', 'fun': self.constraint1}],
        )
        return result.x

    def update(self, lossvec):
        """Accumulate regret statistics and, once they are non-zero, take a step."""
        regret = np.dot(self.pars, lossvec) - lossvec
        sq_regret = regret ** 2
        self.vbarsum += np.dot(self.pars, sq_regret)
        self.visum += sq_regret
        # Until some regret variance has been observed the rates would be infinite.
        if not self.vbarsum[0] > 0:
            return
        etavec = self.beta * (np.maximum(self.visum, self.vbarsum)) ** (-1 / 2)
        clipped = -regret * (np.abs(regret) <= 1 / etavec)
        self.pars = self.compp(clipped, etavec)
            
            
#%%
class lootFTRL:
    """Clipped-loss follow-the-regularised-leader over the probability simplex.

    After each round the instantaneous regrets ``r = <p, l> - l`` are turned
    into clipped surrogate losses ``-r * (|r| <= 1 / eta)`` and accumulated in
    ``self.R``. The next weight vector minimises

        <p, R> + sum_i (p_i log(p_i * dim) - p_i + 1 / dim) / eta_i

    over the simplex. The second term is the unnormalised KL divergence from
    the uniform distribution. Here ``eta_i = beta / sqrt(max(V_i, Vbar))``.
    The problem is solved numerically with ``scipy.optimize.minimize``.

    Parameters
    ----------
    dim : int
        Number of experts.
    nit : int
        Stored as ``self.nit``; not read by the methods of this class.
    T : int
        Horizon; each weight is bounded below by ``1 / (T * dim) ** 10``.
    beta : float
        Learning-rate scale.
    alpha : float
        Accepted for a uniform constructor signature; unused.
    """

    def __init__(self, dim, nit, T, beta, alpha):
        self.dim = dim
        self.name = "lootFTRL"
        self.pars = np.ones(dim)/dim
        self.cumloss = np.zeros(dim)
        self.t = 1
        self.eta = 1.0
        self.gamma = 1.0
        self.nit = nit
        self.nlosses = 0.0
        self.T = T
        self.vbarsum = np.zeros(dim)  # the same scalar is added to every entry
        self.visum = np.zeros(dim)
        self.beta = beta
        self.R = np.zeros(dim)  # cumulative clipped surrogate losses

    def Objective(self, p, ltil, etavec):
        """FTRL objective: linear term plus per-coordinate KL to the uniform vector."""
        return(np.sum(p * ltil + 1/etavec * ((p * np.log(p * self.dim)) - p + 1/self.dim)))

    def constraint1(self, x):
        """Equality constraint for the solver: zero when ``x`` sums to one."""
        return np.sum(x) - 1

    def constraint2(self, x):
        """Identity map (non-negativity form); not passed to the solver here."""
        return x

    def compp2(self, ltil, etavec):
        """Minimise ``Objective`` over the simplex numerically and return the weights.

        The lower bound on each weight is derived from the instance's own horizon
        and dimension. Previously this read module-level ``T`` and ``K`` globals,
        which raised NameError outside the experiment script.
        """
        floor = 1 / (self.T * self.dim) ** 10  # keeps log(p * dim) finite
        bounds = [(floor, None) for _ in range(self.dim)]
        cons = {'type': 'eq', 'fun': self.constraint1}
        sol = sp.optimize.minimize(self.Objective, x0 = self.pars, args = (ltil, etavec), bounds = bounds, constraints = [cons])
        return(sol.x)

    def Lagrange(self, Lda, ltil, etavec):
        """Return ``sum_i exp(-eta_i * (ltil_i + Lda)) - 1``.

        A root in ``Lda`` normalises the closed-form FTRL weights. This class's
        ``compp`` does not call it. The exponentials are shifted by their
        largest exponent before summing, so large losses do not turn into 0/0.
        """
        expo = -etavec * (ltil + Lda)
        shift = np.max(expo)
        return np.exp(shift) * np.sum(np.exp(expo - shift)) - 1

    def compp(self, ltil, etavec):
        """Return the next weight vector for cumulative surrogate losses ``ltil``.

        Always uses the numerical solver ``compp2``.
        """
        return self.compp2(ltil, etavec)

    def update(self, lossvec):
        """Accumulate regret statistics and clipped losses, then recompute the weights."""
        rt = np.dot(self.pars, lossvec) - lossvec
        vt = rt**2
        vtbar = np.dot(self.pars, vt)
        self.vbarsum += vtbar
        self.visum += vt
        # Until some regret variance has been observed the rates would be infinite.
        if self.vbarsum[0] > 0:
            etavec = self.beta * (np.maximum(self.visum, self.vbarsum))**(-1/2)
            ltil = -rt * (np.abs(rt) <= 1/etavec)
            self.R += ltil
            self.pars = self.compp(self.R, etavec)
            
            

#%%

def stable_softmax(x):
    """Return ``exp(x) / sum(exp(x))``, shifting by ``max(x)`` to avoid overflow."""
    exps = np.exp(x - np.max(x))
    return exps / exps.sum()

class Squint:
    """Squint-style expert weights computed from a closed-form log-evidence.

    The class tracks each expert's cumulative regret ``R`` and cumulative
    squared regret ``V``. After every round the weights are the softmax of a
    per-expert log-evidence. For ``V > 0`` that log-evidence is built from
    ``R ** 2 / (4 V)`` and a log erfc-difference whose upper argument depends
    on the loss range ``alpha``.

    Parameters
    ----------
    dim : int
        Number of experts.
    nit, T : int
        Stored on the instance; not read by the methods of this class.
    beta : float
        Accepted for a uniform constructor signature; unused.
    alpha : float
        Loss range.
    """

    def __init__(self, dim, nit, T, beta, alpha):
        self.dim = dim
        self.name = "squint"  # improper prior
        self.alpha = alpha  # loss range
        self.nit = nit
        self.T = T
        # Current weights (uniform at start).
        self.pars = np.ones(dim) / dim
        # Cumulative regret and squared regret per expert.
        self.R = np.zeros(self.dim)
        self.V = np.zeros(self.dim)
        # Bookkeeping fields that the methods below do not read.
        self.cumloss = np.zeros(dim)
        self.t = 1
        self.eta = 1.0
        self.gamma = 1.0
        self.nlosses = 0.0
        self.vbarsum = np.zeros(dim)
        self.visum = np.zeros(dim)

    def log_diff(self, log_p, log_q):
        """Return ``log(exp(log_p) - exp(log_q))``; finite only when ``log_p > log_q``."""
        ratio = np.exp(log_q - log_p)
        return log_p + np.log1p(-ratio)

    def logerfcdif(self, z0, z1):
        """Return ``log(erfc(z0) - erfc(z1))`` (finite only for ``z0 < z1``).

        Same-sign arguments are evaluated in log space through ``log_ndtr``,
        using ``erfc(z) = 2 * Phi(-z * sqrt(2))``. Mixed-sign arguments are
        evaluated directly.
        """
        log_ndtr = sp.special.log_ndtr
        scale = np.sqrt(2)
        both_negative = z0 < 0 and z1 < 0
        mixed_sign = (z0 < 0 and z1 >= 0) or (z0 >= 0 and z1 < 0)
        both_nonnegative = z0 >= 0 and z1 >= 0
        if both_negative:
            result = self.log_diff(log_ndtr(z1 * scale), log_ndtr(z0 * scale)) + np.log(2)
        elif mixed_sign:
            result = np.log(2 - sp.special.erfc(-z0) - sp.special.erfc(z1))
        elif both_nonnegative:
            result = self.log_diff(log_ndtr(-z0 * scale), log_ndtr(-z1 * scale)) + np.log(2)
        # NaN inputs match no branch and fail here, exactly as before.
        return result

    def update(self, lossvec):
        """Accumulate regret statistics and reset the weights from the log-evidence."""
        regret = np.dot(self.pars, lossvec) - lossvec
        self.R += regret
        self.V += regret ** 2
        log_evidence = np.zeros(self.dim)
        for k, (r_k, v_k) in enumerate(zip(self.R, self.V)):
            if v_k == 0:
                # NOTE(review): this constant (log sqrt(pi)) differs from the
                # log(pi/2) - log(2 sqrt(V)) used when V > 0; confirm both match
                # the intended prior.
                log_evidence[k] = np.log(np.sqrt(np.pi))
                continue
            denom = np.sqrt(4 * v_k)
            lower = -r_k / denom
            upper = (1 / (self.alpha) * v_k - r_k) / denom
            log_evidence[k] = (np.log(np.pi / 2) + r_k ** 2 / (4 * v_k)
                               + self.logerfcdif(lower, upper)
                               - np.log(2 * np.sqrt(v_k)))
        self.logev = log_evidence
        self.pars = stable_softmax(log_evidence)



#%%

class AdaHedge:
    """AdaHedge for K experts (de Rooij, van Erven, Grünwald & Koolen, 2014).

    Hedge with learning rate ``eta = log(K) / Delta``, where ``Delta``
    (``self.delta``) is the cumulative mixability gap observed so far.
    ``self.pars`` always holds the weights to be played in the next round.

    Parameters
    ----------
    dim : int
        Number of experts K.
    """

    def __init__(self, dim):
        self.dim = dim
        self.pars = np.ones(dim) / dim
        self.cumulative_lossvec = np.zeros(dim)
        self.delta = 0
        self.eta = 1.0  # not read by mix/update; kept for attribute compatibility

    def mix(self, L):
        """Return ``[p, M]``: exponential weights and mix loss for cumulative losses ``L``.

        With ``eta = log(K) / self.delta`` the mix loss is
        ``M = min(L) - log(mean_i exp(-eta * (L_i - min(L)))) / eta``.
        When ``self.delta == 0`` (``eta = +inf``), or there is a single expert,
        the mass is spread uniformly over the minimisers of ``L`` and
        ``M = min(L)``.
        """
        L = np.asarray(L, dtype=float)
        mn = np.min(L)
        if self.delta == 0 or self.dim == 1:
            # Follow-the-leader limit: ties share the weight.
            w = (L == mn).astype(float)
            M = mn
        else:
            eta = np.log(self.dim) / self.delta
            w = np.exp(-eta * (L - mn))
            M = mn - np.log(np.sum(w) / self.dim) / eta
        p = w / np.sum(w)
        return [p, M]

    def update(self, lossvec):
        """Charge ``self.pars`` on ``lossvec``, grow Delta, then refresh the weights."""
        # Mix loss before the round, using the current Delta (same eta as self.pars).
        M_prev = self.mix(self.cumulative_lossvec)[1]
        h = np.dot(self.pars, lossvec)  # loss of the played mixture
        self.cumulative_lossvec += lossvec
        # Mix loss after the round, still with the same eta.
        M = self.mix(self.cumulative_lossvec)[1]
        self.delta += max(0.0, float(h - (M - M_prev)))
        # Next round's weights must use the updated Delta.
        self.pars = self.mix(self.cumulative_lossvec)[0]



class AdaHedge2:
    """Exponential weights on cumulative regret with a variance-tuned learning rate.

    After each round the weights become ``softmax(eta * R)``. Here ``R`` is the
    cumulative regret ``sum_s (<p_s, l_s> - l_s)`` and
    ``eta = min(1 / alpha, beta / sqrt(sum_s <p_s, r_s ** 2>))``.
    Nothing is updated while that variance sum is still zero.

    Parameters
    ----------
    dim : int
        Number of experts.
    beta : float
        Learning-rate scale.
    alpha : float
        ``1 / alpha`` caps the learning rate.
    """

    def __init__(self, dim, beta = 1, alpha = 1/100000):
        self.dim = dim
        self.alpha = alpha
        self.beta = beta
        self.eta = 1/alpha  # initial cap; update() works with a local rate
        self.delta = 0
        self.pars = np.ones(dim) / dim
        self.R = np.zeros(dim)
        self.vbarsum = np.zeros(dim)  # the same scalar is added to every entry

    def update(self, lossvec):
        """Accumulate regret and, once its variance is non-zero, reweight."""
        regret = np.dot(self.pars, lossvec) - lossvec
        self.vbarsum += np.dot(self.pars, regret**2)
        if not self.vbarsum[0] > 0:
            return
        cap = 1/self.alpha
        tuned = self.beta * (self.vbarsum[0])**(-1/2)
        eta = np.min([cap, tuned])
        self.R += regret
        self.pars = stable_softmax(eta * self.R)
        

#%%

# Lower-bound ("heavy-tailed") environment.
#
# For K = 15, 25, ..., 15 + 10 * (nK - 1) and horizon T = 20 * K, every round
# uses the base loss row [1, 0, 1, ..., 1] (expert 1, zero-based, has base loss
# 0) plus sparse noise. Each entry independently, with probability p = 1/T,
# receives a jump of magnitude sqrt(R) = sqrt(T * K). The jump's sign is random
# when rade = 0. Five learners are run on `reps` independently drawn loss
# matrices per K, and summary statistics of (cumulative loss - minloss) are
# checkpointed to heavydata*.npy after every K.

np.random.seed(5624)

nK = 500
reps = 40

allregs = []   # [mean, K] pairs (kept in memory only)
allregs2 = []  # mean            -> heavydata.npy
allregs3 = []  # median          -> heavydata2.npy
allregs4 = []  # 20% quantile    -> heavydataQ2.npy
allregs5 = []  # 80% quantile    -> heavydataQ8.npy
allregs6 = []  # std. deviation  -> heavydataSD.npy

for s in range(nK):
    # NOTE(review): T and K stay module-level names; check whether any class
    # reads them as globals before renaming.
    K = 15 + s * 10
    T = 20 * K
    R = T * K      # squared jump magnitude
    p = 1/(T)      # per-entry jump probability
    sgn = -1
    mn = sgn * np.sqrt(K/T)
    rade = 0 # 1 for no rademacher noise, 0 for rademacher noise
    # NOTE(review): with rade = 0 this is 0, so the stored "regrets" are the
    # learners' raw cumulative losses.
    minloss = 0 + mn * rade
    # NOTE(review): every round uses this same base row, so expert 1 is always
    # best; confirm a fixed (non-alternating) environment is intended.
    base_row = np.append(np.array([1, 0]), np.repeat(1, K - 2))
    # Rows: OMD, FTRL, AdaHedge, AdaHedge2, Squint; columns: repetitions.
    regs = np.empty((5, reps))

    for i in range(reps):
        # Draw rows one at a time (two binomial calls per row, in this order)
        # into a preallocated matrix. Growing it with np.vstack every round
        # would copy the whole matrix T times.
        lossmat = np.empty((T, K))
        for j in range(T):
            z = sgn * np.sqrt(R)*np.random.binomial(1, p, size = K) * (np.random.binomial(1, 1/2 + 0.5 * rade, size = K)-0.5)*2
            lossmat[j] = base_row + z

        alphaE = np.max(np.abs(lossmat))  # observed loss range
        iOMD = lootOMD(K, 1000, T, np.sqrt(np.log(T*K)), 1)
        iFTRL = lootFTRL(K, 1000, T, np.sqrt(np.log(K)), 1)
        iSQ = Squint(K, 1000, T, 1, alphaE)
        iADA = AdaHedge(K)
        iADA2 = AdaHedge2(K, beta = np.sqrt(np.log(K)), alpha = alphaE)
        learners = (iOMD, iFTRL, iADA, iADA2, iSQ)
        cumloss = np.zeros(len(learners))

        for j in range(T):
            lvt = lossmat[j]
            # Charge every learner for the weights it holds before this round.
            for idx, learner in enumerate(learners):
                cumloss[idx] += np.dot(learner.pars, lvt)
            for learner in learners:
                learner.update(lvt)

        regs[:, i] = cumloss - minloss

    allregs3.append(np.median(regs, axis = 1))
    allregs2.append(np.mean(regs, axis = 1))
    allregs4.append(np.quantile(regs, 0.2, axis = 1))
    allregs5.append(np.quantile(regs, 0.8, axis = 1))
    allregs6.append(np.std(regs, axis = 1))
    allregs.append([allregs2[-1], K])
    # Checkpoint after every K so partial results survive an interrupted run.
    np.save("heavydata.npy", np.array(allregs2))
    np.save("heavydata2.npy", np.array(allregs3))

    np.save("heavydataQ2.npy", np.array(allregs4))
    np.save("heavydataQ8.npy", np.array(allregs5))
    np.save("heavydataSD.npy", np.array(allregs6))
    print([allregs2[-1], K])







#%%

# Non-heavy-tailed environment.
#
# Same protocol as the heavy-tailed run, but the sparse jumps have the fixed
# magnitude R = 2. For K = 15, 25, ..., 15 + 10 * (nK - 1) and T = 20 * K, every
# round uses the base loss row [1, 0, 1, ..., 1]. Each entry independently,
# with probability p = 1/T, receives a jump of magnitude R, with a random sign
# when rade = 0. Summary statistics of (cumulative loss - minloss) for five
# learners are checkpointed to nonheavydata*.npy after every K.

np.random.seed(5624)

nK = 500
reps = 40

allregs = []   # [mean, K] pairs (kept in memory only)
allregs2 = []  # mean            -> nonheavydata.npy
allregs3 = []  # median          -> nonheavydata2.npy
allregs4 = []  # 20% quantile    -> nonheavydataQ2.npy
allregs5 = []  # 80% quantile    -> nonheavydataQ8.npy
allregs6 = []  # std. deviation  -> nonheavydataSD.npy

for s in range(nK):
    # NOTE(review): T and K stay module-level names; check whether any class
    # reads them as globals before renaming.
    K = 15 + s * 10
    T = 20 * K
    R = 2          # jump magnitude
    p = 1/(T)      # per-entry jump probability
    sgn = -1
    mn = sgn * np.sqrt(K/T)
    rade = 0 # 1 for no rademacher noise, 0 for rademacher noise
    # NOTE(review): with rade = 0 this is 0, so the stored "regrets" are the
    # learners' raw cumulative losses.
    minloss = 0 + mn * rade
    # NOTE(review): every round uses this same base row, so expert 1 is always
    # best; confirm a fixed (non-alternating) environment is intended.
    base_row = np.append(np.array([1, 0]), np.repeat(1, K - 2))
    # Rows: OMD, FTRL, AdaHedge, AdaHedge2, Squint; columns: repetitions.
    regs = np.empty((5, reps))

    for i in range(reps):
        # Draw rows one at a time (two binomial calls per row, in this order)
        # into a preallocated matrix. Growing it with np.vstack every round
        # would copy the whole matrix T times.
        lossmat = np.empty((T, K))
        for j in range(T):
            z = sgn * (R)*np.random.binomial(1, p, size = K) * (np.random.binomial(1, 1/2 + 0.5 * rade, size = K)-0.5)*2
            lossmat[j] = base_row + z

        alphaE = np.max(np.abs(lossmat))  # observed loss range
        iOMD = lootOMD(K, 1000, T, np.sqrt(np.log(T*K)), 1)
        iFTRL = lootFTRL(K, 1000, T, np.sqrt(np.log(K)), 1)
        iSQ = Squint(K, 1000, T, 1, alphaE)
        iADA = AdaHedge(K)
        iADA2 = AdaHedge2(K, beta = np.sqrt(np.log(K)), alpha = alphaE)
        learners = (iOMD, iFTRL, iADA, iADA2, iSQ)
        cumloss = np.zeros(len(learners))

        for j in range(T):
            lvt = lossmat[j]
            # Charge every learner for the weights it holds before this round.
            for idx, learner in enumerate(learners):
                cumloss[idx] += np.dot(learner.pars, lvt)
            for learner in learners:
                learner.update(lvt)

        regs[:, i] = cumloss - minloss

    allregs3.append(np.median(regs, axis = 1))
    allregs2.append(np.mean(regs, axis = 1))
    allregs4.append(np.quantile(regs, 0.2, axis = 1))
    allregs5.append(np.quantile(regs, 0.8, axis = 1))
    allregs6.append(np.std(regs, axis = 1))
    # Appended once per K (previously this entry was duplicated).
    allregs.append([allregs2[-1], K])
    # Checkpoint after every K so partial results survive an interrupted run.
    np.save("nonheavydata.npy", np.array(allregs2))
    np.save("nonheavydata2.npy", np.array(allregs3))

    np.save("nonheavydataQ2.npy", np.array(allregs4))
    np.save("nonheavydataQ8.npy", np.array(allregs5))
    np.save("nonheavydataSD.npy", np.array(allregs6))
    print([allregs2[-1], K])




















