using DifferentialEquations
using Random
# CIR (Cox–Ingersoll–Ross) model
# dX_t = a (b - X_t) dt + σ √(X_t) dW_t
"""
    f(u, p, t)

Drift coefficient of the CIR SDE, `a (b - u)`, evaluated element-wise on the state `u`.

`p = (a, b, σ)`: `a` is the mean-reversion rate and `b` the long-run level; the third
entry and `t` are not used here.
"""
function f(u, p, t)
    rate, level, _ = p
    return rate .* (level .- u)
end

"""
    σ(u, p, t)

Diffusion coefficient of the CIR SDE, `σ √u`, evaluated element-wise on the state `u`.

`p = (a, b, σ)`; only the third entry is used, and `t` is unused.

Negative states are truncated to zero before the square root ("full truncation"):
a discretised CIR path can step slightly below zero, where `sqrt` would throw a
`DomainError`. For `u ≥ 0` the result is the same as `σ * sqrt.(u)`.
"""
function σ(u, p, t)
    _, _, vol = p
    return vol * sqrt.(max.(u, zero(eltype(u))))
end
# Unicode alias for `sum`; no other use of `∑` is visible in this script.
∑ = sum
# Squared 2-Wasserstein distance between two 1-D empirical distributions
"""
    quantile(samples)

Return the empirical quantile function (inverse CDF) of `samples` as a closure.

For `n = length(samples)` the closure maps a probability `p` to the
`(floor(p * n) + 1)`-th smallest sample. This is the right-continuous inverse CDF,
constant on each interval `[(k - 1) / n, k / n)`. The index is clamped to `1:n`, so
`p == 1` returns the sample maximum instead of raising a `BoundsError`.

Throws `ArgumentError` if `samples` is empty.
"""
function quantile(samples)
    isempty(samples) && throw(ArgumentError("quantile: samples must be non-empty"))
    sorted_samples = sort(samples)
    n = length(sorted_samples)
    return p -> sorted_samples[clamp(floor(Int, p * n) + 1, 1, n)]
end

"""
    W₂²(u_samples, v_samples)

Squared 2-Wasserstein distance between the 1-D empirical distributions of
`u_samples` and `v_samples`:

    W₂² = ∫₀¹ (F_u⁻¹(p) - F_v⁻¹(p))² dp

Both empirical inverse CDFs are piecewise constant, with breakpoints at `i / n_u` and
`j / n_v`. On the merged grid of breakpoints the integrand is therefore constant on
every cell, and the integral is an exact weighted sum over the cells.

The inverse CDFs are evaluated at cell *midpoints*. Evaluating at a left endpoint
`k / n` is fragile because `p * n` may round to just below `k`
(e.g. `0.29 * 100 == 28.999999999999996`), which selects the wrong order statistic.

Throws `ArgumentError` if either sample is empty.
"""
function W₂²(u_samples, v_samples)
    (isempty(u_samples) || isempty(v_samples)) &&
        throw(ArgumentError("W₂²: both samples must be non-empty"))
    nu, nv = length(u_samples), length(v_samples)
    # `i / n` is correctly rounded, so breakpoints common to both grids produce
    # identical Float64 values and are merged by `unique`.
    grids = sort(unique([[i / nu for i in 0:nu]; [j / nv for j in 0:nv]]))
    midpoints = (grids[1:end-1] .+ grids[2:end]) ./ 2
    U_icdf = quantile(u_samples).(midpoints)
    V_icdf = quantile(v_samples).(midpoints)
    return sum((U_icdf .- V_icdf) .^ 2 .* diff(grids))
end

Random.seed!(456)

# Synthetic data: an ensemble of 200 CIR trajectories started from N(2, 0.5²).
u0 = Float32.(2 .+ 0.5 .* randn(200))

# True CIR parameters (a, b, σ).
p = [1.0, 5.0, 0.5]

# Integrate over t ∈ [0, 2] and save every 0.05 (41 time points).
tspan = (0.0f0, 2.0f0)
times = collect(0.0f0:0.05f0:2.0f0)

prob = SDEProblem(f, σ, u0, tspan, p)
sol = solve(prob; saveat = times)

using CSV
using DataFrames
df = DataFrame(sol)

# One row per saved time step: column 1 holds the time stamps and the remaining
# columns hold the state of each of the length(u0) trajectories.
tU = Matrix(df)
# U is N × n(tsteps) (N == number of trajectories, n(tsteps) == number of saved time
# steps; the state dimension d is 1 and is not stored as an axis). U[i, k] is
# trajectory i at times[k], so U[:, 1] holds the initial states u0.
U = permutedims(Float32.(tU[:, 2:end]))
size(U) == (length(u0), length(times)) || throw(DimensionMismatch(
    "expected U of size $((length(u0), length(times))), got $(size(U))"))
t = tU[:, 1]

using Flux
using DiffEqFlux

# Two independent MLPs with the same 1 → 32 → 32 → 1 architecture (ReLU hidden
# layers, linear output): NN_f for the drift and NN_σ for the diffusion.
# They are built in the order NN_f, then NN_σ.
NN_f, NN_σ = let mlp = () -> Chain(Dense(1, 32, relu), Dense(32, 32, relu), Dense(32, 1))
    mlp(), mlp()
end

# Flatten each network into a flat parameter vector plus a function that rebuilds
# the network from such a vector.
θ_f, re_f = Flux.destructure(NN_f)
θ_σ, re_σ = Flux.destructure(NN_σ)

using RecursiveArrayTools
# All trainable parameters in one array: θ.x[1] feeds the drift net, θ.x[2] the
# diffusion net.
θ = ArrayPartition(θ_f, θ_σ)

# Neural drift / diffusion with the SDEProblem signature (u, p, t). The state vector
# is passed to the network as a single 1 × length(u) batch, and the output is
# flattened back to a vector. `t` is unused.
f̂(u, p, t) = vec(re_f(p.x[1])(reshape(u, 1, :)))
σ̂(u, p, t) = vec(re_σ(p.x[2])(reshape(u, 1, :)))
"""
    predict(θ)

Simulate the neural SDE `du = f̂(u) dt + σ̂(u) dW` from the data's initial ensemble
`u0` over `tspan`, and return the states saved at `times` as an array. For the
vector-valued `u0` this array is N × n(tsteps).

Integration uses Euler–Maruyama (`EM()`) with step `dt = 1f-2`. The `abstol`/`reltol`
values are passed through unchanged; presumably they have no effect for this
non-adaptive method (verify against the solver docs).
"""
function predict(θ)
    neural_prob = SDEProblem(f̂, σ̂, u0, tspan, θ)
    neural_sol = solve(neural_prob, EM();
                       saveat = times, abstol = 1f-1, reltol = 1f-1, dt = 1f-2)
    return Array(neural_sol)
end
using Statistics
# Per-time-step ensemble statistics of the data (each 1 × n(tsteps)), used by the
# moment-matching loss.
means_U = mean(U; dims = 1)
vars_U = var(U; dims = 1, mean = means_U)

"""
    loss_avg_std(θ)

Moment-matching loss between one simulated ensemble `Û = predict(θ)` and the data.

The loss is the summed squared error of the per-time means plus the summed absolute
error of the per-time variances, divided by the number of time steps.

Returns `(loss, means_Û, vars_Û)`, where the last two are the per-time mean and
variance of `Û` (each 1 × n(tsteps)).
"""
function loss_avg_std(θ)
    Û = predict(θ)
    m̂ = mean(Û; dims = 1)
    v̂ = var(Û; dims = 1, mean = m̂)
    mean_term = sum((m̂ - means_U) .^ 2)
    var_term = sum(abs.(v̂ - vars_U))
    return (mean_term + var_term) / length(means_U), m̂, v̂
end

"""
    loss_W2(θ)

Distribution-matching loss. It simulates one ensemble `Û = predict(θ)` and averages,
over the saved time steps `k`, the squared 2-Wasserstein distance between the
simulated and data marginals `Û[:, k]` and `U[:, k]`.

Returns `(loss, means_Û, vars_Û)`. The per-time mean and variance of `Û` (each
1 × n(tsteps)) are returned for monitoring; the training loop passes them to `cb`.
"""
function loss_W2(θ)
    Û = predict(θ)
    m̂ = mean(Û; dims = 1)
    v̂ = var(Û; dims = 1, mean = m̂)
    per_step = map(k -> W₂²(Û[:, k], U[:, k]), axes(U, 2))
    return mean(per_step), m̂, v̂
end

# Smoke test and timing of both losses and their gradients w.r.t. θ before training.
# In a fresh session these timings include compilation.
@time loss_avg_std(θ)
@time loss_W2(θ)

@time gradient(() -> first(loss_avg_std(θ)), Flux.params(θ))
@time gradient(() -> first(loss_W2(θ)), Flux.params(θ))


using Plots
using ProgressBars
# Progress bar over the 1000 training iterations. `iter` counts callback invocations
# and is advanced by `cb`.
prog, iter = ProgressBar(1:1000), 0
"""
    cb(p, l, means, vars; doplot = true)

Training callback that builds a three-panel progress figure:

1. The data's per-time mean ± std against the predicted `means` ± `sqrt(vars)`.
2. The learned drift `NN_f` against the true drift `f(x, p, 0.0)`.
3. The absolute value of the learned diffusion `NN_σ` against the true `σ(x, p, 0.0)`.

Panels 2 and 3 are evaluated on 100 points of `[0, 7]`, with the networks rebuilt
from the global `θ`. The figure is displayed when `doplot` is true and appended to
the global `list_plots`. The loss `l` is shown on the progress bar `prog`.

Increments the global counter `iter`; `list_plots` is reset to `[]` when `iter` is 0
on entry. Returns `false`.
"""
function cb(p, l, means, vars; doplot = true)
    global list_plots, iter
    if iszero(iter)
        list_plots = []
    end
    iter += 1

    xs = range(0, 7, length = 100)
    drift_nn = re_f(θ.x[1])
    diffusion_nn = re_σ(θ.x[2])

    # Panel 1: data vs. predicted ensemble statistics over time.
    plt_stats = Plots.plot(times, mean(U, dims = 1)[:], ribbon = std(U, dims = 1)[:],
        label = "true", color = :black, legend = :topleft)
    Plots.plot!(plt_stats, times, means[:], ribbon = sqrt.(vars)[:],
        label = "pred", color = :red)

    # Panel 2: learned (dashed) vs. true drift.
    plt_drift = Plots.plot(xs, [drift_nn([Float32(x)])[1] for x in xs];
        label = "NN_f", color = :red,
        linestyle = :dash, linealpha = 0.5, linewidth = 2,
        xlabel = "x", ylabel = "f(x)", ylims = (-2, 5),
        title = "iter = $iter", marker = :circle)
    plot!(plt_drift, xs, [f(x, p, 0.0) for x in xs], label = "f", color = :black)

    # Panel 3: absolute value of the learned diffusion (dashed) vs. true diffusion.
    plt_diffusion = Plots.plot(xs, [abs(diffusion_nn([Float32(x)])[1]) for x in xs];
        label = "NN_σ", color = :red,
        linestyle = :dash, linealpha = 0.5, linewidth = 2,
        xlabel = "x", ylabel = "σ(x)", ylims = (0, 3),
        title = "iter = $iter", marker = :circle)
    plot!(plt_diffusion, xs, [σ(x, p, 0.0) for x in xs], label = "σ", color = :black)

    fig = plot(plt_stats, plt_drift, plt_diffusion, layout = (3, 1), size = (800, 1200))
    doplot && display(fig)
    push!(list_plots, fig)
    set_multiline_postfix(prog, "loss = $l")
    return false
end
# AdamW with learning rate 0.002, (β₁, β₂) = (0.9, 0.999) and weight decay 0.001.
opt = ADAMW(0.002, (0.9, 0.999), 0.001)
ps = Flux.params(θ)
for i in prog
    # Every call to loss_W2 re-simulates the neural SDE via predict.
    grads = gradient(() -> first(loss_W2(θ)), ps)
    Flux.update!(opt, ps, grads)
    # Every 10 steps, re-evaluate the loss on a fresh simulation and log/plot it.
    i % 10 == 0 && cb(p, loss_W2(θ)...)
end

# Build the GIF from the figures recorded by `cb`.
# `@animate` records `current()` after each iteration, so a loop body that merely
# evaluates `list_plots[i]` would capture the same last-created plot in every frame.
# Each stored figure is therefore added to the animation explicitly.
if !@isdefined(list_plots) || isempty(list_plots)
    @warn "no callback plots were recorded; skipping loss_W2.gif"
else
    anim = Animation()
    for plt in list_plots
        frame(anim, plt)
    end
    gif(anim, "loss_W2.gif", fps = 5)
end