using Statistics
include("methods/NTR.jl")
include("methods/TR.jl")
include("loader_data.jl")
include("partial_two_body_app.jl")
include("eval.jl")
include("config.jl")

# Ring ranks to sweep; a rank `rnk` is expanded to the same value on every mode
# (`ones(Int8, D) * rnk`) by the experiment drivers below.
# A wider sweep that can be swapped in:
# [2,3,4,5,6,7,8,9,10,11,12,13,14,15,20,25,30,35,40,45,50]
target_ranks = collect(2:6)


"""
    experiment_synthetic1(D; ring_input=false, r=10)

Run `b2_decomp` on `D`-way synthetic tensors whose modes all have size 5, 10, 15 and 25
in turn, and record runtime and approximation quality for each size.

The input tensor is `rand(J...)`, or, when `ring_input` is `true`, the output of
`get_low_ringrank_data` with every ring rank equal to `r`.

The collected vectors (`input_size`, `runtimes`, `n_params`, `fit_vals`, `err_vals`,
`dkl_vals`) are passed to `save` at `results_dir/syn1/ringD<D>.jld2` or
`results_dir/syn1/randD<D>.jld2`. The directory is created if it does not exist.
`results_dir` is a global, presumably from config.jl.
"""
function experiment_synthetic1(D; ring_input=false, r=10)
    println("The performance of proposed method on synthetic datasets.")
    println("Ring input : $ring_input True ring rank : $r")

    runtimes = []
    fit_vals = []
    dkl_vals = []
    err_vals = []
    input_size = []
    params = []

    for j in [5, 10, 15, 25]
        J = ones(Int64, D) * j
        n_params = get_n_params(J)

        if ring_input
            R = ones(Int64, D) * r
            T = get_low_ringrank_data(R, J)
        else
            T = rand(J...)
        end

        # NOTE(review): the first (smallest) size also pays JIT compilation of b2_decomp,
        # so its runtime is inflated; consider an untimed warm-up call if that matters.
        Tr, runtime = @timed b2_decomp(deepcopy(T), newton=true, verbose=true, tmax=10)
        # eval is unpacked in (fit, err, dkl, ...) order, the same order used by
        # experiment_synthetic2; the previous (fit, dkl, err) order swapped dkl and err.
        fit_val, err_val, dkl_val = eval(T, Tr)
        println("b2 \t #param $n_params \t runtime $runtime \t fit $fit_val \t dkl $dkl_val \t err $err_val")

        push!(input_size, J)
        push!(params, n_params)
        push!(runtimes, runtime)
        push!(fit_vals, fit_val)
        push!(dkl_vals, dkl_val)
        push!(err_vals, err_val)
    end

    results = Dict(
         "input_size" => input_size,
         "runtimes" => runtimes,
         "n_params" => params,
         "fit_vals" => fit_vals,
         "err_vals" => err_vals,
         "dkl_vals" => dkl_vals)

    prefix = ring_input ? "ring" : "rand"
    save_path = joinpath(results_dir, "syn1", "$(prefix)D$D.jld2")
    # save does not create missing directories.
    mkpath(dirname(save_path))
    save(save_path, results)
    println("$save_path has been saved")
end

"""
    _run_trials(decompose, T, n_trials, label; discard_first=false)

Call the zero-argument function `decompose` `n_trials` times, timing each call with
`@timed` and scoring each returned reconstruction `Tr` against `T` via `eval(T, Tr)`.
The result of `eval` is unpacked as `(fit, err, dkl, RMSE)`.

When `discard_first` is `true`, the first call only serves as a JIT warm-up and is
neither scored nor recorded. Each recorded trial is printed on one line that starts
with `label`.

Returns a `NamedTuple` `(runtimes, fit_vals, err_vals, dkl_vals, RMSE_vals)` of
per-trial vectors.
"""
function _run_trials(decompose, T, n_trials, label; discard_first::Bool=false)
    runtimes = Float64[]
    fit_vals = []
    err_vals = []
    dkl_vals = []
    RMSE_vals = []
    for cnt in 1:n_trials
        Tr, runtime = @timed decompose()
        discard_first && cnt == 1 && continue

        fit_val, err_val, dkl_val, RMSE_val = eval(T, Tr)
        push!(runtimes, runtime)
        push!(fit_vals, fit_val)
        push!(err_vals, err_val)
        push!(dkl_vals, dkl_val)
        push!(RMSE_vals, RMSE_val)
        println("$label \t runtime $runtime \t fit $fit_val \t dkl $dkl_val \t err $err_val")
    end
    return (runtimes=runtimes, fit_vals=fit_vals, err_vals=err_vals,
            dkl_vals=dkl_vals, RMSE_vals=RMSE_vals)
end

"""
    _summarize_trials!(stats, key, vals; with_sum=false)

Store summary statistics of `vals` in the dictionary `stats`:
- the mean under `key`;
- the standard deviation, minimum and maximum under `key` suffixed with `_std`,
  `_min` and `_max`;
- the total under `key` suffixed with `_sum`, only when `with_sum` is `true`.

Returns `stats`.
"""
function _summarize_trials!(stats, key, vals; with_sum::Bool=false)
    stats[key] = mean(vals)
    stats[key * "_std"] = std(vals)
    stats[key * "_min"] = minimum(vals)
    stats[key * "_max"] = maximum(vals)
    if with_sum
        stats[key * "_sum"] = sum(vals)
    end
    return stats
end

"""
    _sem(v)

Standard error of the mean of `v`, i.e. `std(v) / sqrt(length(v))`.
"""
_sem(v) = std(v) / sqrt(length(v))

"""
    experiment_synthetic2(D; j=50, ring_input=false, true_ring_rank=30)

Compare `b2_decomp` against the `NTR` baselines ("APG", "MU", "MM", "HALS", "lraMM") on a
single synthetic `D`-way tensor whose modes all have size `j`.

The tensor is `rand(J...)`, or, when `ring_input` is `true`, the output of
`get_low_ringrank_data` with every ring rank equal to `true_ring_rank`.

`b2_decomp` is run once per trial. Each baseline is run for every rank in
`target_ranks`. Every configuration is repeated `trial_times` times (a global,
presumably from config.jl). When `trial_times != 1`, the first repetition is a discarded
warm-up. This applies to the baselines as well as b2, so that JIT compilation does not
inflate their first timings.

Mean/std/min/max of runtime, fit, err, dkl and RMSE are stored per method (and per rank
for the baselines), together with parameter counts. The results are passed to `save` at
`results_dir/syn2/ringD<D>J<j>r<true_ring_rank>.jld2` or
`results_dir/syn2/randD<D>J<j>.jld2`. The directory is created if needed.
"""
function experiment_synthetic2(D; j=50, ring_input=false, true_ring_rank=30)
    methods = ["b2", "APG", "MU", "MM", "HALS", "lraMM"]

    println("Experiments on synthetic datasets.")

    J = ones(Int64, D) * j
    T = ring_input ? get_low_ringrank_data(ones(Int64, D) * true_ring_rank, J) : rand(J...)

    println("Input tensor size is fixed as $J")
    println("Ring input : $ring_input True ring rank : $true_ring_rank")

    results = Dict{Any,Any}(
        method => Dict(string(ones(Int8, D) * rnk) => Dict() for rnk in target_ranks)
        for method in methods)

    # The first run of each configuration is treated as a JIT warm-up (unless it is the only one).
    discard_first = trial_times != 1

    # ---- proposed method (b2): rank-free, evaluated once ----
    n_params = get_n_params(J)
    b2 = _run_trials(T, trial_times, "b2 \t #param $n_params"; discard_first=discard_first) do
        b2_decomp(deepcopy(T), newton=true, verbose=false, tmax=10)
    end

    results["true_rank"] = ring_input ? ones(Int64, D) * true_ring_rank : NaN
    results["input_size"] = J

    b2_stats = Dict{String,Any}("n_params" => n_params)
    _summarize_trials!(b2_stats, "runtimes", b2.runtimes)
    _summarize_trials!(b2_stats, "fit_vals", b2.fit_vals)
    _summarize_trials!(b2_stats, "err_vals", b2.err_vals)
    _summarize_trials!(b2_stats, "dkl_vals", b2.dkl_vals)
    _summarize_trials!(b2_stats, "RMSE_vals", b2.RMSE_vals)
    results["b2"] = b2_stats

    println("b2 \t runtime ", mean(b2.runtimes), "\t runtime_std ", std(b2.runtimes), "\t runtime_sem ", _sem(b2.runtimes))
    println("b2 \t err ", mean(b2.err_vals), "\t err_std ", std(b2.err_vals), "\t err_sem ", _sem(b2.err_vals))

    # ---- baselines: one run per target ring rank ----
    for rnk in target_ranks
        r = ones(Int8, D) * rnk
        rkey = string(r)
        n_params_r = get_n_params_ring(r, J)
        println(" ##### r = $r n_params = $n_params_r ##### ")

        for method in methods
            method == "b2" && continue  # b2 was evaluated above

            trials = _run_trials(T, trial_times, method; discard_first=discard_first) do
                NTR(deepcopy(T), r, method=method, verbose=false)
            end

            stats = results[method][rkey]
            _summarize_trials!(stats, "runtimes", trials.runtimes; with_sum=true)
            _summarize_trials!(stats, "fit_vals", trials.fit_vals)
            _summarize_trials!(stats, "err_vals", trials.err_vals)
            _summarize_trials!(stats, "dkl_vals", trials.dkl_vals)
            _summarize_trials!(stats, "RMSE_vals", trials.RMSE_vals)
            stats["n_params"] = n_params_r

            println("$method \t runtime ", mean(trials.runtimes), "\t runtime_sum ", sum(trials.runtimes), "\t runtime_sem ", _sem(trials.runtimes))
            println("$method \t err ", mean(trials.err_vals), "\t err_std ", std(trials.err_vals), "\t err_sem ", _sem(trials.err_vals))
        end
    end

    stem = ring_input ? "ringD$(D)J$(j)r$(true_ring_rank)" : "randD$(D)J$(j)"
    save_path = joinpath(results_dir, "syn2", stem * ".jld2")
    # save does not create missing directories.
    mkpath(dirname(save_path))
    save(save_path, results)
    println("$save_path has been saved")
end


#experiment_synthetic2(7,j=10,ring_input=true,true_ring_rank=6)
#experiment_synthetic2(7,j=10,ring_input=false,true_ring_rank=6)
