using Markdown


using DataFrames, CSV, StatsBase

md"# Analysis of CIR model NSDE reconstruction"

md"""
Script to combine the per-repeat dataframes generated by `example2d_base.py`
into a single dataframe and analyze the mean and standard deviation of each metric.
"""


md"Set up the labels to iterate over: the truth model, the NSDE models with $n_{\mathrm{rotate}} \in \left\{ 1, 2, \dots, 9 \right\}$ rotations, the loss functions, and the repeats."

# Ground-truth model label(s) that appear in the data file names.
truth_labels = ["truth_base"]

# NSDE model labels, one per rotation count: "nsde_n_rotate_1", …, "nsde_n_rotate_9".
nsde_labels = map(n_rotate -> string("nsde_n_rotate_", n_rotate), 1:9)

# Loss functions whose runs are collected. Alternatives that were listed but
# commented out: "mean2_var", "mse", "W2", "W2_rotated", "sliced_W2", "WGAN",
# "MMD", "apprx_loglik".
loss_functions = ["W2_rotated_corrected"]

# Repeat indices 1, …, 10, materialized as a Vector{Int}.
repeats = collect(1:10)

path_example2d_test_error(truth_label, nsde_label, loss_function, repeat) = 
    "data/example2d_$(nsde_label)_$(truth_label)_$(loss_function)_repeat_$(repeat)_losses.csv"

# Collect the last row of every available losses file into `collected_df`,
# tagging each row with its repeat id, loss function and number of rotations.
# Missing or empty files are skipped with a warning instead of aborting.
begin
    collected_df = DataFrame()
    # Same iteration order as four nested loops: `repeat_id` varies fastest.
    for truth_label ∈ truth_labels, nsde_label ∈ nsde_labels,
            loss_function ∈ loss_functions, repeat_id ∈ repeats
        # Build the path once instead of recomputing it for each use.
        path = path_example2d_test_error(truth_label, nsde_label, loss_function, repeat_id)
        if !isfile(path)
            # Not every combination is necessarily present; skip it, but say so.
            @warn "Missing losses file, skipping" path
            continue
        end
        tmp_df = CSV.read(path, DataFrame)
        if nrow(tmp_df) == 0
            # A header-only/empty file has no last row: `tmp_df[end:end, :]`
            # would throw a BoundsError.
            @warn "Empty losses file, skipping" path
            continue
        end
        @info "Reading losses file" path
        # Only keep the last row of the file.
        tmp_df = tmp_df[end:end, :]
        # Add the identifying columns (one value each, matching the single row).
        tmp_df.id = [repeat_id]
        tmp_df.loss_function = [loss_function]
        # The rotation count is the trailing integer of labels like "nsde_n_rotate_3".
        tmp_df.n_rotate = [parse(Int, split(nsde_label, "_")[end])]
        append!(collected_df, tmp_df)
    end
    # Downstream grouping needs the columns above; flag an empty result early.
    isempty(collected_df) && @warn "No losses files were found; collected_df is empty"
end

# Print the collected per-repeat table (expression and value) for a quick check.
@show collected_df


md"""## Write out the collected dataframe
Store the dataframe into a single csv file named `output/example2d_n_rotates_test_errors.csv`
"""



# Write the raw per-repeat results. Create `output/` first: CSV.write fails
# if the target directory does not exist yet (mkpath is a no-op if it does).
mkpath("output")
CSV.write("output/example2d_n_rotates_test_errors.csv", collected_df)


md"""
## Split-Apply-Combine Strategy
Use the split-apply-combine (SAC) strategy to find the mean and std of `mse_f`, `mse_σ`, `memory_usage` and `runtime`.

1. Group the data by `loss_function` and `n_rotate`
2. Compute the mean and std of each statistic, count the repeats in each group, and combine them
"""


# Split the per-repeat rows by loss function and rotation count.
grouped_df = groupby(collected_df, [:loss_function, :n_rotate])

# Summarise every metric by its mean and std across repeats. The specs expand
# to mse_f_mean, mse_f_std, mse_σ_mean, …, runtime_std (metric-major order),
# followed by the number of rows (repeats) in each group.
combine_df = let metrics = (:mse_f, :mse_σ, :memory_usage, :runtime)
    specs = [
        metric => reducer => Symbol(metric, "_", suffix)
        for metric in metrics
        for (reducer, suffix) in ((mean, "mean"), (std, "std"))
    ]
    combine(grouped_df, specs..., nrow => :n_repeats)
end

# Rescale the runtime statistics by 1/60.
# NOTE(review): presumably converts seconds to minutes — confirm the unit of `runtime`.
for runtime_col in (:runtime_mean, :runtime_std)
    combine_df[!, runtime_col] = combine_df[!, runtime_col] ./ 60
end

sort!(combine_df, [:loss_function, :n_rotate])



# """
# [deps]
# CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
# DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
# StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

# [compat]
# CSV = "~0.10.11"
# DataFrames = "~1.6.1"
# StatsBase = "~0.34.0"
# """

"""
    remove_underline(s)

Return `s` with every `"_"` removed, e.g. `"W2_rotated_corrected"` becomes
`"W2rotatedcorrected"`.
"""
remove_underline(s) = replace(s, "_" => "")

# Strip underscores from the loss-function names in the summary table
# (e.g. "W2_rotated_corrected" -> "W2rotatedcorrected").
combine_df.loss_function = remove_underline.(combine_df.loss_function)

# Save the summary dataframe to "output/example2d_n_rotates.csv". Create the
# directory first so the write does not fail when `output/` does not exist.
mkpath("output")
CSV.write("output/example2d_n_rotates.csv", combine_df)

