using CSV, DataFrames
# When started from a directory that contains `src/`, switch into it; the relative
# `../data/...` paths used below are then resolved from `src/`.
isdir("src") && cd("src")

# Parameter grid: first component fixed at 0.5, second swept over 0.0:0.1:2.0.
ps = map(p -> [0.5, p], 0.0:0.1:2.0)
# Number of repeats loaded for every grid point.
n_sample = 20
# data path:
#   ../data/params_trajectory_{W2,avg_std,mse,loglikelihood}_example_$(label)_$(id).csv
# label ∈ eachindex(ps) (1:21 for the grid above) indexes the parameter pair in ps
# id ∈ 1:n_sample (1:20) is the repeat index

# only the last row of each trajectory file (the final parameter) is used
# a DataFrame records these final parameters

# Empty accumulator; one row is appended per trajectory file that is read.
params_df = DataFrame(
    :p₁ => Float64[],       # first component of the parameter pair taken from `ps`
    :p₂ => Float64[],       # second component of the parameter pair taken from `ps`
    :p̂₁ => Float64[],       # final value of column `p1` in the trajectory file
    :p̂₂ => Float64[],       # final value of column `p2` in the trajectory file
    :repeat => Int64[],     # repeat index (`id` in the file name)
    :label => Int64[],      # index into `ps` (`label` in the file name)
    :loss_func => String[], # name of the loss that produced the trajectory
)

"""
    append_final_params!(params_df, ps, n_sample, loss_func, file_tag)

Append the final parameters of every trajectory file of one loss to `params_df`.

For each `(label, p)` in `enumerate(ps)` and each `id` in `1:n_sample`, read
`../data/params_trajectory_\$(file_tag)_example_\$(label)_\$(id).csv` and push one
row `(p[1], p[2], p1[end], p2[end], id, label, loss_func)`.

# Arguments
- `params_df`: data frame to append to (mutated in place).
- `ps`: collection of parameter pairs; its enumeration index is the file label.
- `n_sample`: number of repeats per label.
- `loss_func`: value stored in the `loss_func` column.
- `file_tag`: loss identifier used in the file name.

# Returns
`params_df`.

# Throws
- `ArgumentError` if a trajectory file contains no rows.
"""
function append_final_params!(params_df, ps, n_sample, loss_func, file_tag)
    for (label, p) ∈ enumerate(ps)
        for id ∈ 1:n_sample
            path = "../data/params_trajectory_$(file_tag)_example_$(label)_$(id).csv"
            params = CSV.read(path, DataFrame)
            # fail with the offending path instead of a bare BoundsError on `[end]`
            nrow(params) == 0 &&
                throw(ArgumentError("trajectory file has no rows: $path"))
            # the last row of the trajectory holds the final parameter
            row = (p[1], p[2], params.p1[end], params.p2[end], id, label, loss_func)
            push!(params_df, row)
        end
    end
    return params_df
end

# Loss label stored in `params_df` => tag used in the trajectory file names.
# Rows are appended in this order: W2, avg_std, mse, loglikelihood.
for (loss_func, file_tag) ∈ (
    "W₂²" => "W2",
    "avg_std" => "avg_std",
    "mse" => "mse",
    "loglikelihood" => "loglikelihood",
)
    append_final_params!(params_df, ps, n_sample, loss_func, file_tag)
end

# split according to loss_func, p₁, p₂;
# combine by evaluating the mean and the standard deviation of the Euclidean error
# between p and p̂

using Statistics
"""
    calc_error(df::AbstractDataFrame)

Summarise the estimation error of the rows in `df`.

The error of a row is the Euclidean distance between `(p₁, p₂)` and `(p̂₁, p̂₂)`.

# Returns
A named tuple `(error, std_error)` holding the mean and the standard deviation
(`Statistics.std`) of the per-row errors.
"""
function calc_error(df::AbstractDataFrame)
    # compute the per-row distances once and reuse them for both statistics
    distances = @. sqrt((df.p₁ - df.p̂₁)^2 + (df.p₂ - df.p̂₂)^2)
    # field names become output column names in `combine`, so they stay unchanged
    return (error=mean(distances), std_error=std(distances))
end
# One group per loss function and parameter pair (p₁, p₂).
grouped = groupby(params_df, [:loss_func, :p₁, :p₂])

# Per group: the number of rows, the mean and standard deviation of each estimate,
# and the mean / standard deviation of the error as returned by `calc_error`.
combined = combine(
    grouped,
    :p₁ => length => :count,
    :p̂₁ => mean => :mean_p̂₁,
    :p̂₁ => std => :std_p̂₁,
    :p̂₂ => mean => :mean_p̂₂,
    :p̂₂ => std => :std_p̂₂,
    calc_error,
)

# Persist the per-group statistics.
CSV.write("../data/params_trajectory_analysis.csv", combined)

using Pipe
# Average the per-group error statistics within each loss function and save them.
sum_df = combine(
    groupby(combined, [:loss_func]),
    :error => mean => :mean_error,
    :std_error => mean => :mean_std_error,
)
CSV.write("../data/params_trajectory_analysis_summary.csv", sum_df)
