import os
import sys
import subprocess
import itertools
import numpy as np
import random

curr_dir=os.getcwd()
config_dir=curr_dir+'/config'
if not os.path.exists(config_dir):
    os.makedirs(config_dir)


#OPT_tab = ["SGD","ADAM","GRAFTL"]

#np.save(config_dir+"/opt_algs.npy",np.array(OPT_tab))


## SEEDS

SAVE="False"

if SAVE=="True":

   SEED_tab=list(range(7))

else:

   SEED_tab=list(range(7))


ARCH_tab=['res32']#["res32"]

np.save(config_dir+"/seeds.npy",np.array(SEED_tab))


DATA_tab=['celeba']#stl10']#['cifar']#["lsun"]#["cifar"]#["lsun"]

np.save(config_dir+"/dataset.npy",np.array(DATA_tab))



OPT_tab=[]


EPOCH=1

OPT_tab.append("SGD")

optim_tab=["SGD"]
DLR_tab=[7e-4,8e-4,9e-4,1e-3,2e-3]
GLR_tab=list(np.array(DLR_tab)*.1)#[4e-5,5e-5,6e-5,7e-5,8e-5,9e-5,1e-4,2e-4]
GLR_tab=[float(round(i,6)) for i in GLR_tab]
LR_tab=[]
for i in range(len(GLR_tab)):
    LR_tab.append((GLR_tab[i],DLR_tab[i]))
BATCH_tab = [64]

M_tab=[0.2,0.3,0.4, 0.5]
WD_tab=[0]
GAMMA_tab=[0]
A_tab=[(0,0)]
BETA2_tab=[0]
Q_tab=[0]


list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab, ARCH_tab,BETA2_tab,GAMMA_tab, DATA_tab, Q_tab]
list_param_sgd= list(itertools.product(*list_optim))




################################################################################
################################################################################
################################################################################
## Adam

#OPT_tab.append("ADAM")

optim_tab=["ADAM"]
DLR_tab=[5e-4]#[8e-4, 9e-4, 1e-3, 2e-3 ]
GLR_tab=list(np.array(DLR_tab)*0.1)#[4e-5,5e-5,6e-5,7e-5,8e-5,9e-5,1e-4,2e-4]
GLR_tab=[float(round(i,6)) for i in GLR_tab]
LR_tab=[]
for i in range(len(GLR_tab)):
    LR_tab.append((GLR_tab[i],DLR_tab[i]))

BATCH_tab = [64]#[64]
M_tab= [0.]#[0, 0.1, 0.2, 0.3, 0.4, 0.5]#[0, 0.2, 0.4, 0.5]

A_tab = [(0,0)]
# do large batch adam and set up threshold appropriately
WD_tab=[0]
GAMMA_tab=[0]


BETA2_tab=[0.7]#[.5, .6, .7]#[.5, .6, .7]
Q_tab=[0]


list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab,ARCH_tab,BETA2_tab,GAMMA_tab, DATA_tab, Q_tab]
list_param_adam= list(itertools.product(*list_optim))



################################################################################
################################################################################
################################################################################
## Graft

OPT_tab.append("GRAFTL")


optim_tab=["GRAFTL"]
DLR_tab=[2e-4,3e-4, 4e-4, 5e-4]#[2e-4,3e-4, 4e-4]
GLR_tab=list(np.array(DLR_tab))
GLR_tab=[float(round(i,6)) for i in GLR_tab]
LR_tab=[]
for i in range(len(GLR_tab)):
    LR_tab.append((GLR_tab[i],DLR_tab[i]))


BATCH_tab = [64]
M_tab =[0.2, .4, .5, .6]

A_tab = [(0,0)]

WD_tab=[0]
GAMMA_tab=[0]
BETA2_tab=[.5, .6, .7]#,0.6,0.7]
Q_tab=[0]

list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab,ARCH_tab,BETA2_tab,GAMMA_tab, DATA_tab, Q_tab]
list_param_graft= list(itertools.product(*list_optim))




################################################################################
################################################################################
################################################################################
## ALL Graft

#OPT_tab.append("ALL_GRAFTL")

#optim_tab=["ALL_GRAFTL"]
#DLR_tab=[3e-4]#,5e-4,6e-4,7e-4,8e-4,9e-4,1e-3,2e-3]
#GLR_tab=list(np.array(DLR_tab)*0.1)
#GLR_tab=[float(round(i,6)) for i in GLR_tab]
#LR_tab=[]
#for i in range(len(GLR_tab)):
#    LR_tab.append((GLR_tab[i],DLR_tab[i]))


#BATCH_tab = [64]
#M_tab =[0.6]

#A_tab = [(0,0)]

#WD_tab=[0]
#GAMMA_tab=[0]
#BETA2_tab=[0.7]
#Q_tab=[0]


#list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab,ARCH_tab,BETA2_tab,GAMMA_tab, DATA_tab, Q_tab]
#list_param_allgraft= list(itertools.product(*list_optim))



################################################################################
################################################################################
################################################################################
## Normalized gradient

OPT_tab.append("NORMALIZED")

optim_tab=["NORMALIZED"]
DLR_tab=[0.15]
#[ 0.1,0.11, 0.12, 0.13, 0.14,  0.15, 0.16, 0.17, 0.18, 0.19,  0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.28]
GLR_tab=list(np.array(DLR_tab)*0.1)
GLR_tab=[float(round(i,6)) for i in GLR_tab]
LR_tab=[]
for i in range(len(GLR_tab)):
    LR_tab.append((GLR_tab[i],DLR_tab[i]))
BATCH_tab = [64]
M_tab=[0]
WD_tab=[0]
GAMMA_tab=[0]
A_tab=[(0,0)]
BETA2_tab=[0]
Q_tab=[0]

list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab, ARCH_tab,BETA2_tab,GAMMA_tab,DATA_tab, Q_tab]
list_param_normalized= list(itertools.product(*list_optim))



################################################################################
################################################################################
################################################################################
## Global Normalized gradient

OPT_tab.append("GLOBAL_NORMALIZED")

optim_tab=["GLOBAL_NORMALIZED"]
DLR_tab=[1.7]#[1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4]#[2,2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3, 3.1,3.2,3.3, 3.4, 3.5,3.6,3.7,3.8,3.9,4]
GLR_tab=list(np.array(DLR_tab)*0.1)
GLR_tab=[float(round(i,6)) for i in GLR_tab]
LR_tab=[]
for i in range(len(GLR_tab)):
    LR_tab.append((GLR_tab[i],DLR_tab[i]))
BATCH_tab = [64]
M_tab=[0.]
WD_tab=[0]
GAMMA_tab=[0]
A_tab=[(0,0)]
BETA2_tab=[0]
Q_tab=[0]

list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab, ARCH_tab,BETA2_tab,GAMMA_tab,DATA_tab, Q_tab]
list_param_globalnormalized= list(itertools.product(*list_optim))



################################################################################
################################################################################
################################################################################
## Adam

#OPT_tab.append("ADAM")

optim_tab=["NORMALIZED_ADAM"]
DLR_tab=[3e-4]#[8e-4, 9e-4, 1e-3, 2e-3 ]
GLR_tab=list(np.array(DLR_tab)*0.1)#[4e-5,5e-5,6e-5,7e-5,8e-5,9e-5,1e-4,2e-4]
GLR_tab=[float(round(i,6)) for i in GLR_tab]
LR_tab=[]
for i in range(len(GLR_tab)):
    LR_tab.append((GLR_tab[i],DLR_tab[i]))

BATCH_tab = [64]#[64]
M_tab= [0.2]#[0, 0.1, 0.2, 0.3, 0.4, 0.5]#[0, 0.2, 0.4, 0.5]

A_tab = [(0,0)]
# do large batch adam and set up threshold appropriately
WD_tab=[0]
GAMMA_tab=[0]


BETA2_tab=[0.6]#[.5, .6, .7]#[.5, .6, .7]
Q_tab=[0]


list_optim= [SEED_tab,LR_tab,A_tab,BATCH_tab,M_tab,WD_tab,optim_tab,ARCH_tab,BETA2_tab,GAMMA_tab, DATA_tab, Q_tab]
list_param_normadam= list(itertools.product(*list_optim))






## 1 to remove in the save
np.save(config_dir+"/opt_algs2.npy",np.array(OPT_tab))




list_param = list_param_globalnormalized#globalnormalized#normalized#sgd#globalnormalized#normalized#normadam#normalized#globalnormalized#globalnormalized#graft#normalized#+list_param_adam#+list_param_normalized#list_param_adam#list_param_smoothedallgraft#+list_param_normalized+list_param_allnormalized

#+list_param_normalizedsgdrecovergraft







#list_param_smoothedallgraft#list_param_sgdrecovergraft#list_param_finalrecovallgraft#list_param_allgraft##list_param_finalrecovallgraft###list_param_allgraft##list_param_sgdrecovergraft#list_param_normalizedsgdrecovergraft #+ list_param_allgraft# list_param_normalizedsgdrecovergraft
#list_param_allgraft#list_param_sgd+list_param_adam+list_param_graft#list_param_sgdadam+list_param_adamsgd#+list_param_graft#+list_param_adam


for l in list_param:

   SEED=l[0]
   LR=l[1]
   A=l[2]
   BATCH=l[3]
   M=l[4]
   WD=l[5]
   OPT_CHOICE=l[6]
   ARCH=l[7]
   BETA2=l[8]
   GAMMA=l[9]
   DATA=l[10]
   Q = l[11]

   FA=A[0]
   SA=A[1]

   LR_G=LR[0]
   LR_D=LR[1]

#   M_G=M[0]
#   M_D=M[1]

#   WD_G=WD[0]
#   WD_D=WD[1]


   subprocess.call(['sbatch', 'cifar_exec.slurm', str(SEED), str(LR_G), str(LR_D),\
                  str(BATCH), str(M), str(FA), str(SA), str(WD),\
                  str(SAVE), str(OPT_CHOICE),str(ARCH), str(BETA2), str(GAMMA), str(DATA), str(Q), str(EPOCH)])


   #subprocess.call(['sbatch', 'cifar_exec.slurm', str(SEED), str(LR_G), str(LR_D),  str(BATCH), str(M_G), str(M_D), str(FA), str(SA), str(WD_G), str(WD_D), str(SAVE),str(OPT_CHOICE)])

                                                                                                                                                                                                                                
