
# Unicode alias for `sum`. Declared `const` so that calls made through the alias are
# resolved statically instead of going through an untyped global binding.
const ∑ = sum
# implementation of the 2-Wasserstein distance
"""
    quantile(samples)

Return a closure `p -> value` that evaluates the empirical quantile of `samples`.

`samples` is sorted once when the closure is built. With `n = length(samples)`, a
probability `p` in `[0, 1)` selects the element at position `floor(Int, p * n) + 1` of
the sorted samples, and `p == 1` selects the largest sample. Other values of `p` that
map outside `1:n` raise a `BoundsError`, as does any call when `samples` is empty.
"""
function quantile(samples)
    # NOTE(review): this name coincides with `Statistics.quantile`, which has a
    # different signature; confirm which one is meant wherever both are in scope.
    samples_sorted = sort(samples)
    n = length(samples_sorted)
    # `floor(Int, p * n) + 1` evaluates to `n + 1` at `p == 1`, which is out of bounds,
    # so the upper end point is mapped onto the last order statistic explicitly.
    return p -> samples_sorted[p == 1 ? n : floor(Int, p * n) + 1]
end
"""
    W₂²(u_samples, v_samples)

Return the squared 2-Wasserstein distance between the empirical distributions of two
one-dimensional samples.

The value is the integral over `[0, 1]` of the squared difference between the two
empirical inverse CDFs. Both inverse CDFs are step functions whose jumps lie at
multiples of `1 / length(u_samples)` and `1 / length(v_samples)`, so the integral is
evaluated as a finite sum over the merged grid of jump points. The two samples may
have different lengths.

Adapted from https://github.com/nklb/wasserstein-distance
"""
function W₂²(u_samples, v_samples)
    u_sorted = sort(u_samples)
    v_sorted = sort(v_samples)
    u_grid = [i / length(u_sorted) for i in 0:length(u_sorted)]
    v_grid = [i / length(v_sorted) for i in 0:length(v_sorted)]
    grids = sort(unique([u_grid; v_grid]))
    left_edges = grids[1:end-1]
    # Find the order statistic of each cell by searching the sample's own grid instead
    # of computing `floor(Int, p * n) + 1`: in floating point `(i / n) * n` can round to
    # just below `i` (e.g. `1 / 49 * 49`), which would select the previous sample.
    u_index = [searchsortedlast(u_grid, p) for p in left_edges]
    v_index = [searchsortedlast(v_grid, p) for p in left_edges]
    U_icdf = u_sorted[u_index]
    V_icdf = v_sorted[v_index]
    return sum((U_icdf - V_icdf) .^ 2 .* diff(grids))
end

"""
    checkpoint_process(θ_f, θ_σ, α, β, loss_func, i; label=string(loss_func))

Simulate trajectories with the given parameters and write a checkpoint for iteration `i`.

Three files are written, each named from `label` and `i`:

- `../data/result_<label>_<i>.csv`: a table with a `t` column (the global `times`)
  followed by one column `u<k>` per column of the prediction.
- `../figures/result_<label>_<i>.pdf`: every trajectory in gray plus the mean across
  trajectories with a ribbon of one standard deviation.
- `../data/result_<label>_<i>.bson`: the four parameter objects, saved under the names
  `θ_f`, `θ_σ`, `α` and `β`.

Relies on the globals `u0`, `times`, `tspan` and the helper `remove_in_string`, all
defined outside this function.
"""
function checkpoint_process(θ_f, θ_σ, α, β, loss_func, i; label="$(loss_func)")
    Û = predict(u0, θ_f, θ_σ, α, β)

    # Tabular dump; the original note says this mirrors the sde_ground_truth format.
    result_df = DataFrame(t=times)
    for col in axes(Û, 2)
        result_df[!, Symbol("u$col")] = Û[:, col]
    end
    CSV.write("../data/result_$(label)_$i.csv", result_df)

    # Individual trajectories first, then the ensemble summary drawn on top of them.
    plot(legend=false, size=(400, 400))
    for col in axes(Û, 2)
        plot!(times, Û[:, col], color=:gray, alpha=0.5)
    end
    xlabel!(L"t")
    ensemble_mean = mean(Û, dims=2)
    ensemble_std = std(Û, dims=2)
    plot!(times, ensemble_mean, ribbon=ensemble_std, color=:black, label="mean ± std", width=2)
    ylabel!(L"X(t)")
    xlims!(0, tspan[2])
    title!(remove_in_string("$(label)_$i"))
    ylims!(-5, 15)
    savefig("../figures/result_$(label)_$i.pdf")

    # `BSON.@save` stores each value under its variable name, so the names stay as-is.
    BSON.@save "../data/result_$(label)_$i.bson" θ_f θ_σ α β
end
"""
    predict(u::Number, θ_f, θ_σ, α, β)

Simulate one trajectory of the SDE with drift `f` and diffusion `σ` (both globals),
starting from the scalar state `u`.

The four parameter objects are packed into a single `ArrayPartition` and handed to the
problem as its parameter. The problem is solved over the global `tspan` with the `EM()`
solver at a fixed step `dt = 0.1f0`, saving at the global `times`.

Returns a dense matrix with one row per saved state and one column per state component
(the original note gives the shape as 201 x 1).
"""
function predict(u::Number, θ_f, θ_σ, α, β)
    packed = ArrayPartition(θ_f, θ_σ, α, β)
    problem = SDEProblem(f, σ, [u], tspan, packed)
    solution = solve(problem, EM(); saveat=times, verbose=false, dt=0.1f0)
    # Saved states become columns, then the result is flipped so that time runs down rows.
    states = hcat(solution.u...)
    return Matrix(transpose(states))
end

"""
    predict(U::Array, θ_f, θ_σ, α, β)

Simulate one trajectory per element of `U` by calling `predict` on each element, and
concatenate the results horizontally.

Returns a matrix with one column block per element of `U` (the original note gives the
shape as 201 x length(U)).
"""
function predict(U::Array, θ_f, θ_σ, α, β)
    columns = map(u -> predict(u, θ_f, θ_σ, α, β), U)
    return hcat(columns...)
end
# Us = predict(u0, θ_f, θ_σ, α, β) |> (x -> hcat(x...)) # 201 x 40

"""
    loss_W₂²(θ_f, θ_σ, α, β)

Return the squared 2-Wasserstein distance between predicted and observed samples,
averaged over the time grid.

For every index of the global `times`, row `t` of the prediction from the global `u0`
is compared with row `t` of the global data matrix `U` using `W₂²`.
"""
function loss_W₂²(θ_f, θ_σ, α, β)
    Û = predict(u0, θ_f, θ_σ, α, β)
    n_times = length(times)
    total = ∑(W₂²(Û[t, :], U[t, :]) for t in eachindex(times))
    return total / n_times
end


"""
    loss_mse(θ_f, θ_σ, α, β)

Return the mean squared error between the global data matrix `U` and the prediction
obtained from the global initial states `u0`, averaged over all entries of `U`.
"""
function loss_mse(θ_f, θ_σ, α, β)
    Û = predict(u0, θ_f, θ_σ, α, β)
    residual = U - Û
    return ∑(abs2, residual) / length(U)
end

"""
    loss_avg_std(θ_f, θ_σ, α, β)

Return a moment-matching loss between the global data matrix `U` and the prediction
obtained from the global initial states `u0`.

Moments are taken along the second dimension, giving one value per row. The loss is
the mean squared difference of the row means plus the mean absolute difference of the
row variances. Despite the name, the second term compares variances (`var`), not
standard deviations.
"""
function loss_avg_std(θ_f, θ_σ, α, β)
    Û = predict(u0, θ_f, θ_σ, α, β)
    mean_gap = vec(mean(U, dims=2)) - vec(mean(Û, dims=2))
    var_gap = vec(var(U, dims=2)) - vec(var(Û, dims=2))
    return mean(abs2, mean_gap) + mean(abs, var_gap)
end

"""
    ℓ̂(U, times, f, σ, p)

Return the summed Gaussian log-likelihood of the observed increments in `U`.

Each column of `U` is treated as one path sampled at `times`. For every pair of
consecutive time points, the increment minus `f(...) * dt` is scored under a zero-mean
normal density whose covariance is diagonal and built from `σ(...)^2 * dt`. The
contributions of all steps of all paths are added up.

`f` and `σ` must be the out-of-place forms: they are called as `f([state], p, t)` and
`σ([state], p, t)` and their return value is used directly.
"""
function ℓ̂(U, times, f, σ, p)
    # One vector per column of U, holding that path's state at every time index.
    paths = [[U[t, j] for t in eachindex(times)] for j in 1:size(U, 2)]
    n_steps = length(times) - 1
    ℓ = 0.0
    for path in paths
        for k in 1:n_steps
            t_now = times[k]
            dt = times[k + 1] - t_now
            # State dimension: array-valued states use their length, scalars count as 1.
            n = path[1] isa AbstractArray ? length(path[1]) : 1
            drift = f([path[k]], p, t_now)
            diffusion = σ([path[k]], p, t_now)
            if n == 1
                x = path[k + 1] - path[k] - drift[1] * dt
                Σ = diffusion[1]^2 * dt * I(n)
            else
                x = path[k + 1] - path[k] - drift * dt
                Σ = diffusion .^ 2 .* I(n) * dt
            end
            # The covariance is assembled as a diagonal matrix, i.e. the noise is
            # assumed to be diagonal.
            quad = (x' * inv(Σ) * x)[]
            ℓ += -0.5 * n * log(2π) - 0.5 * logdet(Σ) - 0.5 * quad
        end
    end
    return ℓ
end
"""
    loss_likelihood(θ_f, θ_σ, α, β)

Return the negative log-likelihood of the global data matrix `U` under the model,
divided by the number of entries of `U`.

The four parameter objects are packed into one `ArrayPartition` and passed to `ℓ̂`
together with the globals `U`, `times`, `f` and `σ`.
"""
function loss_likelihood(θ_f, θ_σ, α, β)
    packed = ArrayPartition(θ_f, θ_σ, α, β)
    log_likelihood = ℓ̂(U, times, f, σ, packed)
    return -log_likelihood / length(U)
end

"""
    train!(loss_func; n_epochs = 1000)

Train from a fixed random initialisation by minimising `loss_func` with `ADAMW(0.001)`.

After seeding the RNG with `56789`, fresh `Float32` normal draws with the sizes of the
globals `θ_f` and `θ_σ` are used as starting values. Each epoch takes one gradient step
over those two arrays together with the globals `α` and `β`, re-evaluates the loss,
records it, and forces a garbage collection. Every 100th epoch `checkpoint_process` is
called.

Returns the tuple `(θ_f₀, θ_σ₀, α, β, losses)`, where `losses` holds one loss value per
epoch.
"""
function train!(loss_func; n_epochs = 1000)
    Random.seed!(56789)
    θ_f₀ = randn(Float32, size(θ_f))
    θ_σ₀ = randn(Float32, size(θ_σ))
    @info "Using $(loss_func) as loss function"
    @show ∑(loss_func(θ_f₀, θ_σ₀, α, β) for i in 1:1)

    optimizer = ADAMW(0.001)
    progress = ProgressBar(1:n_epochs)
    losses = []
    for epoch in progress
        trainable = Flux.params(θ_f₀, θ_σ₀, α, β)
        grads = gradient(trainable) do
            ∑(loss_func(θ_f₀, θ_σ₀, α, β) for _ in 1:1)
        end
        Flux.update!(optimizer, trainable, grads)

        # Loss after the update; this is the value reported and recorded for the epoch.
        current_loss = ∑(loss_func(θ_f₀, θ_σ₀, α, β) for _ in 1:1)
        push!(losses, current_loss)
        set_multiline_postfix(progress, "loss=$(current_loss)")

        # Periodic checkpoint: data dump, figure and parameter snapshot.
        epoch % 100 == 0 &&
            checkpoint_process(θ_f₀, θ_σ₀, α, β, loss_func, epoch, label="$(loss_func)")
        GC.gc()
    end
    return (θ_f₀, θ_σ₀, α, β, losses)
end
"""
    train!(loss_func, θ_f₀, θ_σ₀, α, β)

Minimise `loss_func` for 1000 iterations with `ADAM(0.001)`, starting from the supplied
`θ_f₀` and `θ_σ₀`, and return the parameter history.

Only `θ_f₀` and `θ_σ₀` are passed to the optimiser; `α` and `β` are forwarded to
`loss_func` unchanged. `loss_func` is called as `loss_func(θ_f₀, θ_σ₀, α, β)`, the same
four-argument form used for the initial evaluation. Each iteration takes one gradient
step, re-evaluates the loss for the progress bar, stores copies of both arrays, and
forces a garbage collection.

Returns a vector with one `(θ_f₀, θ_σ₀)` tuple of copies per iteration.
"""
function train!(loss_func, θ_f₀, θ_σ₀, α, β)
    @info "Using $(loss_func) as loss function"
    @show ∑(loss_func(θ_f₀, θ_σ₀, α, β) for i in 1:1)

    opt = ADAM(0.001)
    params = []
    prog_bar = ProgressBar(1:1000)
    # NOTE(review): the previous `@show p′` was dropped. `p′` is not defined in this
    # function; it looks like a leftover from an earlier parameter layout.
    for i in prog_bar
        ∇Θ = gradient(Flux.params(θ_f₀, θ_σ₀)) do
            # The loss takes all four parameter groups; the earlier two-argument call
            # did not match the four-argument form used above.
            ∑(loss_func(θ_f₀, θ_σ₀, α, β) for i in 1:1)
        end
        Flux.update!(opt, Flux.params(θ_f₀, θ_σ₀), ∇Θ)
        current_loss = ∑(loss_func(θ_f₀, θ_σ₀, α, β) for i in 1:1)
        set_multiline_postfix(prog_bar,
        "loss=$(current_loss)\n")
        # Snapshot by value: the arrays are updated in place on the next iteration.
        push!(params, (copy(θ_f₀), copy(θ_σ₀)))
        GC.gc()
    end
    return params
end

