using FileIO
include("methods/NTR.jl")
include("loader_data.jl")
include("partial_two_body_app.jl")
include("eval.jl")
include("config.jl")


# Baseline algorithm names; each is passed as `method=` to `NTR` by experiment_real1.
# NOTE(review): this global shadows `Base.methods` in this module; renaming it would
# require updating every reader, so the name is kept.
methods = String["APG", "MU", "MM", "HALS", "lraMM"]
"""
    experiment_real1(datasetname, change_shape=false)

Benchmark the proposed `b2_decomp` against the baselines run through `NTR` (one per
entry of the global `methods`, at every rank in `target_ranks_real[datasetname]`) on
one real dataset. Print per-trial and aggregate results, and save the aggregate
statistics to `joinpath(results_dir, "real1", "results_<datasetname>.jld2")`.

# Arguments
- `datasetname`: key passed to `load_realdata` and used to index `target_ranks_real`
  (and `shapes` when `change_shape` is true).
- `change_shape`: when `true`, reshape the loaded tensor to `shapes[datasetname]`.

# Saved results
- `"input_size"`, `"datasetname"`: tensor size and dataset name.
- `"b2"`: summary dict for the proposed method.
- `results[method][string(rank)]`: summary dict for each baseline method and rank.

Each summary holds the mean (bare key), `_std`, `_min` and `_max` of `runtimes`,
`fit_vals`, `err_vals`, `dkl_vals` and `RMSE_vals`, plus `runtimes_sum` and `n_params`.

# Throws
- `ArgumentError` if the global `trial_times` is less than 1.

Relies on globals defined outside this function (presumably by the included files):
`trial_times`, `target_ranks_real`, `shapes`, `results_dir`, `methods`.
"""
function experiment_real1(datasetname, change_shape=false)
    trial_times >= 1 ||
        throw(ArgumentError("trial_times must be at least 1, got $trial_times"))

    # One (initially empty) result slot per baseline method and target rank.
    results = Dict{Any,Any}(
        method => Dict{Any,Any}(
            string(rnk) => Dict() for rnk in target_ranks_real[datasetname]
        ) for method in methods
    )

    println("===================================")
    T = load_realdata(datasetname, label=false)
    if change_shape
        T = reshape(T, shapes[datasetname])
    end
    # Multiply by a float literal to obtain a floating-point copy of the tensor.
    T = T .* 1.0

    J = size(T)
    n_params = get_n_params(J)
    println("datasetname $datasetname size $J")
    println("n_params of b2 $n_params")

    results["input_size"] = J
    results["datasetname"] = datasetname

    # proposed method
    b2_trials = _new_trials()
    for cnt = 1:trial_times
        Tr, runtime = @timed b2_decomp(deepcopy(T), newton=true, verbose=false, tmax=25)
        # When several trials are requested the first one is discarded (presumably a
        # warm-up so that JIT compilation time does not enter the statistics).
        if cnt == 1 && trial_times != 1
            continue
        end

        fit_val, err_val, dkl_val, RMSE_val = eval(T, Tr)
        # Record each measurement exactly once per kept trial.
        _record_trial!(b2_trials, runtime, fit_val, err_val, dkl_val, RMSE_val)
        println("b2 \t #param $n_params \t runtime $runtime \t fit $fit_val \t dkl $dkl_val \t err $err_val \t RMSE $RMSE_val")
    end
    results["b2"] = _summarize_trials(b2_trials, n_params)

    b2_runtimes = b2_trials.runtimes
    b2_errs = b2_trials.err_vals
    println("b2 \t runtime ", mean(b2_runtimes), "\t runtime_std ", std(b2_runtimes), "\t runtime_sem ", _sem(b2_runtimes))
    println("b2 \t err ", mean(b2_errs), "\t err_std ", std(b2_errs), "\t err_sem ", _sem(b2_errs))

    # compared methods
    for target_rank in target_ranks_real[datasetname]
        n_params = get_n_params_ring(target_rank, J)
        println(" ##### r = $target_rank n_params = $n_params ##### ")
        rank_key = string(target_rank)

        for method in methods
            # The proposed method is benchmarked above; skip it if it is ever listed.
            method == "b2" && continue

            # NOTE(review): unlike b2, no warm-up trial is discarded here, so the first
            # trial of each method likely includes JIT compilation time.
            trials = _new_trials()
            for _ in 1:trial_times
                Gr, runtime = @timed NTR(deepcopy(T), target_rank, method=method, verbose=false)
                fit_val, err_val, dkl_val, RMSE_val = eval(T, Gr)
                _record_trial!(trials, runtime, fit_val, err_val, dkl_val, RMSE_val)
                println("$method \t runtime $runtime \t fit $fit_val \t dkl $dkl_val \t err $err_val \t RMSE $RMSE_val")
            end
            merge!(results[method][rank_key], _summarize_trials(trials, n_params))

            runtimes = trials.runtimes
            errs = trials.err_vals
            println("$method \t runtime ", mean(runtimes), "\t runtime_sum ", sum(runtimes), "\t runtime_sem ", _sem(runtimes))
            println("$method \t err ", mean(errs), "\t err_std ", std(errs), "\t err_sem ", _sem(errs))
        end
    end

    save_path = joinpath(results_dir, "real1", "results_$datasetname.jld2")
    # Make sure the output directory exists; `save` does not create it.
    mkpath(dirname(save_path))
    save(save_path, results)
    println("$save_path has been saved")
    return nothing
end

# Empty per-trial series for one method/rank; field names double as result keys.
function _new_trials()
    return (
        runtimes=Float64[],
        fit_vals=Any[],
        err_vals=Any[],
        dkl_vals=Any[],
        RMSE_vals=Any[],
    )
end

# Append one trial's runtime and metric values to the series in `trials`.
function _record_trial!(trials, runtime, fit_val, err_val, dkl_val, RMSE_val)
    push!(trials.runtimes, runtime)
    push!(trials.fit_vals, fit_val)
    push!(trials.err_vals, err_val)
    push!(trials.dkl_vals, dkl_val)
    push!(trials.RMSE_vals, RMSE_val)
    return trials
end

# Standard error of the mean of `vals`.
_sem(vals) = std(vals) / sqrt(length(vals))

# Mean/std/min/max of every series in `trials` (keys "<series>", "<series>_std",
# "<series>_min", "<series>_max"), plus "runtimes_sum" and "n_params".
function _summarize_trials(trials, n_params)
    summary = Dict{String,Any}("n_params" => n_params)
    for (key, vals) in pairs(trials)
        name = string(key)
        summary[name] = mean(vals)
        summary[name * "_std"] = std(vals)
        summary[name * "_min"] = minimum(vals)
        summary[name * "_max"] = maximum(vals)
    end
    summary["runtimes_sum"] = sum(trials.runtimes)
    return summary
end

# Run the real-data experiment on each dataset, reshaping it to `shapes[name]` first.
# The global `datasetname` is kept in sync with the dataset currently being run.
for name in ("TT_ChartRes", "TT_Origami")
    global datasetname = name
    experiment_real1(datasetname, true)
end
