
import pandas as pd
import numpy as np
import copy
import os
#%%
def runclassifier(alg, lossf, y, X, rep = 1, verbose = True, dataname = "a secret", save = False, bandit = False, gamma = "theo"):
    """Run an online multiclass learner over (X, y) and record its mistakes.

    Each of the ``rep`` repetitions starts from a fresh deep copy of ``alg``
    and makes one pass over the rows of ``X`` in order: predict, record a 0/1
    mistake indicator against the true label, then update the learner.

    Parameters
    ----------
    alg : object
        Learner exposing ``name``, ``predict(x)`` and ``update(y, x)``.
        Learners whose ``name`` contains "Gappletron" are updated with
        ``update(y, x, lossf)`` instead. The caller's object is never
        mutated; every repetition works on a deep copy.
    lossf : object
        Object with a ``name`` attribute; passed to Gappletron updates.
    y : array-like
        One label per row of ``X``; each label is cast with ``int``.
    X : 2-D array
        Feature matrix; row ``i`` is the input of round ``i``.
    rep : int
        Number of independent repetitions.
    verbose : bool
        Print progress every 5000 rounds (counted across all repetitions).
    dataname : str
        Dataset name used in the result file name and progress messages.
    save : bool
        If True, pickle a one-row summary DataFrame to the result path.
    bandit : bool
        Bandit feedback. After a wrong prediction, the learner receives
        ``K + 1`` instead of the true label, where ``K`` is the number of
        distinct labels in ``y``. This is presumably a "label not revealed"
        sentinel; confirm against the learner implementations.
    gamma : object
        Used only (as ``str(gamma)``) in the result file name and summary.

    Returns
    -------
    numpy.ndarray or str
        A float64 array of mistake indicators (1.0 = wrong prediction), one
        row per repetition. Its shape is ``(rep, n_rounds)`` assuming
        ``predict`` returns a scalar.

        If ``results/<alg><loss><data><gamma><setting>.dat`` already exists,
        the function instead returns the string
        "I have already run this experiment". This check is made even when
        ``save`` is False.
    """
    b = 'bandit' if bandit else "fullinfo"
    path = "results/" + alg.name + lossf.name + dataname + str(gamma) + b + ".dat"
    if os.path.exists(path):
        print("I have already run this experiment: " + path)
        return "I have already run this experiment"
    mistakes = [[] for j in range(rep)]
    counter = 0
    # Pristine template: every repetition restarts from this state.
    freshalg = copy.deepcopy(alg)
    K = len(np.unique(y))
    for j in range(rep):
        alg = copy.deepcopy(freshalg)
        for i in range(X.shape[0]):
            xt = X[i, :]
            ytprime = alg.predict(xt)
            yt = int(y[i])
            # The mistake is scored against the true label, before any feedback masking.
            mistakes[j].append(np.round(ytprime != yt))
            if bandit and ytprime != yt:
                # Bandit feedback: the true label is withheld after a mistake.
                yt = K + 1
            if "Gappletron" in alg.name:
                alg.update(yt, xt, lossf)
            else:
                alg.update(yt, xt)
            counter += 1
            if counter % 5000 == 0 and verbose:
                print("Algorithm: " + alg.name + " - " + lossf.name + " --- " + "Dataset: " + dataname + " --- " + "round " + str(i + 1) + " of repetition " + str(j + 1))
    mistakes = np.array(mistakes, dtype = 'float64')

    if save:
        losses = np.sum(mistakes, axis = 1)  # total mistakes per repetition
        meanL = np.sum(mistakes) / rep
        data = {'algname': alg.name,
                'dataname': dataname,
                'lossname': lossf.name,
                'reps': rep,
                'K': K,
                'minL': np.min(losses),
                'maxL': np.max(losses),
                'meanL': meanL,
                'medianL': np.median(losses),
                # Population variance of the per-repetition totals.
                'varL': np.sum((losses - meanL)**2) / rep,
                'bandit': bandit,
                'gamma': str(gamma)
                }
        df = pd.DataFrame(data, columns = ['algname', 'dataname', 'lossname', 'reps', 'K', 'minL', 'maxL', 'meanL', 'medianL', 'varL', 'bandit', 'gamma'], index=[0])
        # pd.to_pickle does not create missing directories. Without this, a
        # missing results/ folder would discard the whole run at the last step.
        os.makedirs(os.path.dirname(path), exist_ok = True)
        pd.to_pickle(df, path)
    return mistakes

#%%
def runRA(alg, lossf, y, X, reveal, rep = 1, verbose = True, dataname = "a secret", save = False, RA = False, gamma = "theo"):
    """Run an online learner where the label is only fed back on a reveal action.

    Each of the ``rep`` repetitions starts from a fresh deep copy of ``alg``
    and makes one pass over the rows of ``X`` in order: predict, record a 0/1
    mistake indicator against the true label, then update the learner.

    Parameters
    ----------
    alg : object
        Learner exposing ``name``, ``predict(x)`` and ``update(y, x)``.
        Learners whose ``name`` contains "Gappletron" are updated with
        ``update(y, x, lossf)`` instead. The caller's object is never
        mutated; every repetition works on a deep copy.
    lossf : object
        Object with a ``name`` attribute; passed to Gappletron updates.
    y : array-like
        One label per row of ``X``; each label is cast with ``int``.
    X : 2-D array
        Feature matrix; row ``i`` is the input of round ``i``.
    reveal : object
        The prediction that makes the true label visible when ``RA`` is True.
        Any other prediction causes the learner to receive ``K + 1``, where
        ``K`` is the number of distinct labels in ``y``. ``K + 1`` is
        presumably a "label not revealed" sentinel; confirm against the
        learner implementations.
    rep : int
        Number of independent repetitions.
    verbose : bool
        Print progress every 5000 rounds (counted across all repetitions).
    dataname : str
        Dataset name used in the result file name and progress messages.
    save : bool
        If True, pickle a one-row summary DataFrame to the result path.
    RA : bool
        Enable the reveal-action feedback described under ``reveal``.
        When False, the learner always receives the true label.
    gamma : object
        Used only (as ``str(gamma)``) in the result file name and summary.

    Returns
    -------
    numpy.ndarray or str
        A float64 array of mistake indicators (1.0 = wrong prediction), one
        row per repetition. Its shape is ``(rep, n_rounds)`` assuming
        ``predict`` returns a scalar.

        If ``results/<alg><loss><data><gamma><setting>.dat`` already exists,
        the function instead returns the string
        "I have already run this experiment". This check is made even when
        ``save`` is False.

    Notes
    -----
    With ``RA=False`` the result path ends in "fullinfo.dat". That is the
    same path ``runclassifier`` uses with ``bandit=False``, so the two
    functions share (and skip on) each other's full-information results.
    """
    b = 'RA' if RA else "fullinfo"
    path = "results/" + alg.name + lossf.name + dataname + str(gamma) + b + ".dat"
    if os.path.exists(path):
        print("I have already run this experiment: " + path)
        return "I have already run this experiment"
    mistakes = [[] for j in range(rep)]
    counter = 0
    # Pristine template: every repetition restarts from this state.
    freshalg = copy.deepcopy(alg)
    K = len(np.unique(y))
    for j in range(rep):
        alg = copy.deepcopy(freshalg)
        for i in range(X.shape[0]):
            xt = X[i, :]
            ytprime = alg.predict(xt)
            yt = int(y[i])
            # The mistake is scored against the true label, before any feedback masking.
            mistakes[j].append(np.round(ytprime != yt))
            if RA and ytprime != reveal:
                # The label is only revealed when the learner plays `reveal`.
                yt = K + 1
            if "Gappletron" in alg.name:
                alg.update(yt, xt, lossf)
            else:
                alg.update(yt, xt)
            counter += 1
            if counter % 5000 == 0 and verbose:
                print("Algorithm: " + alg.name + " - " + lossf.name + " --- " + "Dataset: " + dataname + " --- " + "round " + str(i + 1) + " of repetition " + str(j + 1))
    mistakes = np.array(mistakes, dtype = 'float64')

    if save:
        losses = np.sum(mistakes, axis = 1)  # total mistakes per repetition
        meanL = np.sum(mistakes) / rep
        data = {'algname': alg.name,
                'dataname': dataname,
                'lossname': lossf.name,
                'reps': rep,
                'K': K,
                'minL': np.min(losses),
                'maxL': np.max(losses),
                'meanL': meanL,
                'medianL': np.median(losses),
                # Population variance of the per-repetition totals.
                'varL': np.sum((losses - meanL)**2) / rep,
                'RA': RA,
                'gamma': str(gamma)
                }
        df = pd.DataFrame(data, columns = ['algname', 'dataname', 'lossname', 'reps', 'K', 'minL', 'maxL', 'meanL', 'medianL', 'varL', 'RA', 'gamma'], index=[0])
        # pd.to_pickle does not create missing directories. Without this, a
        # missing results/ folder would discard the whole run at the last step.
        os.makedirs(os.path.dirname(path), exist_ok = True)
        pd.to_pickle(df, path)
    return mistakes




#%%
def runLE(alg, lossf, y, X, rep = 1, verbose = True, dataname = "a secret", save = False, gamma = "theo"):
    """Run an online learner that is only updated when it plays the reveal action.

    ``K`` is the number of distinct labels in ``y``, and the reveal action is
    the prediction ``K``. Each of the ``rep`` repetitions starts from a fresh
    deep copy of ``alg`` and makes one pass over the rows of ``X``. In each
    round the learner predicts and a 0/1 mistake indicator is recorded. The
    learner is updated with the true label only when its prediction equals
    ``K``; in every other round ``update`` is not called.

    Rounds in which the reveal action is played are still scored against the
    true label. Assuming labels lie in ``0..K-1`` (TODO confirm), such rounds
    therefore always count as mistakes.

    Parameters
    ----------
    alg : object
        Learner exposing ``name``, ``predict(x)`` and ``update(y, x)``.
        Learners whose ``name`` contains "Gappletron" are updated with
        ``update(y, x, lossf)`` instead. The caller's object is never
        mutated; every repetition works on a deep copy.
    lossf : object
        Object with a ``name`` attribute; passed to Gappletron updates.
    y : array-like
        One label per row of ``X``; each label is cast with ``int``.
    X : 2-D array
        Feature matrix; row ``i`` is the input of round ``i``.
    rep : int
        Number of independent repetitions.
    verbose : bool
        Print progress every 5000 rounds (counted across all repetitions).
    dataname : str
        Dataset name used in the result file name and progress messages.
    save : bool
        If True, pickle a one-row summary DataFrame to the result path.
    gamma : object
        Used only (as ``str(gamma)``) in the result file name and summary.

    Returns
    -------
    numpy.ndarray or str
        A float64 array of mistake indicators (1.0 = wrong prediction), one
        row per repetition. Its shape is ``(rep, n_rounds)`` assuming
        ``predict`` returns a scalar.

        If ``results/<alg><loss><data><gamma>LE.dat`` already exists, the
        function instead returns the string
        "I have already run this experiment". This check is made even when
        ``save`` is False.
    """
    b = "LE"
    path = "results/" + alg.name + lossf.name + dataname + str(gamma) + b + ".dat"
    if os.path.exists(path):
        print("I have already run this experiment: " + path)
        return "I have already run this experiment"
    mistakes = [[] for j in range(rep)]
    counter = 0
    # Pristine template: every repetition restarts from this state.
    freshalg = copy.deepcopy(alg)
    K = len(np.unique(y))
    reveal = K  # the extra action that reveals the label
    for j in range(rep):
        alg = copy.deepcopy(freshalg)
        for i in range(X.shape[0]):
            xt = X[i, :]
            ytprime = alg.predict(xt)
            yt = int(y[i])
            mistakes[j].append(np.round(ytprime != yt))
            # Label-efficient feedback: no update at all unless the label was revealed.
            if ytprime == reveal:
                if "Gappletron" in alg.name:
                    alg.update(yt, xt, lossf)
                else:
                    alg.update(yt, xt)
            counter += 1
            if counter % 5000 == 0 and verbose:
                print("Algorithm: " + alg.name + " - " + lossf.name + " --- " + "Dataset: " + dataname + " --- " + "round " + str(i + 1) + " of repetition " + str(j + 1))
    mistakes = np.array(mistakes, dtype = 'float64')

    if save:
        losses = np.sum(mistakes, axis = 1)  # total mistakes per repetition
        meanL = np.sum(mistakes) / rep
        data = {'algname': alg.name,
                'dataname': dataname,
                'lossname': lossf.name,
                'reps': rep,
                'K': K,
                'minL': np.min(losses),
                'maxL': np.max(losses),
                'meanL': meanL,
                'medianL': np.median(losses),
                # Population variance of the per-repetition totals.
                'varL': np.sum((losses - meanL)**2) / rep,
                'LE': True,
                'gamma': str(gamma)
                }
        df = pd.DataFrame(data, columns = ['algname', 'dataname', 'lossname', 'reps', 'K', 'minL', 'maxL', 'meanL', 'medianL', 'varL', 'LE', 'gamma'], index=[0])
        # pd.to_pickle does not create missing directories. Without this, a
        # missing results/ folder would discard the whole run at the last step.
        os.makedirs(os.path.dirname(path), exist_ok = True)
        pd.to_pickle(df, path)
    return mistakes