using DifferentialEquations
using Flux
using SciMLSensitivity
using ProgressBars
using Random
using StatsBase: mean, var, std
using Plots
using LaTeXStrings
# Select the PGFPlotsX plotting backend for every figure produced below.
pgfplotsx()

# When launched from the project root, move into `src` so that the relative
# output paths used throughout this script ("../data/...", "../figures/...")
# point at the project-level data and figure directories.
if isdir("src")
    cd("src")
end
using CSV, DataFrames

# NOTE(review): presumably defines `train!` and the losses used below
# (loss_W₂², loss_avg_std, loss_mse, loss_likelihood) — confirm in utils.jl.
include("utils.jl")

########## SDE Definition ################
"""
    f!(du, u, p, t)

In-place drift of the SDE: overwrite `du` elementwise with `p[1] - cos(u)`.
Returns `du`.
"""
function f!(du, u, p, t)
    @. du = -cos(u) + p[1]
    return du
end

"""
    σ!(du, u, p, t)

In-place diffusion of the SDE: set every entry of `du` to the constant noise
amplitude `p[2]` (independent of `u` and `t`). Returns `du`.
"""
function σ!(du, u, p, t)
    @. du = p[2]
    return du
end
#### Out-of-place version
"""
    f(u, p, t)

Out-of-place drift `p[1] - cos(u)`; applied elementwise when `u` is an array,
returns a scalar when `u` is a scalar.
"""
f(u, p, t) = @. -cos(u) + p[1]
"""
    σ(u, p, t)

Out-of-place diffusion: the constant noise amplitude `p[2]`. For an array `u`
the result is an array of the same size filled with `p[2]`; for a scalar `u`
it is `p[2]` itself. Selected by dispatch on the type of `u`.
"""
σ(u::AbstractArray, p, t) = p[2] * ones(size(u))
σ(u, p, t) = p[2]
###### initial condition
Random.seed!(1234)
# 40 initial states drawn from N(0, 1). f! and σ! act elementwise, so each
# component of u0 is one sample path of the scalar SDE. Kept small because, per the
# original author, the Julia backpropagation used for training does not cope with a
# very high-dimensional u0.
u0 = randn(40)
###### parameters
# p[1]: constant offset in the drift f!/f; p[2]: constant noise amplitude in σ!/σ.
p = [0.5, 1.0]
###### time interval
T = 20.0
tspan = (0.0, T)
# Observation grid: the solution is saved at these points (step 0.1).
times = 0.0:0.1:T
###### specify the SDE problem
# dX = (p[1] - cos X) dt + p[2] dW, componentwise over u0.
prob = SDEProblem(SDEFunction(f!,σ!), σ!, u0, tspan, p)
# Noise-free (ODE) variant, kept for reference:
# prob = ODEProblem(f!, u0, tspan, p)
###### solve the SDE problem
# No algorithm is passed, so the solver falls back to DifferentialEquations' default choice.
sol = solve(prob, saveat=times)
# State snapshots (one vector of 40 components per saved time point).
U = sol.u
###########################################

######## Trajectory Information ###########
# Save the ground-truth trajectories once (CSV + figure); skipped on later runs.
if isfile("../data/sde_ground_truth.csv")
    @info "skipping saving ground truth"
else
    @info "saving ground truth as csv to ../data/sde_ground_truth.csv"
    # One row per saved time point; columns t, u1, u2, … (one per sample path).
    df = DataFrame(t = sol.t)
    for i in eachindex(u0)
        df[!, "u$i"] = [u[i] for u in U]
    end
    # Write to the same path the skip check above looks at. The working directory is
    # `src` (see the `cd` at the top), so the data directory is "../data".
    CSV.write("../data/sde_ground_truth.csv", df)

    # Every sample path in gray, plus the ensemble mean with a ±1 std ribbon.
    plot(legend=false, size=(400, 400))
    for i in eachindex(u0)
        plot!(sol.t, [u[i] for u in U], label="U(t)", color=:gray, alpha=0.5)
    end
    xlabel!(L"t")
    plot!(times, mean.(U), ribbon=std.(U), label="mean ± std", width=2, color=:black)
    ylabel!(L"X(t)")
    xlims!(0, T)
    title!("SDE ground truth")
    ylims!(-5, 15)
    savefig("../figures/sde_ground_truth.pdf")
end

######### Training wrt Losses ##################
@info "Training using different losses"
# Two stacked copies of u0 (twice as many paths as the ground truth). Used below to
# re-simulate the SDE with the fitted parameters. Per the original note, the W₂²
# distance allows a different number of samples than the ground truth.
# NOTE(review): utils.jl may also read this global — confirm.
U0 = repeat(u0, 2)

"""
    evaluate_save_plot(loss_func, loss_name)

Fit the SDE parameters by minimizing `loss_func` and record the results.

The run starts from `rand(2)` after reseeding with 56789. It saves the parameter
trajectory returned by `train!` to
`../data/params_trajectory_<loss_name>_example.csv` (columns `p1`, `p2`). It logs
the mean of 10 evaluations of the loss at the fitted parameters. It then
re-simulates the SDE from `U0` with those parameters and saves the sample paths
and their mean ± std to `../figures/sde_<loss_name>_loss_minimization.pdf`.

Does nothing (apart from a log message) if the CSV already exists.
"""
function evaluate_save_plot(loss_func, loss_name)
    csv_path = "../data/params_trajectory_$(loss_name)_example.csv"
    if isfile(csv_path)
        @info "$(loss_name) example already exists, skipping..."
        return nothing
    end

    @info "training for $(loss_name) example"
    Random.seed!(56789)
    p′ = rand(2)
    # NOTE(review): assumes `train!` (utils.jl) updates `p′` in place and returns the
    # per-iteration parameter history — the code below relies on both; confirm.
    params = train!(loss_func, p′)
    CSV.write(csv_path, DataFrame(p1=[θ[1] for θ in params], p2=[θ[2] for θ in params]))
    # The loss is stochastic, so report an average over several evaluations.
    @info "optimal loss for $(loss_name) is $(mean(loss_func(p′) for _ in 1:10))"

    # Re-simulate with the fitted parameters and plot every path.
    prob = SDEProblem(SDEFunction(f!, σ!), σ!, U0, tspan, p′)
    sol = solve(prob, saveat=times)
    U′ = sol.u
    plot(legend=false, size=(400, 400))
    for i in eachindex(U0)
        plot!(sol.t, [u[i] for u in U′], label="U(t)", color=:gray, alpha=0.5)
    end
    xlabel!(L"t")
    plot!(times, mean.(U′), ribbon=std.(U′), label="mean ± std", width=2, color=:black)
    ylabel!(L"X(t)")
    xlims!(0, T)
    ylims!(-5, 15)
    title!("$(loss_name) loss minimization")
    savefig("../figures/sde_$(loss_name)_loss_minimization.pdf")
end

# Fit the parameters with each loss in turn; each run is skipped if its CSV exists.
evaluate_save_plot(loss_W₂², "W2")
evaluate_save_plot(loss_avg_std, "avg_std")
evaluate_save_plot(loss_mse, "mse")
evaluate_save_plot(loss_likelihood, "loglikelihood")


######### Special pathological example ##########
# Train with the W₂² loss from the fixed starting point p′ = [0.9, 0.1] and save the
# parameter trajectory. Runs only while the CSV is missing.
if !isfile("../data/params_trajectory_W2_pathological_example.csv")
    Random.seed!(56789)
    p′ = [0.9, 0.1]
    @show params = train!(loss_W₂², p′)
    CSV.write("../data/params_trajectory_W2_pathological_example.csv",
        DataFrame(p1=[θ[1] for θ in params], p2=[θ[2] for θ in params]))
else
    @info "skipping running pathological example"
end

# Same starting point and seed as the W₂² pathological run above, trained with the
# avg/std loss. Runs only while its own CSV is missing and writes to that avg_std
# file, so the W₂² pathological result is never overwritten.
if isfile("../data/params_trajectory_avg_std_pathological_example.csv")
    @info "skipping running avg_std pathological example"
else
    p′ = [0.9, 0.1]
    Random.seed!(56789)
    @show params = train!(loss_avg_std, p′)
    CSV.write("../data/params_trajectory_avg_std_pathological_example.csv",
        DataFrame(p1 = [θ[1] for θ in params], p2 = [θ[2] for θ in params]))
end

####### Heatmap of loss function ########
# Evaluate each loss on the grid p₁ ∈ 0.01:0.01:1.0, p₂ ∈ 0.01:0.02:2.0
# and plot log(loss) as a heatmap.
"""
    heatmap(loss_func, loss_name; p₁_range=0.01:0.01:1.0, p₂_range=0.01:0.02:2.0)

Evaluate `loss_func([p₁, p₂])` at every point of `p₁_range × p₂_range`.

The values are saved to `../data/sde_<loss_name>_loss.csv` (columns `loss`, `p₁`,
`p₂`, with `p₁` varying slowest). A heatmap of `log(loss)` is saved to
`../figures/sde_<loss_name>_loss.pdf`; the raw loss is plotted instead if any value
is negative. Nothing is computed if the CSV already exists. Returns the result of
`CSV.write`, or `nothing` when skipped.

This function reuses the name `heatmap`, so the plotting call inside is qualified
as `Plots.heatmap`.
"""
function heatmap(loss_func, loss_name; p₁_range=0.01:0.01:1.0, p₂_range=0.01:0.02:2.0)
    csv_path = "../data/sde_$(loss_name)_loss.csv"
    if isfile(csv_path)
        @info "skipping saving $(loss_name) loss"
        return nothing
    end
    @info "saving $(loss_name) loss as csv to $(csv_path)"

    loss = zeros(length(p₁_range), length(p₂_range))  # rows ↔ p₁, columns ↔ p₂
    loss_df = DataFrame(loss=Float64[], p₁=Float64[], p₂=Float64[])
    for (i, p₁) in enumerate(p₁_range), (j, p₂) in enumerate(p₂_range)
        lossᵢ = loss_func([p₁, p₂])
        push!(loss_df, (loss=lossᵢ, p₁=p₁, p₂=p₂))
        loss[i, j] = lossᵢ
    end
    # Persist the (expensive) grid before plotting so a plotting failure cannot lose it.
    result = CSV.write(csv_path, loss_df)

    # `log` throws a DomainError on negative values; plot the raw loss in that case.
    z = if any(<(0), loss)
        @warn "$(loss_name) loss has negative values; plotting it without log scaling"
        loss
    else
        log.(loss)
    end
    # Plots.heatmap(x, y, z) expects size(z) == (length(y), length(x)): y = p₁, x = p₂.
    Plots.heatmap(p₂_range, p₁_range, z,
        color=:hot,
        ylabel=L"p_1",
        xlabel=L"p_2",
        title="$(loss_name) loss",
        size=(400, 400))
    savefig("../figures/sde_$(loss_name)_loss.pdf")
    return result
end

# Loss landscape over the (p₁, p₂) grid for each loss, in a fixed order.
for (lossfn, tag) in ((loss_W₂², "W₂²"), (loss_avg_std, "avg_std"),
                      (loss_mse, "mse"), (loss_likelihood, "loglikelihood"))
    heatmap(lossfn, tag)
end

#############################################