# -*- coding: utf-8 -*-
import compared
import scalefreeMD
import loss
import runclass

import numpy as np

def runbandit(y, X, dataname, rep=10, save=True, gamma="theo", diameter=2, project=False, project_half=False, path=''):
    """Run three bandit-feedback learners on one dataset.

    Builds a Gappletron with logistic loss, a Gaptron with hinge loss and a
    Gaptron with logistic loss, then passes each one to
    ``runclass.runclassifier`` with ``bandit=True``.

    Parameters
    ----------
    y : array-like
        Labels. The number of classes ``K`` is the number of distinct values.
    X : numpy.ndarray
        2-D array. ``X.shape[0]`` is used as the horizon ``Ti`` and
        ``X.shape[1]`` as the dimension handed to the base learners.
    dataname, rep, save, path
        Forwarded unchanged to ``runclass.runclassifier``.
    gamma : {"theo", "onlyT"} or real number
        Selects how the per-learner gamma values are chosen:

        * ``"theo"``  -- closed-form values built from ``K``, the largest row
          norm of ``X``, ``Ti`` and (for the Gappletron) ``diameter``.
        * ``"onlyT"`` -- gamma values depend on ``Ti`` only; the eta scale
          factors are built from ``K`` and the row norm, and the norm used in
          the final ``lada`` expression is replaced by 1.
        * a number    -- that value is used as gamma for all three learners,
          with unit eta scale factors.
    diameter : number
        Passed to the Gappletron's base learner and to ``runclassifier``.
    project, project_half : bool
        Forwarded to every ``scalefreeMD.scalefreeMD`` and to
        ``runclassifier``.

    Returns
    -------
    None
        The values returned by ``runclassifier`` are not propagated.

    Raises
    ------
    ValueError
        If ``gamma`` is neither a recognised mode string nor a real number.
    """
    K = len(np.unique(y))
    xnorm = np.max(np.linalg.norm(X, axis=1))
    Ti = X.shape[0]
    dim = X.shape[1]

    # bool is a subclass of int but is never a meaningful gamma; NumPy scalars
    # (e.g. values taken from np.linspace) are accepted alongside plain numbers.
    gamma_is_number = (
        isinstance(gamma, (int, float, np.integer, np.floating))
        and not isinstance(gamma, bool)
    )

    # Defaults shared by the modes; "onlyT" overrides all three below.
    etascalehin = 1
    etascalelog = 1
    eta_norm = xnorm  # norm appearing in the denominator of the `lada` values

    if gamma_is_number:
        # NOTE(review): previously only the Gappletron gamma was set here and
        # the function crashed on the undefined Gaptron parameters. The value
        # is now used for every learner, which extends the "theo" lada
        # formula to a user-chosen gamma -- confirm this is the intended
        # experiment.
        Gapgamma = gamma
        Gaphingamma = gamma
        Gaploggamma = gamma
    elif isinstance(gamma, str) and gamma == "theo":
        Gapgamma = 1/2 * xnorm * K * 0.5 * diameter
        Gaphingamma = np.min([1, np.sqrt(K**3 * xnorm**2/(2 * (1 - 1/K) * (K - 1) * Ti))])
        Gaploggamma = np.min([1, np.sqrt(K**2 * xnorm**2/(np.log(2) * Ti))])
    elif isinstance(gamma, str) and gamma == "onlyT":
        Gapgamma = 1/2
        Gaphingamma = np.min([1, np.sqrt(1/(2 * Ti))])
        Gaploggamma = np.min([1, np.sqrt(1/(np.log(2) * Ti))])
        # The scale factors are computed with the true row norm, after which
        # the norm in the lada denominator is replaced by 1 (kept exactly as
        # in the original code).
        etascalehin = 1/((1 - 1/K)/(K * xnorm ** 2))
        etascalelog = 1/((np.log(2))/(2 * K * xnorm ** 2))
        eta_norm = 1
    else:
        raise ValueError(
            f"gamma must be 'theo', 'onlyT' or a real number, got {gamma!r}"
        )

    # One regulariser instance is shared by the three base learners.
    reg = scalefreeMD.L2reg()
    algW = scalefreeMD.scalefreeMD(
        reg, diameter, dim=dim, K=K, eta='ada',
        project=project, project_half=project_half,
    )
    algWhin = scalefreeMD.scalefreeMD(
        reg, 1, dim=dim, K=K, eta='constant',
        project=project, project_half=project_half,
        lada=Gaphingamma * etascalehin * (1 - 1/K)/(K * eta_norm ** 2),
    )
    algWlog = scalefreeMD.scalefreeMD(
        reg, 1, dim=dim, K=K, eta='constant',
        project=project, project_half=project_half,
        lada=Gaploggamma * etascalelog * (np.log(2))/(2 * K * eta_norm ** 2),
    )

    Gaplogistic = compared.Gappletron(
        loss.logistic(base=K).minloss, algW, Gapgamma, 1, list(range(K)),
        domset=[], reveal=[], domsetdict=[],
    )
    Gaptronhinge = compared.Gaptron(
        loss.hinge(1/K).minloss, algWhin, Gaphingamma, 1, list(range(K)),
        domset=[], reveal=[], domsetdict=[],
    )
    Gaptronlogistic = compared.Gaptron(
        compared.logmap, algWlog, Gaploggamma, 1, list(range(K)),
        domset=[], reveal=[], domsetdict=[],
    )

    # Every run receives the same bookkeeping arguments; note that the
    # caller's original `gamma` (mode string or number) is what gets forwarded.
    run_kwargs = dict(
        rep=rep, dataname=dataname, save=save, bandit=True, gamma=gamma,
        diameter=diameter, project=project, project_half=project_half,
        path=path,
    )
    runclass.runclassifier(Gaplogistic, loss.logistic(base=K), y, X, **run_kwargs)
    runclass.runclassifier(Gaptronhinge, loss.hinge(1/K), y, X, **run_kwargs)
    # The Gaptron logistic run is evaluated with a base-2 logistic loss.
    runclass.runclassifier(Gaptronlogistic, loss.logistic(base=2), y, X, **run_kwargs)
    
    
