using GraphNeuralNetworks, Graphs, Flux, Tullio
using LinearAlgebra: I


## Dirac-Bianconi recurrent convolution driven by source terms

# Source-driven Dirac-Bianconi layer: one coupled node/edge update with optional
# linear source terms `S_n * s_n` (nodes) and `S_e * s_e` (edges).
struct DBsLayer{T1<:AbstractMatrix,T2<:AbstractMatrix,Ts,T3,T4,T5,T6} <: GNNLayer
    W_ne::T1 # Weights for mapping edge features to node features
    W_en::T2 # Weights for mapping node features to edge features
    sigma::Ts # elementwise activation applied to both the node and the edge update
    beta_e::T3 # edge self-coupling, multiplies the current edge features
    beta_n::T4 # node self-coupling, multiplies the current node features
    S_n::T5 # node source coupling; the scalar 0.0f0 when the constructor gets ds_n == 0
    S_e::T6 # edge source coupling; the scalar 0.0f0 when the constructor gets ds_e == 0
end

Flux.@functor DBsLayer

"""
    DBsLayer(dn_in, de_in, dn_out, de_out, ds_n, ds_e; sigma=relu)

Build a source-driven DB layer mapping `(dn_in, de_in)` node/edge features to
`(dn_out, de_out)`. A source dimension of `0` replaces the corresponding source
coupling by the scalar `0.0f0`.
"""
function DBsLayer(dn_in, de_in, dn_out, de_out, ds_n, ds_e; sigma=relu)
    # Cross couplings between nodes and edges (draw order matters for reproducibility).
    W_ne = Flux.glorot_uniform(dn_out, de_in)
    W_en = Flux.glorot_uniform(de_out, dn_in)

    # Source couplings: disabled (scalar zero) when the source dimension is 0.
    S_e = ds_e == 0 ? 0.0f0 : Flux.glorot_uniform(de_out, ds_e)
    S_n = ds_n == 0 ? 0.0f0 : Flux.glorot_uniform(dn_out, ds_n)

    # Self couplings.
    beta_e = Flux.glorot_uniform(de_out, de_in)
    beta_n = Flux.glorot_uniform(dn_out, dn_in)

    return DBsLayer(W_ne, W_en, sigma, beta_e, beta_n, S_n, S_e)
end

# Forward pass with source terms: returns `(x_new, e_new)`.
function (l::DBsLayer)(g::GNNGraph, x_n, x_e, s_n, s_e)
    # Edge update: endpoint difference, edge self-term and edge source.
    e_new = apply_edges(g, xi=x_n, xj=x_n, e=x_e) do x_i, x_j, e_ij
        pre_e = l.W_en * (x_i - x_j) .+ l.beta_e * e_ij .+ l.S_e * s_e
        l.sigma.(pre_e)
    end

    # Node update uses the incoming (pre-update) edge features summed per node.
    edge_sum = aggregate_neighbors(g, +, x_e)
    pre_n = l.W_ne * edge_sum .+ l.beta_n * x_n .+ l.S_n * s_n
    x_new = l.sigma.(pre_n)

    return x_new, e_new
end


# Source-free forward pass: the same update as the 5-argument method without the
# `S_n` / `S_e` terms. Returns `(x_new, e_new)`.
function (l::DBsLayer)(g::GNNGraph, x_n, x_e)
    e_new = apply_edges(g, xi=x_n, xj=x_n, e=x_e) do x_i, x_j, e_ij
        l.sigma.(l.W_en * (x_i - x_j) .+ l.beta_e * e_ij)
    end

    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = l.sigma.(l.W_ne * edge_sum .+ l.beta_n * x_n)

    return x_new, e_new
end

## Dirac-Bianconi recurrent convolution layer (an Euler step of the Dirac-Bianconi equation with an activation function)

# Dirac-Bianconi layer without source terms. When `W_en` is `nothing` (the
# "balanced" configuration), the adjoint `W_ne'` is used for the node->edge map.
struct DBLayer{T1<:AbstractMatrix,T2,Ts,T3,T4} <: GNNLayer
    W_ne::T1 # Weights for mapping edge features to node features
    W_en::T2 # Weights for mapping node features to edge features (`nothing` when balanced)
    sigma::Ts # activation applied to the coupling terms only
    beta_e::T3 # edge self-coupling: a matrix, or the scalar 1.0f0 without mass matrix
    beta_n::T4 # node self-coupling: a matrix, or the scalar 1.0f0 without mass matrix
end

Flux.@functor DBLayer

"""
    DBLayer(dn, de; sigma=relu, mass_matrix=true, balanced=true, Δ=0.01f0)

Build a DB layer for `dn` node and `de` edge features. `W_ne` starts at zero and the
self couplings start at identity plus `Δ`-scaled noise, so the layer starts close to
the identity. `balanced=true` shares weights (`W_en = nothing`), and
`mass_matrix=false` uses scalar self couplings `1.0f0`.
"""
function DBLayer(dn, de; sigma=relu, mass_matrix=true, balanced=true, Δ=0.01f0)
    W_ne = zeros(Float32, dn, de)
    W_en = balanced ? nothing : Δ .* randn(Float32, de, dn)

    # Tuple elements are evaluated left to right: beta_e is drawn before beta_n.
    beta_e, beta_n = if mass_matrix
        (Δ .* randn(Float32, de, de) + I, Δ .* randn(Float32, dn, dn) + I)
    else
        (1.0f0, 1.0f0)
    end

    return DBLayer(W_ne, W_en, sigma, beta_e, beta_n)
end

# Balanced forward pass (W_en === nothing): the node->edge map reuses the adjoint of
# W_ne. The activation wraps the coupling term only, and the self terms are added
# linearly. Returns `(x_new, e_new)`.
function (l::DBLayer{T1,Nothing,Ts,T3,T4})(g::GNNGraph, x_n, x_e) where {T1<:(AbstractMatrix),Ts,T3,T4}
    W_shared = l.W_ne'
    e_new = apply_edges(g, xi=x_n, xj=x_n, e=x_e) do x_i, x_j, e_ij
        l.sigma.(W_shared * (x_i - x_j)) .+ l.beta_e * e_ij
    end

    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = l.sigma.(l.W_ne * edge_sum) .+ l.beta_n * x_n

    return x_new, e_new
end

# General (unbalanced) forward pass with an independent node->edge weight `W_en`.
# Returns `(x_new, e_new)`.
function (l::DBLayer)(g::GNNGraph, x_n, x_e)
    e_new = apply_edges(g, xi=x_n, xj=x_n, e=x_e) do x_i, x_j, e_ij
        l.sigma.(l.W_en * (x_i - x_j)) .+ l.beta_e * e_ij
    end

    edge_sum = aggregate_neighbors(g, +, x_e)
    x_new = l.sigma.(l.W_ne * edge_sum) .+ l.beta_n * x_n

    return x_new, e_new
end



## Quadratic Edge Layer (conjecture: a layer like this might be able to mimic attention)

# Quadratic edge layer: transforms edge features with a quadratic form `W_eee`
# plus a linear map `pt`.
struct QELayer{T1,T2} <: GNNLayer
    W_eee::T1 # rank-3 coefficient tensor (de × de × de when built by the constructor)
    pt::T2 # linear map on edge features (identity matrix when built by the constructor)
end

Flux.@functor QELayer

# Build a QELayer for `de` edge features. The quadratic coefficients start at zero and
# the linear part starts at the `de × de` identity matrix.
function QELayer(de)
    quad = zeros(Float32, de, de, de)
    lin = Matrix{Float32}(I, de, de)
    return QELayer(quad, lin)
end

"""
    (layer::QELayer)(x_e)

Apply the quadratic edge map column-wise (one column per edge):

    e_new[i, j] = Σ_{k,q} W_eee[i, k, q] * x_e[k, j] * x_e[q, j] + Σ_m pt[i, m] * x_e[m, j]
"""
function (layer::QELayer)(x_e)
    # Bind the tensor to a local name. The receiver is not called `l`, so it cannot be
    # confused with a Tullio index.
    W = layer.W_eee
    # Tullio reduces over every index missing from the left-hand side *for the whole
    # right-hand side*. The quadratic and linear terms have different reduction
    # indices, so they must be computed separately; otherwise each term is summed
    # over the other's indices and scaled by the feature dimension.
    @tullio e_quad[i, j] := W[i, k, q] * x_e[k, j] * x_e[q, j]
    return e_quad .+ layer.pt * x_e
end



## Edge feature initialization layer
# Builds initial edge features as an affine map of the difference of the endpoint
# node features: `W * (xi .- xj) .+ b`.
struct EdgeInitLayer{T1<:AbstractMatrix,T2} <: GNNLayer
    W::T1 # (d_out × d_in) weight applied to the endpoint feature difference
    b::T2 # bias vector of length d_out
end

Flux.@functor EdgeInitLayer

# Build an EdgeInitLayer with zero weights and a unit bias, so every edge initially
# receives an all-ones feature vector of length `d_out`.
function EdgeInitLayer(d_in::Int, d_out::Int)
    return EdgeInitLayer(zeros(Float32, d_out, d_in), ones(Float32, d_out))
end

# Compute one feature column per edge from the node features `x` of its endpoints.
function (l::EdgeInitLayer)(g, x)
    return apply_edges(g, xi=x, xj=x, e=nothing) do x_i, x_j, _
        l.W * (x_i .- x_j) .+ l.b
    end
end



## DBNN_s: stack of source-driven DB layers with varying dimensions
# Stack of source-driven DBsLayers. The model inputs are re-injected as source terms
# at every layer.
struct DBNN_s
    dbls # layers applied in order (built as DBsLayer instances by the constructor)
end

Flux.@functor DBNN_s

"""
    DBNN_s(dims_n, dims_e; n=length(dims_e) - 1)

Build `n` DBsLayers, where layer `k` maps `(dims_n[k], dims_e[k])` to
`(dims_n[k+1], dims_e[k+1])`. The source dimensions are always `dims_n[1]` /
`dims_e[1]`, because the forward pass feeds the original inputs as sources.
"""
function DBNN_s(dims_n::Vector{Int}, dims_e::Vector{Int}; n=length(dims_e) - 1)
    @assert length(dims_n) == length(dims_e)
    @assert n <= length(dims_n) - 1
    src_n, src_e = dims_n[1], dims_e[1]
    layers = [DBsLayer(dims_n[k], dims_e[k], dims_n[k+1], dims_e[k+1], src_n, src_e)
              for k in 1:n]
    return DBNN_s(layers)
end

"""
    (model::DBNN_s)(g, s_n, s_e)

Run all layers in order. The inputs `(s_n, s_e)` are both the initial state and the
source terms fed to every layer. Returns the final node features; a model with no
layers returns `s_n` unchanged.
"""
function (model::DBNN_s)(g::GNNGraph, s_n, s_e)
    x_n, x_e = s_n, s_e
    for dbl in model.dbls
        x_n, x_e = dbl(g, x_n, x_e, s_n, s_e)
    end
    return x_n
end


## DBNN with the Dirac-Bianconi equation in the middle.
# Source-driven input layer, one shared DBLayer applied `n` times, and a source-free
# output layer.
struct DBNN_deep
    dbl_in # DBsLayer: input dims -> hidden dims, driven by the inputs as sources
    dbl_deep # DBLayer applied repeatedly on the hidden dims
    dbl_out # DBsLayer: hidden dims -> output dims, called without sources
    n::Int # number of applications of dbl_deep
end

Flux.@functor DBNN_deep

"""
    DBNN_deep(dims_n, dims_e, n)

`dims_n` and `dims_e` are `[input, hidden, output]` dimensions. The deep layer's
initial noise scale is `0.01f0 / n`.
"""
function DBNN_deep(dims_n::Vector{Int}, dims_e::Vector{Int}, n)
    @assert length(dims_n) == length(dims_e) == 3
    Δ = 0.01f0 / Float32(n)
    dn_in, dn_mid, dn_out = dims_n
    de_in, de_mid, de_out = dims_e

    first_layer = DBsLayer(dn_in, de_in, dn_mid, de_mid, dn_in, de_in)
    middle_layer = DBLayer(dn_mid, de_mid; Δ=Δ)
    # Zero source dims: the output layer is only ever called without sources.
    last_layer = DBsLayer(dn_mid, de_mid, dn_out, de_out, 0, 0)

    return DBNN_deep(first_layer, middle_layer, last_layer, n)
end

# Forward pass; returns only the final node features.
function (model::DBNN_deep)(g::GNNGraph, s_n, s_e)
    h_n, h_e = model.dbl_in(g, s_n, s_e, s_n, s_e)
    # The same DBLayer (shared weights) is applied model.n times.
    for _ in 1:model.n
        h_n, h_e = model.dbl_deep(g, h_n, h_e)
    end
    h_n, _ = model.dbl_out(g, h_n, h_e)
    return h_n
end


## Direct Dirac-Bianconi dynamics with simple initial conditions and readout.
# Node model: input projection, edge-feature initialisation, `n` repeated DB steps
# with dropout, then an output projection.
struct DBEquation
    ud_in # node input projection
    ei # EdgeInitLayer building initial edge features from raw node features
    dbl # DB layer applied at every step
    dropout # dropout applied to node and edge features after each step
    ud_out # node output projection
    n::Int # number of DB steps
end

Flux.@functor DBEquation

"""
    DBEquation(dims_n, d_e, n; balanced=true)

`dims_n` is `[input, hidden, output]`; `d_e` is the edge feature dimension. The
LinDBLayer receives `Δ = 0.01f0 / n`.
"""
function DBEquation(dims_n::Vector{Int}, d_e::Int, n; balanced=true)
    @assert length(dims_n) == 3
    Δ = 0.01f0 / Float32(n)
    d_in, d_hidden, d_out = dims_n
    proj_in = UDLayer(d_in, d_hidden)
    edge_init = EdgeInitLayer(d_in, d_e)
    step_layer = LinDBLayer(d_hidden, d_e; Δ=Δ, balanced=balanced)
    proj_out = UDLayer(d_hidden, d_out)
    return DBEquation(proj_in, edge_init, step_layer, Dropout(0.005), proj_out, n)
end

# Forward pass. Edge features are built from the raw node features before the input
# projection is applied.
function (model::DBEquation)(g::GNNGraph, x_n)
    h_e = model.ei(g, x_n)
    h_n = model.ud_in(x_n)
    for _ in 1:model.n
        h_n, h_e = model.dbl(g, h_n, h_e)
        h_n, h_e = model.dropout(h_n), model.dropout(h_e)
    end
    return model.ud_out(relu.(h_n))
end


## Direct Dirac-Bianconi dynamics with simple initial conditions and readout, used for graph regression.
# Graph-level variant of DBEquation: adds global pooling and a scalar dense head.
struct DBEquationGraphRegression
    ud_in # node input projection
    ei # EdgeInitLayer building initial edge features from raw node features
    dbl # DB layer applied at every step
    dropout # dropout applied after each step
    ud_out # node output projection
    n::Int # number of DB steps
    globalPool # per-graph pooling of node outputs
    dense # final Dense layer producing one value per graph
end

Flux.@functor DBEquationGraphRegression

# Build a DBEquationGraphRegression. `dims_n` is `[input, hidden, output]`, `d_e` is
# the edge feature dimension, and `n` is the number of DB steps (`Δ = 0.01f0 / n`).
function DBEquationGraphRegression(dims_n::Vector{Int}, d_e::Int, n)
    @assert length(dims_n) == 3
    Δ = 0.01f0 / Float32(n)
    proj_in = UDLayer(dims_n[1], dims_n[2])
    edge_init = EdgeInitLayer(dims_n[1], d_e)
    step_layer = LinDBLayer(dims_n[2], d_e; Δ=Δ)
    proj_out = UDLayer(dims_n[2], dims_n[3])
    # Graph readout: mean pooling followed by a scalar linear head.
    pooling = GlobalPool(Flux.mean)
    head = Dense(dims_n[end], 1)
    return DBEquationGraphRegression(proj_in, edge_init, step_layer, Dropout(0.005), proj_out, n, pooling, head)
end

# Forward pass: node dynamics as in DBEquation, then pool per graph and regress.
function (model::DBEquationGraphRegression)(g::GNNGraph, x_n)
    h_e = model.ei(g, x_n)
    h_n = model.ud_in(x_n)
    for _ in 1:model.n
        h_n, h_e = model.dbl(g, h_n, h_e)
        h_n, h_e = model.dropout(h_n), model.dropout(h_e)
    end
    node_out = model.ud_out(relu.(h_n))
    pooled = model.globalPool(g, node_out)
    return model.dense(pooled)
end


## Direct Dirac-Bianconi dynamics with readout, used for graph regression (edge features supplied by the caller).
# Like DBEquationGraphRegression, but without an edge-initialisation layer: the
# forward pass takes edge features as an argument.
struct DBEquationGraphRegression2
    ud_in # node input projection
    # (no `ei` field: edge features are provided by the caller)
    dbl # DB layer applied at every step
    dropout # dropout applied after each step
    ud_out # node output projection
    n::Int # number of DB steps
    globalPool # per-graph pooling of node outputs
    dense # final Dense layer producing one value per graph
end

Flux.@functor DBEquationGraphRegression2

# Build a DBEquationGraphRegression2. `dims_n` is `[input, hidden, output]`; `d_e` must
# match the width of the edge features passed to the forward pass.
function DBEquationGraphRegression2(dims_n::Vector{Int}, d_e::Int, n)
    @assert length(dims_n) == 3
    Δ = 0.01f0 / Float32(n)
    proj_in = UDLayer(dims_n[1], dims_n[2])
    step_layer = LinDBLayer(dims_n[2], d_e; Δ=Δ)
    proj_out = UDLayer(dims_n[2], dims_n[3])
    pooling = GlobalPool(Flux.mean)
    head = Dense(dims_n[end], 1)
    return DBEquationGraphRegression2(proj_in, step_layer, Dropout(0.005), proj_out, n, pooling, head)
end

# Forward pass with caller-supplied edge features; returns one prediction per graph.
function (model::DBEquationGraphRegression2)(g::GNNGraph, x_n, x_e)
    h_n = model.ud_in(x_n)
    h_e = x_e
    for _ in 1:model.n
        h_n, h_e = model.dbl(g, h_n, h_e)
        h_n, h_e = model.dropout(h_n), model.dropout(h_e)
    end
    node_out = model.ud_out(relu.(h_n))
    return model.dense(model.globalPool(g, node_out))
end






## Direct Dirac-Bianconi dynamics with optional batch norm and dense layers, used for graph regression.

# Configurable graph-regression model. Optional components are stored as `false`
# when disabled.
struct DBEquation_GR
    ud_n_in # node input projection
    ud_e_in # edge input projection
    dbl # LinDBLayer applied at every step
    dropout_n # node dropout after each step
    dropout_e # edge dropout after each step
    batchnorm_n # node BatchNorm after each step, or `false`
    batchnorm_e # edge BatchNorm after each step, or `false`
    n::Int # number of DB steps
    globalPool # per-graph pooling of node features
    dense # final Dense layer producing one value per graph
    dense_after_conv # Dense on node features before pooling, or `false`
    dense_after_pool # Dense on pooled features, or `false`
    Δ # per-step value (the constructor's `Δ` divided by `n`)
end

Flux.@functor DBEquation_GR

"""
    DBEquation_GR(dims_n, dims_e, n; balanced=false, dropout_n=0.05, dropout_e=0.05,
                  batchnorm_n=false, batchnorm_e=false, dense_after_conv_dim=false,
                  dense_after_pool_dim=false, pool_criterion=Flux.mean, Δ)

`dims_n` / `dims_e` are `[input, hidden]`. Any `*_dim` or `batchnorm_*` argument equal
to `false` disables that component. The LinDBLayer receives `Δ / n`.
"""
function DBEquation_GR(dims_n::Vector{Int}, dims_e::Vector{Int}, n; balanced=false, dropout_n=0.05, dropout_e=0.05, batchnorm_n=false, batchnorm_e=false, dense_after_conv_dim=false, dense_after_pool_dim=false, pool_criterion=Flux.mean, Δ)
    @assert length(dims_n) == 2
    @assert length(dims_e) == 2
    Δ = Δ / Float32(n)
    ud_n_in = UDLayer(dims_n[1], dims_n[2])
    ud_e_in = UDLayer(dims_e[1], dims_e[2])
    dbl_deep = LinDBLayer(dims_n[2], dims_e[2]; Δ=Δ, balanced=balanced)
    globalPool = GlobalPool(pool_criterion)
    bn_n = batchnorm_n != false ? BatchNorm(dims_n[2]) : batchnorm_n
    bn_e = batchnorm_e != false ? BatchNorm(dims_e[2]) : batchnorm_e

    # Chain the optional dense layers, tracking the current feature width.
    width = dims_n[2]
    dense_after_conv = false
    if dense_after_conv_dim != false
        dense_after_conv = Dense(width, dense_after_conv_dim)
        width = dense_after_conv_dim
    end
    dense_after_pool = false
    if dense_after_pool_dim != false
        dense_after_pool = Dense(width, dense_after_pool_dim)
        width = dense_after_pool_dim
    end
    dense = Dense(width, 1)

    return DBEquation_GR(ud_n_in, ud_e_in, dbl_deep, Dropout(dropout_n), Dropout(dropout_e), bn_n, bn_e, n, globalPool, dense, dense_after_conv, dense_after_pool, Δ)
end

# Forward pass: each step applies the DB layer, then dropout, optional BatchNorm and
# ReLU. Then an optional dense layer, pooling, an optional dense layer and the head.
function (model::DBEquation_GR)(g::GNNGraph, x_n, x_e)
    h_n = model.ud_n_in(x_n)
    h_e = model.ud_e_in(x_e)
    for _ in 1:model.n
        h_n, h_e = model.dbl(g, h_n, h_e)
        h_n = model.dropout_n(h_n)
        h_e = model.dropout_e(h_e)
        if model.batchnorm_n != false
            h_n = model.batchnorm_n(h_n)
        end
        if model.batchnorm_e != false
            h_e = model.batchnorm_e(h_e)
        end
        h_n, h_e = relu.(h_n), relu.(h_e)
    end

    if model.dense_after_conv != false
        h_n = model.dense_after_conv(h_n)
    end
    pooled = model.globalPool(g, h_n)
    if model.dense_after_pool != false
        pooled = model.dense_after_pool(pooled)
    end
    return model.dense(pooled)
end
##

# Baseline graph-regression model: pool node features, then apply a linear head.
struct Bench_GR
    globalPool # per-graph pooling of node features
    dense # Dense layer producing one value per graph
end

Flux.@functor Bench_GR

# Build the baseline for node features of width `dim_n`: mean pooling and Dense(dim_n, 1).
function Bench_GR(dim_n::Int)
    return Bench_GR(GlobalPool(Flux.mean), Dense(dim_n, 1))
end

# Forward pass. `x_e` is accepted for interface compatibility with the other models
# and is ignored.
function (model::Bench_GR)(g::GNNGraph, x_n, x_e)
    pooled = model.globalPool(g, x_n)
    return model.dense(pooled)
end

##

# Node-regression model. Optional components are stored as `false` when disabled.
struct DBEquation_NR
    ud_n_in # node input projection
    ud_e_in # edge input projection
    dbl # LinDBLayer applied at every step
    dropout_n # node dropout after each step
    dropout_e # edge dropout after each step
    batchnorm_n # node BatchNorm after each step, or `false`
    batchnorm_e # edge BatchNorm after each step, or `false`
    n::Int # number of DB steps
    dense # final Dense layer producing one value per node
    dense_after_conv # Dense on node features before the head, or `false`
    Δ::Float32 # per-step value (the constructor's `Δ` divided by `n`)
end

Flux.@functor DBEquation_NR

"""
    DBEquation_NR(dims_n, dims_e, n; balanced=false, dropout_n=0.05, dropout_e=0.05,
                  batchnorm_n=false, batchnorm_e=false, dense_after_conv_dim=false, Δ)

`dims_n` / `dims_e` are `[input, hidden]`. `false` disables an optional component.
The LinDBLayer receives `Δ / n`.
"""
function DBEquation_NR(dims_n::Vector{Int}, dims_e::Vector{Int}, n; balanced=false, dropout_n=0.05, dropout_e=0.05, batchnorm_n=false, batchnorm_e=false, dense_after_conv_dim=false, Δ)
    @assert length(dims_n) == 2
    @assert length(dims_e) == 2
    Δ = Δ / Float32(n)
    ud_n_in = UDLayer(dims_n[1], dims_n[2])
    ud_e_in = UDLayer(dims_e[1], dims_e[2])
    dbl_deep = LinDBLayer(dims_n[2], dims_e[2]; Δ=Δ, balanced=balanced)
    bn_n = batchnorm_n != false ? BatchNorm(dims_n[2]) : batchnorm_n
    bn_e = batchnorm_e != false ? BatchNorm(dims_e[2]) : batchnorm_e

    if dense_after_conv_dim == false
        dense_after_conv = false
        dense = Dense(dims_n[2], 1)
    else
        dense_after_conv = Dense(dims_n[2], dense_after_conv_dim)
        dense = Dense(dense_after_conv_dim, 1)
    end

    return DBEquation_NR(ud_n_in, ud_e_in, dbl_deep, Dropout(dropout_n), Dropout(dropout_e), bn_n, bn_e, n, dense, dense_after_conv, Δ)
end

# Forward pass; returns the adjoint of the head's output (one row per node).
function (model::DBEquation_NR)(g::GNNGraph, x_n, x_e)
    h_n = model.ud_n_in(x_n)
    h_e = model.ud_e_in(x_e)
    for _ in 1:model.n
        h_n, h_e = model.dbl(g, h_n, h_e)
        h_n = model.dropout_n(h_n)
        h_e = model.dropout_e(h_e)
        if model.batchnorm_n != false
            h_n = model.batchnorm_n(h_n)
        end
        if model.batchnorm_e != false
            h_e = model.batchnorm_e(h_e)
        end
        h_n, h_e = relu.(h_n), relu.(h_e)
    end

    if model.dense_after_conv != false
        h_n = model.dense_after_conv(h_n)
    end
    prediction = model.dense(h_n)
    return prediction'
end


## DBGAT_Model: Dirac-Bianconi steps followed by an optional GATv2 layer
# DB steps followed by optional dense / GATv2 / pooling / dense stages and a final
# dense head. Optional stages are stored as `false` when disabled.
struct DBGAT_Model
    ud_n_in # node input projection
    ud_e_in # edge input projection
    dbl # LinDBLayer applied `n` times
    dropout_n # node dropout after each DB step
    dropout_e # edge dropout after each DB step
    n::Int # number of DB steps
    dense_after_db # Dense on node features after the DB steps, or `false`
    GAT # GATv2Conv using the final edge features, or `false`
    dense_after_gat # Dense after GAT (and pooling), or `false`
    dense # final Dense head
    Δ::Float32 # per-step value (the constructor's `Δ` divided by `n`)
    pool # GlobalPool for graph-level output, or `false` for node-level output
end

Flux.@functor DBGAT_Model

"""
    DBGAT_Model(dims_n::Vector{Int}, dims_e::Vector{Int}, n; balanced=false, dropout_n=0.05, dropout_e=0.05, dense_after_db_dim=false, dense_after_gat_dim=false, Δ, use_GAT, GAT_heads, GAT_bias, GAT_concat, GAT_negative_slope, GAT_output_n_dim, pool=false)

Generates DBGAT_Model

- `dims_n::Vector{Int}`: Vector with dimensions for nodal embeddings of DB Layer
- `dims_n::Vector{Int}`: Vector with dimensions for edge embeddings of DB Layer
- `balanced` Bool for using balanced setup (weight sharing)
- `dropout_n`: node dropout
- `dropout_e`: edge dropout
- `dense_after_db_dim`: dimension of dense layer after DB, set false if no dense layer
- `dense_after_gat_dim`: dimension of dense layer after GAT, set false if no dense layer
- `Δ`: Δ value, must be float, mass matrix of DB layer
- `use_GAT`: use GAT or not
- `GAT_heads`: number of heads for GAT
- `GAT_bias`: use bias or not
- `GAT_concat`: concat heads or not
- `GAT_negative_slope`: negative slope of Leaky Relu for GAT
- `GAT_output_n_dim`: output dimension for GAT
- `pool`: use pooling for graph-based setup (Flux.mean or Flux.max) for graph-based setup or false for node setup
- `target_classes`: number of target classes

"""
function DBGAT_Model(dims_n::Vector{Int}, dims_e::Vector{Int}, n; balanced=false, dropout_n=0.05, dropout_e=0.05, dense_after_db_dim=false, dense_after_gat_dim=false, Δ, use_GAT, GAT_heads, GAT_bias, GAT_concat, GAT_negative_slope, GAT_output_n_dim, pool=false, target_classes)
    @assert length(dims_n) == 2
    @assert length(dims_e) == 2
    Δ = Δ / Float32(n)
    ud_n_in = UDLayer(
        dims_n[1],
        dims_n[2])
    ud_e_in = UDLayer(
        dims_e[1],
        dims_e[2],)
    dbl_deep = LinDBLayer(
        dims_n[2],
        dims_e[2];
        Δ=Δ, balanced=balanced)
    # dense_after_db
    dense_after_db_dim_in = dims_n[2]
    dense_after_db_dim_out = dense_after_db_dim
    if dense_after_db_dim != false
        dense_after_db = Dense(dense_after_db_dim_in, dense_after_db_dim_out, init=init_function)
    else
        dense_after_db = false
        dense_after_db_dim_out = dims_n[2]
    end

    if dense_after_db_dim != false
        GAT_n_in = dense_after_db_dim_out
    else
        GAT_n_in = dims_n[2]
    end

    if use_GAT
        GAT = GATv2Conv((GAT_n_in, dims_e[2]) => GAT_output_n_dim, heads=GAT_heads, concat=GAT_concat, bias=GAT_bias, negative_slope=GAT_negative_slope, add_self_loops=false)
        dense_after_gat_dim_in = GAT_output_n_dim * GAT_heads
    else
        dense_after_gat_dim = false
        GAT_output_n_dim = dense_after_db_dim_out
        GAT = false
        dense_after_gat_dim_in = dense_after_db_dim_out
    end

    if dense_after_gat_dim != false
        dense_after_gat_dim_out = dense_after_gat_dim
        dense_after_gat = Dense(dense_after_gat_dim_in, dense_after_gat_dim_out, init=init_function)
    else
        dense_after_gat = false
        if use_GAT == false
            dense_after_gat_dim_out = GAT_output_n_dim
        else
            dense_after_gat_dim_out = GAT_output_n_dim * GAT_heads
        end
    end

    dense = Dense(dense_after_gat_dim_out, target_classes, init=init_function)

    if pool != false
        pool = GlobalPool(pool)
    end
    DBGAT_Model(ud_n_in, ud_e_in, dbl_deep, Dropout(dropout_n), Dropout(dropout_e), n, dense_after_db, GAT, dense_after_gat, dense, Δ, pool)
end

# Forward pass of DBGAT_Model; disabled stages (stored as `false`) are skipped.
# Returns the adjoint of the head's output.
function (model::DBGAT_Model)(g::GNNGraph, x_n, x_e)
    h_n = model.ud_n_in(x_n)
    h_e = model.ud_e_in(x_e)
    # `n` steps of the shared DB layer, each followed by dropout and ReLU.
    for _ in 1:model.n
        h_n, h_e = model.dbl(g, h_n, h_e)
        h_n = relu.(model.dropout_n(h_n))
        h_e = relu.(model.dropout_e(h_e))
    end

    if model.dense_after_db != false
        h_n = model.dense_after_db(h_n)
    end
    if model.GAT != false
        h_n = model.GAT(g, h_n, h_e)
    end
    if model.pool != false
        h_n = model.pool(g, h_n)
    end
    if model.dense_after_gat != false
        h_n = model.dense_after_gat(h_n)
    end
    prediction = model.dense(h_n)
    return prediction'
end



Flux.@functor DBLayerModul

# Build a DBLayerModul: one LinDBLayer reused for `num_steps` iterations, plus node and
# edge dropout. (Original author's note: mass matrix = self loops.)
function DBLayerModul(dims_n, dims_e, Δ, balanced, num_steps, dropout_n, dropout_e, init_function)
    shared_step = LinDBLayer(dims_n, dims_e; Δ=Δ, balanced=balanced, init_function=init_function)
    return DBLayerModul(shared_step, num_steps, Dropout(dropout_n), Dropout(dropout_e))
end


# Apply the shared LinDBLayer `num_steps` times. Each step is followed by dropout and
# ReLU on both node and edge features. Returns `(x_n, x_e)`.
function (db::DBLayerModul)(g::GNNGraph, x_n, x_e)
    for _ in 1:db.num_steps
        x_n, x_e = db.linDB(g, x_n, x_e)
        x_n = relu.(db.dropout_n(x_n))
        x_e = relu.(db.dropout_e(x_e))
    end
    return x_n, x_e
end

## DBmultGAT_Model: several Dirac-Bianconi modules followed by an optional GATv2 layer
# Several DBLayerModul blocks applied in sequence, followed by optional
# dense / GATv2 / pooling / dense stages and a final dense head. Optional stages are
# stored as `false` when disabled.
struct DBmultGAT_Model
    ud_n_in # node input projection
    ud_e_in # edge input projection
    dbl # vector of DBLayerModul, applied in order
    steps_dbl # number of steps inside each DBLayerModul
    dropout_n # node dropout module (the modules in `dbl` hold their own dropout)
    dropout_e # edge dropout module (the modules in `dbl` hold their own dropout)
    n::Int # NOTE(review): the constructor stores `steps_dbl` here; the forward pass does not use it
    dense_after_db # Dense after the DB modules, or `false`
    GAT # GATv2Conv using the final edge features, or `false`
    dense_after_gat # Dense after GAT (and pooling), or `false`
    dense # final Dense head
    Δ::Float32 # per-step value (the constructor's `Δ` divided by `steps_dbl`)
    pool # GlobalPool for graph-level output, or `false` for node-level output
end

Flux.@functor DBmultGAT_Model

"""
    DBmultGAT_Model(dims_n::Vector{Int}, dims_e::Vector{Int}, n; balanced=false, dropout_n=0.05, dropout_e=0.05, dense_after_db_dim=false, dense_after_gat_dim=false, Δ, use_GAT, GAT_heads, GAT_bias, GAT_concat, GAT_negative_slope, GAT_output_n_dim, pool=false)

Generates DBmultGAT_Model

- `dims_n::Vector{Int}`: Vector with dimensions for nodal embeddings of DB Layer
- `dims_n::Vector{Int}`: Vector with dimensions for edge embeddings of DB Layer
- `balanced` Bool for using balanced setup (weight sharing)
- `dropout_n`: node dropout
- `dropout_e`: edge dropout
- `dense_after_db_dim`: dimension of dense layer after DB, set false if no dense layer
- `dense_after_gat_dim`: dimension of dense layer after GAT, set false if no dense layer
- `Δ`: Δ value, must be float, mass matrix of DB layer
- `use_GAT`: use GAT or not
- `GAT_heads`: number of heads for GAT
- `GAT_bias`: use bias or not
- `GAT_concat`: concat heads or not
- `GAT_negative_slope`: negative slope of Leaky Relu for GAT
- `GAT_output_n_dim`: output dimension for GAT
- `pool`: use pooling for graph-based setup (Flux.mean or Flux.max) for graph-based setup or false for node setup
- `target_classes`: number of target classes

"""
function DBmultGAT_Model(dims_n::Vector{Int}, dims_e::Vector{Int}, num_dbl, steps_dbl; balanced=false, dropout_n=0.05, dropout_e=0.05, dense_after_db_dim=false, dense_after_gat_dim=false, Δ, use_GAT, GAT_heads, GAT_bias, GAT_concat, GAT_negative_slope, GAT_output_n_dim, pool=false, target_classes, init_function)
    @assert length(dims_n) == 2
    @assert length(dims_e) == 2
    # @assert length(dims_n) == length(dims_e)
    Δ = Δ / Float32(steps_dbl)
    ud_n_in = UDLayer(
        dims_n[1],
        dims_n[2],
        init_function)
    ud_e_in = UDLayer(
        dims_e[1],
        dims_e[2],
        init_function)
    dbl_deep = Array{DBLayerModul,1}(undef, num_dbl)
    for i in 1:num_dbl
        dbl_deep[i] = DBLayerModul(dims_n[2], dims_e[2], Δ, balanced, steps_dbl, dropout_n, dropout_e, init_function)
    end
    # dense_after_db
    dense_after_db_dim_in = dims_n[2]
    dense_after_db_dim_out = dense_after_db_dim
    if dense_after_db_dim != false
        dense_after_db = Dense(dense_after_db_dim_in, dense_after_db_dim_out, init =init_function)
    else
        dense_after_db = false
        dense_after_db_dim_out = dims_n[2]
    end

    if dense_after_db_dim != false
        GAT_n_in = dense_after_db_dim_out
    else
        GAT_n_in = dims_n[2]
    end

    if use_GAT
        GAT = GATv2Conv((GAT_n_in, dims_e[2]) => GAT_output_n_dim, heads=GAT_heads, concat=GAT_concat, bias=GAT_bias, negative_slope=GAT_negative_slope, add_self_loops=false)
        dense_after_gat_dim_in = GAT_output_n_dim * GAT_heads
    else
        dense_after_gat_dim = false
        GAT_output_n_dim = dense_after_db_dim_out
        GAT = false
        dense_after_gat_dim_in = dense_after_db_dim_out
    end

    if dense_after_gat_dim != false
        dense_after_gat_dim_out = dense_after_gat_dim
        dense_after_gat = Dense(dense_after_gat_dim_in, dense_after_gat_dim_out, init = init_function)
    else
        dense_after_gat = false
        if use_GAT == false
            dense_after_gat_dim_out = GAT_output_n_dim
        else
            dense_after_gat_dim_out = GAT_output_n_dim * GAT_heads
        end
    end

    dense = Dense(dense_after_gat_dim_out, target_classes, init=init_function)

    if pool != false
        pool = GlobalPool(pool)
    end
    DBmultGAT_Model(ud_n_in, ud_e_in, dbl_deep, steps_dbl, Dropout(dropout_n), Dropout(dropout_e), steps_dbl, dense_after_db, GAT, dense_after_gat, dense, Δ, pool)
end

# Forward pass of DBmultGAT_Model. The DB modules run in sequence (each applies its own
# dropout and ReLU), then the optional stages run; stages stored as `false` are
# skipped. Returns the adjoint of the head's output.
function (model::DBmultGAT_Model)(g::GNNGraph, x_n, x_e)
    h_n = model.ud_n_in(x_n)
    h_e = model.ud_e_in(x_e)
    for db_module in model.dbl
        h_n, h_e = db_module(g, h_n, h_e)
    end

    if model.dense_after_db != false
        h_n = model.dense_after_db(h_n)
    end
    if model.GAT != false
        h_n = model.GAT(g, h_n, h_e)
    end
    if model.pool != false
        h_n = model.pool(g, h_n)
    end
    if model.dense_after_gat != false
        h_n = model.dense_after_gat(h_n)
    end
    prediction = model.dense(h_n)
    return prediction'
end

