using Markdown


using DataFrames, CSV, Statistics


md"# Analysis of CIR model NSDE reconstruction"

md"""
Script to combine the data frames of the different repeats generated by cix_base.py
into one data frame and analyze their mean and standard deviation
"""


md"Setting up basic preambles for testing different network widths $\in \left\{ 16, 32, 64, 128 \right\}$"

# Labels of the ground-truth configurations to collect.
truth_labels = ["truth_base"]

# One NSDE label per network width; any "." in the printed width is replaced
# by "_" so that the label itself contains no dots.
nsde_labels = map([16, 32, 64, 128]) do w
    "nsde_width_" * replace(string(w), "." => "_")
end

# Loss functions to collect; the commented-out entries are currently disabled.
loss_functions = [
    # "mean2_var",
    # "mse",
    "W2",
    # "WGAN",
    # "MMD",
    # "apprx_loglik"
]

# Indices of the repeated runs.
repeats = collect(1:10)

"""
    path_cix_test_error(truth_label, nsde_label, loss_function, repeat)

Return the relative path of the test-error CSV file for the given combination
of labels, loss function and repeat index. The file name has the form
`data/cix_<nsde_label>_<truth_label>_<loss_function>_repeat_<repeat>_test_errors.csv`.
"""
function path_cix_test_error(truth_label, nsde_label, loss_function, repeat)
    stem = join((nsde_label, truth_label, loss_function, "repeat", repeat), "_")
    return "data/cix_" * stem * "_test_errors.csv"
end

# Smoke test: build the file path for the first entry of every setting.
truth_label = truth_labels[1]
nsde_label = nsde_labels[1]
loss_function = loss_functions[1]
repeat = repeats[1]
path = path_cix_test_error(truth_label, nsde_label, loss_function, repeat)

# Collect the per-repeat test-error tables into a single data frame, adding a
# `width` column parsed from the NSDE label to every table before appending it.
begin
    # Columns are created from empty untyped vectors, so they hold `Any`.
    collected_df = DataFrame(
        id = [],
        train_n_samples = [],
        test_n_samples = [],
        test_error = [],
        train_error = [],
        mse_f = [],
        mse_σ = [],
        loss_function = [],
        width = []
    )
    for truth_label ∈ truth_labels
        for nsde_label ∈ nsde_labels
            # The third "_"-separated field of the label is the width,
            # e.g. "nsde_width_16" -> 16.
            width = parse(Int, split(nsde_label, "_")[3])
            for loss_function ∈ loss_functions
                for repeat ∈ repeats
                    csv_path = path_cix_test_error(truth_label, nsde_label, loss_function, repeat)
                    if !isfile(csv_path)
                        # Missing repeats are tolerated, but reported so that an
                        # incomplete dataset does not go unnoticed.
                        @warn "Skipping missing test-error file" csv_path
                        continue
                    end
                    tmp_df = CSV.read(csv_path, DataFrame)
                    # insert the width as the second column of tmp_df
                    insertcols!(tmp_df, 2, :width => width)
                    append!(collected_df, tmp_df)
                end
            end
        end
    end
end

# Show the collected table.
collected_df


md"""## Write out the collected dataframe
Store the dataframe in a single CSV file named `output/cix_widths_test_errors.csv`
"""



# Make sure the output directory exists before writing; `mkpath` does nothing
# if the directory is already there.
mkpath("output")
CSV.write("output/cix_widths_test_errors.csv", collected_df)


md"""
## Split-Apply-Combine Strategy
Use the SAC strategy to find the mean and std of `test_error`, `train_error`, `mse_f`, ...

1. Group the data by `loss_function`, `width`, `test_n_samples`
2. Compute mean and std of each statistic and combine them
"""


# Split: one group per (loss function, width, test sample size).
grouped_df = groupby(collected_df, [:loss_function, :width, :test_n_samples])


# Apply and combine: mean and standard deviation of every error metric, plus
# the number of rows in each group (stored as `n_repeats`).
combine_df = combine(grouped_df,
                :test_error => mean => :test_error_mean,
                :test_error => std => :test_error_std,
                :train_error => mean => :train_error_mean,
                :train_error => std => :train_error_std,
                :mse_f => mean => :mse_f_mean,
                :mse_f => std => :mse_f_std,
                :mse_σ => mean => :mse_σ_mean,
                :mse_σ => std => :mse_σ_std,
                :id => length => :n_repeats
                )


sort!(combine_df, [:loss_function, :width, :test_n_samples])

# keep only the rows with test_n_samples == 200
combine_df = combine_df[combine_df.test_n_samples .== 200, :]

# save the summary to "output/cir_widths.csv", creating the output directory
# first if it does not exist yet
mkpath("output")
CSV.write("output/cir_widths.csv", combine_df)


# """
# [deps]
# CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
# DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
# StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

# [compat]
# CSV = "~0.10.11"
# DataFrames = "~1.6.1"
# StatsBase = "~0.34.0"
# """


