using CombDiff

# Experiment 1: differentiate through a curried closure that is called twice.
# Inside the quoted program, the DSL-level `f` is curried multiplication
# (z ↦ v ↦ z * v); it is unrelated to the Julia global `f` being assigned here.
# The differentiated function is z ↦ f(z)(2) + f(z)(3) = 2z + 3z = 5z, so a
# correct pullback should scale the incoming cotangent by 5.
# The macro result is destructured as `f, _`; the second component is discarded
# (NOTE(review): its meaning is defined by CombDiff, not visible here).
f, _ = @main :(
    begin
        f = (z::R) -> (v::R) -> z * v
        pullback((z::R) -> begin
            x = f(z)(2)
            y = f(z)(3)
            x + y
        end)
    end)

# Inspect the term produced by `eval_pullback` on `f` (value is only displayed).
f |> eval_pullback
# Run the full chain: eval_pullback → simplify → first → deprimitize →
# eval_pullback → eval_all. (`first` is applied to `simplify`'s output,
# presumably a collection of candidate terms — verify in CombDiff.)
result = eval_all(eval_pullback(deprimitize(first(simplify(eval_pullback(f))))))

# Two views of the copy-stripped `result`: with and without a further
# `eval_pullback` before simplification.
first(simplify(eval_pullback(eval_all(CombDiff.strip_copy(result)))))
first(simplify(eval_all(CombDiff.strip_copy(result))))

# Check whether two terms that differ only in the name of their bound variable
# (`x` vs `y`) compare equal under `==`, i.e. whether CombDiff's term equality
# ignores bound-variable renaming (alpha-equivalence).
a, _ = @comb :((x::N) -> x + 1)
b, _ = @comb :((y::N) -> y + 1)

a == b

# Re-inspect `f`, then rewrite it with `let_copy_to_comp`, take its body and `decompose` it.
f |> eval_pullback
f |> let_copy_to_comp |> get_body |> decompose

# Pullback of a `▷` pipeline applied to `z`. Assuming `▷` composes left to right
# (its semantics are defined by CombDiff), the stages compute
#   z ↦ 2z ↦ (3·2z, z) ↦ z + 6z.
# The second stage closes over the outer `z`. The third stage binds the
# incoming tuple as (y, x), so y = 6z and x = z. Under that reading the
# function is 7z.
# NOTE(review): in a top-to-bottom run this `f` is reassigned by the next
# statement before anything uses it.
f, _ = @comb :(pullback((z::R) ->
    (((z::R) -> z * 2) ▷
     ((x::R) -> (x * 3, z)) ▷
     ((y::R, x::R) -> (x + y)))(z)))

# Pullback with respect to `w` (with `z` bound by the outer lambda) through a
# pipeline that threads `w` via a tuple. Assuming `▷` composes left to right:
#   z ↦ (z, w) ↦ (z·w, w) ↦ w·(z·w),
# i.e. w ↦ z·w², whose derivative in `w` is 2·z·w.
f, _ = @comb :((z::R) -> pullback((w::R) ->
    (((a::R) -> (a, w)) ▷
     ((a::R, b::R) -> (a * w, b)) ▷
     ((a::R, b::R) -> (b * a)))(z)))

# Chain: eval_pullback → deprimitize → eval_pullback → eval_all → simplify → first.
result = first(simplify(eval_all(eval_pullback(deprimitize(eval_pullback(f))))))
# Compare against applying the steps to copy-stripped terms.
f |> CombDiff.strip_copy |> eval_all |> eval_pullback
result |> CombDiff.strip_copy |> eval_all |> simplify |> first


# Same shape as the previous experiment, but every stage takes a single input.
# Assuming `▷` composes left to right:
#   z ↦ 2z ↦ 2z·w ↦ 2·(2z·w),
# i.e. w ↦ 4·z·w, whose derivative in `w` is 4z.
f, _ = @comb :((z::R) -> pullback((w::R) ->
    (((x::R) -> x * 2) ▷
     ((x::R) -> x * w) ▷
     ((x::R) -> (2 * x)))(z)))
eval_pullback(f)


# Variant in which the pipeline calls functions `f`, `g` of type `RF`. These
# are supplied as parameters of the outer lambda, so no body is available for them.
# Assuming `▷` composes left to right, the function is w ↦ 2·f(z)·g(w). Its
# derivative in `w` must therefore go through g's derivative at `w`.
f, _ = @comb :((z::R, f::RF, g::RF) -> pullback((w::R) ->
    (((x::R) -> f(x)) ▷
     ((x::R) -> x * g(w)) ▷
     ((x::R) -> (2 * x)))(z)))
# Evaluate the pullback directly, then after each of the two `let` rewrites.
f |> eval_pullback

f |> let_copy_to_comp |> eval_pullback

f |> let_copy_to_call |> eval_pullback

# Pullback of a two-input, two-output function (z, y) ↦ (z, f(y)). Here `f` is
# an `RF`-typed parameter of the outer lambda.
f, _ = @comb :((f::RF) -> pullback((z::R, y::R) -> (z, f(y))))

eval_pullback(f)


# `@meinview` experiments. NOTE(review): what `@meinview` does differently from
# `@comb` is not visible in this file; confirm in CombDiff.
# Define a wrapper around `CombDiff.fft` that is generic over a size parameter
# (`{P::N}`) inside a lambda over `y::N{4}`, then apply it to `y`. Presumably
# `P` must be inferred from the argument type at the call site.
f, _ = @meinview :(
    (y::N{4}) ->
        begin
            f = {P::N}(x::N{P}) -> CombDiff.fft(x)
            f(y)
        end
)

# Same as the previous experiment but without the enclosing `(y::N{4}) ->`
# binder, so `y` in `f(y)` is not bound anywhere in this expression.
# NOTE(review): confirm whether the free `y` is intentional (e.g. probing how
# free variables are handled) or a leftover from the previous block.
f, _ = @meinview :(
    begin
        f = {P::N}(x::N{P}) -> CombDiff.fft(x)
        f(y)
    end
)

# Take the pullback of an element-wise square, generic over the size `P`
# (x ↦ i ↦ x(i)^2 for x::RV{P}), and apply the result to the index-identity
# lambda `(i::N{10}) -> i`.
# NOTE(review): the calling convention of a pullback result (primal input only,
# or primal plus cotangent) is defined by CombDiff; verify that this call matches it.
f, _ = @meinview :(
    begin
        f = pullback({P::N}(x::RV{P}) -> (i::N{P}) -> x(i)^2)
        f((i::N{10}) -> i)
    end
)



# Pullback with respect to `w` of a program that locally defines a weighted sum,
# generic over the size `P`: f(x) = Σ_i w·x(i). The `sum((i::N{P}), …)` form is
# read here as summation over index `i`; compare `∑` in the last experiment.
# The program applies f to the index-identity lambda over `N{10}`.
# Mathematically this is w ↦ w·Σ_i i, whose derivative is the constant Σ_i i.
f, _ = @meinview :(
    pullback((w::R) ->
    begin
        f = {P::N}(x::RV{P}) -> sum((i::N{P}), w * x(i))
        f((i::N{10}) -> i)
    end)
)
# Rewrite the resulting term with `let_copy_to_comp` (presumably turning the
# local `f = …` definition into composition — verify in CombDiff).
let_copy_to_comp(f)

# Element-wise square without explicit size parameters (`RV` / `N` instead of
# `RV{P}` / `N{P}`), followed by evaluating its pullback.
f, _ = @meinview :(pullback((x::RV) -> (i::N) -> x(i)^2))
eval_pullback(f)
# The weighted-sum experiment again, this time expressed as a `▷` pipeline.
# `@space T` declares a space `T` whose `type` is `(RV,) -> R`; it is used to
# annotate the function-valued input of the second stage.
# The first stage is a zero-argument lambda that returns x ↦ ∑_i w·x(i). The
# second stage applies that function to the index-identity lambda `i ↦ i`.
# The whole pipeline is invoked with no arguments via the trailing `()`.
f, _ = @comb :(
    begin
        @space T begin
            type = (RV,) -> R
        end
        pullback((w::R) ->
            (
                (() -> (x::RV) -> ∑((i), w * x(i))) ▷
                ((f::T) -> f((i::N) -> i))
            )()
        )
    end
)

# Evaluate the pullback of the pipeline term, then of its copy-stripped form
# after `eval_all`.
f |> eval_pullback
f |> CombDiff.strip_copy |> eval_all |> eval_pullback

# Repeated evaluation of the original term (kept for side-by-side inspection).
f |> eval_pullback
