include("../WFD/WFD.jl")
include("../WFD/common.jl")
include("../WFD/pagerank.jl")

using ArgParse
using Random
using Printf
using DelimitedFiles
using Statistics
using LinearAlgebra
using HDF5
using .GC

using ScikitLearn
@sk_import linear_model: LogisticRegression


"""
    parse_commandline_args()

Parse the command line with ArgParse and return the resulting `Dict`, keyed by option
name without the leading dashes.

Recognised options: `--dataset` (input directory, e.g. `data/coauthor_cs`), `--output`
(result file path), `--num_data` (`Int`, default 10) and `--num_trials` (`Int`,
default 100).
"""
function parse_commandline_args()
    settings = ArgParseSettings()

    @add_arg_table! settings begin
        "--dataset"
            help = "Specify the dataset (e.g. data/coauthor_cs)"
        "--output"
            help = "Specify the filename"
        "--num_data"
            help = "Specify the number of data"
            arg_type = Int
            default = 10
        "--num_trials"
            help = "Specify the number of trials"
            arg_type = Int
            default = 100
    end

    parsed = parse_args(settings)
    return parsed
end

parsed_args = parse_commandline_args()

# Access the values of the arguments.
# ArgParse stores `nothing` for an option that was not given and has no default, so a
# `get(..., "")` fallback never fires; check the required paths explicitly instead of
# failing later on `nothing * "/labels.txt"`.
dataset = get(parsed_args, "dataset", nothing)
output = get(parsed_args, "output", nothing)
isnothing(dataset) && throw(ArgumentError("--dataset is required (e.g. --dataset data/coauthor_cs)"))
isnothing(output) && throw(ArgumentError("--output is required"))

num_data = get(parsed_args, "num_data", 10)
num_trials = get(parsed_args, "num_trials", 100)
# num_data positive and num_data negative nodes are sampled per trial; zero would leave
# the seed set empty.
num_data >= 1 || throw(ArgumentError("--num_data must be at least 1, got $num_data"))

labels = vec(readdlm(dataset*"/labels.txt", ' ', Int64, '\n'))
num_classes = length(unique(labels))

@printf("number of classes: %d\n", num_classes)

# Adjacency list: line v of adj.txt holds the neighbor ids of node v (a blank line means
# no neighbors). `split` without a delimiter splits on runs of any whitespace and drops
# empty tokens, so repeated spaces, tabs or trailing blanks no longer break `parse`.
ll = Vector{Vector{Int64}}()
for line in eachline(dataset*"/adj.txt")
    push!(ll, Int64[parse(Int64, tok) for tok in split(line)])
end
n = length(ll)
@printf("number of nodes: %d\n", n)

degree = [length(l) for l in ll]
G = AdjacencyList(ll, degree, n)
cc = connected_components(G)
sizes = [length(c) for c in cc]
largest = argmax(sizes)
@printf("number of nodes in the largest component: %d\n", length(cc[largest]))

# Relabel every node outside the largest connected component as -1 so that no class
# target includes it.
for (k, comp) in enumerate(cc)
    k == largest && continue
    labels[comp] .= -1
end

# Experiment hyper-parameters.
# NOTE(review): this global `eps` shadows `Base.eps` inside Main; confirm that none of the
# included WFD files call `eps(...)` at run time.
eps = 0.05                                # lower bound for the prediction-based edge weights
alpha_range = [0.01, 0.05, 0.1, 0.2, 0.4] # alpha values tried in each PageRank line search
gamma = 0.02                              # passed to edge_weights for the attribute weights

Z = zeros(n, 1)
X = readdlm(dataset*"/attributes.txt", ' ', Float64, '\n')
edge_weight_z = edge_weights(G, Z, 0.0)
edge_weight_x = edge_weights(G, X, gamma)
# Two independent all-zero buffers laid out like edge_weight_z; they are refilled from
# the classifier output in every trial.
edge_weight_p, edge_weight_r = (edge_weight_z * 0.0 for _ in 1:2)

# One F1 matrix per method: row = index into `targets`, column = trial.
targets = collect(1:num_classes)
F1_CLF, F1_SFD, F1_WFD, F1_LFD_p, F1_LFD_r, F1_PR_z, F1_PR_p, F1_PR_r =
    ntuple(_ -> zeros(length(targets), num_trials), 8)

min_output_size = 100
sink = collect(Float64, G.degree)
num_pos = num_neg = num_data

"""
    _f1_score(C, Kset)

Return the F1 score of the predicted node collection `C` against the target node set
`Kset`, with precision `|C ∩ K| / length(C)` and recall `|C ∩ K| / length(Kset)`.

Returns `0.0` whenever `C` and `Kset` share no node (including an empty `C` or `Kset`);
the plain formula would yield `NaN` from `0/0` there.
"""
function _f1_score(C, Kset::AbstractSet)
    hits = length(intersect(Set(C), Kset))
    hits == 0 && return 0.0
    pr = hits / length(C)
    re = hits / length(Kset)
    return 2 * pr * re / (pr + re)
end

"""
    _best_pagerank_f1(G, source, rho, alphas, Kset, node_scale, cut_weights, min_size;
                      pr_kwargs...)

Line search over `alphas`. For each alpha, it:

1. runs `pagerank(G, source, rho, alpha; tol=1.0e-7, max_iters=100, pr_kwargs...)`;
2. divides every positive entry `p[v]` by `node_scale[v]`;
3. sweep-cuts the result with `cut_weights` and `min_size`.

Returns the largest F1 against `Kset`, or `0.0` if no alpha gives a positive score.
Extra keywords (e.g. `weights=...`) are forwarded to `pagerank`.
"""
function _best_pagerank_f1(G, source, rho, alphas, Kset, node_scale, cut_weights, min_size;
                           pr_kwargs...)
    best = 0.0
    for alpha in alphas
        p = pagerank(G, source, rho, alpha; tol=1.0e-7, max_iters=100, pr_kwargs...)
        # Visit every vertex: iterating over the integer `G.nv` itself would only
        # visit the last vertex.
        for v in 1:G.nv
            if p[v] > 0
                p[v] = p[v] / node_scale[v]
            end
        end
        C, _ = sweepcut(G, p, cut_weights, min_size=min_size)
        f1 = _f1_score(C, Kset)
        if f1 > best
            best = f1
        end
    end
    return best
end

for (i, c) in enumerate(targets)
    # Target cluster K: nodes labelled c. Nodes outside the largest component were
    # relabelled -1, so K lies inside it. Kc: every node not labelled c.
    K = findall(x -> x == c, labels)
    # NOTE(review): Kc also contains the -1 nodes outside the largest component, some of
    # which may truly be class c; confirm they are meant to be eligible negatives.
    Kc = findall(x -> x != c, labels)
    Kset = Set(K)
    length(K) >= num_pos || error("class $c has only $(length(K)) nodes in the largest " *
                                  "component, fewer than num_pos = $num_pos")
    length(Kc) >= num_neg || error("class $c has only $(length(Kc)) other nodes, " *
                                   "fewer than num_neg = $num_neg")

    vol_K = sum(G.degree[K])

    # Binary one-vs-rest targets
    y = zeros(n)
    y[K] .= 1

    for j in 1:num_trials
        # Sample num_pos positive and num_neg negative training nodes without replacement
        train_ids = [K[randperm(length(K))[1:num_pos]]; Kc[randperm(length(Kc))[1:num_neg]]]
        X_train = X[train_ids, :]
        y_train = y[train_ids]

        # Fit a classifier on the node attributes
        clf = LogisticRegression(C=1.0e0, tol=1e-6, class_weight="balanced", max_iter=1000, solver="lbfgs")
        clf.fit(X_train, y_train)

        # Classifier baseline: the set of nodes predicted positive
        y_pred = clf.predict(X)
        F1_CLF[i, j] = _f1_score(findall(x -> x == 1, y_pred), Kset)

        # Edge weights from the classifier output:
        #   p: product of predicted probabilities;
        #   r: product of binary predictions, floored at eps.
        # y_pred above is reused instead of calling predict again.
        y_prob = clf.predict_proba(X)[:, 2]
        for u in 1:G.nv
            for (k, v) in enumerate(G.neighbors[u])
                edge_weight_p[u][k] = y_prob[u] * y_prob[v]
                edge_weight_r[u][k] = max(y_pred[u] * y_pred[v], eps)
            end
        end

        # Seed mass on the positive training nodes, proportional to degree and summing to
        # 2*vol_K
        seeds = train_ids[1:num_pos]
        vol_seeds = sum(G.degree[seeds])
        source = zeros(n)
        source[seeds] = 2*vol_K/vol_seeds*G.degree[seeds]
        rho = 1/sum(source) # for PageRank

        # Flow diffusions, one per edge weighting:
        #   standard (edge_weight_z), attribute-weighted (edge_weight_x),
        #   label-based with probabilities (edge_weight_p),
        #   label-based with predictions (edge_weight_r).
        # Every sweep cut is computed with edge_weight_z.
        for (F1, w) in ((F1_SFD, edge_weight_z), (F1_WFD, edge_weight_x),
                        (F1_LFD_p, edge_weight_p), (F1_LFD_r, edge_weight_r))
            x = WFD(G, source, sink, weights=w, max_iters=100, epsilon=1.0e-4)
            C, _ = sweepcut(G, x, edge_weight_z, min_size=min_output_size)
            F1[i, j] = _f1_score(C, Kset)
        end

        # PageRank baselines with a line search over alpha_range. Each variant receives
        # its own copy of the degree-proportional seed distribution, because this file
        # does not show whether pagerank modifies its input.
        source_page_rank = zeros(Float64, n)
        source_page_rank[seeds] = 1.0 ./ vol_seeds .* G.degree[seeds]
        F1_PR_z[i, j] = _best_pagerank_f1(G, copy(source_page_rank), rho, alpha_range, Kset,
                                          G.degree, edge_weight_z, min_output_size)
        # Weighted degrees used to normalise the weighted PageRank vectors
        wdeg_p = [sum(edge_weight_p[v]) for v in 1:G.nv]
        F1_PR_p[i, j] = _best_pagerank_f1(G, copy(source_page_rank), rho, alpha_range, Kset,
                                          wdeg_p, edge_weight_z, min_output_size;
                                          weights=edge_weight_p)
        wdeg_r = [sum(edge_weight_r[v]) for v in 1:G.nv]
        F1_PR_r[i, j] = _best_pagerank_f1(G, copy(source_page_rank), rho, alpha_range, Kset,
                                          wdeg_r, edge_weight_z, min_output_size;
                                          weights=edge_weight_r)

        @printf("Class %d, trial %d, F1s are %.2f, %.2f, %.2f, %.2f, %.2f, %.2f, %.2f, %.2f\n", c, j, 
            F1_CLF[i,j], F1_SFD[i,j], F1_WFD[i,j], F1_LFD_p[i,j], F1_LFD_r[i,j], F1_PR_z[i,j], F1_PR_p[i,j], F1_PR_r[i,j]) 
        GC.gc()
    end
end

# Make sure the output directory exists. dirname is "" for a bare filename, which needs
# no directory.
dir_path = dirname(output)
if !isempty(dir_path) && !isdir(dir_path)
    mkpath(dir_path)
end

# Write every F1 matrix (row = index into `targets`, column = trial) under the
# "matrices" group. The do-block form closes the file even if a write throws.
h5open(output, "w") do h5file
    create_group(h5file, "matrices")
    for (name, M) in ("F1_CLF" => F1_CLF, "F1_SFD" => F1_SFD, "F1_WFD" => F1_WFD,
                      "F1_LFD_p" => F1_LFD_p, "F1_LFD_r" => F1_LFD_r,
                      "F1_PR_z" => F1_PR_z, "F1_PR_r" => F1_PR_r, "F1_PR_p" => F1_PR_p)
        write(h5file, "matrices/" * name, M)
    end
end