using DifferentialEquations
using Flux
using SciMLSensitivity
using ProgressBars
using Random
using StatsBase: mean, var, std
using Plots
using LaTeXStrings
pgfplotsx()

if isdir("src")
    cd("src")
end
using CSV, DataFrames
using BSON
include("nsde_utils.jl")

########## SDE Definition ################
"""
    f!(du, u, p, t)

In-place drift of the SDE: write `p[1] - cos(u)` elementwise into `du` and return `du`.
The time argument `t` is not used.
"""
function f!(du, u, p, t)
    @. du = p[1] - cos(u)
    return du
end

"""
    σ!(du, u, p, t)

In-place diffusion of the SDE: every entry of `du` is set to the constant amplitude
`p[2]` (independent of `u` and `t`); returns `du`.
"""
function σ!(du, u, p, t)
    amplitude = p[2]
    du .= amplitude
    return du
end
#### Out-of-place versions (return a new value instead of writing into `du`)
"""
    f(u, p, t)

Out-of-place drift: return `p[1] - cos(u)` elementwise (a scalar for scalar `u`, an
array for array `u`). The time argument `t` is not used.
"""
f(u, p, t) = p[1] .- cos.(u)
"""
    σ(u, p, t)

Out-of-place diffusion with constant amplitude `p[2]`.

For a scalar (non-array) `u`, return `p[2]` itself. For an array `u`, return a new array
of the same size filled with `p[2]`; its element type is that of `p[2] * 1.0`, because it
is built from `ones` (Float64). The time argument `t` is not used.
"""
σ(u, p, t) = p[2]
# Array states get one amplitude per component; selected by dispatch, not a runtime type check.
σ(u::AbstractArray, p, t) = p[2] * ones(size(u))
###### initial condition
# Seed the global RNG before solving. Presumably this makes the sampled noise (and so the
# ground-truth paths) reproducible — assumes the solver draws its noise seed from the
# global RNG; confirm.
Random.seed!(1234)
# 100 components, all starting at 0.1. f! and σ! act elementwise with no coupling, so
# (with the default diagonal noise) each component behaves as an independent sample path
# of the same scalar SDE; later code treats component i as trajectory i.
# NOTE: the original author reports that backpropagation through the solver does not
# handle a very high-dimensional u0, so keep this dimension modest.
u0 = 0.1 * ones(100)
###### parameters
# p[1]: constant added to the drift (f!: -cos(u) + p[1]); p[2]: constant diffusion amplitude (σ!).
p = [0.5, 1.0]
###### time interval
# NOTE(review): T, tspan and times are Float32 while u0 and p are Float64 — confirm the
# mixed precision is intended.
T = 20.0f0
tspan = (0.0f0, T)
# Output grid: every 0.1 time units from 0 to T (inclusive).
times = 0.0f0:0.1f0:T
###### specify the SDE problem
# In-place drift f! and diffusion σ!; σ! is passed both inside SDEFunction and as the
# separate g argument.
prob = SDEProblem(SDEFunction(f!,σ!), σ!, u0, tspan, p)
# Deterministic alternative that drops the noise term:
# prob = ODEProblem(f!, u0, tspan, p)
###### solve the SDE problem
# No algorithm is given, so DifferentialEquations chooses one; states are saved at `times`.
sol = solve(prob, saveat=times)
# Vector of state vectors, one per saved time point.
U = sol.u
###########################################

######## Trajectory Information ###########
# Persist the ground-truth ensemble once and plot it. The whole else-branch (CSV and
# figures) is skipped when the CSV already exists.
if isfile("../data/nsde_ground_truth.csv")
    @info "skipping saving ground truth"
else
    @info "saving ground truth as csv to ../data/nsde_ground_truth.csv"
    # CSV.write / savefig fail if the target directories do not exist yet.
    mkpath("../data")
    mkpath("../figures")

    # One row per saved time point: column t, then u1, u2, ... (one column per component).
    df = DataFrame(t = sol.t)
    for i in eachindex(u0)
        df[!, "u$i"] = [u[i] for u in U]
    end
    CSV.write("../data/nsde_ground_truth.csv", df)

    # Potential of the drift: -H'(x) = -cos(x) + p[1], which matches f!/f. The drift
    # offset is p[1]; p[2] is the diffusion amplitude and does not belong here.
    H(x) = sin(x) - p[1] * x
    # Keep this figure in its own object and save it; otherwise the next `plot(...)`
    # call would silently replace it and it would never be written out.
    potential_plot = plot(-2:0.1:8, H.(-2:0.1:8), label="H(x)")
    plot!(potential_plot, U[end], H.(U[end]), seriestype=:scatter, label="U(t)")
    savefig(potential_plot, "../figures/nsde_potential.pdf")

    # Up to 40 individual trajectories in gray, plus the ensemble mean ± one std.
    plot(legend=false, size=(400,400))
    for i in 1:min(40, length(u0))
        plot!(sol.t, [u[i] for u in U], label="U(t)", color=:gray, alpha=0.5)
    end
    xlabel!(L"t")
    plot!(sol.t, mean.(U), ribbon=std.(U), label="mean ± std", width=2, color=:black)
    ylabel!(L"X(t)")
    xlims!(0,T)
    title!("SDE ground truth")
    ylims!(-5,15)
    savefig("../figures/nsde_ground_truth.pdf")
end

"""
    Û()

Build the SDE problem again from the global `u0`, `tspan` and `p` (drift `f!`, diffusion
`σ!`), solve it with states saved at `times`, and return the vector of saved states.
Each call performs a new solve — presumably with a fresh noise realisation (confirm the
solver's seeding).
"""
function Û()
    sde_problem = SDEProblem(SDEFunction(f!, σ!), σ!, u0, tspan, p)
    solution = solve(sde_problem; saveat = times)
    return solution.u
end


# Stack the saved states into matrices with one row per time point and one column per
# component (trajectory). `reduce(hcat, ...)` avoids splatting every state vector into
# one varargs call.
Matrix_U = permutedims(reduce(hcat, U))
# Û() solves the SDE again, so this is a second ensemble of the same SDE
# (independently sampled, assuming fresh noise per solve — confirm).
Matrix_Û = permutedims(reduce(hcat, Û()))

# Elementwise mean squared difference between the two ensembles.
@info "MSE of U and Û: " ∑(abs2, Matrix_U - Matrix_Û) / length(Matrix_U)

"""
    loss_avg_std(U, Û)

Moment-matching loss between two ensembles stored as matrices. Statistics are taken
across columns (`dims=2`), giving one mean and one corrected sample variance per row.
Return the mean squared difference of the row means plus the mean absolute difference
of the row variances. Note that the second term compares variances, not standard
deviations, despite the function's name.
"""
function loss_avg_std(U, Û)
    row_means(A) = vec(mean(A; dims=2))
    row_vars(A) = vec(var(A; dims=2))
    mean_term = mean(abs2, row_means(U) - row_means(Û))
    var_term = mean(abs, row_vars(U) - row_vars(Û))
    return mean_term + var_term
end

# Log the moment-matching loss between the ground-truth ensemble and the re-solved one.
@info "loss_avg_std of U and Û: " loss_avg_std(Matrix_U, Matrix_Û)


"""
    loss_W₂²(U, Û)

Average, over rows, of `W₂²(Û[t, :], U[t, :])`. `W₂²` is defined outside this block;
presumably it is the squared 2-Wasserstein distance between two empirical samples.
In this script, rows are time points and columns are sample paths.

Throws `DimensionMismatch` if `U` and `Û` have a different number of rows, and
`ArgumentError` if they have no rows.
"""
function loss_W₂²(U, Û)
    size(U, 1) == size(Û, 1) || throw(DimensionMismatch(
        "U has $(size(U, 1)) rows but Û has $(size(Û, 1))"))
    # Iterate over the argument's own rows rather than the global `times`, so the loss
    # is correct for any time grid.
    rows = axes(U, 1)
    isempty(rows) && throw(ArgumentError("U and Û must have at least one row"))
    # Rows are copied (not viewed) because W₂² may modify its arguments.
    return ∑(W₂²(Û[t, :], U[t, :]) for t ∈ rows) / length(rows)
end

# Log the time-averaged W₂² loss between the ground-truth ensemble and the re-solved one.
@info "loss_W₂² of U and Û: " loss_W₂²(Matrix_U, Matrix_Û)

