using FileIO
using Plots
include("plot_config.jl")
include("config.jl")


"""
    str2vec(st)

Parse the string form of an integer vector, e.g. `"[3, 3, 3]"`, into a `Vector{Int}`.

Surrounding whitespace and the first/last characters (the brackets) are dropped, and
the remainder is split on commas and/or whitespace, so both `"[1, 2]"` and `"[1,2]"`
give `[1, 2]`. `"[]"` gives an empty `Vector{Int}`.
"""
function str2vec(st)
    inner = strip(st)[2:end-1]  # drop the surrounding brackets
    # Split on commas AND whitespace: deleting commas first would merge "1,2" into "12".
    tokens = split(inner, c -> c == ',' || isspace(c); keepempty=false)
    return [parse(Int, tok) for tok in tokens]
end

# Select Plots.jl's PyPlot backend for every plot created below.
# NOTE(review): newer Plots.jl releases deprecate this backend in favour of
# PythonPlot — confirm against the pinned Plots version.
Plots.pyplot()

"""
    get_error_bar(list_ave, list_std, list_min, list_max, yer; n_trials=nothing)

Return the error-bar value to pass as `yerrors` for a series whose per-point means
are `list_ave`.

- `yer == "min_max"`: the tuple `(list_ave .- list_min, list_max .- list_ave)`, i.e.
  the lower/upper distances from the mean to the minimum/maximum.
- `yer == "std"`: `list_std` unchanged.
- `yer == "sem"`: the standard error of the mean, `list_std ./ sqrt(n_trials)`.
  When `n_trials` is not given, the global `trial_times` is used (presumably
  defined by an included config file — verify).

Throws `ArgumentError` for any other `yer`.
"""
function get_error_bar(list_ave, list_std, list_min, list_max, yer; n_trials=nothing)
    if yer == "min_max"
        return (list_ave .- list_min, list_max .- list_ave)
    elseif yer == "std"
        return list_std
    elseif yer == "sem"
        # Resolve the trial count lazily so `trial_times` is only required for "sem".
        n = n_trials === nothing ? trial_times : n_trials
        return list_std ./ sqrt(n)
    end
    throw(ArgumentError(
        "unknown error-bar type yer=$(repr(yer)); expected \"min_max\", \"std\" or \"sem\""))
end


"""
    syn2(; ring_input, D, j=50, true_ring_rank=NaN, yer="sem", accum=false)

Plot the `syn2` synthetic-experiment results: running time, relative error and fit of
each baseline method against its number of parameters. The "b2" entry of the results
(labelled "proposal") is drawn as dotted horizontal/vertical reference lines.

Results are loaded from `results_dir/syn2/ringD{D}J{j}r{true_ring_rank}.jld2` when
`ring_input` is true, otherwise from `results_dir/syn2/randD{D}J{j}.jld2`. Figures are
saved as PDFs under `figs_dir`, and each loaded/saved path is printed.

# Keywords
- `ring_input`: selects the ring/random result files and figure names, the title, and the
  lower y-limit of the error plot (0 for ring input, unbounded otherwise).
- `D`, `j`: identify the result and figure files; `D` also selects hand-picked y-ticks
  for the random-input minimum-error plot.
- `true_ring_rank`: part of the ring file names; its first element appears in the title.
- `yer`: error-bar type forwarded to `get_error_bar` (`"min_max"`, `"std"` or `"sem"`).
- `accum`: also plot each method's `runtimes_sum` and `err_vals_min`.
"""
function syn2(; ring_input, D, j=50, true_ring_rank=NaN, yer="sem", accum=false)
    methods = ["APG", "MU", "MM", "HALS", "lraMM"]

    # --- Load results --------------------------------------------------------------
    result_name = ring_input ? "ringD$(D)J$(j)r$(true_ring_rank)" : "randD$(D)J$(j)"
    save_path = joinpath(results_dir, "syn2", result_name * ".jld2")
    results = load(save_path)
    println("$save_path has been loaded")

    J = results["input_size"]
    b2 = results["b2"]
    b2_time = b2["runtimes"]
    b2_fit = b2["fit_vals"]
    b2_error = b2["err_vals"]
    b2_n_params = b2["n_params"]

    # The rank grid is taken from methods[2] ("MU"); assumes every method was run on the
    # same ranks. Keys are stringified rank vectors, so parse them, sort them
    # (lexicographically), then re-stringify for the lookups below.
    target_ranks = sort!(str2vec.(collect(keys(results[methods[2]]))))
    rank_keys = string.(target_ranks)

    stat_keys = ["n_params"]
    for base in ("runtimes", "err_vals", "fit_vals"), suffix in ("", "_max", "_min", "_std")
        push!(stat_keys, base * suffix)
    end
    accum && push!(stat_keys, "runtimes_sum")

    # stats[key][method] is a Float64 vector aligned with `target_ranks`.
    column(method, key) = Float64[results[method][rk][key] for rk in rank_keys]
    gather(key) = Dict(method => column(method, key) for method in methods)
    stats = Dict(key => gather(key) for key in stat_keys)

    series(name, method) = stats[name][method]
    errbar(name, method) = get_error_bar(series(name, method),
                                         series(name * "_std", method),
                                         series(name * "_min", method),
                                         series(name * "_max", method),
                                         yer)

    # --- Titles, paths and axis settings ---------------------------------------------
    title = ring_input ? "Tensor size $J R $(true_ring_rank[1])" : "Tensor size $J"
    fig_prefix = ring_input ? "syn2_D$(D)J$(j)r$(true_ring_rank)" : "syn2_randD$(D)J$(j)"
    fig_path(suffix) = joinpath(figs_dir, "$(fig_prefix)_$(suffix).pdf")

    error_ylim = (ring_input ? 0 : -Inf, Inf)
    fit_vals = values(stats["fit_vals"])
    fit_ylim = (minimum(minimum, fit_vals), maximum(maximum, fit_vals))

    # Hand-picked y-ticks for the minimum-error plot of the random-input runs with
    # D == 7 or D == 5; every other configuration uses automatic ticks.
    error_min_yticks = if !ring_input && D == 7
        ([0.50002, 0.50003, 0.50004, 0.50005, 0.50006],
         ["0.50002", "0.50003", "0.50004", "0.50005", "0.50006"])
    elseif !ring_input && D == 5
        ([0.50000, 0.50002, 0.50004, 0.50006, 0.50008],
         ["0.50000", "0.50002", "0.50004", "0.50006", "0.50008"])
    else
        :auto
    end

    # --- Per-method series -----------------------------------------------------------
    plt_time = plot()
    plt_error = plot()
    plt_fit = plot()
    plt_time_sum = plot()
    plt_error_min = plot()

    for method in methods
        x = series("n_params", method)
        # Attributes shared by every series drawn for this method.
        style = (title = title,
                 line = (linetype_dict[method], linewidth, linecolors[method]),
                 xaxis = :log,
                 xlabel = "# parameters",
                 markersize = markersize,
                 markercolor = markercolor,
                 markershapes = markershapes_dict[method],
                 markerstrokewidth = markerstrokewidth,
                 markerstrokecolor = linecolors[method],
                 size = img_size,
                 xguidefont = fnt1,
                 yguidefont = fnt1,
                 xtickfont = fnt1,
                 ytickfont = fnt1,
                 label = method)

        plot!(plt_time, x, series("runtimes", method);
              style...,
              yerrors = errbar("runtimes", method),
              yaxis = :log,
              ylabel = "Running time (sec)",
              legend = :topleft)
        plot!(plt_error, x, series("err_vals", method);
              style...,
              yerrors = errbar("err_vals", method),
              ylabel = "Relative Error",
              ylim = error_ylim,
              legend = :topleft)
        plot!(plt_fit, x, series("fit_vals", method);
              style...,
              yerrors = errbar("fit_vals", method),
              ylabel = "Fit(%)",
              ylim = fit_ylim,
              legend = :topleft)

        if accum
            plot!(plt_time_sum, x, series("runtimes_sum", method);
                  style...,
                  yaxis = :log,
                  ylabel = "Running time (sec)",
                  legend = legend)
            plot!(plt_error_min, x, series("err_vals_min", method);
                  style...,
                  ylabel = "Relative Error",
                  yticks = error_min_yticks,
                  legend = legend)
        end
    end

    # --- Reference lines for the proposed method, then save --------------------------
    ref_line = (linewidth, linecolors["b2"], :dot)
    function add_reference!(plt, y)
        hline!(plt, [y], label = "proposal", line = ref_line)
        vline!(plt, [b2_n_params], label = "", line = ref_line)
        return plt
    end
    # Save every figure first, then report the paths.
    function save_figures(figs)
        for (plt, path) in figs
            savefig(plt, path)
        end
        for (_, path) in figs
            println("$path has been saved")
        end
    end

    add_reference!(plt_time, b2_time)
    add_reference!(plt_error, b2_error)
    add_reference!(plt_fit, b2_fit)
    save_figures([(plt_time, fig_path("time")),
                  (plt_error, fig_path("error")),
                  (plt_fit, fig_path("fit"))])

    if accum
        add_reference!(plt_time_sum, b2_time)
        add_reference!(plt_error_min, b2_error)
        save_figures([(plt_time_sum, fig_path("time_sum")),
                      (plt_error_min, fig_path("error_min"))])
    end
    return nothing
end

#yer = "sem"
#syn2(ring_input=false,D=5,j=30,true_ring_rank=15, yer=yer,accum=true)
#syn2(ring_input=true, D=5,j=30,true_ring_rank=15, yer=yer,accum=true)
#syn2(ring_input=false,D=6,j=20,true_ring_rank=10, yer=yer,accum=true)
#syn2(ring_input=true, D=6,j=20,true_ring_rank=10, yer=yer,accum=true)