using Markdown


using DataFrames, CSV, StatsBase

md"# Analysis of CIR model NSDE reconstruction"

md"""
Script to combine the per-repeat dataframes generated by cix_base.py
into a single data frame and summarize the mean and standard deviation of each error statistic.
"""


md"Configuration: the ground-truth sample sizes $\in \left\{ 50, 100, 200, 400 \right\}$ together with the NSDE models, loss functions, and repeats whose results are collected."

# Ground-truth data set labels, one per sample size in {50, 100, 200, 400}.
truth_labels = ["truth_n_sample_$n" for n in (50, 100, 200, 400)]

# NSDE model labels; they appear verbatim in the result file names.
nsde_labels = ["nsde_base"]

# Loss functions whose results are collected. "WGAN" is deliberately left out.
loss_functions = ["mean2_var", "mse", "W2", "MMD", "apprx_loglik"]

# Repeat indices 1, 2, …, 10, materialised as a Vector{Int}.
repeats = collect(1:10)

"""
    path_cix_test_error(truth_label, nsde_label, loss_function, repeat)

Return the relative path of the per-repeat test-error CSV, i.e.
`data/cix_<nsde_label>_<truth_label>_<loss_function>_repeat_<repeat>_test_errors.csv`.
Each argument is rendered with `print`, so `repeat` may be an integer.
"""
function path_cix_test_error(truth_label, nsde_label, loss_function, repeat)
    stem = join(
        ("cix", nsde_label, truth_label, loss_function, "repeat", repeat, "test_errors"),
        "_",
    )
    return string("data/", stem, ".csv")
end

# Combine every per-repeat test-error table (one CSV per truth label × NSDE label ×
# loss function × repeat) into a single DataFrame.
collected_df = let
    per_repeat_dfs = DataFrame[]
    # In a comma-separated `for` header the leftmost iterator is the outermost loop,
    # so files are read in the same order as four nested loops would read them.
    for truth_label ∈ truth_labels,
        nsde_label ∈ nsde_labels,
        loss_function ∈ loss_functions,
        repeat ∈ repeats

        path = path_cix_test_error(truth_label, nsde_label, loss_function, repeat)
        push!(per_repeat_dfs, CSV.read(path, DataFrame))
    end

    # Concatenate once. This replaces appending into a DataFrame whose columns were
    # built from untyped `[]` (i.e. `Vector{Any}`): `vcat` gives each column a
    # concrete, promoted element type. `cols = :setequal` throws if the files do
    # not all share the same column names.
    combined = reduce(vcat, per_repeat_dfs; cols = :setequal)

    # Every file must have exactly this schema. The vector also fixes the column
    # order of the combined table, and therefore of the CSV written from it.
    expected_columns = ["id", "train_n_samples", "test_n_samples", "test_error",
                        "train_error", "mse_f", "mse_σ", "loss_function"]
    Set(names(combined)) == Set(expected_columns) || throw(ArgumentError(
        "per-repeat CSVs have columns $(names(combined)), expected $(expected_columns)"))
    select!(combined, expected_columns)
end

# Print the combined per-repeat table (prefixed by its name) as a quick sanity check.
@show collected_df


md"""## Write out the collected dataframe
Store the dataframe in a single CSV file named `output/cix_n_samples_test_errors.csv`
"""



# Persist the combined per-repeat table. `CSV.write` does not create missing
# directories, so make sure `output/` exists first (a no-op if it already does).
mkpath("output")
CSV.write("output/cix_n_samples_test_errors.csv", collected_df)


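md"""
## Split-Apply-Combine Strategy
Use the split-apply-combine (SAC) strategy to compute the mean and standard deviation of `test_error`, `train_error`, `mse_f`, and `mse_σ`.

1. Group the data by `loss_function`, `train_n_samples`, `test_n_samples`
2. Compute the mean and standard deviation of each statistic within every group, together with the number of repeats, and combine them into one table
"""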


# Split: one group per (loss function, training sample size, test sample size).
grouped_df = groupby(collected_df, [:loss_function, :train_n_samples, :test_n_samples])

# Apply + combine: for each error statistic emit `<column>_mean` then `<column>_std`,
# and finish with the number of rows (repeats) in each group.
combine_df = let specs = Pair[]
    for column in (:test_error, :train_error, :mse_f, :mse_σ),
        (stat, suffix) in ((mean, "mean"), (std, "std"))

        push!(specs, column => stat => Symbol(column, "_", suffix))
    end
    push!(specs, :id => length => :n_repeats)
    combine(grouped_df, specs...)
end

# Order the summary rows by the grouping keys.
sort!(combine_df, [:loss_function, :train_n_samples, :test_n_samples])


# Save the per-group summary (mean/std of each statistic plus the repeat count) to
# "output/cir_n_samples.csv". `CSV.write` does not create missing directories, so
# ensure `output/` exists first (a no-op if it already does).
mkpath("output")
CSV.write("output/cir_n_samples.csv", combine_df)


# """
# [deps]
# CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
# DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
# StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

# [compat]
# CSV = "~0.10.11"
# DataFrames = "~1.6.1"
# StatsBase = "~0.34.0"
# """


