include("../WFD/WFD.jl")
include("../WFD/common.jl")
include("../WFD/pagerank.jl")

using Random
using Printf
using DelimitedFiles
using Statistics
using LinearAlgebra

# Load the edge list (space-delimited integer rows; the first two columns are the
# endpoints) and build symmetric adjacency lists: every edge is recorded in the
# neighbor list of both of its endpoints.
edges = readdlm("data/synthetic/sbm_p05q0025.txt", ' ', Int, '\n')
n = 10000
ll = map(_ -> Int64[], 1:n)
for (u, v) in eachrow(edges)
    push!(ll[u], v)
    push!(ll[v], u)
end
# A node's degree is the length of its neighbor list (duplicates counted).
degree = length.(ll)
G = AdjacencyList(ll, degree, n)

# Clusters are contiguous blocks of `k` node ids: cluster c owns (c-1)*k+1 : c*k,
# so the graph holds n ÷ k of them.
k = 500
num_clusters = n ÷ k

# An all-zero edge-weight structure shaped like G's adjacency lists; its
# entries are overwritten with label-based weights in every trial.
Z = zeros(n, 1)
edge_weight_z = edge_weights(G, Z, 0.)
edge_weight_r = edge_weight_z * 0.0

num_trials = 100
# Weight floor for edges that do not join two label-1 nodes. Deliberately not
# called `eps`: assigning `eps` in Main would shadow `Base.eps` for every
# function defined in Main, including the included WFD sources.
min_weight = 0.05
flip_ratios = collect(0.5:-0.02:0.4)

# Rows follow `flip_ratios`, columns are trials.
F1_SPR = zeros(length(flip_ratios), num_trials)
F1_LPR = zeros(length(flip_ratios), num_trials)
F1_LBL = zeros(length(flip_ratios), num_trials)

"""
    cluster_f1(predicted, truth::AbstractSet)

Return the F1 score of the node collection `predicted` (assumed free of duplicate
entries) against the ground-truth node set `truth`. Returns 0.0 when the two do
not overlap — including when `predicted` is empty — instead of the NaN that the
naive precision/recall formula yields for `0/0`.
"""
function cluster_f1(predicted, truth::AbstractSet)
    isempty(predicted) && return 0.0
    hits = count(in(truth), predicted)
    hits == 0 && return 0.0
    precision = hits / length(predicted)
    recall = hits / length(truth)
    return 2 * precision * recall / (precision + recall)
end

"""
    fill_label_weights!(weights, G, labels, min_weight)

Overwrite `weights` in place so that `weights[u][idx]`, the weight of the edge from
`u` to `G.neighbors[u][idx]`, equals `max(labels[u] * labels[v], min_weight)`.
Returns `weights`. Kept as a function so the O(edges) loop does not run over
untyped globals.
"""
function fill_label_weights!(weights, G, labels, min_weight)
    for u in 1:G.nv
        for (idx, v) in enumerate(G.neighbors[u])
            weights[u][idx] = max(labels[u] * labels[v], min_weight)
        end
    end
    return weights
end

"""
    best_pagerank_f1(G, source, volK, truth; kwargs...)

Sweep `pagerank` over rho = 1/(a*volK) for a in 1.25:0.25:2.5 and alpha in
0.05:0.05:0.2. Return the best F1 of the PageRank support (nodes with a nonzero
value) against `truth`. Extra keywords (e.g. `weights`) are forwarded to `pagerank`.
"""
function best_pagerank_f1(G, source, volK, truth; kwargs...)
    best_f1 = 0.0
    for a in 1.25:0.25:2.5
        rho = 1 / (a * volK)
        for alpha in 0.05:0.05:0.2
            p = pagerank(G, source, rho, alpha; tol=1.0e-7, max_iters=30, kwargs...)
            best_f1 = max(best_f1, cluster_f1(findall(!iszero, p), truth))
        end
    end
    return best_f1
end

for (ri, r) in enumerate(flip_ratios)
    for t in 1:num_trials
        # Pick a random ground-truth cluster K.
        cluster_id = rand(1:num_clusters)
        K = (cluster_id - 1) * k + 1:cluster_id * k
        Kset = Set(K)
        volK = sum(G.degree[K])

        # Noisy labels: start from the indicator of K, then flip a fraction r
        # of K to 0 and the same fraction of its complement to 1.
        labels = zeros(Int64, n)
        labels[K] .= 1
        labels[K[randperm(k)[1:floor(Int, r * k)]]] .= 0
        Kc = setdiff(1:n, K)
        labels[Kc[randperm(n - k)[1:floor(Int, r * (n - k))]]] .= 1
        fill_label_weights!(edge_weight_r, G, labels, min_weight)

        # F1 of the raw labels themselves.
        F1_LBL[ri, t] = cluster_f1(findall(isone, labels), Kset)

        # Single seed node drawn uniformly from the target cluster.
        source = zeros(n)
        source[rand(K)] = 1

        # Standard PageRank vs. PageRank on label-based edge weights.
        F1_SPR[ri, t] = best_pagerank_f1(G, source, volK, Kset)
        F1_LPR[ri, t] = best_pagerank_f1(G, source, volK, Kset; weights=edge_weight_r)

        @printf("Label acc %.2f, trial %d, F1s are label %.2f, SPR %.2f, LPR %.2f\n", 1-r, t,
            F1_LBL[ri, t], F1_SPR[ri, t], F1_LPR[ri, t])
        flush(stdout)
    end
end

# Write each method's F1 matrix to its own delimited text file, in this order:
# label baseline, standard PageRank, label-weighted PageRank.
for (outfile, scores) in (
    ("F1_Labels.txt", F1_LBL),
    ("F1_PR.txt", F1_SPR),
    ("F1_LPR.txt", F1_LPR),
)
    writedlm(outfile, scores)
end