using CombDiff
# Build an initial CombDiff context that declares three natural-number (`N`)
# size parameters. Only `ctx_g` is used afterwards: it is passed as the context
# argument of the second `@comb` call below. The `g` returned here is
# overwritten by that second call.
#   n_spatial — sizes the index domains below (`S` has 2 * n_spatial elements);
#               presumably the number of spatial orbitals — TODO confirm
#   N_e       — size of the `occ` domain below
#   N_sb      — NOTE(review): not referenced anywhere in the visible code
g, ctx_g = @comb :(
    begin
        (n_spatial::N, N_e::N, N_sb::N) -> _
    end
)
# Main symbolic program. It is built on top of the context `ctx_g` created above
# (passed, together with the previous `g`, as the trailing arguments of `@comb`).
# The context returned by this call is discarded.
g, _ = @comb :(
    begin
        # Index domains.
        # `S` is declared but not referenced in the visible code.
        # `vir` presumably spans N_e + 1 .. 2 * n_spatial, i.e. the indices of
        # `S` not in `occ` — TODO confirm CombDiff's two-argument domain syntax.
        @domain S N{2 * n_spatial}
        @domain occ N{N_e}
        @domain vir N{N_e + 1,2 * n_spatial}

        # Real-valued 4-index space over (vir, vir, occ, occ). Each symmetry
        # entry pairs an index permutation with `:neg`, presumably meaning the
        # value flips sign under that swap (slots 1<->2 and 3<->4).
        @space T2 begin
            type = (vir, vir, occ, occ) -> R
            symmetries = (((2, 1, 3, 4), :neg), ((1, 2, 4, 3), :neg))
        end

        # Real-valued 8-index space: four vir slots followed by four occ slots.
        # The listed transpositions act only within the vir slots (1-4) or only
        # within the occ slots (5-8), each tagged `:neg`.
        @space T4 begin
            type = (vir, vir, vir, vir, occ, occ, occ, occ) -> R
            symmetries = (((2, 1, 3, 4, 5, 6, 7, 8), :neg), ((1, 3, 2, 4, 5, 6, 7, 8), :neg), ((1, 2, 4, 3, 5, 6, 7, 8), :neg), ((4, 2, 3, 1, 5, 6, 7, 8), :neg),
                ((1, 2, 3, 4, 6, 5, 7, 8), :neg), ((1, 2, 3, 4, 5, 7, 6, 8), :neg), ((1, 2, 3, 4, 5, 6, 8, 7), :neg), ((1, 2, 3, 4, 8, 6, 7, 5), :neg),
            )
        end

        # fd: maps a T2 tensor to the 4-index function that evaluates it
        # pointwise (an identity map on t_2).
        # fd and fq are only referenced from the commented-out objective below.
        fd = (t_2::T2) -> (a::vir, b::vir, i::occ, j::occ) -> t_2(a, b, i, j)
        # fq: maps a T2 tensor to an 8-index function. The function is a sum of
        # 18 terms, each a product of two t_2 factors with coefficient +1.0 or -1.0.
        fq = (t_2::T2) -> (a::vir, b::vir, c::vir, d::vir, i::occ, j::occ, k::occ, l::occ) -> ((-1.0) * t_2(a, c, j, k) * t_2(b, d, i, l) + 1.0 * t_2(c, d, k, l) * t_2(a, b, i, j) + (-1.0) * t_2(c, d, j, l) * t_2(a, b, i, k) + 1.0 * t_2(a, d, j, k) * t_2(b, c, i, l) + (-1.0) * t_2(c, d, i, k) * t_2(b, a, l, j) + (-1.0) * t_2(a, d, i, k) * t_2(b, c, j, l) + 1.0 * t_2(a, d, k, l) * t_2(b, c, i, j) + 1.0 * t_2(a, c, i, k) * t_2(b, d, j, l) + (-1.0) * t_2(a, c, k, l) * t_2(b, d, i, j) + 1.0 * t_2(a, d, i, j) * t_2(b, c, k, l) + 1.0 * t_2(a, c, j, l) * t_2(b, d, i, k) + (-1.0) * t_2(a, c, i, j) * t_2(b, d, k, l) + (-1.0) * t_2(a, d, j, l) * t_2(b, c, i, k) + 1.0 * t_2(c, d, i, l) * t_2(a, b, j, k) + (-1.0) * t_2(a, c, i, l) * t_2(b, d, j, k) + 1.0 * t_2(a, d, i, l) * t_2(b, c, j, k) + 1.0 * t_2(c, d, j, k) * t_2(a, b, i, l) + 1.0 * t_2(c, d, i, j) * t_2(a, b, k, l))

        # Pullback of the same 8-index map that `fq` defines. This is the final
        # expression of the block.
        # NOTE(review): the lambda below duplicates the body of `fq` verbatim,
        # so any edit to one must be mirrored in the other. If CombDiff resolves
        # block-local bindings inside `pullback`, `pullback(fq)` would remove the
        # duplication — verify before changing.
        pullback((t_2::T2) -> (a::vir, b::vir, c::vir, d::vir, i::occ, j::occ, k::occ, l::occ) -> ((-1.0) * t_2(a, c, j, k) * t_2(b, d, i, l) + 1.0 * t_2(c, d, k, l) * t_2(a, b, i, j) + (-1.0) * t_2(c, d, j, l) * t_2(a, b, i, k) + 1.0 * t_2(a, d, j, k) * t_2(b, c, i, l) + (-1.0) * t_2(c, d, i, k) * t_2(b, a, l, j) + (-1.0) * t_2(a, d, i, k) * t_2(b, c, j, l) + 1.0 * t_2(a, d, k, l) * t_2(b, c, i, j) + 1.0 * t_2(a, c, i, k) * t_2(b, d, j, l) + (-1.0) * t_2(a, c, k, l) * t_2(b, d, i, j) + 1.0 * t_2(a, d, i, j) * t_2(b, c, k, l) + 1.0 * t_2(a, c, j, l) * t_2(b, d, i, k) + (-1.0) * t_2(a, c, i, j) * t_2(b, d, k, l) + (-1.0) * t_2(a, d, j, l) * t_2(b, c, i, k) + 1.0 * t_2(c, d, i, l) * t_2(a, b, j, k) + (-1.0) * t_2(a, c, i, l) * t_2(b, d, j, k) + 1.0 * t_2(a, d, i, l) * t_2(b, c, j, k) + 1.0 * t_2(c, d, j, k) * t_2(a, b, i, l) + 1.0 * t_2(c, d, i, j) * t_2(a, b, k, l)))

        # Disabled alternative: pullback of a sum-of-squared-differences
        # objective that fits fd(t_2) to a target cd::T2 and fq(t_2) to a
        # target cq::T4.
        #= (cd::T2, cq::T4) -> pullback((t_2::T2) -> 
                                     sum((a::vir, b::vir, i::occ, j::occ), (fd(t_2)(a, b, i, j) - cd(a, b, i, j))^2) 
                                     + 
                                     sum((a::vir, b::vir, c::vir, d::vir, i::occ, j::occ, k::occ, l::occ), (fq(t_2)(a, b, c, d, i, j, k, l) - cq(a, b, c, d, i, j, k, l))^2) 
                                     ) =#
    end

) g ctx_g

# Evaluate every definition in `g` with `eval_all`, then pass that result to
# `eval_pullback` (same as `eval_pullback(eval_all(g))`).
g |> eval_all |> eval_pullback
