import os
import sys
import subprocess
import itertools
import numpy as np
 
# Directory (inside the current working directory) where the hyperparameter
# arrays of this launcher are written.
curr_dir = os.getcwd()
config_dir = curr_dir + '/config'
# exist_ok=True removes the check-then-create race of the former
# "if not exists: makedirs" pattern (e.g. two launchers started together).
os.makedirs(config_dir, exist_ok=True)


## Hyperparameter tables
# Each *_tab list holds the candidate values of one hyperparameter; every list
# is also written to <config_dir>/<name>.npy by the loop below.
SEED_tab = [*range(1)]   # random seeds
MTRAIN_tab = [10000]     # commented-out alternative in the original: 20000
MTEST_tab = [1000]       # commented-out alternative in the original: 2000
D_tab = [15]
P_tab = [6]
K_tab = [5]
NEURON_tab = [5000]      # other sizes (10 ... 1000) were commented out

# Single value shared by all jobs (commented-out alternative: 80).
EPOCH = 2500

# Persist the tables in the same order and under the same file names as before.
for _stem, _values in (
    ("seeds", SEED_tab),
    ("mtrain", MTRAIN_tab),
    ("mtest", MTEST_tab),
    ("d", D_tab),
    ("p", P_tab),
    ("k", K_tab),
    ("neuron", NEURON_tab),
):
    np.save(config_dir + "/" + _stem + ".npy", np.array(_values))

# Names of the optimizers used by this launch; saved to opt_algs.npy below.
OPT_tab = []

## SGD sweep
OPT_tab.append("SGD")
opt_tab = ["SGD"]

# Effective sweep values.  The original script assigned LR_tab three times and
# WD_tab twice; only the last assignment of each was ever used.  The dead
# stores are removed, and their values are kept here for reference:
#   LR_tab = [3e-2, 5e-2, 7e-2, 9e-2, 0.2, 0.35, 0.5, 0.7]
#   LR_tab = [0.07, 0.05, 0.03, 0.005, 0.5]
#   WD_tab = [5e-4, 3e-4, 6e-4, 7e-4, 5e-5, 3e-5, 7e-5, 3e-3, 5e-3, 7e-3]
LR_tab = [0.0003]
WD_tab = [9e-3]
BATCH_tab = [200]
M_tab = [0.9]          # presumably momentum -- confirm against the job script
# Each entry is a pair; its two components are forwarded to the job script as
# two separate arguments (FA, SA) by the submission loop.
A_tab = [(400, 600)]

# Order matters: the submission loop reads the fields of every combination
# by position, in exactly this order.
list_param_mm = [
    SEED_tab, opt_tab, LR_tab,
    BATCH_tab, M_tab, A_tab, WD_tab, MTRAIN_tab, MTEST_tab,
    D_tab, P_tab, K_tab, NEURON_tab,
]

# Full Cartesian product of all tables: one tuple per job to submit.
list_param_gdm = list(itertools.product(*list_param_mm))

np.save(config_dir + "/opt_algs.npy", np.array(OPT_tab))

list_param = list_param_gdm

# Forwarded verbatim (as the string "False") as the last job argument.
SAVE = "False"

#print(list_param)

# Submit one SLURM job per hyperparameter combination.  Submission stays
# best-effort (a rejected job does not stop the sweep), but non-zero sbatch
# exit codes are now collected and reported instead of being discarded.
_failed_submissions = []

for combo in list_param:
    # Positional layout mirrors list_param_mm.
    (SEED, OPT, LR, BATCH, M, A, WD,
     MTRAIN, MTEST, D, P, K, NEURON) = combo
    FA, SA = A[0], A[1]

    # Argument order is the positional interface of cifar_exec.slurm.
    job_args = [SEED, OPT, LR, BATCH, M, FA, SA, WD,
                EPOCH, MTRAIN, MTEST, D, P, K, NEURON, SAVE]
    command = ['sbatch', 'cifar_exec.slurm'] + [str(arg) for arg in job_args]

    return_code = subprocess.call(command)
    if return_code != 0:
        _failed_submissions.append((return_code, combo))

if _failed_submissions:
    print('%d of %d sbatch submissions failed:'
          % (len(_failed_submissions), len(list_param)), file=sys.stderr)
    for return_code, combo in _failed_submissions:
        print('  exit code %d for parameters %r' % (return_code, combo),
              file=sys.stderr)

print('done')
