include("../src/nodal_regression.jl")


using HDF5

using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics, MLUtils
using Flux: DataLoader
# 
# """
#     struct performance
    
#     Container for the performances
# """
# mutable struct performance
#     loss::Any
#     r2::Any
# end

"""
    read_snbs_homogeneous_datasets(path, task, use_edge_features=true)

Load the `train`, `valid` and `test` splits stored under `path`.

Returns a 3-tuple `(train_data, valid_data, test_data)`. Each element is whatever
`read_one_ds` returns for the corresponding subdirectory.

- `path`: root directory containing the `train`, `valid` and `test` subdirectories
- `task`: task name, `"SNBS"` or `"TM"`
- `use_edge_features`: `true` if the power flow should be used as edge features
"""
function read_snbs_homogeneous_datasets(path, task, use_edge_features=true)
    split_names = ("train", "valid", "test")
    # `map` over a tuple yields a tuple, evaluated in order train -> valid -> test.
    return map(split_names) do split_name
        read_one_ds(string(path, "/", split_name), task, use_edge_features)
    end
end

"""
    read_one_ds(path, task, use_edge_features=true)

Read one split of an SNBS dataset stored in HDF5 format.

Reads the `"grids"` dataset from `<path>/input_data.h5` and all of
`<path>/<task>.h5`. Returns `(all_graphs, all_targets)`: one `GNNGraph` per grid,
together with the target stored under the same key. Both vectors are filled in the
same loop, so index `i` of one matches index `i` of the other.

- `path`: directory of the dataset split
- `task`: task name, `"SNBS"` or `"TM"`; selects the target file `<task>.h5`
- `use_edge_features`: if `true`, the stored edge attributes (power flow) are used as
  edge features; otherwise every edge gets the constant feature `1`
"""
function read_one_ds(path, task, use_edge_features=true)
    # Function form of `h5open` closes the file even if `read` throws.
    grids = h5open(fid -> read(fid, "grids"), string(path, "/input_data.h5"), "r")
    targets = h5open(read, string(path, "/", task, ".h5"), "r")

    all_graphs = GNNGraph[]
    all_targets = Array{Array{Float64}}(undef, 0)
    for (key, one_g) in grids
        edge_index = one_g["edge_index"]
        edge_attr = one_g["edge_attr"]
        node_features = one_g["node_features"]

        # `unsqueeze(x, 1)` prepends a singleton feature dimension
        # (e.g. a length-n vector becomes a 1×n matrix).
        node_data = Float32.(unsqueeze(node_features, 1))
        edge_data = if use_edge_features
            Float32.(unsqueeze(edge_attr, 1))
        else
            # Same shape as the real edge features, but constant: edges carry no information.
            ones(Float32, 1, size(edge_attr)...)
        end

        g = GNNGraph(edge_index[1, :], edge_index[2, :];
                     ndata=(; features=node_data), edata=(; features=edge_data))
        push!(all_graphs, g)
        push!(all_targets, targets[string(key)])
    end
    return all_graphs, all_targets
end


"""
    eval_loss_r2(model, data_loader, use_e_data)

Evaluate `model` on every batch of `data_loader` and summarize the fit.

Returns `performance(loss, r2)` where
- `loss` is the mean squared error, rounded to 4 digits;
- `r2` is `r2_score` multiplied by 100 (i.e. in percent), rounded to 2 digits and
  converted to `Float32`.

- `model`: model to evaluate
- `data_loader`: loader yielding the evaluation batches
- `use_e_data`: `true` if edge data is considered
"""
function eval_loss_r2(model, data_loader, use_e_data)
    predictions, truths = get_predictions_targets(model, data_loader, use_e_data)
    mse_value = Flux.mse(predictions, truths)
    r2_value = r2_score(predictions, truths)
    rounded_mse = round(mse_value, digits=4)
    r2_percent = Float32(round(r2_value * 100, digits=2))
    return performance(rounded_mse, r2_percent)
end

"""
    report(epoch, model, train_loader, valid_loader, test_loader, use_e_data)

Evaluate `model` on the train, validation and test sets. Print a one-line summary
for `epoch` and return the three results of `eval_loss_r2`.

- `epoch`: current epoch number (only used in the printed line)
- `model`: model to evaluate
- `train_loader`: loader for the training set
- `valid_loader`: loader for the validation set
- `test_loader`: loader for the test set
- `use_e_data`: `true` if edge data is considered

Returns `(train_performance, valid_performance, test_performance)`.
"""
function report(epoch, model, train_loader, valid_loader, test_loader, use_e_data)
    train_performance = eval_loss_r2(model, train_loader, use_e_data)
    valid_performance = eval_loss_r2(model, valid_loader, use_e_data)
    test_performance = eval_loss_r2(model, test_loader, use_e_data)
    # Print all three splits so the validation score can be monitored alongside train/test.
    println("Epoch: $epoch   Train: $(train_performance)   " *
            "Valid: $(valid_performance)   Test: $(test_performance)")
    return train_performance, valid_performance, test_performance
end

# """
# train(model, train_loader, valid_loader, test_loader, ps, opt, num_epochs, use_e_data, model_name=false, save_model_criterion=false)

# Train model

# - `model`: model
# - `train_loader`: train_loader
# - `valid_loaders`: valid_loaders
# - `test_loaders`: test_loaders
# - `ps`: Flux parameters of model
# - `opt`: Optimizer
# - `num_epochs`: number of epochs
# - `use_e_data`: bool true if edge data is considered
# - `model_name`: name of model
# - `save_model_criterion`: use the following performance criterion for valid evaluation to decide if the model outperforms previous runs and should be stored
# """
# function train(model, train_loader, valid_loader, test_loader, ps, opt, num_epochs, use_e_data, model_name=false, save_model_criterion=false)
#     if save_model_criterion == "R2" && model_name != false
#         save_model = true
#         best_val_perf = -1E10
#     else
#         save_model = false
#     end
#     train_loss = Array{Float64}(undef, num_epochs)
#     valid_loss = Array{Float64}(undef, num_epochs)
#     test_loss = Array{Float64}(undef, num_epochs)
#     train_r2 = Array{Float64}(undef, num_epochs)
#     valid_r2 = Array{Float64}(undef, num_epochs)
#     test_r2 = Array{Float64}(undef, num_epochs)
#     for epoch = 1:num_epochs
#         # trainmode!(model)
#         for (g, y) in train_loader
#             g, y = (g, y) |> device
#             gs = Flux.gradient(ps) do
#                 ŷ = eval_model(model, g, use_e_data)
#                 Flux.mse(ŷ, reshape(y, :, 1))
#             end
#             Flux.Optimise.update!(opt, ps, gs)
#         end
#         report_res = report(epoch, model, train_loader, valid_loader, test_loader, use_e_data)
#         train_loss[epoch] = report_res[1].loss
#         train_r2[epoch] = report_res[1].r2
#         valid_loss[epoch] = report_res[2].loss
#         valid_r2[epoch] = report_res[2].r2
#         test_loss[epoch] = report_res[3].loss
#         test_r2[epoch] = report_res[3].r2
#         if save_model
#             if best_val_perf < valid_r2[epoch]
#                 best_val_perf = valid_r2[epoch]
#                 # file_model = string("models/", model_name, "_epoch_", epoch, ".bson")
#                 file_model = string("models/", model_name, ".bson")
#                 model_cpu = model |> cpu
#                 @save file_model model_cpu
#                 println("The model: ", model_name, " is saved after the following epoch: ", string(epoch), " with valid r2: ", round(valid_r2[epoch], digits=2))
#             end
#         end
#     end
#     train_res = performance(train_loss, train_r2)
#     valid_res = performance(valid_loss, valid_r2)
#     test_res = performance(test_loss, test_r2)
#     return train_res, valid_res, test_res
# end


"""
    make_groups_performances(group, res::performance)

Write the loss and R² stored in `res` into `group` under the keys `"loss"` and
`"R2"`, then return `group`.

- `group`: container supporting `group[key] = value` with string keys
  (presumably an HDF5 group — confirm against callers)
- `res`: a `performance` result
"""
function make_groups_performances(group, res::performance)
    # NOTE(review): the R² value is read from field `perf`, while the commented-out
    # `performance` definition and `train` in this file use a field named `r2`.
    # Confirm against the struct defined in ../src/nodal_regression.jl.
    for (key, field) in ("loss" => :loss, "R2" => :perf)
        group[key] = getproperty(res, field)
    end
    return group
end