using GraphNeuralNetworks, Graphs

using Statistics

using BSON: @load, @save

using MLDatasets, DataFrames, Pickle

using MLUtils

using MultivariateAnomalies

using HDF5

using Random

"""
    dataloader_properties

Container for the settings of the three data loaders built by `set_dataloaders`.

# Fields
- `train_bs`, `valid_bs`, `test_bs`: forwarded as `batchsize` to the train /
  valid / test `DataLoader`.
- `train_shuffle`, `valid_shuffle`, `test_shuffle`: forwarded as `shuffle`.
- `train_collate`, `valid_collate`, `test_collate`: forwarded as `collate`.
"""
mutable struct dataloader_properties
    train_bs::Int64
    valid_bs::Int64
    test_bs::Int64
    train_shuffle::Bool
    valid_shuffle::Bool
    test_shuffle::Bool
    train_collate::Bool
    valid_collate::Bool
    test_collate::Bool
end
"""
    DATA

Container for the train, valid and test data.

Each field is handed unchanged to a `DataLoader` by `set_dataloaders`.
`reduce_train_data!` additionally indexes `train[1]` and `train[2]`, which it
treats as graphs and labels respectively.

# Fields
- `train`: training split.
- `valid`: validation split.
- `test`: test split.
"""
mutable struct DATA
    train::Any
    valid::Any
    test::Any
end

"""
    DATA_single_graph

Container for a single graph and its labels.

# Fields
- `graph`: the graph; `train_one_configuration` passes it to `set_model` when
  the data is not of type `DATA`.
- `y`: the labels.
"""
mutable struct DATA_single_graph
    graph
    y
end


"""
    DATA_single_graph_mask

Container for a single graph, its labels and a mask.

# Fields
- `graph`: the graph.
- `y`: the labels.
- `mask`: the mask (its exact layout is not defined in this file).
"""
mutable struct DATA_single_graph_mask
    graph
    y
    mask
end


"""
    results

Container for the results of train, valid and test.

`train_one_configuration` fills the three fields with the three values returned
by `train`; `save_results` writes each of them into its own HDF5 group.

# Fields
- `train`: results on the training split.
- `valid`: results on the validation split.
- `test`: results on the test split.
"""
mutable struct results
    train
    valid
    test
end
"""
    opt_properties

Container for the properties of the optimizer; consumed by `set_opt`.

# Fields
- `name`: optimizer selector; `set_opt` handles `"ADAM"` and `"ADAMW"`.
- `LR`: learning rate (also the first argument of `ExpDecay` when decay is on).
- `decay_momentum`: second argument of `AdamW`.
- `weight_decay`: third argument of `AdamW`.
- `use_LR_decay`: if `true`, an `ExpDecay` stage is combined with the optimizer.
- `LR_decay`, `LR_decay_step`, `LR_clip`: remaining arguments of `ExpDecay`.
- `gradient_clipping`: `false` to disable, otherwise the value given to
  `ClipValue`.
"""
mutable struct opt_properties
    name::String
    LR::Float32
    decay_momentum::Tuple
    weight_decay::Float32
    use_LR_decay::Bool
    LR_decay::Float32
    LR_decay_step::Int
    LR_clip::Float32
    gradient_clipping
end

"""
    model_properties_GR

Container for the properties of the (old) model DBEquation_GR.

`set_model` selects this model when `name == "DBEquation_GR"` and forwards
`deep`, `balanced`, the dropout and batchnorm settings, `dim_dense_after_conv`,
`dim_dense_after_pool` and `pool_criterion` to `DBModule.DBEquation_GR`.

NOTE(review): `set_model` also reads `d_hidden_n`, `d_hidden_e` and `Δ`, and
`train_one_configuration` reads `init_lindb` / `init_dense`; none of these
fields is declared here (only `d_hidden`), so this struct does not match the
current accessors.
"""
mutable struct model_properties_GR
    name::String
    d_hidden::Int32
    deep::Int32
    dropout_n::Float32
    dropout_e::Float32
    batchnorm_n::Any
    batchnorm_e::Any
    balanced::Bool
    dim_dense_after_conv::Any
    dim_dense_after_pool::Any
    pool_criterion::Any
end
"""
    model_properties_NR

Container for the properties of the (old) model DBEquation_NR.

`set_model` selects this model when `name == "DBEquation_NR"` and forwards
`deep`, `balanced`, the dropout and batchnorm settings, `dim_dense_after_conv`
and `Δ` to `DBModule.DBEquation_NR`.

NOTE(review): `set_model` also reads `d_hidden_n` and `d_hidden_e`, and
`train_one_configuration` reads `init_lindb` / `init_dense`; this struct only
declares `d_hidden`, so it does not match the current accessors.
"""
mutable struct model_properties_NR
    name::String
    d_hidden::Int32
    deep::Int32
    dropout_n::Float32
    dropout_e::Float32
    batchnorm_n::Any
    batchnorm_e::Any
    balanced::Bool
    dim_dense_after_conv::Any
    Δ::Float32
end

"""
    model_properties_DBGAT_Model

Container for the properties of the model DBGAT_Model.

`set_model` selects this model when `name == "DBGAT_Model"`. It builds the
node / edge dimension vectors from `d_hidden_n` / `d_hidden_e`, passes `deep`
positionally, and forwards the remaining fields as keyword arguments to
`DBModule.DBGAT_Model` (`dim_dense_after_db` as `dense_after_db_dim`,
`dim_dense_after_gat` as `dense_after_gat_dim`, all others under their own
name).

NOTE(review): `train_one_configuration` reads init-method fields
(`init_dense_in` / `init_ud`, `init_lindb`, `init_dense`) that are not declared
here.
"""
mutable struct model_properties_DBGAT_Model
    name::String
    deep::Int32
    d_hidden_n::Int32
    d_hidden_e::Int32
    dropout_n::Float32
    dropout_e::Float32
    balanced::Bool
    dim_dense_after_db
    dim_dense_after_gat
    Δ::Float32
    use_GAT::Bool
    GAT_heads::Int
    GAT_bias::Bool
    GAT_concat::Bool
    GAT_negative_slope::Float32
    GAT_output_n_dim
    pool
    target_classes::Int32
end

"""
    model_properties_DBmultGAT_Model

Container for the properties of the model DBmultGAT_Model.

`set_model` selects this model when `name == "DBmultGAT_Model"`. It builds the
node / edge dimension vectors from `d_hidden_n` / `d_hidden_e`, passes
`num_dbl` and `steps_dbl` positionally, and forwards the remaining fields as
keyword arguments to `DBModule.DBmultGAT_Model` (`dim_dense_after_db` as
`dense_after_db_dim`, `dim_dense_after_gat` as `dense_after_gat_dim`, all
others under their own name).

NOTE(review): `train_one_configuration` reads init-method fields
(`init_dense_in` / `init_ud`, `init_lindb`, `init_dense`) that are not declared
here.
"""
mutable struct model_properties_DBmultGAT_Model
    name::String
    d_hidden_n::Int32
    d_hidden_e::Int32
    num_dbl::Int32
    steps_dbl::Int32
    dropout_n::Float32
    dropout_e::Float32
    balanced::Bool
    dim_dense_after_db
    dim_dense_after_gat
    Δ::Float32
    use_GAT::Bool
    GAT_heads::Int
    GAT_bias::Bool
    GAT_concat::Bool
    GAT_negative_slope::Float32
    GAT_output_n_dim
    pool
    target_classes::Int32
end


"""
    model_properties_DBGC_Model

Container for the properties of DBGC_Model.

`set_model` selects this model when `name == "DBGC_Model"` and forwards the
fields to `DBModule.DBGCModel` (`dim_dense_after_db` is passed as `dense_dim`).

# Fields used outside the model constructor
- `init_ud`, `init_lindb`, `init_dense`: names of initialisation methods;
  `train_one_configuration` turns each into an init function through
  `make_init_function` (`init_ud` is the fallback when `init_dense_in` is not
  available).
"""
mutable struct model_properties_DBGC_Model
    name::String
    d_hidden_n::Int32
    d_hidden_e::Int32
    num_dbl::Int32
    steps_dbl::Int32
    skip_connection_n::Bool
    skip_connection_e::Bool
    skip_connection_n_share_weights::Bool
    skip_connection_e_share_weights::Bool
    dropout_n::Float32
    dropout_e::Float32
    balanced::Bool
    dim_dense_after_db
    Δ::Float32
    pool
    target_classes::Int32
    init_ud::String
    init_lindb::String
    init_dense::String
end

"""
    model_properties_DBGNN

Container for the properties of DBGNN.

`set_model` selects this model when `name == "DBGNN"` and forwards the fields
to `DBModule.DBGNN` (`pool` is passed as `final_pool`).

# Fields used outside the model constructor
- `init_dense_in`, `init_lindb`, `init_dense`: names of initialisation methods;
  `train_one_configuration` turns each into an init function through
  `make_init_function`.
"""
mutable struct model_properties_DBGNN
    name::String
    d_hidden_n::Int32
    d_hidden_e::Int32
    num_dbl::Int32
    steps_dbl::Int32
    skip_connection_n::Bool
    skip_connection_e::Bool
    skip_connection_n_share_weights::Bool
    skip_connection_e_share_weights::Bool
    dropout_n::Float32
    dropout_e::Float32
    balanced::Bool
    dense_after_linDB
    Δ::Float32
    pool
    target_classes::Int32
    init_dense_in::String
    init_lindb::String
    init_dense::String
end

"""
    early_stopping_struct

Settings for early stopping; evaluated by `check_early_stopping`.

# Fields
- `use`: Bool, if true, use early stopping.
- `mode_max`: Bool, if true the criterion is maximized (training stops when the
  best validation performance is below `threshold`); if false it stops when the
  performance is above `threshold`.
- `grace_period`: Int32, number of epochs for which no early stopping is used
  (the check only fires once `epoch >= grace_period`).
- `threshold`: Float32, threshold for early stopping.
"""
mutable struct early_stopping_struct
    use::Bool
    mode_max::Bool
    grace_period::Int32
    threshold::Float32
end

"""
    training_properties

Container for the training properties. All four fields are forwarded to
`train` by `train_one_configuration`.

# Fields
- `epochs`: number of epochs.
- `use_e_data`: Bool forwarded to `train` (presumably toggles the use of edge
  data — confirm against `train`).
- `loss_function`: loss function handed to `train`.
- `early_stopping`: early-stopping settings, see `early_stopping_struct`.
"""
mutable struct training_properties
    epochs::Int32
    use_e_data::Bool
    loss_function
    early_stopping::early_stopping_struct
end

"""
    parameter_settings

Properties of a hyperparameter study: one complete configuration plus the
range of sample indices to run.

# Fields
- `num_samples`: number of samples (not read by the functions in this file).
- `samples_start`, `samples_end`: first and last sample index iterated by the
  `study_*` functions.
- `save_model_crit`: `false` disables model saving; any other value is
  forwarded to `train`.
- `model_name`, `hyper_study_name`: used to build the model file name prefix;
  `hyper_study_name` is also the directory `save_results` writes into.
- `opt_props`: optimizer settings, see `opt_properties`.
- `model_props`: one of the `model_properties_*` structs.
- `training_props`: training settings, see `training_properties`.
- `loader_props`: data loader settings, see `dataloader_properties`.
- `seed`: seed used by `train_one_configuration`.
"""
mutable struct parameter_settings
    num_samples::Int32
    samples_start::Int32
    samples_end::Int32
    save_model_crit
    model_name::String
    hyper_study_name::String
    opt_props::opt_properties
    model_props::Any
    training_props::training_properties
    loader_props::dataloader_properties
    seed::Int32
end


"""
    set_dataloaders(data::DATA, props::dataloader_properties)

Build one `DataLoader` per data split.

# Arguments
- `data::DATA`: struct containing the train, valid and test data.
- `props::dataloader_properties`: struct containing batch size, shuffle flag
  and collate flag for each split.

# Returns
The tuple `(train_loader, valid_loader, test_loader)`.
"""
function set_dataloaders(data::DATA, props::dataloader_properties)
    # One row per split: (samples, batch size, shuffle flag, collate flag).
    split_settings = (
        (data.train, props.train_bs, props.train_shuffle, props.train_collate),
        (data.valid, props.valid_bs, props.valid_shuffle, props.valid_collate),
        (data.test, props.test_bs, props.test_shuffle, props.test_collate),
    )
    loaders = map(split_settings) do (samples, bs, do_shuffle, do_collate)
        DataLoader(samples; batchsize=bs, shuffle=do_shuffle, collate=do_collate)
    end
    return loaders
end
"""
    set_opt(props)

Generate the optimizer described by `props`.

# Arguments
- `props`: optimizer properties (normally an `opt_properties`). `props.name`
  must be `"ADAM"` or `"ADAMW"`.

# Returns
The base optimizer, wrapped in a `Flux.Optimiser` together with an `ExpDecay`
stage when `props.use_LR_decay` is true and with a `ClipValue` stage when
`props.gradient_clipping` is neither `false` nor `nothing`.

# Throws
- `ArgumentError` if `props.name` is not a supported optimizer name.
"""
function set_opt(props)
    if props.name == "ADAM"
        opt = ADAM(props.LR)
    elseif props.name == "ADAMW"
        opt = AdamW(props.LR, props.decay_momentum, props.weight_decay)
    else
        # Previously `opt` stayed unassigned and a confusing UndefVarError
        # surfaced further down.
        throw(ArgumentError(
            "unknown optimizer name \"$(props.name)\"; expected \"ADAM\" or \"ADAMW\"",
        ))
    end
    if props.use_LR_decay == true
        # NOTE(review): Flux's documentation composes ExpDecay *after* the base
        # optimizer (`Optimiser(Adam(), ExpDecay(...))`); here it is applied
        # first. Ordering kept as in the original — confirm it is intended.
        expDecay = ExpDecay(props.LR, props.LR_decay, props.LR_decay_step, props.LR_clip)
        opt = Flux.Optimiser(expDecay, opt)
    end
    clip = props.gradient_clipping
    if clip !== nothing && clip != false
        # Clip the raw gradients before any other stage sees them.
        opt = Flux.Optimiser(ClipValue(clip), opt)
    end
    return opt
end

"""
    set_model(props, g, init_dense_in=Flux.randn32, init_lindb=Flux.randn32,
              init_dense=Flux.kaiming_normal; seed=1)

Generate the model selected by `props.name`.

# Arguments
- `props`: struct containing the model properties (one of the
  `model_properties_*` structs).
- `g`: graph; the node input dimension is `size(g.ndata.features, 1)` and the
  edge input dimension is `size(g.edata.features, 1)` (1 if edge features
  cannot be read).
- `init_dense_in`, `init_lindb`, `init_dense`: init functions forwarded to the
  `"DBGC_Model"` and `"DBGNN"` constructors (`init_dense_in` is passed as
  `init_ud` to `"DBGC_Model"`).

# Keywords
- `seed`: forwarded as `seed` to `"DBmultGAT_Model"` only.

# Throws
- `ArgumentError` if `props.name` is not one of the supported model names.

NOTE(review): the `"DBEquation_GR"` / `"DBEquation_NR"` branches require
`props.d_hidden_n` and `props.d_hidden_e`; the `model_properties_GR` /
`model_properties_NR` structs in this file only declare `d_hidden`.
"""
function set_model(props, g, init_dense_in=Flux.randn32, init_lindb = Flux.randn32, init_dense = Flux.kaiming_normal; seed=1)
    init_ud = init_dense_in # needed for older models
    nin = size(g.ndata.features, 1)
    ein = 1
    try
        ein = size(g.edata.features, 1)
    catch
        # Best effort: graphs without edge features fall back to one dimension.
        ein = 1
        println("Dimension of edge features is set to 1")
    end
    dims_n = [nin, props.d_hidden_n]
    dims_e = [ein, props.d_hidden_e]

    # Only the newer property structs carry skip-connection flags; treat a
    # missing flag as "no skip connection". With a skip connection the input
    # features are concatenated, which widens the hidden dimension.
    skip_n = hasproperty(props, :skip_connection_n) && props.skip_connection_n
    skip_e = hasproperty(props, :skip_connection_e) && props.skip_connection_e
    dim_hidden_n = skip_n ? props.d_hidden_n + nin : props.d_hidden_n
    dim_hidden_e = skip_e ? props.d_hidden_e + ein : props.d_hidden_e

    if props.name == "DBEquation_GR"
        return DBModule.DBEquation_GR(
            dims_n, dims_e, props.deep;
            balanced=props.balanced,
            dropout_n=props.dropout_n,
            dropout_e=props.dropout_e,
            batchnorm_n=props.batchnorm_n,
            batchnorm_e=props.batchnorm_e,
            dense_after_conv_dim=props.dim_dense_after_conv,
            dense_after_pool_dim=props.dim_dense_after_pool,
            pool_criterion=props.pool_criterion,
            Δ=props.Δ,
        )
    elseif props.name == "DBEquation_NR"
        return DBModule.DBEquation_NR(
            dims_n, dims_e, props.deep;
            balanced=props.balanced,
            dropout_n=props.dropout_n,
            dropout_e=props.dropout_e,
            batchnorm_n=props.batchnorm_n,
            batchnorm_e=props.batchnorm_e,
            dense_after_conv_dim=props.dim_dense_after_conv,
            Δ=props.Δ,
        )
    elseif props.name == "DBGAT_Model"
        return DBModule.DBGAT_Model(
            dims_n, dims_e, props.deep;
            balanced=props.balanced,
            dropout_n=props.dropout_n,
            dropout_e=props.dropout_e,
            dense_after_db_dim=props.dim_dense_after_db,
            dense_after_gat_dim=props.dim_dense_after_gat,
            Δ=props.Δ,
            use_GAT=props.use_GAT,
            GAT_heads=props.GAT_heads,
            GAT_bias=props.GAT_bias,
            GAT_concat=props.GAT_concat,
            GAT_negative_slope=props.GAT_negative_slope,
            GAT_output_n_dim=props.GAT_output_n_dim,
            pool=props.pool,
            target_classes=props.target_classes,
        )
    elseif props.name == "DBmultGAT_Model"
        return DBModule.DBmultGAT_Model(
            dims_n, dims_e, props.num_dbl, props.steps_dbl;
            balanced=props.balanced,
            dropout_n=props.dropout_n,
            dropout_e=props.dropout_e,
            dense_after_db_dim=props.dim_dense_after_db,
            dense_after_gat_dim=props.dim_dense_after_gat,
            Δ=props.Δ,
            use_GAT=props.use_GAT,
            GAT_heads=props.GAT_heads,
            GAT_bias=props.GAT_bias,
            GAT_concat=props.GAT_concat,
            GAT_negative_slope=props.GAT_negative_slope,
            GAT_output_n_dim=props.GAT_output_n_dim,
            pool=props.pool,
            target_classes=props.target_classes,
            seed=seed,
        )
    elseif props.name == "DBGC_Model"
        return DBModule.DBGCModel(
            dims_n, dims_e, dim_hidden_n, dim_hidden_e, props.num_dbl, props.steps_dbl;
            balanced=props.balanced,
            skip_connection_n=props.skip_connection_n,
            skip_connection_e=props.skip_connection_e,
            skip_connection_n_share_weights=props.skip_connection_n_share_weights,
            skip_connection_e_share_weights=props.skip_connection_e_share_weights,
            dropout_n=props.dropout_n,
            dropout_e=props.dropout_e,
            dense_dim=props.dim_dense_after_db,
            Δ=props.Δ,
            pool=props.pool,
            target_classes=props.target_classes,
            init_ud=init_ud,
            init_lindb=init_lindb,
            init_dense=init_dense,
        )
    elseif props.name == "DBGNN"
        return DBModule.DBGNN(
            nin, dim_hidden_n, props.target_classes, ein, dim_hidden_e,
            props.num_dbl, props.steps_dbl;
            Δ=props.Δ,
            skip_connection_n=props.skip_connection_n,
            skip_connection_e=props.skip_connection_e,
            skip_n_share_weight=props.skip_connection_n_share_weights,
            skip_e_share_weight=props.skip_connection_e_share_weights,
            final_pool=props.pool,
            dense_after_linDB=props.dense_after_linDB,
            balanced=props.balanced,
            dropout_n=props.dropout_n,
            dropout_e=props.dropout_e,
            init_dense_in=init_dense_in,
            init_lindb=init_lindb,
            init_dense=init_dense,
        )
    else
        # Returning `nothing` here used to fail much later, far from the cause.
        throw(ArgumentError("unknown model name \"$(props.name)\""))
    end
end

"""
    make_init_function(string_init, seed)

Return the initialisation function selected by `string_init`, bound to a
random number generator.

# Arguments
- `string_init`: name of the init method; one of `"kaiming_normal"`,
  `"kaiming_uniform"`, `"normal_distribution"`, `"glorot_uniform"`,
  `"glorot_normal"` or `"zeros"`.
- `seed`: an `AbstractRNG`, or a seed from which a `MersenneTwister` is built.
  The Flux initialisers are called with the RNG as their only argument.

# Returns
The value returned by the selected Flux initialiser when called with the RNG,
or the string `"zeros"` for `"zeros"`.

# Throws
- `ArgumentError` if `string_init` is not a known init method.
"""
function make_init_function(string_init, seed)
    # A bare integer handed to the Flux initialisers would be read as an array
    # dimension, so always hand over an actual RNG.
    rng = seed isa AbstractRNG ? seed : MersenneTwister(seed)
    if string_init == "kaiming_normal"
        return Flux.kaiming_normal(rng)
    elseif string_init == "kaiming_uniform"
        return Flux.kaiming_uniform(rng)
    elseif string_init == "normal_distribution"
        return Flux.randn32(rng)
    elseif string_init == "glorot_uniform"
        return Flux.glorot_uniform(rng)
    elseif string_init == "glorot_normal"
        return Flux.glorot_normal(rng)
    elseif string_init == "zeros"
        println("Setting zeros for initialization might not work properly.")
        return "zeros"
    else
        # Fail here instead of handing `nothing` to the model constructor.
        throw(ArgumentError("Init method not implemented: $(repr(string_init))"))
    end
end


"""
    train_one_configuration(data, props::parameter_settings, sample_run, load_model_name=false)

Run one full cycle for one configuration: build (or load) the model, train it
through `train` and return the train / valid / test results.

# Arguments
- `data`: a `DATA` struct (data loaders are built via `set_dataloaders`) or
  another container with a `graph` field that is passed to `train` directly.
- `props::parameter_settings`: struct containing all the configuration
  parameters.
- `sample_run`: index of the run; only used to build the model file name.
- `load_model_name`: file of the model to be restored (read with `@load`, key
  `model_cpu`); set `false` to generate a new model.

# Returns
A `results` struct holding the three values returned by `train`.
"""
function train_one_configuration(data, props::parameter_settings, sample_run, load_model_name=false)
    loaders_props = props.loader_props
    opt_props = props.opt_props
    model_props = props.model_props
    training_props = props.training_props

    # Dedicated generator shared by the three init functions.
    rng = MersenneTwister(props.seed)
    # Newer property structs call the field `init_dense_in`, older ones `init_ud`.
    init_dense_in_str = if hasproperty(model_props, :init_dense_in)
        model_props.init_dense_in
    else
        model_props.init_ud
    end
    init_dense_in = make_init_function(init_dense_in_str, rng)
    init_lindb = make_init_function(model_props.init_lindb, rng)
    init_dense = make_init_function(model_props.init_dense, rng)
    # Seed the global RNG with the integer seed. Calling `Random.seed!` on the
    # generator object instead would re-seed it from system entropy and undo
    # the reproducible initialisation set up above.
    Random.seed!(props.seed)

    uses_loaders = data isa DATA
    if uses_loaders
        train_loader, valid_loader, test_loader = set_dataloaders(data, loaders_props)
    end
    opt = set_opt(opt_props)
    if load_model_name == false
        # The first training graph (or the single graph) fixes the input sizes.
        example_graph = uses_loaders ? train_loader.data[1][1] : data.graph
        model = set_model(model_props, example_graph, init_dense_in, init_lindb, init_dense)
        model = model |> device
    else
        @load load_model_name model_cpu
        println("model loaded, file: ", load_model_name)
        model = model_cpu |> device
    end
    ps = Flux.params(model)
    save_model_crit = props.save_model_crit
    if save_model_crit == false
        model_name = false
    else
        model_name = string(props.model_name, "_", props.hyper_study_name, "_sample_", sample_run, "_")
    end
    if uses_loaders
        res_tr, res_val, res_test = train(model, train_loader, valid_loader, test_loader, ps, opt, training_props.loss_function, training_props.epochs, training_props.use_e_data, training_props.early_stopping, model_name, save_model_crit)
    else
        res_tr, res_val, res_test = train(model, data, ps, opt, training_props.loss_function, training_props.epochs, training_props.use_e_data, training_props.early_stopping, model_name, save_model_crit)
    end
    return results(res_tr, res_val, res_test)
end
"""
    save_results(res::results, props::parameter_settings, sample_run)

Save the results of one run to `<hyper_study_name>/performances.hdf5`.

The file is created if necessary; existing content is preserved. The run is
stored in a group named `string(sample_run)` with the sub-groups `train`,
`valid` and `test`, each filled by `make_groups_performances`.

# Arguments
- `res::results`: struct containing the results.
- `props::parameter_settings`: struct containing all the configuration
  parameters; `props.hyper_study_name` is the output directory.
- `sample_run`: index of the run.

# Returns
`nothing`.
"""
function save_results(res::results, props::parameter_settings, sample_run)
    dir_path = props.hyper_study_name
    # `mkpath` is a no-op for existing directories and also creates parents.
    mkpath(dir_path)

    filename = joinpath(dir_path, "performances.hdf5")
    # "cw": open read-write, create the file if missing, keep existing content.
    # The do-block closes the file even if writing a group throws.
    h5open(filename, "cw") do fid_files
        sample_group = create_group(fid_files, string(sample_run))
        train_group = create_group(sample_group, "train")
        valid_group = create_group(sample_group, "valid")
        test_group = create_group(sample_group, "test")

        make_groups_performances(train_group, res.train)
        make_groups_performances(valid_group, res.valid)
        make_groups_performances(test_group, res.test)
    end
    return nothing
end



"""
    study_opt_lr(data, params::parameter_settings, lr, load_model_name=false)

Hyperparameter study varying the learning rate.

For every sample index from `params.samples_start` to `params.samples_end` the
learning rate `lr[index]` is written into `params.opt_props.LR`, one
configuration is trained and its results are saved.

# Arguments
- `data`: struct containing the data.
- `params::parameter_settings`: struct containing all the configuration
  parameters; its optimizer properties are modified in place.
- `lr`: array with learning rates, indexed by sample index.
- `load_model_name`: name of the model to be restored, set `false` to generate
  a new model.
"""
function study_opt_lr(data, params::parameter_settings, lr, load_model_name=false)
    run_indices = params.samples_start:params.samples_end
    for run_idx in run_indices
        println("sample idx: ", run_idx)
        params.opt_props.LR = lr[run_idx]
        outcome = @time train_one_configuration(data, params, run_idx, load_model_name)
        save_results(outcome, params, run_idx)
    end
end

"""
    study_dropout(data, params::parameter_settings, dropout_values_n, dropout_values_e)

Hyperparameter study varying the dropout values of nodes and edges.

For every sample index from `params.samples_start` to `params.samples_end` the
dropout values at that index are written into `params.model_props`, one
configuration is trained and its results are saved.

# Arguments
- `data`: struct containing the data.
- `params::parameter_settings`: struct containing all the configuration
  parameters; its model properties are modified in place.
- `dropout_values_n`: array with dropout values for the nodes.
- `dropout_values_e`: array with dropout values for the edges.
"""
function study_dropout(data, params::parameter_settings, dropout_values_n, dropout_values_e)
    run_indices = params.samples_start:params.samples_end
    foreach(run_indices) do run_idx
        println("sample idx: ", run_idx)
        params.model_props.dropout_n = dropout_values_n[run_idx]
        params.model_props.dropout_e = dropout_values_e[run_idx]
        outcome = @time train_one_configuration(data, params, run_idx)
        save_results(outcome, params, run_idx)
    end
end


"""
    study_init_seeds(data, params, seeds)

Hyperparameter study varying the seed.

For every sample index from `params.samples_start` to `params.samples_end` the
seed `seeds[index]` is written into `params.seed`, one configuration is trained
and its results are saved.

# Arguments
- `data`: struct containing the data.
- `params`: struct containing all the configuration parameters
  (a `parameter_settings`); its `seed` is modified in place.
- `seeds`: array with seeds, indexed by sample index.
"""
function study_init_seeds(data, params, seeds)
    first_run = params.samples_start
    last_run = params.samples_end
    for run_idx in first_run:last_run
        println("sample idx: ", run_idx)
        params.seed = seeds[run_idx]
        outcome = @time train_one_configuration(data, params, run_idx)
        save_results(outcome, params, run_idx)
    end
end


"""
    load_model(path_model_file)

Load a model from a BSON file.

# Arguments
- `path_model_file`: path to the file; the entry stored under the key `model`
  is read and returned.
"""
function load_model(path_model_file)
    restored = let
        # `@load` binds the stored entry to a local variable named `model`.
        @load path_model_file model
        model
    end
    return restored
end

"""
    check_early_stopping(best_val_perf, epoch, early_stopping)

Decide whether training should be stopped early because the performance is
considered too low, to avoid unnecessary training time.

# Arguments
- `best_val_perf`: best obtained validation performance.
- `epoch`: current epoch of training.
- `early_stopping`: struct containing the early-stopping settings (`use`,
  `mode_max`, `grace_period`, `threshold`).

# Returns
`true` if early stopping is enabled, `epoch >= grace_period`, and the
performance is below the threshold (`mode_max == true`) or above it
(`mode_max == false`); `false` otherwise.
"""
function check_early_stopping(best_val_perf, epoch, early_stopping)
    # Guard: disabled, or still inside the grace period.
    is_active = early_stopping.use == true && epoch >= early_stopping.grace_period
    is_active || return false

    limit = early_stopping.threshold
    if early_stopping.mode_max == true
        # Criterion is maximised: stop while the best value is under the limit.
        return best_val_perf < limit
    elseif early_stopping.mode_max == false
        # Criterion is minimised: stop while the best value is over the limit.
        return best_val_perf > limit
    end
    return false
end

"""
    set_early_stopping_mode(save_model_crit)

Derive the early-stopping mode from the metric used for saving the model.

# Arguments
- `save_model_crit`: string, indicating the performance metric for saving the
  model.

# Returns
`true` for `"R2"`, `"accu"` and `"auc"`, `false` for `"MAE"`. For any other
value a message is printed and `nothing` is returned.
"""
function set_early_stopping_mode(save_model_crit)
    maximised_metrics = ("R2", "accu", "auc")
    minimised_metrics = ("MAE",)
    save_model_crit in maximised_metrics && return true
    save_model_crit in minimised_metrics && return false
    print("Can not set early stopping mode, invalid save_model_crit")
    return nothing
end


"""
    shuffle_graphs_labels(graphs, labels, seed)

Shuffle graphs and labels with one common permutation, so that each graph
keeps its label.

# Arguments
- `graphs`: list of graphs.
- `labels`: list of labels.
- `seed`: seed of the `MersenneTwister` that draws the permutation.

# Returns
The tuple `(new_graphs, new_labels)`; the inputs are not modified.
"""
function shuffle_graphs_labels(graphs, labels, seed)
    permutation = shuffle(MersenneTwister(seed), 1:length(graphs))
    shuffled_graphs = similar(graphs)
    shuffled_labels = similar(labels)
    for (target, source) in enumerate(permutation)
        shuffled_graphs[target] = graphs[source]
        shuffled_labels[target] = labels[source]
    end
    return shuffled_graphs, shuffled_labels
end


"""
    reduce_train_data!(data, train_share, seed)

Replace `data.train` by a random subset of itself.

# Arguments
- `data`: container (e.g. `DATA`) whose `train` field holds the graphs in
  `train[1]` and the labels in `train[2]`.
- `train_share`: share of the training samples to keep; the number of kept
  samples is `floor(Int, train_share * length(data.train[1]))`.
- `seed`: seed of the `MersenneTwister` that draws the random order.

# Returns
The new `(graphs, labels)` tuple, which is also stored in `data.train`.
"""
function reduce_train_data!(data, train_share, seed)
    graphs, labels = data.train[1], data.train[2]
    permutation = shuffle(MersenneTwister(seed), 1:length(graphs))
    n_keep = floor(Int, train_share * length(permutation))
    # Same samples as permuting everything and taking the first `n_keep`.
    kept = permutation[1:n_keep]
    reduced = (graphs[kept], labels[kept])
    data.train = reduced
    return reduced
end
