using CombDiff

# Global Float64 value 1.0. No Julia code in this script reads it: the `z` inside the
# quoted `pullback` expression further down is a lambda parameter, not this global.
z = one(Float64)
add_one(x::Float64) = x + 1
add_one(x::Vector{Float64}) = x .+ 1

"""
    pullback_add_one(x, k)

Pullback of `add_one` at the primal point `x`, applied to the cotangent `k`.

`add_one` only adds a constant, so its Jacobian is the identity and the cotangent
passes through unchanged; the primal `x` is not needed. A method is provided for
each `add_one` method (scalar `Float64` and `Vector{Float64}`), so that
differentiating the vector form does not hit a `MethodError`.
"""
pullback_add_one(::Float64, k::Float64) = k
pullback_add_one(::Vector{Float64}, k::Vector{Float64}) = k
# Extend CombDiff's `find_pullback` for `add_one` (dispatching on `Val{add_one}`) so
# that it returns `pullback_add_one`; presumably CombDiff consults this when it
# differentiates a call to `add_one`.
function CombDiff.find_pullback(::Val{add_one})
    return pullback_add_one
end


# Exploratory checks: ask CombDiff for the pullback of the user-defined `add_one`,
# once with a scalar (`R`) argument and once with a vector (`RV`) argument.
# NOTE(review): both results are bound to `f`, which the `@comb` call below
# reassigns; presumably these lines only check that construction succeeds.
f, _ = @main :(pullback((x::R) -> add_one(x)))
f, _ = @main :(pullback((x::RV) -> add_one(x)))

# Size parameter of the `occ` domain declared inside the `@comb` block below.
# NOTE(review): `N_e` appears only as a symbol inside the quoted expression, so
# CombDiff presumably resolves it from global scope — confirm before renaming it.
N_e = 5

# Build the CombDiff definitions that the later `@main ... f ctx` calls use, and keep
# the resulting `(f, ctx)` pair. The quoted block declares:
#   * `S3{T}`: functions `(N{T}, N{T}, N{T}) -> R`, declared symmetric under swapping
#     the first two indices and under swapping the last two. Together these two
#     transpositions generate every permutation of the three indices.
#   * `mvp`: matrix-vector product, `mvp(A, x)(i) = sum over j of A(i, j) * x(j)`.
#   * `plus_one`: a DSL-level wrapper around the Julia function `add_one`.
#   * `occ`: a domain over `N{N_e}` (its intended meaning is not stated here).
f, ctx = @comb :(
    begin
        @space S3{T} begin
            type = (N{T}, N{T}, N{T}) -> R
            symmetries = (
                t -> (i, j, k) -> t(j, i, k),
                t -> (i, j, k) -> t(i, k, j),
            )
        end
        mvp = {M, N}(A::RM{M,N}, x::RV{N}) -> (i::N{M}) -> sum((j::N{N}), A(i, j) * x(j))
        plus_one = (x::R) -> add_one(x)

        @domain occ N{N_e}
    end
)

# Pullback of `plus_one` with respect to `z`, applied to the arguments `(1.0, 1.0)`,
# built against the definitions in `f`/`ctx`. The arguments are presumably the primal
# point and the cotangent — confirm against CombDiff's `pullback`. `@main` also
# returns `inputs`, which the generated functions below receive as splatted arguments.
# NOTE(review): `z` is typed `N` while `plus_one` takes `R`, and the arguments are
# Float64 literals — confirm this is intended.
g, inputs = @main :(
    pullback((z::N) -> plus_one(z))(1.0, 1.0)
) f ctx

# Simplify the pullback with `:link_pullback` enabled, take the first candidate,
# and evaluate it. The result is bound once and then reused for code generation,
# instead of running the identical simplify/eval pipeline a second time.
linked = simplify(eval_pullback(g); settings=custom_settings(:link_pullback => true)) |>
         first |> eval_pullback |> eval_all

# Generate Julia code from the linked expression, `eval` it into a callable, and
# call it with the inputs returned by `@main`. This is kept as a single top-level
# statement, matching the original form of this call.
eval(codegen(linked))(inputs...)

# Variant without `:link_pullback` and without `eval_all`: the generated function
# is applied to `inputs`, and its result is then called with `(1.0, 1.0)`.
func = eval(eval_pullback(g) |> simplify |> first |> eval_pullback |> codegen)
func(inputs...)(1.0, 1.0)

# For a size `M`, build
# `(X::S3{M}, y::RV{M}) -> (i, j) -> mvp(k -> 2 * X(i, j, k), y)`
# against the definitions in `f`/`ctx`.
# NOTE(review): `mvp` is declared with a two-index `A::RM{M,N}`, but here it receives
# the one-index function `k -> 2 * X(i, j, k)`; confirm this is intended.
# NOTE(review): `g` and `inputs` are overwritten by the next `@main` call without
# being used in between.
g, inputs = @main :(
    (M::N) ->
        (X::S3{M}, y::RV{M}) -> (i::N, j::N) -> mvp((k::N{M}) -> 2 * X(i, j, k), y)
) f ctx

# Gradient of x -> sum over i of x(i) * x(i) for a real vector `x`, built against
# `f`/`ctx`; analytically the gradient is 2x.
# NOTE(review): the result is not used further in the visible script.
g, inputs = @main :(
    grad((x::RV) -> sum(i, x(i) * x(i)))
) f ctx


