using DifferentialEquations
using Flux
using SciMLSensitivity
using ProgressBars
using Random
using StatsBase: mean, var, std
using Plots
using LaTeXStrings
using RecursiveArrayTools
using LinearAlgebra
using BSON
# Render figures through the PGFPlotsX backend of Plots.
pgfplotsx()

# The relative paths used below ("../data", "../figures") assume the working
# directory is `src/`; step into it when launched from the directory above.
isdir("src") && cd("src")
using CSV, DataFrames


########## SDE Definition ################
# Seed the global RNG *before* building the networks: their initial weights
# are drawn at construction time, so seeding only afterwards would leave the
# starting point (and hence the whole training run) non-reproducible.
Random.seed!(1234)

# Scalar scale factors for the drift (α) and diffusion (β) network outputs.
# They are kept as 1-element arrays so they can be collected by
# `Flux.params` alongside the network weights.
# NOTE(review): presumably packed as p.x[3] / p.x[4] of the SDE parameter
# object read by `f` and `σ` — confirm where `p` is built.
α = [0.001f0]
β = [0.01f0]

# Drift network: 1 → 32 → 32 → 1 MLP with tanh hidden activations.
NN_f = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
# Diffusion network: same trunk, but the output passes through tanh and then
# abs, so every value lies in [0, 1).
NN_σ = Chain(
    Dense(1, 32, tanh),
    Dense(32, 32, tanh),
    Dense(32, 1, tanh),
    x -> abs.(x),
)

# Flatten each network into a flat parameter vector plus a rebuild closure so
# the weights can be passed around (and differentiated) as plain arrays.
θ_f, re_f = Flux.destructure(NN_f)
θ_σ, re_σ = Flux.destructure(NN_σ)


#### Out-of-place version
"""
    f(u, p, t)

Drift term of the neural SDE. Rebuild the drift network from the flat
weights `p.x[1]`, evaluate it on the state `u`, and scale the result
elementwise by the scalar `p.x[3][1]`. `t` is unused.

NOTE(review): assumes `p` is an `ArrayPartition(θ_f, θ_σ, α, β)` (only the
`.x` field layout is visible here) — confirm at the call site.
"""
function f(u, p, t)
    weights = p.x[1]
    scale = p.x[3][1]
    drift_net = re_f(weights)
    return scale .* drift_net(u)
end
"""
    σ(u, p, t)

Diffusion term of the neural SDE. Rebuild the (non-negative-output)
diffusion network from the flat weights `p.x[2]`, evaluate it on the state
`u`, and scale the result elementwise by the scalar `p.x[4][1]`. `t` is
unused.

NOTE(review): assumes `p` is an `ArrayPartition(θ_f, θ_σ, α, β)` — confirm
at the call site.
"""
function σ(u, p, t)
    weights = p.x[2]
    scale = p.x[4][1]
    diffusion_net = re_σ(weights)
    return scale .* diffusion_net(u)
end
###### Reseed the RNG and load shared helpers
# Reset the global RNG to a fixed seed before loading the helper file and
# running anything stochastic below (the losses and training presumably
# sample from it — confirm in nsde_utils.jl).
# NOTE(review): `loss_W₂²`, `train!` and `checkpoint_process` are used later
# but not defined in the code above; presumably they come from nsde_utils.jl.
Random.seed!(1234)
include("nsde_utils.jl")


######## Trajectory Information ###########
# Ground-truth data: a time column `t` plus one column per sample path
# (`u1` … `u100`).
df = let path = "../data/nsde_ground_truth.csv"
    # Abort here instead of only logging: everything below needs `df`, and
    # continuing would just fail later with an unrelated UndefVarError.
    isfile(path) || error("no ground truth found at $(abspath(path))")
    CSV.read(path, DataFrame)
end

# convert df to U
times = Vector{Float32}(df.t)                 # observation times
tspan = (times[1], times[end])
# Rows = time points, columns = the 100 sample paths, in u1…u100 order.
U = Matrix{Float32}(df[!, [Symbol("u$i") for i in 1:100]])
u0 = U[1, :]                                  # state of every path at times[1]

######### Training wrt Losses ##################
# Smoke test before the real training run: evaluate the Wasserstein loss once,
# then differentiate it w.r.t. all trainable arrays (implicit Flux.params).
# `loss_mse`, `loss_avg_std` and `loss_likelihood` share the same signature
# and can be swapped in here to check the alternative objectives.
loss_W₂²(θ_f, θ_σ, α, β)
∇Θ = gradient(() -> loss_W₂²(θ_f, θ_σ, α, β), Flux.params(θ_f, θ_σ, α, β))

"""
    remove_in_string(str)

Return a copy of `str` with every underscore replaced by a single space
(used to turn a loss-function name into a readable plot title).
"""
remove_in_string(str) = replace(str, '_' => ' ')


# Name of the objective, used in every output file name and the plot title
# (`string` of a function is its name, e.g. "loss_W₂²").
loss_name = string(loss_W₂²)

# Create the output directories up front so the writes below cannot fail on
# a fresh checkout after a long training run.
mkpath("../data")
mkpath("../figures")

result = train!(loss_W₂²)
BSON.@save "../data/result_$(loss_name).bson" result
# Reload from disk so the unpacked variables below come from the saved file.
BSON.@load "../data/result_$(loss_name).bson" result
θ_f, θ_σ, α, β, losses = result
checkpoint_process(θ_f, θ_σ, α, β, loss_W₂², 1000; label = loss_name)
CSV.write("../data/losses_$(loss_name).csv", DataFrame(losses = losses))

# Training-loss curve, titled with the loss name (underscores → spaces).
plot(
    losses;
    legend = false,
    size = (400, 400),
    xlabel = "epoch",
    ylabel = "training loss",
    title = remove_in_string(loss_name),
)

savefig("../figures/losses_$(loss_name).pdf")