import structure as struc
import structure_self_multilabel_fixed as struc_self
import numpy as np
import time
import os

def run_bandit_struc(y,K,X,dataname,rep,save,gamma,diameter,project=False, project_half = False, path=''):
    """Run the bandit multiclass experiment ``rep`` times and optionally save it.

    Each repetition calls ``struc.run_bandit_multiclass`` and stores its
    return value as one row of a ``(rep, len(y))`` array.  The result file
    name encodes the experiment settings; if that file already exists the
    experiment is skipped.

    Args:
        y: Label sequence; ``len(y)`` is used as the horizon ``T``.
        K: Used in the ``'theo'`` rate ``sqrt(K / T)`` and forwarded to the
            runner (presumably the number of classes -- confirm in
            ``structure``).
        X: Forwarded unchanged to the runner.
        dataname: Dataset name, embedded in the result file name.
        rep: Number of independent repetitions.
        save: If true, write the result array with ``np.save``.
        gamma: Selects the value ``q`` handed to the runner:
            ``'theo'`` uses ``min(1, diameter * sqrt(K / T))`` and
            ``'onlyT'`` uses the constant ``1e-3``.
        diameter: Forwarded to the runner and used in the ``'theo'`` rate.
            Only appears in the file name when it differs from 2.
        project: Forwarded to the runner; adds a ``project`` tag to the
            file name.
        project_half: Forwarded to the runner; with ``project`` set, the
            file-name tag becomes ``project_half``.
        path: String prefix prepended verbatim to the file name (no path
            separator is inserted).

    Returns:
        The string ``"I have already run this experiment"`` when the result
        file already exists, otherwise ``None``.

    Raises:
        ValueError: If ``gamma`` is neither ``'theo'`` nor ``'onlyT'``.
    """
    # Fail fast on an unknown mode: previously no branch ran and an all-zero
    # matrix was reported (and saved) as if it were a real result.
    if gamma not in ('theo', 'onlyT'):
        raise ValueError(
            "gamma must be 'theo' or 'onlyT', got " + repr(gamma)
        )

    # Assemble the result file name piece by piece.  The spelling
    # "prediciton" is kept on purpose so existing result files still match.
    path_lossdata = path + "results_structured_prediciton_bandit" + dataname + str(gamma)
    if diameter != 2:
        path_lossdata += 'diameter' + str(diameter)
    if rep != 10:
        # rep == 10 is treated as the default and left out of the name.
        path_lossdata += 'rep' + str(rep)
    if project:
        path_lossdata += 'project_half' if project_half else 'project'
    path_lossdata += ".npy"

    if os.path.exists(path_lossdata):
        print("I have already run this experiment: " + path_lossdata)
        return("I have already run this experiment")

    T = len(y)
    if gamma == 'theo':
        q = min(1, diameter*np.sqrt(K/T))
    else:  # gamma == 'onlyT'
        q = 10**(-3)

    mistakes = np.zeros((rep, T))
    for i in range(rep):
        mistakes[i] = struc.run_bandit_multiclass(y,K,X,q,diameter,project=project, project_half=project_half)
    print('finish_bandit'+dataname)

    if save:
        np.save(path_lossdata, mistakes)
        
        
def run_bandit_struc_multilabel_fixed(y,K,m,X,dataname,rep,save,gamma,diameter,q_fixed=10**(-3),normed=True, project=False, project_half = False, path=''):
    """Run the "general" and "self" multilabel bandit experiments side by side.

    Each repetition calls ``struc.run_bandit_multilabel_fixed`` and then
    ``struc_self.run_bandit_multilabel_fixed`` with identical arguments and
    stores each return value as one row of a ``(rep, len(y))`` array.  The
    two arrays go to separate files whose names encode the settings.

    Args:
        y: Label sequence; ``len(y)`` is used as the horizon ``T``.
        K: Used in the ``'theo'`` rate ``sqrt(K / T)`` and forwarded to both
            runners (presumably the number of labels -- confirm in
            ``structure``).
        m: Forwarded unchanged to both runners.
        X: Forwarded unchanged to both runners.
        dataname: Dataset name, embedded in the result file names.
        rep: Number of repetitions; always embedded in the file names.
        save: If true, write both result arrays with ``np.save``.
        gamma: Selects the value ``q`` handed to the runners:
            ``'theo'`` uses ``min(1, diameter * sqrt(K / T))`` and
            ``'onlyT'`` uses ``q_fixed``.
        diameter: Forwarded to the runners and used in the ``'theo'`` rate.
            Only appears in the file names when it differs from 2.
        q_fixed: Value of ``q`` used when ``gamma == 'onlyT'``.
        normed: Forwarded unchanged to both runners.
        project: Forwarded to the runners; adds a ``project`` tag to the
            file names.
        project_half: Forwarded to the runners; with ``project`` set, the
            file-name tag becomes ``project_half``.
        path: String prefix prepended verbatim to the file names (no path
            separator is inserted).

    Returns:
        The string ``"I have already run this experiment"`` when the
        "general" result file already exists, otherwise ``None``.

    Raises:
        ValueError: If ``gamma`` is neither ``'theo'`` nor ``'onlyT'``.
    """
    # Fail fast on an unknown mode: previously no branch ran and two all-zero
    # matrices were reported (and saved) as if they were real results.
    if gamma not in ('theo', 'onlyT'):
        raise ValueError(
            "gamma must be 'theo' or 'onlyT', got " + repr(gamma)
        )

    start = time.time()

    # Both files share everything after the method-specific prefix.  The
    # spelling "prediciton" is kept so existing result files still match.
    suffix = "results_structured_prediciton_bandit" + dataname + str(gamma)
    if diameter != 2:
        suffix += 'diameter' + str(diameter)
    suffix += 'rep' + str(rep)
    if project:
        suffix += 'project_half' if project_half else 'project'
    suffix += ".npy"
    path_lossdata_1 = path + "multilabel_fixed_general_" + suffix
    path_lossdata_2 = path + "multilabel_fixed_self_" + suffix

    # NOTE(review): only the "general" file is checked; a missing "self"
    # file next to an existing "general" file is not detected.
    if os.path.exists(path_lossdata_1):
        print("I have already run this experiment: " + path_lossdata_1)
        return("I have already run this experiment")

    T = len(y)
    if gamma == 'theo':
        q = min(1, diameter*np.sqrt(K/T))
    else:  # gamma == 'onlyT'
        q = q_fixed

    mistakes_1 = np.zeros((rep, T))
    mistakes_2 = np.zeros((rep, T))
    for i in range(rep):
        mistakes_1[i] = struc.run_bandit_multilabel_fixed(y,K,m,X,q,diameter,normed=normed,project=project, project_half=project_half)
        print('finish_general')
        mistakes_2[i] = struc_self.run_bandit_multilabel_fixed(y,K,m,X,q,diameter,normed=normed,project=project, project_half=project_half)
        print('finish_self')

    end = time.time()
    time_diff = end-start
    print('finish_bandit'+dataname,'\n' ,time_diff)

    if save:
        np.save(path_lossdata_1, mistakes_1)
        np.save(path_lossdata_2, mistakes_2)