using CombDiff

# Build a CombDiff program for the pullback of a small two-layer network's
# scalar output with respect to its weights.
#
# Inside `@main`:
#   * `mvp(A, x)` is the index function i -> Σ_j A(i, j) * x(j), i.e. a
#     matrix–vector product.
#   * `vip(x, y)` is Σ_i x(i) * y(i), i.e. an inner product.
#   * The program itself is a lambda over `batch` and `Relu`. It returns
#     `pullback` of a function of the weights (w_1, w_2, w_3), built as a
#     `▷` chain of five stages:
#         mvp(w_1, ·) → elementwise Relu → mvp(w_2, ·) → elementwise Relu
#         → vip(·, w_3)
#     That chain is applied to `batch`.
#   * `Relu` is not defined here. It is an opaque parameter of the program.
#     Note that `batch` is typed `RV` (a single vector), despite its name.
#
# NOTE(review): RM / RV / N / RF are presumably CombDiff's real-matrix,
# real-vector, index and real-function type tags. Confirm against CombDiff.
# NOTE(review): `▷` is read here as left-to-right composition (the first
# stage is applied first). The stage signatures only line up in that order
# (the chain starts with a vector and ends in a scalar).
# NOTE(review): the second value returned by `@main` is discarded. Its
# meaning is not visible in this file.
f, _ = @main begin
    # Matrix–vector product, expressed as an index function over i.
    mvp = (A::RM, x::RV) -> (i::N) -> sum(j, A(i, j) * x(j))
    # Inner product of two vectors.
    vip = (x::RV, y::RV) -> sum(i, x(i) * y(i))
    # Program body: pullback of the network output w.r.t. (w_1, w_2, w_3).
    (batch::RV, Relu::RF) ->
        pullback((w_1::RM, w_2::RM, w_3::RV) ->
            (((x::RV) -> mvp(w_1, x)) ▷
             ((x::RV) -> (i::N) -> Relu(x(i))) ▷
             ((x::RV) -> mvp(w_2, x)) ▷
             ((x::RV) -> (i::N) -> Relu(x(i))) ▷
             ((x::RV) -> vip(x, w_3)))(
                batch
            ))
end

# Differentiate the program with CombDiff's `vdiff`, then post-process the
# result. The passes run in this order:
#     eval_all → deprimitize → eval_pullback → eval_all
# NOTE(review): the exact semantics of each pass are defined by CombDiff.
# This script only fixes their order.
df = vdiff(f)
edf = df |> eval_all |> deprimitize |> eval_pullback |> eval_all


