# Environment setup for a single DBGC_Model training run.
#
# Set before `using CUDA` below, presumably so CUDA.jl does not use its
# BinaryBuilder-provided toolkit when it loads (a system CUDA install is assumed — confirm;
# newer CUDA.jl versions may ignore this variable). Julia stores the value as the string "false".
ENV["JULIA_CUDA_USE_BINARYBUILDER"]=false
using Pkg

# Project environment and source tree of the GNN code.
# NOTE(review): this and the other relative paths in the script assume it is started from
# the directory that contains `julia_gnn/` (and `dataset/`) — confirm.
package_dir = "julia_gnn"
Pkg.activate(package_dir)

using CUDA, Flux

# DBModule is loaded from its source file (not as a registered package) and brought into scope.
include(string(package_dir,"/src/DBModule.jl"))
using .DBModule


# Dataset to train on. Only the homogeneous power-grid datasets are handled below; any
# other name (e.g. "ogbn-molhiv") stops the script immediately with an ArgumentError
# instead of failing much later on an undefined `data`.
# dataset="ogbn-molhiv"
dataset = "homog_power_grids_20"
# dataset="homog_power_grids_100"

# Use the GPU when CUDA is functional, otherwise fall back to the CPU.
device = CUDA.functional() ? Flux.gpu : Flux.cpu

# Training helpers used by the rest of the script.
include(string(package_dir, "/src/training.jl"))

# Load the train / validation / test splits of the selected dataset into `data`.
if occursin("homog_power_grids", dataset)
    include(string(package_dir, "/src/snbs_homo.jl"))
    # Map each supported grid size to its directory; an unknown size would otherwise
    # leave `ds_dir` undefined and surface as an UndefVarError on the next call.
    ds_dir = if dataset == "homog_power_grids_20"
        "dataset/ds20"
    elseif dataset == "homog_power_grids_100"
        "dataset/ds100"
    else
        throw(ArgumentError("unsupported power-grid dataset \"$dataset\"; expected " *
                            "\"homog_power_grids_20\" or \"homog_power_grids_100\""))
    end
    # Do not request edge features from the reader for this run.
    use_edge_features = false
    train_data, valid_data, test_data =
        read_snbs_homogeneous_datasets(ds_dir, "snbs", use_edge_features)
    data = DATA(train_data, valid_data, test_data)
else
    # No loader exists in this script for other datasets; without this branch `data`
    # stays undefined and the error only appears at the training call.
    throw(ArgumentError("unsupported dataset \"$dataset\"; no loader is defined for it " *
                        "in this script"))
end

# ---- Run bookkeeping ---------------------------------------------------------------
num_samples, samples_start, samples_end = 5, 1, 1
num_epochs = 10000
seed = 1
save_model_crit = "R2"
loss_function = "MSE"
model_name = "DBGC_Model"
hyper_study_name = "hyper_tmp"

# ---- Hidden sizes, regularisation and initialisation --------------------------------
dhidden_n, dhidden_e = 500, 10
dropout_n = 0.061
dropout_e = 0
batchnorm_n = batchnorm_e = false
init_ud = init_lindb = "normal_distribution"
init_dense = "kaiming_normal"

# ---- DB layers -------------------------------------------------------------------------
num_dbl, steps_dbl = 4, 12
Δ = 0.01f0
DB_balanced = false
# `false` here; 250 was listed as an alternative (presumably a dense-layer width — confirm).
dense_after_db_dim = false

# ---- GAT settings (presumably only used when `use_GAT` is true — confirm) -----------
use_GAT = false
GAT_heads = 10
GAT_bias = GAT_concat = true
GAT_negative_slope = 0.02f0
GAT_output_n_dim = 500
dense_after_gat_dim = false

# ---- Readout ---------------------------------------------------------------------------
pool = false   # `Flux.max` was listed as an alternative
target_classes = 1

# ---- Optimiser settings (passed positionally to `opt_properties`) --------------------
opt_name = "ADAMW"
opt_LR = 3.5938136638046257e-4
opt_decay_mom = (0.9, 0.999)   # presumably the (β1, β2) moment decays — confirm
opt_weight_decay = 1e-9

# Learning-rate schedule and clipping.
use_LR_decay = true
LR_decay = 0.65
LR_decay_step = 14_000
LR_clip = 1e-5
gradient_clipping = false   # 1e-1 was listed as an alternative

# ---- Early stopping (disabled for this run) --------------------------------------------
use_early_stopping = false
grace_period = 1000
early_stopping_threshold = 0.5
# Stopping direction derived from the model-selection metric (presumably whether larger
# values are better — confirm in `set_early_stopping_mode`).
mode_max = set_early_stopping_mode(save_model_crit)
early_stopping = early_stopping_struct(use_early_stopping, mode_max, grace_period,
                                       early_stopping_threshold)

# Skip-connection toggles, presumably for the node (`_n`) and edge (`_e`) paths, plus
# whether each skip connection shares its weights.
skip_connection_n = skip_connection_e = false
skip_connection_n_share_weights = skip_connection_e_share_weights = true

# Bundle all settings into the property structs consumed by the training code.
# NOTE(review): the positional Bool arguments of `dataloader_properties` and the `true`
# passed to `training_properties` are unnamed here; check their meaning at the definitions.
loaders_props = dataloader_properties(250, 250, 250, true, false, false, true, true, true)
opt_props = opt_properties(opt_name, opt_LR, opt_decay_mom, opt_weight_decay, use_LR_decay,
                           LR_decay, LR_decay_step, LR_clip, gradient_clipping)

# The share-weight flags are passed in (n, e) order like every other `_n`/`_e` pair, so the
# edge setting `skip_connection_e_share_weights` actually reaches the model.
model_props = model_properties_DBGC_Model(model_name, dhidden_n, dhidden_e, num_dbl,
                                          steps_dbl, skip_connection_n, skip_connection_e,
                                          skip_connection_n_share_weights,
                                          skip_connection_e_share_weights, dropout_n,
                                          dropout_e, DB_balanced, dense_after_db_dim, Δ,
                                          pool, target_classes, init_ud, init_lindb,
                                          init_dense)

training_props = training_properties(num_epochs, true, loss_function, early_stopping)
hyper_params = parameter_settings(num_samples, samples_start, samples_end, save_model_crit,
                                  model_name, hyper_study_name, opt_props, model_props,
                                  training_props, loaders_props, seed)

# Train the configuration for sample index `samples_start` (timed), then pass the result to
# `save_results`.
@time res = train_one_configuration(data, hyper_params, samples_start)
save_results(res, hyper_params, samples_start)
