seed = 100

DEBUG = False 
dump_dataset = True
dataset = 'E1'
PE = False
query_path = None
#query_path = 'GPS-202108130219-randnum5'
#query_path = 'GPS-202109251744'
#query_path = 'baseline'

################# Query #####################
query_num = 5
diff_embedding = True
hard_softmax = True 

lr_scheduler_step_size = 40
gamma = 0.1
learn_rate = 0.0001
######### MI3 ########
dist_dim = 256 # == dense_output_size/2
ns_gamma = 1
distribution = 'Normal'
#distribution = 'Beta'

######### MI2 ########
noise = False 
gumbel = False
hellinger = False 
latent_code = True

######### MI #########
t_space = True
io_recon = True
program_recon = True

lambda_ps  = 1.0
lambda_z_n = 0.001 
lambda_t_n = 0.0001 
lambda_z_t = 0.001 

lambda_io = 0.001
lambda_program = 0.001 
#############################################


# General training
num_epochs = 100
patience = 20
#model_output_path = 'trained_models/E1/PE_model' #'trained_models/E1/GPS_model/'
if PE:
    model_output_path = 'trained_models/' + dataset + '/PE-'
else:
    model_output_path = 'trained_models/' + dataset + '/GPS-'


load_from_checkpoint = None
#load_from_checkpoint = '202109251751'
#load_from_checkpoint = 'trained_models/E1/GPS-202107221144/model-latest' 
max_len = None #For debugging set a lower number, use None for running training for full data
if dataset == 'E1':
    max_prog_len = 4
else:
    max_prog_len = 12

# File params (Inference)
global_model_path = 'trained_models/'+dataset+'/GPS_model/best_model.th' #GPS model path
PE_model_path = 'trained_models/'+dataset+'/PE_model/best_model.th' #PE model path
result_path = 'results/'

# GPS and PE model training params
if PE:
    train_path = 'data/data/'+dataset+'/train_dataset_pe'
    val_path = 'data/data/'+dataset+'/val_dataset_pe'
else:
    train_path = 'data/data/'+dataset+'/train_dataset_gps'
    val_path = 'data/data/'+dataset+'/val_dataset_gps'
batch_size = 128 // 4
val_iterator_size = 32

#    train_path = val_path

#Aggregator training params
att_batch_size = 100 #256

#Inference params
search_method = 'beam'
num_workers = 1
max_beam_size = 819200
dfs_max_width = 50
cab_beam_size = 100
cab_width = 10
cab_width_growth = 10

# DSL params
integer_min = -256
integer_max = 255
integer_range = integer_max - integer_min + 1
max_list_len = 20
num_inputs = 3
num_statements = 1298
num_operators = 38

# Program State Params
max_program_len = 8
max_program_vars = max_program_len + num_inputs
state_len = max_program_vars + 1
state_dim = 256

# H_theta and W_phi network params
type_vector_len = 2
embedding_size = 20
var_encoder_size = 56
dense_output_size = int(dist_dim * 2)
dense_num_layers = 10
dense_growth_size = 56

#################################################
kernel_size = 3
conv_stack = [64,64,64]
fc_stack = [dense_output_size]
tgt_embedding_size = 256
lstm_hidden_size = 256
nb_lstm_layers = 2

nb_ios = 5

train_file = 'datasets/karel/1m_6ex_karel/train.json'
val_file = 'datasets/karel/1m_6ex_karel/val.json'
vocab_file = 'datasets/karel/1m_6ex_karel/new_vocab.vocab'
vocab_size = 52

############ query ps #############
if query_path is not None:
    print('Training with Queried Data!')
    train_path = 'trained_models/' + dataset + '/' + query_path + '/train_gps'
    val_path = 'trained_models/' + dataset + '/' + query_path + '/val_gps'
    if dataset == 'E1':
        batch_size = 32 // 4
    elif dataset == 'E2':
        batch_size = 100 // 4
    else:
        raise ValueError('Wrong dataset')
    learn_rate = 0.0001
###################################

############## debug ###############
if DEBUG:
    model_output_path = 'trained_models/debug/'
####################################

def s2intL(str_args):
    return [int(val) for val in str_args.split(',')]
