include("../WFD/WFD.jl")
include("../WFD/common.jl")

using Random
using Printf
using DelimitedFiles
using Statistics
using LinearAlgebra

# Read the space-delimited integer edge list and turn it into an adjacency list
# over `n` nodes. Every row (u, v) is recorded in both directions, so v is
# appended to u's neighbor list and u to v's.
edges = readdlm("data/synthetic/sbm_p05q0025.txt", ' ', Int, '\n')
n = 10000
ll = [Int64[] for _ in 1:n]
for row in axes(edges, 1)
    push!(ll[edges[row, 1]], edges[row, 2])
    push!(ll[edges[row, 2]], edges[row, 1])
end
# Degree of a node is the length of its neighbor list.
degree = map(length, ll)
G = AdjacencyList(ll, degree, n)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Cluster size. The code below treats cluster c as the contiguous node-id block
# (c-1)*k+1 : c*k, so the graph is expected to hold n ÷ k such blocks.
k = 500
num_clusters = n ÷ k

# `Z` and the trailing 0. are forwarded unchanged to `edge_weights` (defined in the
# included WFD sources). NOTE(review): presumably an all-zero feature matrix so the
# resulting weights carry no attribute information -- confirm against common.jl.
Z = zeros(n,1)
edge_weight_z = edge_weights(G, Z, 0.)
# Same nested layout as `edge_weight_z`; every entry is overwritten in each trial.
edge_weight_r = edge_weight_z * 0.0

num_trials = 100
# Lower bound applied to every label-based edge weight. Deliberately not called
# `eps`, so that `Base.eps` is not shadowed in `Main`, where the included sources
# also live.
min_edge_weight = 0.05
flip_ratios = collect(0.5:-0.02:0.02)
# Multipliers of k tried as the source mass placed on the seed node.
alphas = 2:0.25:4

# One row per flip ratio, one column per trial.
F1_SFD = zeros(length(flip_ratios), num_trials)
F1_LFD = zeros(length(flip_ratios), num_trials)
F1_LBL = zeros(length(flip_ratios), num_trials)

# Passed to `WFD` as its `sink` argument.
sink = ones(n)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

"""
    cluster_f1(C, K)

Return the F1 score of the predicted node collection `C` against the target node
collection `K`.

Both collections must be free of duplicates. When `C` and `K` share no element
(which includes an empty `C`), the score is `0.0` instead of the `NaN` that the
raw precision/recall formula would produce from `0/0`.
"""
function cluster_f1(C, K)
    tp = count(in(K), C)
    tp == 0 && return 0.0
    pr = tp / length(C)
    re = tp / length(K)
    return 2 * pr * re / (pr + re)
end

"""
    assign_label_weights!(weights, G, labels, floor_weight)

Overwrite `weights` in place so that, for every node `u` in `1:G.nv` and every
position `idx` of a neighbor `v` in `G.neighbors[u]`,
`weights[u][idx] == max(labels[u] * labels[v], floor_weight)`.

Return `weights`.
"""
function assign_label_weights!(weights, G, labels, floor_weight)
    for u in 1:G.nv
        wu = weights[u]
        lu = labels[u]
        for (idx, v) in enumerate(G.neighbors[u])
            wu[idx] = max(lu * labels[v], floor_weight)
        end
    end
    return weights
end

"""
    best_diffusion_f1(G, source, sink, weights, seed, K, alphas)

Run `WFD` once per value `a` in `alphas`, each time setting `source[seed]` to
`a * length(K)`, and return the largest F1 score (see `cluster_f1`) obtained by
comparing the nonzero entries of the `WFD` output against `K`.

`source` is modified in place: on return `source[seed]` holds the mass of the last
value tried. The result is `0.0` if no run overlaps `K`.
"""
function best_diffusion_f1(G, source, sink, weights, seed, K, alphas)
    best = 0.0
    for a in alphas
        source[seed] = a * length(K)
        x = WFD(G, source, sink, weights=weights, max_iters=100, epsilon=1.0e-8)
        f1 = cluster_f1(findall(!iszero, x), K)
        if f1 > best
            best = f1
        end
    end
    return best
end

# ---------------------------------------------------------------------------
# Trials
# ---------------------------------------------------------------------------

for (i, r) in enumerate(flip_ratios)

    for t in 1:num_trials

        # Pick a target cluster uniformly at random; K is its block of node ids.
        cluster_id = rand(1:num_clusters)
        K = (cluster_id-1)*k+1:cluster_id*k

        # Create random noisy labels: start from the true indicator of K, then
        # clear a fraction r of the labels inside K and set a fraction r of the
        # labels outside K.
        labels = zeros(Int64, n)
        labels[K] .= 1
        labels[K[randperm(k)[1:floor(Int, r*k)]]] .= 0
        Kc = setdiff(1:n, K)
        labels[Kc[randperm(length(Kc))[1:floor(Int, r*length(Kc))]]] .= 1

        # Label-based edge weights: 1 when both endpoints carry label 1,
        # otherwise the floor `min_edge_weight`.
        assign_label_weights!(edge_weight_r, G, labels, min_edge_weight)

        # Compute F1 of labels
        F1_LBL[i,t] = cluster_f1(findall(isone, labels), K)

        # Randomly select a seed node
        s = rand(K)
        source = zeros(n)

        # Standard Flow Diffusion
        F1_SFD[i,t] = best_diffusion_f1(G, source, sink, edge_weight_z, s, K, alphas)

        # Label-based Flow Diffusion
        F1_LFD[i,t] = best_diffusion_f1(G, source, sink, edge_weight_r, s, K, alphas)

        @printf("Label acc %.2f, trial %d, F1s are label %.2f, SFD %.2f, LFD %.2f\n", 1-r, t,
            F1_LBL[i,t], F1_SFD[i,t], F1_LFD[i,t])

    end

end

# Dump each F1 matrix to its own delimited text file in the working directory,
# in the order: raw labels, standard flow diffusion, label-based flow diffusion.
for (path, scores) in (
    ("F1_Labels.txt", F1_LBL),
    ("F1_FD.txt", F1_SFD),
    ("F1_LFD.txt", F1_LFD),
)
    writedlm(path, scores)
end