# =================================================================================================#
# Description: Metrics for evaluating performance
# Author: Ryan Thompson
# =================================================================================================#

"""
    brier_score(w, ŵ)

Compute a Brier-type score for support recovery.

For each entry, the "probability" is the fraction of slices along the third dimension of
`ŵ` in which that entry is nonzero. The "outcome" is whether the matching entry of `w` is
nonzero. The squared differences are summed over all entries, not averaged.

NOTE(review): presumably `ŵ` stacks several estimates (e.g. posterior draws) along
dimension 3 and `w` is the matching ground-truth matrix. Confirm with callers.
"""
function brier_score(w, ŵ)
    n_slices = size(ŵ, 3)
    selection_freq = sum(ŵ .≠ 0, dims = 3) / n_slices
    squared_err = ((w .≠ 0) .- selection_freq) .^ 2
    return sum(squared_err)
end

"""
    struct_hamming_dist(w, ŵ)

Compute the structural Hamming distance between the nonzero patterns of two square matrices.

Each unordered index pair `(i, j)` contributes one error if the patterns of `w` and `ŵ`
disagree at `(i, j)` or at `(j, i)`. A reversed edge therefore counts as a single error. A
diagonal entry is treated as a pair with itself, so a self-loop present in only one matrix
also counts as one error.

# Returns
The number of errors, as a `Float64`.

# Throws
`DimensionMismatch` if `w` and `ŵ` differ in size or are not square matrices.
"""
function struct_hamming_dist(w, ŵ)
    size(w) == size(ŵ) ||
        throw(DimensionMismatch("w has size $(size(w)) but ŵ has size $(size(ŵ))"))
    # true wherever exactly one of the two matrices has a nonzero entry
    mismatch = (w .≠ 0) .⊻ (ŵ .≠ 0)
    ndims(mismatch) == 2 && size(mismatch, 1) == size(mismatch, 2) ||
        throw(DimensionMismatch("expected square matrices, got size $(size(mismatch))"))
    p = size(mismatch, 1)
    # Visit each unordered pair exactly once (i ≤ j). Counting the diagonal here avoids
    # the halving that made a lone self-loop mismatch count as 0.5.
    n_errors = count(mismatch[i, j] | mismatch[j, i] for j in 1:p for i in 1:j)
    return float(n_errors)
end

"""
    f1_score(w, ŵ)

Compute the F1 score of the nonzero pattern of `ŵ` against that of `w`.

An entry is a true positive when both arrays are nonzero there. It is a false positive when
only `ŵ` is nonzero, and a false negative when only `w` is nonzero. The score is
`2tp / (2tp + fp + fn)`. When both arrays are entirely zero, that ratio is undefined and the
score is defined as `1.0` (perfect agreement).
"""
function f1_score(w, ŵ)
    truth = w .≠ 0
    guess = ŵ .≠ 0
    true_pos = count(truth .& guess)
    false_pos = count(.!truth .& guess)
    false_neg = count(truth .& .!guess)
    denom = 2 * true_pos + false_pos + false_neg
    # Every count is non-negative, so a zero denominator means all three counts are zero.
    if iszero(denom)
        return 1.0
    end
    return 2 * true_pos / denom
end

"""
    binary_auroc(w, ŵ)

Compute the area under the ROC curve for recovering the nonzero pattern of `w` from `ŵ`.

Each entry's score is the fraction of slices `ŵ[:, :, k]` in which that entry is nonzero.
Its label is `1` if the matching entry of `w` is nonzero and `0` otherwise. The AUROC is
evaluated in Python by `torchmetrics.classification.BinaryAUROC` via PyCall. `PyCall` must
therefore be loaded, and `torch`, `torchmetrics` and `numpy` must be importable in its
Python environment.

NOTE(review): assumes `w` matches the first two dimensions of `ŵ`. Confirm with callers.
"""
function binary_auroc(w, ŵ)
    n_slices = size(ŵ, 3)
    scores = sum(ŵ .≠ 0, dims = 3)[:, :, 1] / n_slices
    labels = Int64.(w .≠ 0)
    py_torchmetrics = PyCall.pyimport("torchmetrics")
    py_torch = PyCall.pyimport("torch")
    py_numpy = PyCall.pyimport("numpy")
    auroc_metric = py_torchmetrics.classification.BinaryAUROC()
    auroc_tensor = auroc_metric(py_torch.tensor(scores), py_torch.tensor(labels))
    # The metric returns a 0-d tensor; convert it via numpy and take its single element.
    return py_numpy.array(auroc_tensor)[1]
end