using BSON: @load, @save
using Metrics


"""
    performance(loss, perf)

Mutable record pairing loss value(s) with the matching performance value(s).

It is used both for a single evaluation (scalar `loss`/`perf`) and for whole
training histories (one entry per epoch).

- `loss`: loss value or vector of per-epoch losses
- `perf`: performance value or vector of per-epoch performances
"""
mutable struct performance
    loss
    perf
end

"""
    eval_model(model, g, use_e_data)

Apply `model` to the graph `g` and return the output flattened with `vec`.

- `model`: callable, invoked as `model(g, node_features)` or
  `model(g, node_features, edge_features)`
- `g`: GNNGraph; node features are read from `g.ndata.features`
- `use_e_data`: when `true`, the edge features `g.edata.features` are passed as a
  third argument
"""
function eval_model(model, g, use_e_data)
    use_e_data || return vec(model(g, g.ndata.features))
    return vec(model(g, g.ndata.features, g.edata.features))
end


"""
    get_predictions_targets(model, data_loader, use_e_data)

Run `model` over every batch of `data_loader` and collect predictions and targets.

- `model`: model, evaluated through `eval_model`
- `data_loader`: iterable of `(g, y)` batches
- `use_e_data`: bool true if edge data is considered

Returns `(all_ŷ, all_y)`:
- `all_ŷ` is a vector holding the predictions of all batches, concatenated in loader
  order, with element type promoted with `Float32`.
- `all_y` is an `N×1` matrix holding the targets in the same order, with element type
  promoted with `Int32`.

For an empty loader the result is `(Float32[], Int32[])`.
"""
function get_predictions_targets(model, data_loader, use_e_data)
    batches = map(data_loader) do (g, y)
        # Only the graph is needed on the device; the targets stay on the host.
        ŷ = eval_model(model, device(g), use_e_data) |> vec |> cpu
        # `copy` keeps the stored targets valid even if the loader reuses its buffers.
        return ŷ, copy(reshape(y, :, 1))
    end
    isempty(batches) && return (Array{Float32}(undef, 0), Array{Int32}(undef, 0))
    # Concatenate once (linear) instead of re-allocating with `vcat` per batch
    # (quadratic). The leading empty arrays reproduce the element-type promotion
    # of the previous incremental implementation.
    all_ŷ = vcat(Float32[], reduce(vcat, first.(batches)))
    all_y = vcat(Int32[], reduce(vcat, last.(batches)))
    return all_ŷ, all_y
end


"""
    train(model, train_loader, valid_loader, test_loader, ps, opt, loss_function,
          num_epochs, use_e_data, early_stopping, model_name=false,
          save_model_criterion=false)

Train `model` for up to `num_epochs` epochs and record the per-epoch loss and
performance on the train, validation and test sets (as returned by `report`).

- `model`: model
- `train_loader`: iterable of `(g, y)` batches used for the gradient updates
- `valid_loader`: validation data, forwarded to `report`
- `test_loader`: test data, forwarded to `report`
- `ps`: Flux parameters of model
- `opt`: Optimizer
- `loss_function`: `"MSE"` or `"MAE"`; anything else throws an `ArgumentError`
- `num_epochs`: maximum number of epochs
- `use_e_data`: bool true if edge data is considered
- `early_stopping`: early-stopping setting, forwarded to `check_early_stopping`
- `model_name`: name of model; `false` disables saving
- `save_model_criterion`: validation criterion used to track the best epoch:
  `"R2"` (higher is better) or `"MAE"` (lower is better). It is required when
  `model_name` is given; otherwise an `ArgumentError` is thrown.

With a criterion, the best validation performance is tracked every epoch. The model is
saved on improvement only when `model_name` is given. The tracked value is passed to
`check_early_stopping`. Without a criterion, `check_early_stopping` is not called.

Training stops when the training loss becomes NaN or early stopping triggers. In that
case the remaining epochs repeat the results of the last completed epoch.

Returns `(train_res, valid_res, test_res)`, each a `performance` whose `loss` and
`perf` are vectors of length `num_epochs`.
"""
function train(model, train_loader, valid_loader, test_loader, ps, opt, loss_function,
               num_epochs, use_e_data, early_stopping, model_name=false,
               save_model_criterion=false)
    # Direction in which the validation criterion improves (`nothing` = no criterion).
    opt_mode = if save_model_criterion == "R2"
        "max"
    elseif save_model_criterion == "MAE"
        "min"
    else
        nothing
    end
    do_save = model_name != false
    if do_save && opt_mode === nothing
        throw(ArgumentError("saving model $(repr(model_name)) requires " *
                            "save_model_criterion \"R2\" or \"MAE\", got " *
                            "$(repr(save_model_criterion))"))
    end
    # Always defined, so early stopping works even when the model is not saved.
    best_val_perf = opt_mode == "max" ? -1E10 : 1E10

    # Resolve the loss once, up front, instead of re-checking the name every batch.
    loss_fn = if loss_function == "MSE"
        Flux.mse
    elseif loss_function == "MAE"
        Flux.mae
    else
        throw(ArgumentError("unsupported loss_function $(repr(loss_function)); " *
                            "expected \"MSE\" or \"MAE\""))
    end

    train_loss = Array{Float64}(undef, num_epochs)
    valid_loss = Array{Float64}(undef, num_epochs)
    test_loss = Array{Float64}(undef, num_epochs)
    train_perf = Array{Float64}(undef, num_epochs)
    valid_perf = Array{Float64}(undef, num_epochs)
    test_perf = Array{Float64}(undef, num_epochs)

    for epoch in 1:num_epochs
        for (g, y) in train_loader
            # Fresh names (not reassigning the loop variables) keep the closure unboxed.
            g_dev, y_dev = (g, y) |> device
            gs = Flux.gradient(ps) do
                ŷ = eval_model(model, g_dev, use_e_data)
                loss_fn(ŷ, reshape(y_dev, :, 1))
            end
            Flux.Optimise.update!(opt, ps, gs)
        end

        report_res = report(epoch, model, train_loader, valid_loader, test_loader, use_e_data)
        train_loss[epoch] = report_res[1].loss
        train_perf[epoch] = report_res[1].perf
        valid_loss[epoch] = report_res[2].loss
        valid_perf[epoch] = report_res[2].perf
        test_loss[epoch] = report_res[3].loss
        test_perf[epoch] = report_res[3].perf

        if do_save
            best_val_perf = check_performance_save_model(best_val_perf, valid_perf[epoch],
                                                         model, model_name, opt_mode, epoch)
        elseif opt_mode == "max" && valid_perf[epoch] > best_val_perf
            best_val_perf = valid_perf[epoch]
        elseif opt_mode == "min" && valid_perf[epoch] < best_val_perf
            best_val_perf = valid_perf[epoch]
        end

        stop_training = false
        if isnan(report_res[1].loss)
            println("################################################")
            println("train loss is NaN, set epoch to num_epochs")
            stop_training = true
        end
        if opt_mode !== nothing && check_early_stopping(best_val_perf, epoch, early_stopping)
            println("################################################")
            println("apply early stopping after ", epoch, " epochs, with best_val_perf = ", best_val_perf)
            stop_training = true
        end
        if stop_training
            # Fill the remaining epochs with the last results so no `undef` entries
            # are returned to the caller.
            train_loss[epoch:num_epochs] .= report_res[1].loss
            train_perf[epoch:num_epochs] .= report_res[1].perf
            valid_loss[epoch:num_epochs] .= report_res[2].loss
            valid_perf[epoch:num_epochs] .= report_res[2].perf
            test_loss[epoch:num_epochs] .= report_res[3].loss
            test_perf[epoch:num_epochs] .= report_res[3].perf
            break
        end
    end

    train_res = performance(train_loss, train_perf)
    valid_res = performance(valid_loss, valid_perf)
    test_res = performance(test_loss, test_perf)
    return train_res, valid_res, test_res
end

"""
    save_model(model, model_name, valid_perf, epoch)

Save a CPU copy of `model` to `models/<model_name>.bson` and print a confirmation.

- `model`: model; it is moved to the CPU with `cpu` before saving
- `model_name`: file name (without extension) below `models/`
- `valid_perf`: validation performance, printed rounded to 2 digits
- `epoch`: epoch after which the model is saved, printed in the message

The model is written with BSON's `@save` under the variable name `model_cpu`. The
target directory is created if it does not exist yet.
"""
function save_model(model, model_name, valid_perf, epoch)
    file_model = string("models/", model_name, ".bson")
    # Create the directory first so a missing `models/` folder does not abort training.
    mkpath(dirname(file_model))
    model_cpu = model |> cpu
    @save file_model model_cpu
    println("The model: ", model_name, " is saved after the following epoch: ", string(epoch), " with valid perf: ", round(valid_perf, digits=2))
    return nothing
end
"""
    check_performance_save_model(best_val_perf, valid_perf, model, model_name, mode, epoch)

Compare the latest validation performance with the best one so far. When it is
strictly better, save the model with `save_model` and return the new best value.
Otherwise return `best_val_perf` unchanged.

- `best_val_perf`: best validation performance seen so far
- `valid_perf`: validation performance after the current epoch
- `model`: model
- `model_name`: model_name
- `mode`: `"max"` (higher is better) or `"min"` (lower is better); any other value
  never counts as an improvement
- `epoch`: epoch, forwarded to `save_model`
"""
function check_performance_save_model(best_val_perf, valid_perf, model, model_name, mode, epoch)
    improved = (mode == "max" && valid_perf > best_val_perf) ||
               (mode == "min" && valid_perf < best_val_perf)
    improved || return best_val_perf
    save_model(model, model_name, valid_perf, epoch)
    return valid_perf
end


        
