using GraphNeuralNetworks, Graphs, Flux, Tullio
using LinearAlgebra: I, norm, Diagonal, tr
using Random


## Linear Dirac–Bianconi recurrent convolution layer (one Euler step of the Dirac–Bianconi equation)

"""
    LinDBLayer(W_ne, W_en, beta_e, beta_n)

Parameters of one linear Dirac–Bianconi update step on a graph with node
feature dimension `dn` and edge feature dimension `de`.

- `W_ne`: matrix (`dn × de` when built by the keyword constructor) mapping
  aggregated edge features to node features.
- `W_en`: matrix (`de × dn`) mapping node-feature differences to edge features,
  or `nothing` in the balanced (weight-sharing) setup, in which case the forward
  pass uses `W_ne'` instead.
- `beta_e`, `beta_n`: self-loop ("mass matrix") terms multiplied onto the current
  edge / node features; either matrices or the scalar `1f0`.
"""
struct LinDBLayer{T1<:AbstractMatrix,T2,T3,T4} <: GNNLayer
    W_ne::T1 # Weights for mapping edge features to node features (dn × de)
    W_en::T2 # Weights for mapping node features to edge features (de × dn), or `nothing` when balanced
    beta_e::T3 # edge self-loop term: (de × de) matrix or scalar
    beta_n::T4 # node self-loop term: (dn × dn) matrix or scalar
end

# Make the fields traversable by Flux (parameter collection, device movement, ...).
Flux.@functor LinDBLayer

"""
    LinDBLayer(dn, de; mass_matrix=true, balanced=true, Δ=1f0, init_function)

Construct a `LinDBLayer` for node feature dimension `dn` and edge feature
dimension `de`.

# Keywords
- `mass_matrix`: if `true`, the self-loop terms `beta_e` / `beta_n` are matrices
  initialized as `Δ .* init_function(...) + I`; otherwise they are the scalar `1f0`.
- `balanced`: if `true`, no separate `W_en` is stored (`W_en = nothing`) and the
  forward pass uses `W_ne'` for the node → edge map (weight sharing).
- `Δ`: scale applied to the weights produced by `init_function`.
- `init_function`: initializer called as `init_function(rows, cols)`, or the
  string `"zeros"` to make every weight (including the self-loop matrices) zero.
"""
function LinDBLayer(dn, de; mass_matrix=true, balanced=true, Δ=1f0, init_function) # mass matrix = self loops
    if init_function != "zeros"
        # Initialize the layer close to the identity function (assuming Δ and
        # init_function yield small values): small weights, plus I on the self-loops.
        W_ne = Δ .* init_function(dn, de)
        # Balanced setup: W_en is not stored; the forward pass reuses W_ne'.
        W_en = balanced ? nothing : Δ .* init_function(de, dn)

        if mass_matrix
            beta_e = Δ .* init_function(de, de) + I
            beta_n = Δ .* init_function(dn, dn) + I
        else
            beta_e = 1.0f0
            beta_n = 1.0f0
        end
    else
        # All-zero initialization. `mass_matrix` is not consulted here: the
        # self-loop terms are always zero matrices.
        W_ne = Flux.zeros32(dn, de)
        # W_en maps node features (dn) to edge features (de), so it is de × dn.
        # Respect `balanced` so a zero-initialized layer still shares W_ne'.
        W_en = balanced ? nothing : Flux.zeros32(de, dn)
        beta_e = Flux.zeros32(de, de)
        beta_n = Flux.zeros32(dn, dn)
    end
    return LinDBLayer(W_ne, W_en, beta_e, beta_n)
end

# Balanced forward step (`W_en === nothing`): the node → edge map is the adjoint of
# `W_ne`, so one matrix drives both halves of the update.
#   edges: e' = W_ne' * (x_i - x_j) + beta_e * e
#   nodes: x' = W_ne * (sum of edge features aggregated with `+`) + beta_n * x
# Returns `(x_new, e_new)`.
function (l::LinDBLayer{T1,Nothing,T3,T4})(g::GNNGraph, x_n, x_e) where {T1<:AbstractMatrix,T3,T4}
    W_shared = l.W_ne'
    e_new = apply_edges(
        (xi, xj, eij) -> W_shared * (xi - xj) .+ l.beta_e * eij,
        g;
        xi=x_n, xj=x_n, e=x_e,
    )

    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = l.W_ne * edge_sum .+ l.beta_n * x_n

    return x_new, e_new
end

# Unbalanced forward step: separate matrices for the two directions.
#   edges: e' = W_en * (x_i - x_j) + beta_e * e
#   nodes: x' = W_ne * (sum of edge features aggregated with `+`) + beta_n * x
# Returns `(x_new, e_new)`.
function (l::LinDBLayer)(g::GNNGraph, x_n, x_e)
    node_to_edge(xi, xj, eij) = l.W_en * (xi - xj) .+ l.beta_e * eij
    e_new = apply_edges(node_to_edge, g; xi=x_n, xj=x_n, e=x_e)

    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = l.W_ne * edge_sum .+ l.beta_n * x_n

    return x_new, e_new
end



## node/edgewise up/down mapping
"""
    UDLayer(W)

Bias-free linear map `x -> W * x`, used to project node or edge features up or
down to another dimension.
"""
struct UDLayer{T1} <: GNNLayer
    W::T1 # weight matrix, (d_out × d_in) when built by UDLayer(d_in, d_out, init_function)
end

# Make `W` traversable by Flux (parameter collection, device movement, ...).
Flux.@functor UDLayer

# Build a `d_in -> d_out` UDLayer. The weight is `init_function(d_out, d_in)` with 1
# added on the leading diagonal, biasing the map toward identity (truncation or
# zero-padding when the dimensions differ).
function UDLayer(d_in, d_out, init_function)
    weights = init_function(d_out, d_in)
    for k in 1:min(d_in, d_out)
        weights[k, k] += 1.
    end
    return UDLayer(weights)
end

# Apply the up/down projection to features stored column-wise.
(l::UDLayer)(x) = l.W * x


"""
    DBLayerModul(linDB, num_steps, dropout_n, dropout_e)

A `LinDBLayer` applied recurrently `num_steps` times, with node/edge dropout
followed by ReLU after each step.
"""
struct DBLayerModul
    linDB # LinDBLayer reused for every step
    num_steps # number of recurrent applications of `linDB`
    dropout_n # dropout applied to node features after each step
    dropout_e # dropout applied to edge features after each step
end

# Make the fields traversable by Flux (parameter collection, device movement, ...).
Flux.@functor DBLayerModul

# Build a DBLayerModul around a freshly initialized LinDBLayer. `mass_matrix` is left
# at LinDBLayer's default (`true`); dropout rates are wrapped in `Dropout` layers.
function DBLayerModul(dims_n, dims_e, Δ, balanced, num_steps, dropout_n, dropout_e, init_function) # mass matrix = self loops
    lin = LinDBLayer(dims_n, dims_e; init_function=init_function, balanced=balanced, Δ=Δ)
    drop_n, drop_e = Dropout(dropout_n), Dropout(dropout_e)
    return DBLayerModul(lin, num_steps, drop_n, drop_e)
end


# Forward pass: run `num_steps` LinDB steps; after each step apply dropout and then
# ReLU to node features, then to edge features. Returns the final `(x_n, x_e)`.
function (db::DBLayerModul)(g::GNNGraph, x_n, x_e)
    h_n, h_e = x_n, x_e
    for _ in 1:db.num_steps
        h_n, h_e = db.linDB(g, h_n, h_e)
        h_n = relu.(db.dropout_n(h_n))
        h_e = relu.(db.dropout_e(h_e))
    end
    return h_n, h_e
end


## DBGCModel 
"""
Container for the DBGC graph model: input UD projections, a stack of
`DBLayerModul`s, optional skip-connection UD layers, and a dense read-out.
Fields holding `false` denote disabled components.
"""
struct DBGCModel
    ud_n_in # UDLayer projecting input node features to the hidden node dimension
    ud_e_in # UDLayer projecting input edge features to the hidden edge dimension
    udl_deep_n # node skip-connection layer(s): `false`, one shared UDLayer, or a vector with one per DB module
    udl_deep_e # edge skip-connection layer(s): `false`, one shared UDLayer, or a vector with one per DB module
    dbl # vector of DBLayerModul, applied in order
    steps_dbl # recurrent steps per DBLayerModul
    skip_connection_n # re-inject raw node features after each DB module
    skip_connection_e # re-inject raw edge features after each DB module
    skip_connection_n_share_weights # use one node skip UDLayer for all DB modules
    skip_connection_e_share_weights # use one edge skip UDLayer for all DB modules
    dropout_n # node Dropout layer; NOTE(review): not read by the forward pass shown with this type — confirm
    dropout_e # edge Dropout layer; NOTE(review): not read by the forward pass shown with this type — confirm
    dense_after_db # optional Dense layer after the DB stack, or `false`
    dense # final Dense read-out to the target classes
    Δ::Float32 # per-step Δ (the constructor's Δ divided by steps_dbl)
    pool # GlobalPool for graph-level tasks, or `false` for node-level tasks
end

# Make the fields traversable by Flux (parameter collection, device movement, ...).
Flux.@functor DBGCModel

"""
    DBGCModel(dims_ud_n::Vector{Int}, dims_ud_e::Vector{Int}, d_hidden_n, d_hidden_e,
              num_dbl, steps_dbl; Δ, target_classes, init_ud, init_lindb, init_dense,
              balanced=false, skip_connection_n=false, skip_connection_e=false,
              skip_connection_n_share_weights=false, skip_connection_e_share_weights=false,
              dropout_n=0.05, dropout_e=0.05, dense_dim=false, pool=false)

Construct a `DBGCModel`.

# Arguments
- `dims_ud_n::Vector{Int}`: `[d_in, d_hidden]` for nodes; the input UD layer maps
  `d_in`-dimensional node features to the `d_hidden`-dimensional DB embedding.
- `dims_ud_e::Vector{Int}`: the same for edges.
- `d_hidden_n`: input dimension of the node skip-connection UD layer(s), i.e. the row
  count of `vcat(g.ndata.features, x_n)` (typically `dims_ud_n[1] + dims_ud_n[2]`).
- `d_hidden_e`: the same for edges.
- `num_dbl`: number of `DBLayerModul`s stacked in sequence.
- `steps_dbl`: recurrent (Euler) steps per `DBLayerModul`.

# Keywords
- `Δ`: step size; divided by `steps_dbl` and used to scale the initial LinDB weights.
- `balanced`: use the balanced (weight-sharing) LinDB layer.
- `skip_connection_n` / `skip_connection_e`: after each DB module, concatenate the raw
  node/edge features with the current ones and project back through a UD layer.
- `skip_connection_n_share_weights` / `skip_connection_e_share_weights`: use one UD
  layer for all modules instead of one per module.
- `dropout_n`, `dropout_e`: node / edge dropout rates used inside the DB modules.
- `dense_dim`: width of an extra dense layer after the DB stack, or `false` for none.
- `pool`: aggregation for `GlobalPool` (e.g. `mean` or `max`) for graph-level tasks,
  or `false` for node-level tasks.
- `target_classes`: number of outputs of the final dense layer.
- `init_ud`, `init_lindb`, `init_dense`: initializers for the UD, LinDB and Dense layers.
"""
function DBGCModel(dims_ud_n::Vector{Int}, dims_ud_e::Vector{Int}, d_hidden_n, d_hidden_e, num_dbl, steps_dbl; balanced=false, skip_connection_n=false, skip_connection_e=false, skip_connection_n_share_weights=false, skip_connection_e_share_weights=false, dropout_n=0.05, dropout_e=0.05, dense_dim=false,  Δ, pool=false, target_classes, init_ud, init_lindb, init_dense)
    @assert length(dims_ud_n) == 2 "dims_ud_n must have length 2, got $(length(dims_ud_n))"
    @assert length(dims_ud_e) == 2 "dims_ud_e must have length 2, got $(length(dims_ud_e))"
    d_in_n, d_n = dims_ud_n
    d_in_e, d_e = dims_ud_e
    # Split Δ evenly over the Euler steps of each DB module.
    Δ = Δ / Float32(steps_dbl)

    # Layers are created in a fixed order so initializer RNG draws are reproducible.
    ud_n_in = UDLayer(d_in_n, d_n, init_ud)
    ud_e_in = UDLayer(d_in_e, d_e, init_ud)
    udl_deep_n = _dbgc_skip_udlayers(skip_connection_n, skip_connection_n_share_weights,
                                     d_hidden_n, d_n, num_dbl, init_ud)
    udl_deep_e = _dbgc_skip_udlayers(skip_connection_e, skip_connection_e_share_weights,
                                     d_hidden_e, d_e, num_dbl, init_ud)
    dbl_deep = DBLayerModul[
        DBLayerModul(d_n, d_e, Δ, balanced, steps_dbl, dropout_n, dropout_e, init_lindb)
        for _ in 1:num_dbl
    ]

    if dense_dim != false
        dense_after_db = Dense(d_n, dense_dim, init=init_dense)
        dense = Dense(dense_dim, target_classes, init=init_dense)
    else
        dense_after_db = false
        dense = Dense(d_n, target_classes, init=init_dense)
    end
    pool_layer = pool != false ? GlobalPool(pool) : pool

    return DBGCModel(ud_n_in, ud_e_in, udl_deep_n, udl_deep_e, dbl_deep, steps_dbl,
                     skip_connection_n, skip_connection_e,
                     skip_connection_n_share_weights, skip_connection_e_share_weights,
                     Dropout(dropout_n), Dropout(dropout_e), dense_after_db, dense, Δ, pool_layer)
end

# Build the skip-connection UD layer(s) for one feature kind: `false` when disabled,
# a single shared `UDLayer` when `share_weights`, otherwise one `UDLayer` per DB module.
function _dbgc_skip_udlayers(enabled, share_weights, d_in, d_out, num_dbl, init_ud)
    enabled == true || return false
    share_weights == true && return UDLayer(d_in, d_out, init_ud)
    return UDLayer[UDLayer(d_in, d_out, init_ud) for _ in 1:num_dbl]
end

# Forward pass of DBGCModel.
# 1. Project node/edge features with the input UD layers.
# 2. Run every DB module in order. With skip connections enabled, concatenate the raw
#    `g.ndata.features` / `g.edata.features` with the current features after each
#    module and project back through the (shared or per-module) skip UD layer.
# 3. Apply the optional dense layer and global pooling.
# Returns the transposed output of the final dense layer.
function (model::DBGCModel)(g::GNNGraph, x_n, x_e)
    # Raw features re-injected by the skip connections; read once and reused.
    orig_n_features = model.skip_connection_n == true ? g.ndata.features : nothing
    orig_e_features = model.skip_connection_e == true ? g.edata.features : nothing

    x_n = model.ud_n_in(x_n)
    x_e = model.ud_e_in(x_e)
    for i in eachindex(model.dbl)
        x_n, x_e = model.dbl[i](g, x_n, x_e)
        if model.skip_connection_n == true
            x_n = vcat(orig_n_features, x_n)
            udl_n = model.skip_connection_n_share_weights == true ? model.udl_deep_n : model.udl_deep_n[i]
            x_n = udl_n(x_n)
        end
        if model.skip_connection_e == true
            x_e = vcat(orig_e_features, x_e)
            udl_e = model.skip_connection_e_share_weights == true ? model.udl_deep_e : model.udl_deep_e[i]
            x_e = udl_e(x_e)
        end
    end

    if model.dense_after_db != false
        x_n = model.dense_after_db(x_n)
    end
    if model.pool != false
        x_n = model.pool(g, x_n)
    end
    return model.dense(x_n)'
end

"""
    DirichletE(g, x)

Return `∑_{e ∈ edges(g)} ‖x[:, src(e)] - x[:, dst(e)]‖₂` for node features `x`
stored column-wise (one column per node). The norms are not squared.
Returns zero for a graph without edges.
"""
function DirichletE(g, x)
    # Accumulate directly, without a temporary array of per-edge norms or copies of
    # the endpoint columns; `init` also makes the edgeless case well defined.
    acc0 = zero(float(eltype(x)))
    return sum(edges(g); init=acc0) do e
        @views norm(x[:, src(e)] .- x[:, dst(e)])
    end
end

"""
    DirichletEMatrix(g, x)

Rayleigh-quotient form of the Dirichlet energy, `tr(xᵀ L x) / tr(xᵀ x)`, where
`L = D - A`, `A = adjacency_matrix(g)` and `D` is the diagonal of the row sums of `A`.
`x` has one row per node (callers in this file pass `x_n'`). The result is NaN when
`x` is all zeros.
"""
function DirichletEMatrix(g, x)
    A = adjacency_matrix(g)
    D = Diagonal(A * ones(g.num_nodes))
    L = D - A
    # For real x, tr(x * xᵀ) == tr(xᵀ * x) == sum of squared entries. Computing it
    # directly avoids materializing an N × N matrix just to take its trace.
    return tr(transpose(x) * L * x) / sum(abs2, x)
end

"""
    DirichletEnergy(model::DBGCModel, g::GNNGraph, mode_matrix=false)

Trace `model`'s feature propagation on `g` (input UD layers, DB modules, and skip
connections) and return a `Vector{Float32}` with the node-feature Dirichlet energy
after every recurrent step of every DB module. `mode_matrix == false` uses
`DirichletE`; anything else uses `DirichletEMatrix`.

Each module's steps are run exactly once, on the states the model actually
produces, so entry `k` is the energy after the `k`-th step of the stack.

NOTE(review): the modules' dropout layers are invoked, so results depend on Flux's
train/test mode for them — confirm the intended mode.
"""
function DirichletEnergy(model::DBGCModel, g::GNNGraph, mode_matrix=false)
    orig_n_features = model.skip_connection_n == true ? g.ndata.features : nothing
    orig_e_features = model.skip_connection_e == true ? g.edata.features : nothing

    x_n = model.ud_n_in(g.ndata.features)
    x_e = model.ud_e_in(g.edata.features)
    energies = Float32[]
    for i in eachindex(model.dbl)
        db = model.dbl[i]
        # Advance through module i once, recording the energy after each step
        # (the same update as DBLayerModul's forward pass).
        for _ in 1:db.num_steps
            x_n, x_e = db.linDB(g, x_n, x_e)
            x_n = relu.(db.dropout_n(x_n))
            x_e = relu.(db.dropout_e(x_e))
            push!(energies, mode_matrix == false ? DirichletE(g, x_n) : DirichletEMatrix(g, x_n'))
        end
        # Skip connections between modules, as in the model's forward pass.
        if model.skip_connection_n == true
            x_n = vcat(orig_n_features, x_n)
            udl_n = model.skip_connection_n_share_weights == true ? model.udl_deep_n : model.udl_deep_n[i]
            x_n = udl_n(x_n)
        end
        if model.skip_connection_e == true
            x_e = vcat(orig_e_features, x_e)
            udl_e = model.skip_connection_e_share_weights == true ? model.udl_deep_e : model.udl_deep_e[i]
            x_e = udl_e(x_e)
        end
    end
    return energies
end

# Run the recurrent steps of DB module `db` on `(x_n, x_e)`, using the same update as
# its forward pass (LinDB step, dropout, ReLU). Return a `Vector{Float32}` with the
# node-feature Dirichlet energy after each step: `DirichletE` when
# `mode_matrix == false`, `DirichletEMatrix` on the transposed features otherwise.
function DirichletEnergy(db, g::GNNGraph, x_n, x_e, mode_matrix)
    trace = Float32[]
    h_n, h_e = x_n, x_e
    for _ in 1:db.num_steps
        h_n, h_e = db.linDB(g, h_n, h_e)
        h_n = relu.(db.dropout_n(h_n))
        h_e = relu.(db.dropout_e(h_e))
        energy = mode_matrix == false ? DirichletE(g, h_n) : DirichletEMatrix(g, h_n')
        push!(trace, energy)
    end
    return trace
end
