
using DifferentialEquations
using Flux
using SciMLSensitivity
using ProgressBars
using Random
using StatsBase: mean, var, std
using CSV, DataFrames
using Plots
using LaTeXStrings
using Printf
# Unicode shorthand for `sum`; the loss functions below use it for reductions.
∑ = sum

# When launched from a directory that contains `src/`, switch into it so the
# relative `include` below and the `../data`, `../figures` paths used later
# are resolved from there.
isdir("src") && cd("src")

# Project helpers (presumably the source of `W₂²` and `ℓ̂` used below — confirm).
include("utils.jl")

"""
    f!(du, u, p, t)

In-place drift term of the SDE: write `-cos(u) + 0.5` elementwise into `du`.

`p` and `t` are accepted to match the solver's callback signature but are not
used. Returns `du`.
"""
function f!(du, u, p, t)
    @. du = 0.5 - cos(u)
    return du
end

"""
    σ!(du, u, p, t)

In-place diffusion term of the SDE: set every entry of `du` to `p[2]`.

`u` and `t` are not used. Returns `du`.
"""
function σ!(du, u, p, t)
    noise_level = p[2]
    du .= noise_level
    return du
end
"""
    f(u, p, t)

Out-of-place drift term: return `-cos(u) + 0.5`, elementwise when `u` is an
array and as a scalar when `u` is a scalar. `p` and `t` are not used.
"""
function f(u, p, t)
    drift = @. 0.5 - cos(u)
    return drift
end
"""
    σ(u, p, t)

Out-of-place diffusion term.

For an array state `u`, return an array with the shape of `u` in which every
entry equals `p[2]`; for any other (scalar) state, return `p[2]` itself.
`t` is not used.

The array/scalar distinction is made by dispatch rather than a runtime type
check, so each method has a single, inferable return type.
"""
function σ(u::AbstractArray, p, t)
    return p[2] * ones(size(u))
end

function σ(u, p, t)
    return p[2]
end

### Configuration ################
# Fix the global RNG seed so the initial ensemble below is reproducible.
Random.seed!(1234)
# Initial condition: 100 draws from `rand`. Kept modest in size because the
# Julia implementation of backpropagation does not handle a very
# high-dimensional u0.
u0 = rand(100)
# End of the integration window, the window itself, and the save grid
# (every 0.1 time units from 0 to T).
T = 20.0
tspan = (0.0, T)
times = range(0.0, T; step=0.1)
##################################


"""
    loss_mse(p′)

Simulate the SDE from the global `u0` with parameters `p′` (LambaEM, saved at
the global `times`) and return the squared difference to the reference
trajectory `U`, summed over state entries and time steps and divided by the
number of time steps.

NOTE(review): `U` is read as a global here, but in this file it is only bound
inside `evaluate` — confirm it is defined elsewhere before calling this.
"""
function loss_mse(p′)
    problem = SDEProblem(SDEFunction(f!, σ!), σ!, u0, tspan, p′)
    simulated = solve(problem, LambaEM(); saveat=times).u
    sq_err_per_step = [∑((x̂ - x).^2) for (x̂, x) in zip(simulated, U)]
    return ∑(sq_err_per_step) / length(U)
end

"""
    loss_avg(p′)

Simulate the SDE from the global `U0` with parameters `p′` (LambaEM, saved at
the global `times`) and return the squared difference between the per-time-step
ensemble means of the simulation and of the reference trajectory `U`, summed
over time steps and divided by the number of time steps.

Only the means enter the loss, so no standard deviations are computed.

NOTE(review): `U` is read as a global here, but in this file it is only bound
inside `evaluate` — confirm it is defined elsewhere before calling this.
"""
function loss_avg(p′)
    prob = SDEProblem(SDEFunction(f!, σ!), σ!, U0, tspan, p′)
    sol = solve(prob, LambaEM(), saveat=times)
    U′ = sol.u
    𝔼u = mean.(U)
    𝔼u′ = mean.(U′)
    return ∑(abs2, 𝔼u - 𝔼u′) / length(U)
end

"""
    train!(loss_func)

Fit a two-element parameter vector by gradient descent on `loss_func`.

The global RNG is re-seeded with 56789 and the initial guess is `rand(2)`.
The parameters are updated in place for 1000 epochs with `ADAMW(0.01)`; after
every step the loss is re-evaluated and shown, together with both parameters,
on the progress bar.

# Arguments
- `loss_func`: callable taking the parameter vector and returning a scalar loss.

# Returns
The parameter vector after the final update.
"""
function train!(loss_func)
    Random.seed!(56789)
    p′ = rand(2)
    @info "Using $(loss_func) as loss function"
    @show ∑(loss_func(p′) for i in 1:1)

    opt = ADAMW(0.01)

    # The original bound here was the undefined name `n_repeat00` (apparently a
    # mangled `1000`); 1000 matches the default of the other `train!` method.
    n_epochs = 1000
    prog_bar = ProgressBar(1:n_epochs)
    @show p′
    for _ in prog_bar
        ∇p′ = gradient(Flux.params(p′)) do
            ∑(loss_func(p′) for i in 1:1)
        end
        Flux.update!(opt, Flux.params(p′), ∇p′)
        # Re-evaluate after the step so the bar shows the loss at the new p′.
        current_loss = ∑(loss_func(p′) for i in 1:1)
        set_multiline_postfix(prog_bar,
        "loss=$(current_loss)\np₁=$(p′[1])\np₂=$(p′[2])")
    end
    return p′
end
"""
    train!(loss_func, p′; n_epochs=1000)

Fit `p′` in place by gradient descent on `loss_func` and record its trajectory.

Each epoch computes the gradient of `loss_func(p′)`, applies an `ADAM(0.01)`
step, re-evaluates the loss for the progress bar, stores a copy of `p′`, and
runs the garbage collector.

If the gradient contains a NaN, an error is logged and training stops at that
point: the NaN step is not applied, and the trajectory collected so far is
returned.

# Arguments
- `loss_func`: callable taking the parameter vector and returning a scalar loss.
- `p′`: initial parameter vector; mutated in place.

# Keywords
- `n_epochs`: maximum number of update steps.

# Returns
A vector holding one copy of `p′` per completed epoch.
"""
function train!(loss_func, p′; n_epochs=1000)
    @info "Using $(loss_func) as loss function"
    @show ∑(loss_func(p′) for i in 1:1)

    opt = ADAM(0.01)
    # Typed container: every entry is a copy of p′.
    trajectory = typeof(p′)[]
    prog_bar = ProgressBar(1:n_epochs)
    @show p′
    for _ in prog_bar
        ∇p′ = gradient(Flux.params(p′)) do
            ∑(loss_func(p′) for i in 1:1)
        end
        if any(isnan, ∇p′[p′])
            @error "NaN encountered in gradient, aborting"
            # Stop here: applying a NaN gradient would turn p′ into NaN and
            # every later epoch would be wasted.
            break
        end
        Flux.update!(opt, Flux.params(p′), ∇p′)
        # Re-evaluate after the step so the bar shows the loss at the new p′.
        current_loss = ∑(loss_func(p′) for i in 1:1)
        set_postfix(prog_bar, # format loss and pars using printf
        loss = @sprintf("%.3f", current_loss),
        p₁ = @sprintf("%.3f", p′[1]),
        p₂ = @sprintf("%.3f", p′[2]))
        # copy: p′ keeps being mutated by later updates
        push!(trajectory, copy(p′))
        GC.gc()
    end
    return trajectory
end
# Initial ensemble used for the simulated model: `u0` repeated once. For the
# W₂² distance the number of simulated samples may differ from the data, so
# the repeat count can be changed to generate more samples.
U0 = repeat(u0, 1)
"""
    split_U0(U0, n_sample_max)

Split `U0` into consecutive chunks of at most `n_sample_max` elements.

All chunks except possibly the last have exactly `n_sample_max` elements; the
last one holds the remainder. Each chunk is a copy of the corresponding slice.
An empty `U0` yields an empty result.

# Arguments
- `U0`: indexable collection of samples.
- `n_sample_max`: maximum chunk length; must be a positive integer.

# Returns
A vector of chunks, in the original order.

# Throws
- `ArgumentError` if `n_sample_max` is not positive.
"""
function split_U0(U0, n_sample_max)
    n_sample_max > 0 ||
        throw(ArgumentError("n_sample_max must be positive, got $(n_sample_max)"))
    index_chunks = Iterators.partition(eachindex(U0), n_sample_max)
    return [U0[idx] for idx in index_chunks]
end
# Batches of at most 25 initial states each; the loss functions defined inside
# `evaluate` solve the SDE once per batch and concatenate the results.
U0s = split_U0(U0, 25)

"""
    sde_construct_solve(U0, p′)

Build the SDE with drift `f!` and diffusion `σ!` for initial state `U0` and
parameters `p′`, integrate it over the global `tspan` with `LambaEM`, and
return the sequence of states saved at the global `times`.
"""
function sde_construct_solve(U0, p′)
    dynamics = SDEFunction(f!, σ!)
    problem = SDEProblem(dynamics, σ!, U0, tspan, p′)
    trajectory = solve(problem, LambaEM(); saveat=times)
    return trajectory.u
end

# Simulate the model once per batch in the global `U0s` and, for every saved
# time step, concatenate the batch states into a single ensemble vector.
function _simulate_batches(p′)
    U′s = map(batch -> sde_construct_solve(batch, p′), U0s)
    return [vcat([U′s[i][j]  for i in eachindex(U′s)]...) for j in 1:length(U′s[1])]
end

# Run the repeated trainings for one loss function. `label` is the name used
# in log messages and `tag` the name used in the output file; a repeat whose
# trajectory file already exists is skipped. `announce_missing` controls the
# extra "does not exist, start training" log line.
function _train_repeats(loss_func, label, tag, id; test, n_repeat, n_epochs,
                        announce_missing=false)
    # Re-seed so that every loss function starts from the same initial guesses.
    Random.seed!(56789)
    p_inits = [rand(2) for _ in 1:n_repeat]
    for (p′, rep) in zip(p_inits, ProgressBar(1:n_repeat))
        outfile = "../data/params_trajectory_$(tag)_example_$(id)_$(rep).csv"
        if isfile(outfile)
            @info "$(label) loss function already exists, skip training"
            continue
        end
        if announce_missing
            @info "$(label) loss function does not exist, start training"
        end
        @info "repeat $(rep) for $(label) loss function"

        # The first component of every initial guess is overwritten with 0.5.
        p′[1] = 0.5
        params = train!(loss_func, p′; n_epochs=n_epochs)
        if !test
            CSV.write(outfile,
                DataFrame(p1=[q[1] for q in params], p2=[q[2] for q in params]))
        end
    end
    return nothing
end

"""
    evaluate(p, id; test = false)

Generate a reference SDE trajectory for parameters `p` and fit the parameters
back with four different loss functions.

Steps:
1. Solve the SDE from the global `u0` with parameters `p`, saving at the global
   `times`; write the result to `../data/sde_ground_truth_<id>.csv` and a plot
   of up to 40 sample paths with the ensemble mean ± std to
   `../figures/sde_ground_truth_<id>.pdf`.
2. For each of the W₂², mean/variance, MSE and likelihood losses, run repeated
   trainings with `train!` and write each parameter trajectory to
   `../data/params_trajectory_<tag>_example_<id>_<repeat>.csv`. A repeat whose
   file already exists is skipped.

# Arguments
- `p`: parameter vector used to generate the reference trajectory.
- `id`: identifier inserted into all output file names.

# Keywords
- `test`: when `true`, run a single repeat of 10 epochs per loss and do not
  write trajectory files; otherwise run 30 repeats of 1000 epochs.

# Returns
`nothing`.
"""
function evaluate(p, id; test = false)
    # Reference trajectory. Solved before the loss closures are defined so that
    # they capture an already-assigned `U`.
    prob = SDEProblem(SDEFunction(f!, σ!), σ!, u0, tspan, p)
    sol = solve(prob, saveat=times)
    U = sol.u

    # Loss functions; all of them compare against the reference `U` above.
    function loss_avg_std(p′; var_weight = 1.0)
        U′ = _simulate_batches(p′)
        𝔼u = mean.(U)
        varu = var.(U)
        𝔼u′ = mean.(U′)
        varu′ = var.(U′)
        return (∑(abs2,𝔼u - 𝔼u′) + ∑(abs,varu - varu′) * var_weight) / length(U)
    end
    function loss_W₂²(p′)
        U′ = _simulate_batches(p′)
        return ∑([W₂²(uₜ′, uₜ) for (uₜ′, uₜ) in zip(U′, U)]) / length(U)
    end
    function loss_likelihood(p′)
        return -ℓ̂(U, times, f, σ, p′)
    end
    function loss_mse(p′)
        U′ = _simulate_batches(p′)
        return ∑([∑((uₜ′ - uₜ).^2) for (uₜ′, uₜ) in zip(U′, U)]) / length(U)
    end

    # Persist the reference trajectory: columns t, u1, u2, ...
    df = DataFrame(t = sol.t)
    for i in eachindex(u0)
        df[!, "u$i"] = [u[i] for u in U]
    end
    CSV.write("../data/sde_ground_truth_$(id).csv", df)

    gr()
    # headless mode for gr
    ENV["GKSwstype"] = "nul"

    plot(legend=false, size=(400,400))
    # Draw at most 40 sample paths; bounded by the state dimension so a smaller
    # u0 does not cause an out-of-bounds access.
    for i in 1:min(40, length(u0))
        plot!(sol.t, [u[i] for u in U], label="U(t)", color=:gray, alpha=0.5)
    end
    xlabel!(L"t")
    plot!(times, mean.(U), ribbon=std.(U), label="mean ± std", width=2, color=:black)
    ylabel!(L"X(t)")
    xlims!(0,T)
    title!("SDE ground truth, p₁=$(p[1]), p₂=$(p[2])")
    ylims!(-5,15)

    savefig("../figures/sde_ground_truth_$(id).pdf")

    n_repeat, n_epochs = test ? (1, 10) : (30, 1000)

    _train_repeats(loss_W₂², "W₂²", "W2", id;
        test=test, n_repeat=n_repeat, n_epochs=n_epochs)
    _train_repeats(loss_avg_std, "avg_std", "avg_std", id;
        test=test, n_repeat=n_repeat, n_epochs=n_epochs, announce_missing=true)
    _train_repeats(loss_mse, "mse", "mse", id;
        test=test, n_repeat=n_repeat, n_epochs=n_epochs, announce_missing=true)
    _train_repeats(loss_likelihood, "loglikelihood", "loglikelihood", id;
        test=test, n_repeat=n_repeat, n_epochs=n_epochs)
    return nothing
end


# evaluate([0.5, 0.5], 1; test=true)