using Random
using SparseArrays

include("struct.jl")

"""
    pagerank(G::AdjacencyList, s, rho, alpha;
             tol=1.0e-7, weights=nothing, max_iters=1000, rng=Random.default_rng())

Compute the L1-regularized PageRank vector as described in:
    'Variational Perspective on Local Graph Clustering.' K. Fountoulakis et al.

The solver runs randomized coordinate proximal-gradient updates with unit step size.
In each sweep it only updates nodes whose gradient violates the soft-threshold
condition by more than `tol`. Work therefore stays near the nonzero entries of `s`.

# Arguments
- `G`: adjacency-list representation of the graph. Node indices must run from 1 to
  `G.nv`. `G.neighbors[v]` lists the neighbours of `v` and `G.degree[v]` is its
  unweighted degree.
- `s`: source (seed) mass vector of length `G.nv`. Its nonzero entries mark the
  seed nodes.
- `rho`: regularization parameter.
- `alpha`: teleportation parameter.

# Keywords
- `tol`: tolerance for convergence. A node is updated only if it violates the
  optimality condition by more than `tol`. The solver stops after a sweep in which
  no coordinate increased by `tol` or more.
- `weights`: edge weights, indexed as `weights[v][i]` for the edge from `v` to
  `G.neighbors[v][i]` (e.g. a `Dict` or a vector of vectors).
  - If `nothing`, every edge has unit weight and `G.degree` is the degree vector.
  - Otherwise the weighted degree `sum(weights[v])` is used, both for the gradient
    and for the penalty threshold.
- `max_iters`: maximum number of sweeps.
- `rng`: random number generator used to shuffle the update order in each sweep.

# Returns
- `p::Vector{Float64}`: nonnegative solution vector of length `G.nv`.

# Throws
- `DimensionMismatch` if `length(s) != G.nv`.
"""
function pagerank(G::AdjacencyList, s::AbstractVector{<:Real}, rho::Real, alpha::Real;
                  tol::Float64=1.0e-7, weights=nothing, max_iters::Int=1000,
                  rng::AbstractRNG=Random.default_rng())
    n = G.nv
    length(s) == n ||
        throw(DimensionMismatch("length(s) = $(length(s)) but the graph has $n nodes"))

    # The same degree vector must feed both the gradient normalization and the
    # L1 penalty threshold rho*alpha*sqrt(d[v]); otherwise the weighted problem
    # mixes two different degree matrices.
    d = weights === nothing ? G.degree : [sum(weights[v]) for v in 1:n]

    g = Float64[-alpha * x for x in s]      # gradient of the objective
    p = zeros(Float64, n)                    # PageRank (solution) vector
    S = Set{Int}(findall(!iszero, s))        # nodes whose gradient may be nonzero

    self_coef = alpha + (1 - alpha) / 2      # diagonal entry of the quadratic term
    nbr_coef = (1 - alpha) / 2               # scale of the normalized-adjacency term

    for _ in 1:max_iters
        err = 0.0
        # Nodes violating the optimality condition by more than tol.
        C = [v for v in S if -g[v] > rho * alpha * sqrt(d[v]) + tol]
        for v in shuffle!(rng, C)
            sqrt_dv = sqrt(d[v])
            prev = p[v]
            # Unit gradient step followed by the nonnegative soft-threshold (prox).
            p[v] = max(p[v] - g[v] - rho * alpha * sqrt_dv, 0.0)
            push = p[v] - prev
            push > 0 || continue

            err = max(err, push)
            g[v] += self_coef * push
            for (i, u) in enumerate(G.neighbors[v])
                # A node whose gradient is still zero has never been touched; start
                # tracking it now that its gradient is about to change.
                if iszero(g[u])
                    push!(S, u)
                end
                w = weights === nothing ? 1.0 : weights[v][i]
                g[u] -= w * nbr_coef / (sqrt(d[u]) * sqrt_dv) * push
            end
        end
        err < tol && break
    end
    return p
end