using FileIO
using Plots
include("plot_config.jl")
include("config.jl")

# Baseline methods compared in every figure, plotted (and thus listed in the legend) in
# this order; each name is also the key of that method's entry in the loaded results.
# Note: this global shadows `Base.methods` in this script.
methods = String["MU", "APG", "MM", "HALS", "lraMM"]

"""
    str2vec(st)

Parse the printed form of an integer vector (e.g. `"[1, 2, 3]"`, as produced by
`string([1, 2, 3])`) back into a `Vector{Int}`.

The first and last characters (the enclosing brackets) are dropped; the remaining
integers may be separated by commas, whitespace, or both.
"""
function str2vec(st)
    # Split on any run of commas and/or whitespace so that both "[1, 2, 3]" and the
    # compact "[1,2,3]" parse; deleting the commas first would turn the latter into 123.
    tokens = split(chop(st; head = 1, tail = 1), r"[\s,]+"; keepempty = false)
    return [parse(Int, tok) for tok in tokens]
end

# Select the PyPlot backend for every figure produced by this script.
# NOTE(review): newer Plots.jl releases deprecate this backend in favour of PythonPlot.
Plots.pyplot()

"""
    get_error_bar(list_ave, list_std, list_min, list_max, yer; n_trials = nothing)

Return the value to pass as a Plots `yerrors` attribute for a series centred at `list_ave`.

`yer` selects the kind of bar:
- `"min_max"`: asymmetric bars `(list_ave .- list_min, list_max .- list_ave)`, spanning
  from `list_min` up to `list_max`;
- `"std"`: symmetric bars of `list_std`;
- `"sem"`: symmetric bars `list_std ./ sqrt(n)`, where `n` is `n_trials`, or, when
  `n_trials` is `nothing`, the global `trial_times` (defined outside this file —
  presumably in one of the included config files).

Throws an `ArgumentError` for any other `yer`.
"""
function get_error_bar(list_ave, list_std, list_min, list_max, yer; n_trials = nothing)
    if yer == "min_max"
        return (list_ave .- list_min, list_max .- list_ave)
    elseif yer == "std"
        return list_std
    elseif yer == "sem"
        # Only consult the global here, so "min_max"/"std" never depend on it.
        n = n_trials === nothing ? trial_times : n_trials
        return list_std ./ sqrt(n)
    end
    throw(ArgumentError(
        "unknown error-bar type yer = $(repr(yer)); expected \"min_max\", \"std\" or \"sem\""))
end


# Gather `results[m][string(rank)][f]` into `stats[f][m][i]` for every method `m` in the
# global `methods`, every field `f` in `fields`, and the i-th rank of `target_ranks`.
function _real1_load_stats(results, target_ranks, fields)
    n = length(target_ranks)
    stats = Dict(f => Dict(m => zeros(n) for m in methods) for f in fields)
    for m in methods, (i, rank) in enumerate(target_ranks)
        entry = results[m][string(rank)]   # one lookup per (method, rank), not per field
        for f in fields
            stats[f][m][i] = entry[f]
        end
    end
    return stats
end

# Keyword arguments shared by every per-method series in the real1 figures; the
# figure-specific ones (`ylabel`, `yaxis`, `yerrors`) are added at each call site.
function _real1_series_style(method, title, xaxis)
    color = linecolors[method]
    return (title = title,
            line = (linetype_dict[method], linewidth, color),
            xaxis = xaxis,
            xlabel = "# parameters",
            markersize = markersize,
            markershapes = markershapes_dict[method],
            markerstrokewidth = markerstrokewidth,
            markerstrokecolor = color,
            markercolor = markercolor,
            size = img_size,
            xguidefont = fnt1,
            yguidefont = fnt1,
            xtickfont = fnt1,
            ytickfont = fnt1,
            legend = legend,
            label = method)
end

# Overlay the proposed method on `plt`: a dotted horizontal line at `yval` (legend entry
# "proposal") and an unlabeled dotted vertical line at `n_params`.
function _real1_proposal_lines!(plt, yval, n_params)
    ref = (linewidth, linecolors["b2"], :dot)
    hline!(plt, [yval], label = "proposal", line = ref)
    vline!(plt, [n_params], label = "", line = ref)
    return plt
end

# Write `plt` to `path` and report the saved file on stdout.
function _real1_save(plt, path)
    savefig(plt, path)
    println("$path has been saved")
    return path
end

"""
    real1(datasetname; RMSE=false, accum=false, yer="sem", xaxis=identity)

Plot the "real1" results for `datasetname` and save the figures as PDFs.

Loads `joinpath(results_dir, "real1", "results_\$datasetname.jld2")`. For every method in
`methods`, it plots running time (log y-axis) and relative error against the number of
parameters, with one point per target rank and error bars chosen by `yer`. The `"b2"`
entry of the results is overlaid as dotted "proposal" reference lines. Figures are written
to `figs_dir` as `real1_<datasetname>_<kind>.pdf`.

# Keywords
- `RMSE`: also plot the stored `RMSE_vals` (`_RMSE.pdf`).
- `accum`: also plot `runtimes_sum` (`_time_sum.pdf`) and `err_vals_min` (`_error_min.pdf`).
- `yer`: error-bar kind passed to `get_error_bar` (`"min_max"`, `"std"` or `"sem"`).
- `xaxis`: forwarded unchanged to the Plots `xaxis` attribute.
"""
function real1(datasetname; RMSE=false, accum=false, yer="sem", xaxis=identity)
    save_path = joinpath(results_dir, "real1", "results_$datasetname.jld2")
    results = load(save_path)
    println("$save_path has been loaded")

    # Reference values of the proposed method, drawn as "proposal" lines on each figure.
    b2 = results["b2"]
    b2_time = b2["runtimes"]
    b2_error = b2["err_vals"]
    b2_n_params = b2["n_params"]
    b2_RMSE = RMSE ? b2["RMSE_vals"] : nothing   # read up front so a missing key fails early
    J = results["input_size"]

    if datasetname in ("light", "TT_ChartRes")
        # These datasets use the configured rank list.
        target_ranks = target_ranks_real[datasetname]
    else
        # Recover the ranks from the stringified keys of the first method's results.
        target_ranks = sort!(str2vec.(collect(keys(results[methods[1]]))))
    end

    fields = ["n_params",
              "err_vals", "err_vals_max", "err_vals_min", "err_vals_std",
              "runtimes", "runtimes_max", "runtimes_min", "runtimes_std"]
    RMSE && push!(fields, "RMSE_vals")
    accum && push!(fields, "runtimes_sum")
    stats = _real1_load_stats(results, target_ranks, fields)

    title = "$datasetname $J"
    fig_path(kind) = joinpath(figs_dir, "real1_" * datasetname * "_" * kind * ".pdf")

    plt_time, plt_error, plt_RMSE = plot(), plot(), plot()
    plt_time_sum, plt_error_min = plot(), plot()

    for method in methods
        style = _real1_series_style(method, title, xaxis)
        n_params = stats["n_params"][method]
        runtimes = stats["runtimes"][method]
        err_vals = stats["err_vals"][method]

        time_bars = get_error_bar(runtimes, stats["runtimes_std"][method],
                                  stats["runtimes_min"][method],
                                  stats["runtimes_max"][method], yer)
        plot!(plt_time, n_params, runtimes; style...,
              yerrors = time_bars, yaxis = :log, ylabel = "Running time (sec)")

        err_bars = get_error_bar(err_vals, stats["err_vals_std"][method],
                                 stats["err_vals_min"][method],
                                 stats["err_vals_max"][method], yer)
        plot!(plt_error, n_params, err_vals; style...,
              yerrors = err_bars, ylabel = "Relative Error")

        if RMSE
            plot!(plt_RMSE, n_params, stats["RMSE_vals"][method]; style...,
                  ylabel = "RMSE")
        end
        if accum
            plot!(plt_time_sum, n_params, stats["runtimes_sum"][method]; style...,
                  yaxis = :log, ylabel = "Running time (sec)")
            plot!(plt_error_min, n_params, stats["err_vals_min"][method]; style...,
                  ylabel = "Relative Error")
        end
    end

    _real1_proposal_lines!(plt_time, b2_time, b2_n_params)
    _real1_proposal_lines!(plt_error, b2_error, b2_n_params)
    _real1_save(plt_time, fig_path("time"))
    _real1_save(plt_error, fig_path("error"))

    if RMSE
        # Dotted like every other figure's reference lines.
        _real1_proposal_lines!(plt_RMSE, b2_RMSE, b2_n_params)
        _real1_save(plt_RMSE, fig_path("RMSE"))
    end

    if accum
        # NOTE(review): the time_sum reference line uses b2's `runtimes`, not a summed
        # runtime — confirm this is intended.
        _real1_proposal_lines!(plt_time_sum, b2_time, b2_n_params)
        _real1_proposal_lines!(plt_error_min, b2_error, b2_n_params)
        _real1_save(plt_time_sum, fig_path("time_sum"))
        _real1_save(plt_error_min, fig_path("error_min"))
    end
    return nothing
end


# Render the real1 figures for each dataset in turn. `datasetname` is (re)assigned as a
# global on every pass, so after the loop it still names the last dataset processed.
for name in ("TT_ChartRes", "TT_Paint", "TT_Origami", "light")
    global datasetname = name
    real1(datasetname; RMSE = false, accum = true)
end
