using DifferentialEquations
using Flux
using SciMLSensitivity
using ProgressBars
using Random
using StatsBase: mean, var, std
using Plots
using LaTeXStrings
using RecursiveArrayTools
using LinearAlgebra
using BSON
# Select the PGFPlotsX backend for Plots.jl.
pgfplotsx()

# The relative "../data" and "../figures" paths used below assume the working
# directory is `src/`; step into it when a `src` directory exists under the cwd.
isdir("src") && cd("src")
using CSV, DataFrames


########## SDE Definition ################
# Initial scale factors for the learned drift (α) and diffusion (β). f̂ / σ̂ multiply the
# network outputs by the α / β values stored in the loaded results.
α, β = [0.001f0], [0.001f0]

# Drift network: 1 → 32 → 32 → 1 with tanh hidden activations.
NN_f = Chain(
    Dense(1, 32, tanh),
    Dense(32, 32, tanh),
    Dense(32, 1),
)
# Diffusion network: same shape, but tanh on the output followed by abs, so the
# diffusion estimate is non-negative.
NN_σ = Chain(
    Dense(1, 32, tanh),
    Dense(32, 32, tanh),
    Dense(32, 1, tanh),
    x -> abs.(x),
)

# Flatten each network into (parameter vector, rebuilder). f̂ / σ̂ use the rebuilders
# re_f / re_σ to turn trained flat parameter vectors back into callable networks.
θ_f, re_f = Flux.destructure(NN_f)
θ_σ, re_σ = Flux.destructure(NN_σ)


#### Out-of-place version
"""
    f̂(u, p, t)

Learned drift. `p` is an `ArrayPartition` of `(θ_f, θ_σ, α, β)`: the flat drift weights
`p.x[1]` are rebuilt into a network with `re_f` on every call, and the network output at
`u` is scaled by the scalar `p.x[3][1]`. `t` is unused.
"""
function f̂(u, p, t)
    weights = p.x[1]
    scale = p.x[3][1]
    drift_net = re_f(weights)
    return scale .* drift_net(u)
end
"""
    σ̂(u, p, t)

Learned diffusion. `p` is an `ArrayPartition` of `(θ_f, θ_σ, α, β)`: the flat diffusion
weights `p.x[2]` are rebuilt into a network with `re_σ` on every call, and the network
output at `u` is scaled by the scalar `p.x[4][1]`. `t` is unused.
"""
function σ̂(u, p, t)
    weights = p.x[2]
    scale = p.x[4][1]
    diffusion_net = re_σ(weights)
    return scale .* diffusion_net(u)
end
##### Ground truth
"""
    f(u, p, t)

Ground-truth drift `p[1] - cos(u)`, applied elementwise to `u` (scalar or array).
`t` is unused.
"""
f(u, p, t) = p[1] .- cos.(u)
"""
    σ(u, p, t)

Ground-truth diffusion: the constant noise level `p[2]`, independent of `u` and `t`.
For an array `u`, returns an array of `size(u)` filled with `p[2]` (computed as
`p[2] * ones(size(u))`, so the element type is promoted with `Float64`); for any other
`u`, returns `p[2]` itself.

Selected by dispatch on `u`'s type instead of a runtime `typeof` check.
NOTE(review): Flux re-exports NNlib's sigmoid as `σ`; this definition relies on that
binding not having been resolved in `Main` before this point.
"""
σ(u::AbstractArray, p, t) = p[2] * ones(size(u))
σ(u, p, t) = p[2]

###### Random seed and shared utilities
# NOTE(review): the seed is set after NN_f / NN_σ were constructed above, so it only
# affects randomness drawn from here on, not their initial weights.
Random.seed!(1234)
include("nsde_utils.jl")

##### Loss funcs
# Trained results for each loss (presumably defined in nsde_utils.jl), loaded from
# "../data/result_<name>_1000.bson". Interpolating a function into a string yields its
# name, e.g. "loss_mse". Each result is indexed later by :θ_f, :θ_σ, :α and :β.
loss_funcs = [loss_mse, loss_avg_std, loss_W₂², loss_likelihood]
results = map(loss -> BSON.load("../data/result_$(loss)_1000.bson"), loss_funcs)

##### Plot reconstructed f̂ and σ̂, compare with ground truth
# Ground-truth parameters: θ[1] is the drift offset in `f`, θ[2] the constant noise
# level returned by `σ`.
θ = [0.5, 0.9]
using PGFPlotsX

# Evaluation grid for the drift/diffusion comparison, with the ground truth sampled on it.
x_range = -5.0:0.1:5
X = collect(x_range)
F = f.(X, Ref(θ), 0.0)
Σ = σ.(X, Ref(θ), 0.0)

# Drift figure: ground-truth f against the learned f̂ for every loss.
axis = @pgf Axis(
    {
        xlabel = "\$x\$",
        ylabel = "\$f(x)\$",
        width = "10cm",
        height = "10cm",
    }
)

# CSVs are written relative to `src/`. The Table nodes get the same files as absolute
# paths, so the saved .tex/.pdf can find them from any LaTeX working directory. The
# absolute path is derived with `abspath` rather than hard-coded to one machine's checkout.
ground_truth_csv = "../figures/nsde_f_σ.csv"
CSV.write(ground_truth_csv, DataFrame(x = X, f = F, σ = Σ))
t = @pgf Table(
    {
        x = "x",
        y = "f",
        col_sep = "comma",
    },
    abspath(ground_truth_csv),
)

plot_line = @pgf Plot(
    {no_marks, color = "black"},
    t
)
legend_entry = @pgf LegendEntry("ground truth")
push!(axis, plot_line)
push!(axis, legend_entry)

# Per-loss plot styling, index-aligned with `loss_funcs` and `results`.
markers = ["o","triangle","square","pentagon"]
sizes = [1.8,2,1.8,2]
colors = ["red", "orange", "green", "blue"]
loss_names = ["MSE", "AvgStd", "W2", "Likelihood"]
for (loss_func, result, marker, mark_size, color, loss_name) ∈ zip(loss_funcs, results, markers, sizes, colors, loss_names)
    # The loop-local names deliberately differ from the globals `t`, `plot_line` and
    # `legend_entry`. Assigning to an existing global inside a top-level loop is an
    # ambiguous soft-scope assignment: a warning in scripts, an overwrite in the REPL.
    params = ArrayPartition(result[:θ_f], result[:θ_σ], result[:α], result[:β])
    F̂ = [f̂([x], params, 0.0f0)[1] for x in X]
    Σ̂ = [σ̂([x], params, 0.0f0)[1] for x in X]
    fit_csv = "../figures/nsde_f̂_σ̂_$(loss_func).csv"
    CSV.write(fit_csv, DataFrame(x = X, f̂ = F̂, σ̂ = Σ̂))
    fit_table = @pgf Table(
        {
            x = "x",
            y = "f̂",
            col_sep = "comma",
        },
        abspath(fit_csv),
    )
    fit_line = @pgf Plot(
        {
            color = color,
            mark = marker,
            mark_size = mark_size,
        },
        fit_table
    )
    push!(axis, fit_line)
    push!(axis, @pgf LegendEntry(loss_name))
end
axis
pgfsave("../figures/nsde_f.tex", axis, include_preamble = false)
pgfsave("../figures/nsde_f.pdf", axis)

# plot reconstructed σ̂
# Diffusion figure: ground-truth σ against the learned σ̂ for every loss.
axis = @pgf Axis(
    {
        xlabel = "\$x\$",
        ylabel = "\$\\sigma(x)\$",
        width = "10cm",
        height = "10cm",
    }
)
# σ column of the ground-truth CSV written by the drift section. The absolute path is
# derived with `abspath` (not hard-coded), so the saved .tex/.pdf can locate it from any
# LaTeX working directory.
t = @pgf Table(
    {
        x = "x",
        y = "σ",
        col_sep = "comma",
    },
    abspath("../figures/nsde_f_σ.csv"),
)

plot_line = @pgf Plot(
    {no_marks, color = "black"},
    t
)
legend_entry = @pgf LegendEntry("ground truth")

push!(axis, plot_line)
push!(axis, legend_entry)

for (loss_func, result, marker, mark_size, color, loss_name) ∈ zip(loss_funcs, results, markers, sizes, colors, loss_names)
    # The loop-local names deliberately differ from the globals `t`, `plot_line` and
    # `legend_entry`, to avoid ambiguous soft-scope assignments to existing globals
    # (a warning in scripts, an overwrite in the REPL).
    params = ArrayPartition(result[:θ_f], result[:θ_σ], result[:α], result[:β])
    F̂ = [f̂([x], params, 0.0f0)[1] for x in X]
    Σ̂ = [σ̂([x], params, 0.0f0)[1] for x in X]
    # Rewrites the same CSV as the drift section, so this section also works when run alone.
    fit_csv = "../figures/nsde_f̂_σ̂_$(loss_func).csv"
    CSV.write(fit_csv, DataFrame(x = X, f̂ = F̂, σ̂ = Σ̂))
    fit_table = @pgf Table(
        {
            x = "x",
            y = "σ̂",
            col_sep = "comma",
        },
        abspath(fit_csv),
    )
    fit_line = @pgf Plot(
        {
            color = color,
            mark = marker,
            mark_size = mark_size,
        },
        fit_table
    )
    push!(axis, fit_line)
    push!(axis, @pgf LegendEntry(loss_name))
end
axis
pgfsave("../figures/nsde_σ.tex", axis, include_preamble = false)
pgfsave("../figures/nsde_σ.pdf", axis)


#### Plot trajectories and loss 


for (loss_func, loss_name) ∈ zip(loss_funcs, loss_names)
    # All names assigned here are fresh (`traj_axis`, `traj_df`, ...) rather than the
    # script's globals `axis`, `t`, `df` and `plot_line`. Assigning to an existing global
    # inside a top-level loop is an ambiguous soft-scope assignment: a warning in
    # scripts, an overwrite of the global in the REPL.
    traj_axis = @pgf Axis(
        {
            xlabel = "\$t\$",
            ylabel = "\$x\$",
            width = "10cm",
            height = "10cm",
            title = "Reconstructed Dynamics with $(loss_name) Loss",
            ymin = -3,
            ymax = 15,
            xmin = 0,
            xmax = 20,
        }
    )
    # Simulated trajectories: a time column `t`, then one column per sample path
    # (u1, u2, ...). A sample dump of this file had 201 rows × 101 columns.
    traj_df = CSV.read("../data/result_$(loss_func)_1000.csv", DataFrame)
    times = traj_df[!, :t]
    for (i, path_column) in enumerate(names(traj_df)[2:end])
        # One CSV per path; the Table references it by an absolute path derived with
        # `abspath`, so the saved .tex/.pdf resolve it from any LaTeX working directory.
        path_csv = "../figures/nsde_$(loss_func)_u$(i).csv"
        CSV.write(path_csv, DataFrame(t = times, ui = traj_df[!, path_column]))
        path_table = @pgf Table(
            {
                x = "t",
                y = "ui",
                col_sep = "comma",
            },
            abspath(path_csv),
        )
        push!(traj_axis, @pgf Plot({no_marks, color = "gray"}, path_table))
    end

    pgfsave("../figures/nsde_$(loss_func)_trajectories.tex", traj_axis, include_preamble = false)
    pgfsave("../figures/nsde_$(loss_func)_trajectories.pdf", traj_axis)
end

# plot ground truth trajectories
axis = @pgf Axis(
    {
        xlabel = "\$t\$",
        ylabel = "\$x\$",
        width = "10cm",
        height = "10cm",
        title = "Ground Truth Dynamics",
        ymin = -3,
        ymax = 15,
        xmin = 0,
        xmax = 20,
    }
)

# Ground-truth sample paths: a time column `t`, then one column per path.
df = CSV.read("../data/nsde_ground_truth.csv", DataFrame)
t = df[!, :t]
ui_columns = names(df)[2:end]
for (i, ui_column) in enumerate(ui_columns)
    # Loop-local names avoid an ambiguous soft-scope assignment to the global
    # `plot_line`. The Table path is made absolute with `abspath` instead of being
    # hard-coded to one machine's checkout.
    path_csv = "../figures/nsde_ground_truth_u$(i).csv"
    CSV.write(path_csv, DataFrame(t = t, ui = df[!, ui_column]))
    path_table = @pgf Table(
        {
            x = "t",
            y = "ui",
            col_sep = "comma",
        },
        abspath(path_csv),
    )
    push!(axis, @pgf Plot({no_marks, color = "gray"}, path_table))
end

axis

pgfsave("../figures/nsde_ground_truth_trajectories.tex", axis, include_preamble = false)
# Saved like the other PDFs in this script, without `include_preamble`; presumably that
# option only matters for .tex output, where the figure is embedded in another document.
pgfsave("../figures/nsde_ground_truth_trajectories.pdf", axis)


# plot loss
for (loss_func, loss_name) ∈ zip(loss_funcs, loss_names)
    # Fresh local names (`loss_df`, `loss_axis`, ...) rather than the globals `df`,
    # `axis`, `t`, `plot_line` and `legend_entry`. Assigning to an existing global inside
    # a top-level loop is an ambiguous soft-scope assignment: a warning in scripts, an
    # overwrite in the REPL.
    loss_df = CSV.read("../data/losses_$(loss_func).csv", DataFrame)
    # Add a 1-based row index used as the epoch on the x axis (presumably one row per epoch).
    loss_df.epoch = 1:size(loss_df, 1)
    loss_csv = "../figures/nsde_$(loss_func)_loss.csv"
    CSV.write(loss_csv, loss_df)
    loss_axis = @pgf Axis(
        {
            xlabel = "epoch",
            ylabel = "loss",
            width = "10cm",
            height = "10cm",
            title = "$(loss_name) Loss",
        }
    )
    # The absolute path is derived with `abspath` (not hard-coded), so the saved
    # .tex/.pdf locate the CSV from any LaTeX working directory.
    loss_table = @pgf Table(
        {
            x = "epoch",
            y = "losses",
            col_sep = "comma",
        },
        abspath(loss_csv),
    )
    push!(loss_axis, @pgf Plot({no_marks, color = "black"}, loss_table))
    push!(loss_axis, @pgf LegendEntry(loss_name))
    pgfsave("../figures/nsde_$(loss_func)_loss.tex", loss_axis, include_preamble = false)
    pgfsave("../figures/nsde_$(loss_func)_loss.pdf", loss_axis)
end
