using GraphNeuralNetworks, Graphs, Flux, Tullio
using LinearAlgebra: I
using Random


## Linear Dirac-Bianconi recurrent convolution layer (Euler step of the Dirac-Bianconi equation)

# Parameters of one linear Dirac-Bianconi update step.
# An instance is callable as `l(g, x_n, x_e)` and returns `(x_new, e_new)`.
struct LinDBLayer{T1<:AbstractMatrix,T2,T3,T4} <: GNNLayer
    W_ne::T1 # Weights for mapping edge features to node features
    W_en::T2 # Weights for mapping node features to edge features; `nothing` in the balanced setup, where `W_ne'` is used instead
    beta_e::T3 # Factor multiplying the current edge features (a matrix, or the scalar 1.0f0 without mass matrix)
    beta_n::T4 # Factor multiplying the current node features (a matrix, or the scalar 1.0f0 without mass matrix)
end

# Register the struct with Flux through `@functor`.
Flux.@functor LinDBLayer

"""
    LinDBLayer(dn, de; mass_matrix=true, balanced=true, Δ, init_function)

Build a `LinDBLayer` for node features of dimension `dn` and edge features of
dimension `de`.

# Arguments
- `dn`: node feature dimension.
- `de`: edge feature dimension.

# Keywords
- `mass_matrix`: if `true`, `beta_e` and `beta_n` are the matrices
  `Δ .* init_function(d, d) + I`; otherwise both are the scalar `1.0f0`.
- `balanced`: if `true`, `W_en` is set to `nothing`, which selects the forward
  method that reuses `W_ne'` for the node-to-edge map.
- `Δ`: scale applied to every randomly initialised weight matrix.
- `init_function`: callable `(rows, cols) -> matrix`, or the string `"zeros"`.
  With `"zeros"` all four fields are zero matrices and `mass_matrix`, `balanced`
  and `Δ` are ignored.
"""
function LinDBLayer(dn, de; mass_matrix=true, balanced=true, Δ, init_function)
    if init_function == "zeros"
        W_ne = Flux.zeros32(dn, de)
        # W_en is multiplied with node features in the forward pass, so it must be
        # `de × dn` (the transposed shape of W_ne), not `dn × de`.
        W_en = Flux.zeros32(de, dn)
        beta_e = Flux.zeros32(de, de)
        beta_n = Flux.zeros32(dn, dn)
        return LinDBLayer(W_ne, W_en, beta_e, beta_n)
    end

    # Small weights plus an identity `beta` keep the layer close to the identity map.
    W_ne = Δ .* init_function(dn, de)
    W_en = balanced ? nothing : Δ .* init_function(de, dn)

    if mass_matrix
        beta_e = Δ .* init_function(de, de) + I
        beta_n = Δ .* init_function(dn, dn) + I
    else
        beta_e = 1.0f0
        beta_n = 1.0f0
    end

    return LinDBLayer(W_ne, W_en, beta_e, beta_n)
end

# Balanced forward pass (selected when `W_en` is `nothing`): the node-to-edge map
# is the adjoint of the edge-to-node weights `W_ne`.
# Returns the tuple `(x_new, e_new)`.
function (layer::LinDBLayer{M,Nothing,BE,BN})(g::GNNGraph, x_n, x_e) where {M<:AbstractMatrix,BE,BN}
    # Per-edge update from the endpoint feature difference and the current edge features.
    edge_step(xi, xj, eij) = layer.W_ne' * (xi - xj) .+ layer.beta_e * eij
    e_new = apply_edges(edge_step, g; xi=x_n, xj=x_n, e=x_e)

    # The node update aggregates the incoming `x_e`, not the freshly computed `e_new`.
    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = layer.W_ne * edge_sum .+ layer.beta_n * x_n

    return x_new, e_new
end

# General forward pass: uses the separate weight matrix `W_en` for the
# node-to-edge map. Returns the tuple `(x_new, e_new)`.
function (layer::LinDBLayer)(g::GNNGraph, x_n, x_e)
    # Per-edge update from the endpoint feature difference and the current edge features.
    edge_step(xi, xj, eij) = layer.W_en * (xi - xj) .+ layer.beta_e * eij
    e_new = apply_edges(edge_step, g; xi=x_n, xj=x_n, e=x_e)

    # The node update aggregates the incoming `x_e`, not the freshly computed `e_new`.
    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = layer.W_ne * edge_sum .+ layer.beta_n * x_n

    return x_new, e_new
end


##

"""
    DBnStep(linDB, num_steps, dropout_n, dropout_e)

Block that applies the layer `linDB` `num_steps` times, each application
followed by node/edge dropout and `relu`.

The struct is parametric so that every field has a concrete type; untyped
fields would be stored as `Any` and make the forward pass type-unstable.

# Fields
- `linDB`: layer called as `linDB(g, x_n, x_e)`.
- `num_steps`: number of repeated applications of `linDB`.
- `dropout_n`: callable applied to the node features after each step.
- `dropout_e`: callable applied to the edge features after each step.
"""
struct DBnStep{L,N,DN,DE}
    linDB::L
    num_steps::N
    dropout_n::DN
    dropout_e::DE
end

# Register the struct with Flux through `@functor`.
Flux.@functor DBnStep

# Convenience constructor: builds the inner `LinDBLayer` from the feature
# dimensions `d_n` / `d_e` and wraps the dropout rates in `Dropout` layers.
function DBnStep(d_n, d_e, Δ, balanced, num_steps, dropout_n, dropout_e, init_function)
    layer = LinDBLayer(d_n, d_e; Δ=Δ, balanced=balanced, init_function=init_function)
    node_dropout = Dropout(dropout_n)
    edge_dropout = Dropout(dropout_e)
    return DBnStep(layer, num_steps, node_dropout, edge_dropout)
end


# Forward pass: repeat `num_steps` times the sequence
# linDB update -> dropout -> relu, for node and edge features alike.
# Returns the tuple of final node and edge features.
function (db::DBnStep)(g::GNNGraph, x_n, x_e)
    nodes, edges = x_n, x_e
    for _ in 1:db.num_steps
        nodes, edges = db.linDB(g, nodes, edges)
        # Node dropout is evaluated before edge dropout, as in every step.
        nodes = relu.(db.dropout_n(nodes))
        edges = relu.(db.dropout_e(edges))
    end
    return nodes, edges
end



## DBGNN 
# Full DBGNN model. Callable as `model(g, x_n_in, x_e_in)`.
struct DBGNN{T1, T2, T3, T4, T5, T6, T7, T8} # Add type information so that branching on "isnothing" and "maybe_idx" is free.
    dense_n_in::T1 # Input layer applied to the raw node features
    dense_e_in::T2 # Input layer applied to the raw edge features
    skip_dense_n::T3 # Node skip-connection layer(s): `nothing`, one shared layer, or an array with one layer per DBnStep
    skip_dense_e::T4 # Edge skip-connection layer(s): `nothing`, one shared layer, or an array with one layer per DBnStep
    db_nstep_array::T5 # Collection of DBnStep blocks, applied in order
    dense_after_linDB::T6 # Optional layer applied to the node features after the last DBnStep, or `nothing`
    pool::T7 # Optional pooling layer called as `pool(g, x_n)`, or `nothing`
    final_dense::T8 # Output layer producing the final node/graph embedding
end

# Register the struct with Flux through `@functor`.
Flux.@functor DBGNN

"""
    DBGNN(d_n_in, d_n_hidden, d_n_out, d_e_in, d_e_hidden, num_dbnstep, steps_dbnstep;
          Δ, skip_connection_n=true, skip_connection_e=true,
          skip_n_share_weight=false, skip_e_share_weight=false,
          final_pool=false, dense_after_linDB=false, balanced=false,
          dropout_n=0.05, dropout_e=0.05, init_dense_in, init_lindb, init_dense)

Generate a DBGNN model.

# Arguments
- `d_n_in`: input dimension of node features
- `d_n_hidden`: hidden dimension of node embedding
- `d_n_out`: output dimension of node embedding
- `d_e_in`: input dimension of edge features
- `d_e_hidden`: hidden dimension of edge embedding
- `num_dbnstep`: number of DBnStep layers
- `steps_dbnstep`: number of steps per DBnStep

# Keywords
- `Δ`: total step size of one DBnStep (required). It is divided by `steps_dbnstep`
  and converted to `Float32` before being handed to the LinDB layers, where it
  scales the initial weights.
- `skip_connection_n`: Bool, if true use a skip connection for nodes
- `skip_connection_e`: Bool, if true use a skip connection for edges
- `skip_n_share_weight`: Bool, if true all DBnSteps share one node skip layer,
  otherwise one layer per DBnStep is created
- `skip_e_share_weight`: Bool, if true all DBnSteps share one edge skip layer,
  otherwise one layer per DBnStep is created
- `final_pool`: aggregation passed to `GlobalPool` (e.g. `mean` or `max`) for a
  graph-level output, or `false` for a node-level output
- `dense_after_linDB`: if not `false`, output dimension of a Dense layer applied
  after the final DBnStep
- `balanced`: Bool for using the balanced setup (weight sharing)
- `dropout_n`: node dropout
- `dropout_e`: edge dropout
- `init_dense_in`: initialization function for the input Dense layers
- `init_lindb`: initialization function for the LinDB layers
- `init_dense`: initialization function for all other Dense layers
"""
function DBGNN(
    d_n_in, d_n_hidden, d_n_out, d_e_in, d_e_hidden, num_dbnstep, steps_dbnstep;
    Δ,
    skip_connection_n=true,
    skip_connection_e=true,
    skip_n_share_weight=false,
    skip_e_share_weight=false,
    final_pool=false,
    dense_after_linDB=false,
    balanced=false,
    dropout_n=0.05,
    dropout_e=0.05,
    init_dense_in,
    init_lindb,
    init_dense,
)
    # Each DBnStep takes `steps_dbnstep` sub-steps, so every sub-step gets an equal share of Δ.
    Δ_step = Float32(Δ / steps_dbnstep)

    dense_n_in = Dense(d_n_in => d_n_hidden; init=init_dense_in)
    dense_e_in = Dense(d_e_in => d_e_hidden; init=init_dense_in)

    ## init layers for skip connections
    # The skip layers consume the raw input features concatenated with the hidden features.
    if skip_connection_n
        if skip_n_share_weight
            skip_dense_n = Dense(d_n_in + d_n_hidden => d_n_hidden; init=init_dense)
        else
            skip_dense_n = [
                Dense(d_n_in + d_n_hidden => d_n_hidden; init=init_dense) for
                _ in 1:num_dbnstep
            ]
        end
    else
        skip_dense_n = nothing
    end

    if skip_connection_e
        if skip_e_share_weight
            skip_dense_e = Dense(d_e_in + d_e_hidden => d_e_hidden; init=init_dense)
        else
            skip_dense_e = [
                Dense(d_e_in + d_e_hidden => d_e_hidden; init=init_dense) for
                _ in 1:num_dbnstep
            ]
        end
    else
        skip_dense_e = nothing
    end

    # One independently initialised DBnStep per requested layer.
    db_nstep_array = [
        DBnStep(
            d_n_hidden, d_e_hidden, Δ_step, balanced, steps_dbnstep,
            dropout_n, dropout_e, init_lindb,
        ) for _ in 1:num_dbnstep
    ]

    # Keep the requested dimension and the constructed layer in separate variables:
    # the dimension is still needed as the input size of `final_dense`.
    if dense_after_linDB == false
        post_dense = nothing
        final_dense_in_dim = d_n_hidden
    else
        post_dense = Dense(d_n_hidden => dense_after_linDB; init=init_dense)
        final_dense_in_dim = dense_after_linDB
    end

    pool = final_pool == false ? nothing : GlobalPool(final_pool)

    final_dense = Dense(final_dense_in_dim => d_n_out; init=init_dense)

    return DBGNN(
        dense_n_in,
        dense_e_in,
        skip_dense_n,
        skip_dense_e,
        db_nstep_array,
        post_dense,
        pool,
        final_dense,
    )
end

# Arrays hold one entry per step: pick the `i`-th one.
function maybe_idx(a::AbstractArray, i)
    return getindex(a, i)
end

# Anything that is not an array is shared across steps: return it unchanged.
function maybe_idx(a, i)
    return a
end

# Forward pass of the full model.
# 1. Embed raw node and edge features with the input layers.
# 2. Run every DBnStep; after each one, optionally concatenate the raw input
#    features on top of the hidden features and pass them through the skip layer.
# 3. Optionally apply the post-DB dense layer and the pooling layer.
# 4. Apply the final dense layer and return the adjoint of its output.
function (model::DBGNN)(g::GNNGraph, x_n_in, x_e_in)
    h_n = model.dense_n_in(x_n_in)
    h_e = model.dense_e_in(x_e_in)

    for (k, db_step) in enumerate(model.db_nstep_array)
        h_n, h_e = db_step(g, h_n, h_e)

        if model.skip_dense_n !== nothing
            # Shared layer or the k-th per-step layer, depending on how the model was built.
            skip_n = maybe_idx(model.skip_dense_n, k)
            h_n = skip_n(vcat(x_n_in, h_n))
        end

        if model.skip_dense_e !== nothing
            skip_e = maybe_idx(model.skip_dense_e, k)
            h_e = skip_e(vcat(x_e_in, h_e))
        end
    end

    if model.dense_after_linDB !== nothing
        h_n = model.dense_after_linDB(h_n)
    end

    if model.pool !== nothing
        h_n = model.pool(g, h_n)
    end

    return model.final_dense(h_n)'
end
