#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 12 07:47:38 2025

Load results generated by ProposedAccuracy.py and generate plots.
"""


import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# from scipy.stats import ecdf

def plot_confidence_interval(x, values, z=1.96, color1="tab:red", color2="indianred", horizontal_line_width=0.3):
    """
    Draw a capped vertical bar and a centre marker on the current pyplot axes.

    The bar is centred on ``np.mean(values)`` and reaches ``z * np.std(values)``
    above and below it. The standard deviation is used directly (it is not
    divided by ``sqrt(len(values))``), so the bar shows the spread of the
    samples rather than the standard error of their mean.

    Parameters
    ----------
    x : horizontal position of the bar.
    values : samples summarised by the bar.
    z : multiplier applied to the standard deviation.
    color1 : colour of the marker drawn at the mean.
    color2 : colour of the vertical bar and its two caps.
    horizontal_line_width : total width of each cap, in x units.

    Returns
    -------
    tuple
        ``(mean, half_width)`` where ``half_width = z * np.std(values)``.
    """
    centre = np.mean(values)
    half_width = z * np.std(values)
    lower, upper = centre - half_width, centre + half_width
    half_cap = horizontal_line_width / 2

    # Vertical bar first, then the lower and upper caps, then the mean marker.
    plt.plot([x, x], [lower, upper], color=color2)
    for level in (lower, upper):
        plt.plot([x - half_cap, x + half_cap], [level, level], color=color2)
    plt.plot(x, centre, 'o', color=color1)
    return centre, half_width

def compute_ecdf(vals, num=500):
    """
    Evaluate the empirical CDF of ``vals`` just above each of its distinct values.

    The evaluation grid consists of one point slightly below the minimum
    (``min - |min - max| * 1e-3``) followed by the sorted distinct values of
    ``vals``; the whole grid is then shifted upwards by ``1e-6``. At every grid
    point ``x`` the returned value is the fraction of samples ``<= x``.

    Parameters
    ----------
    vals : one-dimensional array-like of samples (must be non-empty).
    num : unused; kept so existing callers that pass it keep working.

    Returns
    -------
    ecdf : np.ndarray of dtype float32, same length as ``xarr``.
    xarr : np.ndarray with the evaluation points, in increasing order.
    """
    samples = np.asarray(vals)  # also accept plain Python sequences
    min_ = np.min(samples)
    max_ = np.max(samples)
    xarr = np.hstack((min_-np.abs(min_-max_)*1e-03, np.unique(samples)))+1e-06

    # Counting "samples <= x" for every grid point is a right-sided insertion
    # index into the sorted samples; this replaces one full scan per point.
    sorted_samples = np.sort(samples, axis=None)
    counts = np.searchsorted(sorted_samples, xarr, side="right")
    ecdf = (counts/len(samples)).astype(np.float32)
    return ecdf, xarr


if __name__ == "__main__":
    import ast  # only needed for parsing the metadata file below

    plt.style.use("seaborn-v0_8-whitegrid")
    fsize = (9.6, 5.76)
    mpl.rcParams['figure.dpi'] = 600

    # Results are read from <folder>/<loadname>.npz, with the accompanying
    # run description in <folder>/<loadname>.txt.
    folder = "MNIST_ln"
    loadname = "MNIST_partial_ln_labelAdaDetect_LR_SVC"

    in_data = np.load(f"{folder}/{loadname}.npz")
    save_res = True         # write pgfplots tables next to the input files
    scalar_metrics = False
    zval = 1                # multiplier passed to plot_confidence_interval

    # Names, pgfplots colour names and matplotlib colours of the four
    # estimators stored along the last axis of the "proposed" arrays.
    name_list = ["Storey", "Quantile", "Sum", "Fisher"]
    color_list = ["color1", "color2", "color4", "color3"]
    colors_ = ["tab:orange", "tab:olive", "tab:purple", "tab:red"]

    def _field_value(line):
        """Return the second space-separated token of ``line`` minus its last character."""
        return line.split(" ")[1][:-1]

    with open(f"{folder}/{loadname}.txt", "r") as file:
        lines = file.readlines()

    # The parsing is positional: it relies on the exact line layout of the
    # description file, so check the one marker we know before reading values.
    if lines[4].split(" ")[0][0] != "n":
        raise ValueError(f"{folder}/{loadname}.txt: line 5 does not start with 'n'")

    # literal_eval only accepts Python literals, unlike eval, which would
    # execute arbitrary expressions found in the file.
    Kminus1 = ast.literal_eval(_field_value(lines[2]))
    K = Kminus1 + 1
    n = ast.literal_eval(_field_value(lines[4]))
    m = ast.literal_eval(_field_value(lines[5]))

    # The "pi" entry sits on line index 10, or on line index 12 otherwise.
    pi_line = lines[10] if lines[10].split(" ")[0] == "pi" else lines[12]
    if pi_line.split(" ")[0] != "pi":
        raise ValueError(f"{folder}/{loadname}.txt: no 'pi' entry on line 11 or 13")
    pi_model = pi_line.split(" ")[-1][:-1]

    # Last line: ", "-separated method names; the final piece is discarded.
    include = lines[-1].split(", ")[:-1]

    # Manual overrides of the method selection (uncomment one to use it):
    # include = ["baseline", "baseline_COD", "oracle", "oracleth", "oracle_COD", "oracle_CODth", "oracle_CODth_picky", "proposed", "proposed_picky"]
    # include = ["baseline", "oracleth", "oracle_CODth", "oracle_CODth_picky", "proposed", "proposed_picky"]
    # include = ["baseline", "oracle", "oracle_CODth", "oracle_CODth_picky", "proposed", "proposed_picky"]
    # include = ["baseline", "oracleth", "oracle_CODth", "oracle_CODth_picky", "proposed", "proposed_picky"]
    # include = ["baseline", "oracle", "oracle_COD", "oracle_CODth", "proposed"]
    # include = ["baseline"]
    # include = ["baseline", "oracleth", "oracle_CODth", "proposed"]
    # include = ["baseline", "oracle", "oracle_COD", "proposed"]
    print(include)

    # For every method listed in ``include``, average the stored scores, the
    # purchased amounts (divided by ``m``) and the used amounts over axis 0.
    # The length of that axis is what the script calls ``sims`` below.
    def _sim_mean(key, reducer=np.mean):
        """Reduce the array stored under ``key`` along axis 0."""
        return reducer(in_data[key], axis=0)

    def _budget_mean(key):
        """Average the array stored under ``key``, divided by ``m``, along axis 0."""
        return np.mean(in_data[key]/m, axis=0)

    if "baseline" in include:
        model_scores_baselines_mean = _sim_mean("model_scores_baselines")
        bought_data_baselines_mean = _budget_mean("bought_data_baselines")
        used_data_baselines_mean = _sim_mean("used_data_baselines")
        # Mean (std) of the first baseline column, in percent.
        baseline_pct = in_data["model_scores_baselines"]*100
        print(f'{np.mean(baseline_pct, axis=0)[0]:.3f} ({np.std(baseline_pct, axis=0)[0]:.3f})')

    if "baseline_COD" in include:
        model_scores_baselines_COD_mean = _sim_mean("model_scores_baselines_COD")
        bought_data_baselines_COD_mean = _budget_mean("bought_data_baselines_COD")
        used_data_baselines_COD_mean = _sim_mean("used_data_baselines_COD")

    if "oracle" in include:
        model_scores_oracles_mean = _sim_mean("model_scores_oracles")
        bought_data_oracles_mean = _budget_mean("bought_data_oracles")
        used_data_oracles_mean = _sim_mean("used_data_oracles")

    if "oracleth" in include:
        model_scores_oraclesth_mean = _sim_mean("model_scores_oraclesth")
        bought_data_oraclesth_mean = _budget_mean("bought_data_oraclesth")
        used_data_oraclesth_mean = _sim_mean("used_data_oraclesth")

    if "oracle_COD" in include:
        model_scores_partoracles_mean = _sim_mean("model_scores_partoracles")
        bought_data_partoracles_mean = _budget_mean("bought_data_partoracles")
        used_data_partoracles_mean = _sim_mean("used_data_partoracles")

    if "oracle_CODth" in include:
        model_scores_partoracles_th_mean = _sim_mean("model_scores_partoracles_th")
        bought_data_partoracles_th_mean = _budget_mean("bought_data_partoracles_th")
        used_data_partoracles_th_mean = _sim_mean("used_data_partoracles_th")

    if "oracle_CODth_picky" in include:
        model_scores_partoracles_th_picky_mean = _sim_mean("model_scores_partoracles_th_picky")
        bought_data_partoracles_th_picky_mean = _budget_mean("bought_data_partoracles_th_picky")
        # NaN entries are ignored for this quantity only.
        used_data_partoracles_th_picky_mean = _sim_mean("used_data_partoracles_th_picky", reducer=np.nanmean)

    if "proposed" in include:
        model_scores_proposed_mean = _sim_mean("model_scores_proposed")
        bought_data_proposed_mean = _budget_mean("bought_data_proposed")
        used_data_proposed_mean = _sim_mean("used_data_proposed")

    if "proposed_picky" in include:
        model_scores_proposed_picky_mean = _sim_mean("model_scores_proposed_picky")
        bought_data_proposed_picky_mean = _budget_mean("bought_data_proposed_picky")
        used_data_proposed_picky_mean = _sim_mean("used_data_proposed_picky")

    # Length of axis 0 of the baseline scores.
    sims = len(in_data["model_scores_baselines"])

    # =============================================================================
    # Accuracy plot
    # =============================================================================
    # Scatter of mean accuracy against mean budget for every included method.
    # ``alpha_idx`` is a boolean mask over axis 1 of the "proposed" means; all
    # entries are currently selected.
    alpha_idx = np.ones(bought_data_proposed_mean.shape[1], dtype=bool)
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan", "k"]
    fixed_color, oracle_color, all_data_color = colors[0], colors[2], colors[9]
    estimator_labels = ["Storey", "Quantile", "Summation", "Fisher"]

    _ = plt.figure(figsize=fsize)

    # Reference methods.
    if "baseline" in include:
        plt.plot(bought_data_baselines_mean, model_scores_baselines_mean, "o", color=fixed_color, label="Fixed budget")
    if "baseline_COD" in include:
        plt.plot(bought_data_baselines_COD_mean, model_scores_baselines_COD_mean, "*", color=fixed_color, label="Fixed budget (COD)")
    if "oracle" in include:
        plt.plot(bought_data_oracles_mean, model_scores_oracles_mean, "o", color=oracle_color, label="Oracle")
    if "oracleth" in include:
        plt.plot(bought_data_oraclesth_mean, model_scores_oraclesth_mean, "1", color=oracle_color, label="Oracle threshold")
    if "oracle_COD" in include:
        plt.plot(bought_data_partoracles_mean, model_scores_partoracles_mean, "*", color=all_data_color, label="Oracle all data")
    if "oracle_CODth" in include:
        plt.plot(bought_data_partoracles_th_mean, model_scores_partoracles_th_mean, "1", color=all_data_color, label="Oracle threshold all data")
    if "oracle_CODth_picky" in include:
        plt.plot(bought_data_partoracles_th_picky_mean, model_scores_partoracles_th_picky_mean, "x", color=all_data_color, label="Oracle picky")

    # Proposed method: one point cloud per estimator (last array axis).
    if "proposed" in include:
        for est, (est_color, est_label) in enumerate(zip(colors_, estimator_labels)):
            budget = bought_data_proposed_mean[:, alpha_idx, est].flatten()
            accuracy = model_scores_proposed_mean[:, alpha_idx, est].flatten()
            plt.plot(budget, accuracy, "o", color=est_color, label=est_label)
    if "proposed_picky" in include:
        for est, (est_color, est_label) in enumerate(zip(colors_, estimator_labels)):
            budget = bought_data_proposed_picky_mean[:, alpha_idx, est].flatten()
            accuracy = model_scores_proposed_picky_mean[:, alpha_idx, est].flatten()
            plt.plot(budget, accuracy, "x", color=est_color, label=f"{est_label} picky")

    plt.xlabel("Budget")
    plt.ylabel("Accuracy")
    plt.legend()
    # Dataset-specific y-limits used for earlier figures:
    # plt.ylim(0.678, 0.705) #retinalOCT_partial_AdaDetect_SVC_SVC
    # plt.ylim(0.705, 0.73) #retinalOCT_partial_m80_AdaDetect_SVC_SVC
    # plt.ylim(0.71, 0.735)
    # plt.ylim(0.81, 0.845)
    # plt.ylim(0.85, 0.875)
    plt.xlim(9.9, 20.1)
    plt.show()

    if save_res is True:
        # Export the points of the accuracy plot as pgfplots "addplot" tables.
        def _write_marks(fh, options, xs, ys, label):
            """Write one marks-only pgfplots table with label ``plot:<label>`` to ``fh``."""
            fh.write(f"\\addplot[{options}]\n")
            fh.write("table{%\n")
            for x, y in zip(xs, ys):
                fh.write(f"{x:.2f}  {y:.4f}\n")
            fh.write("};\\label{plot:" + label + "}\n")

        mark_tail = ", mark options={solid}, only marks, "
        with open(f"{folder}/Accuracy_{loadname}.txt", "w") as file:
            if "baseline" in include:
                _write_marks(file, "mark=square" + mark_tail + "color6",
                             bought_data_baselines_mean, model_scores_baselines_mean, "random")

            if "oracleth" in include:
                _write_marks(file, "mark=square" + mark_tail + "color7",
                             bought_data_oraclesth_mean, model_scores_oraclesth_mean, "oracleth")

            if "oracle_CODth" in include:
                _write_marks(file, "mark=square" + mark_tail + "color8",
                             bought_data_partoracles_th_mean, model_scores_partoracles_th_mean, "partoracleth")

            if "proposed" in include:
                # One table per estimator, labelled with its entry in name_list.
                marker_list = ["diamond"]*4
                for est, (marker, tex_color, est_name) in enumerate(zip(marker_list, color_list, name_list)):
                    _write_marks(file, f"mark={marker}" + mark_tail + f"{tex_color}",
                                 bought_data_proposed_mean[:, alpha_idx, est].flatten(),
                                 model_scores_proposed_mean[:, alpha_idx, est].flatten(),
                                 f"{est_name}")

    # =============================================================================
    # Accuracy plot (with confidence intervals)
    # =============================================================================
    # Same axes as the accuracy plot, but each point is drawn as
    # mean +/- zval * std over axis 0 of the raw scores.
    fig = plt.figure(figsize=fsize)

    # Proposed method, estimator index 0 only. Change the last index (and the
    # colours) to draw one of the other three estimators instead.
    proposed_budgets = bought_data_proposed_mean[:, :, 0].flatten()
    proposed_samples = in_data["model_scores_proposed"][:, :, :, 0].reshape((sims, -1)).T
    for bd, ms in zip(proposed_budgets, proposed_samples):
        plot_confidence_interval(bd, ms, zval, colors_[0], "indianred")

    # "partoracles_th" scores, one bar per column.
    partoracle_samples = in_data["model_scores_partoracles_th"].T
    for bd, ms in zip(bought_data_partoracles_th_mean, partoracle_samples):
        plot_confidence_interval(bd, ms, zval, "tab:cyan", "paleturquoise")

    # Baseline columns 1..K (column 0 is skipped).
    baseline_samples = in_data["model_scores_baselines"]
    for col in range(1, K + 1):
        plot_confidence_interval(bought_data_baselines_mean[col], baseline_samples[:, col], zval, "tab:blue", "cornflowerblue")
    plt.show()

    # =============================================================================
    # Accuracy : difference wrt. baseline (compute ecdfs)
    # =============================================================================
    # For each baseline column rho_idx_ + 1, pick the proposed / oracle setting
    # whose mean budget is closest to the baseline's mean budget, and store the
    # per-run accuracy differences to the baseline together with their ECDF in
    # diff_dict[str(rho_idx_)][series] = {"ecdf", "xarr", "diff"}.
    def _ecdf_record(sample_diff):
        """Bundle a difference vector with its empirical CDF."""
        values, grid = compute_ecdf(sample_diff)
        return {"ecdf": values, "xarr": grid, "diff": sample_diff}

    rho_idx_arr = np.arange(0, 11, 1)
    diff_dict = dict()

    proposed_grid_shape = model_scores_proposed_mean.shape[:2]
    proposed_scores_all = in_data["model_scores_proposed"]
    partoracle_scores_all = in_data["model_scores_partoracles_th"]
    oracle_scores_all = in_data["model_scores_oraclesth"]
    baseline_scores_all = in_data["model_scores_baselines"]

    # NOTE: pi_th_idx, alpha_idx and pi_th_idx_oracle stay bound after the loop
    # with the values of the last iteration.
    for rho_idx_ in rho_idx_arr:
        baseline_budget = bought_data_baselines_mean[rho_idx_+1]

        # Budget matching uses estimator index 0 of the proposed method; the
        # selected (pi_th, alpha) pair is then applied to all four estimators.
        closest_flat = np.argmin(np.abs(baseline_budget-bought_data_proposed_mean[:, :, 0]))
        pi_th_idx, alpha_idx = np.unravel_index(closest_flat, proposed_grid_shape)

        # Both oracle variants are matched through the "oraclesth" budgets.
        pi_th_idx_oracle = np.argmin(np.abs(baseline_budget-bought_data_oraclesth_mean))

        baseline_scores = baseline_scores_all[:, rho_idx_+1]
        records = {
            f"{est_name}": _ecdf_record(proposed_scores_all[:, pi_th_idx, alpha_idx, est] - baseline_scores)
            for est, est_name in enumerate(name_list)
        }
        records["partoracles_th"] = _ecdf_record(partoracle_scores_all[:, pi_th_idx_oracle] - baseline_scores)
        records["oraclesth"] = _ecdf_record(oracle_scores_all[:, pi_th_idx_oracle] - baseline_scores)
        diff_dict[f"{rho_idx_}"] = records

    # =============================================================================
    # Accuracy : difference wrt. baseline (CDF)
    # =============================================================================
    # Empirical CDFs of the accuracy difference to the baseline for one budget
    # index; ``rho_idx`` is also read by the sections further down.
    rho_idx = 6
    cdf_curves = diff_dict[f"{rho_idx}"]

    fig = plt.figure(figsize=fsize)
    plt.plot(cdf_curves["oraclesth"]["xarr"], cdf_curves["oraclesth"]["ecdf"], color="tab:green", label="Oracle")
    plt.plot(cdf_curves["partoracles_th"]["xarr"], cdf_curves["partoracles_th"]["ecdf"], color="tab:cyan", label="Oracle all data")
    for est_color, est_name in zip(colors_, name_list):
        curve = cdf_curves[f"{est_name}"]
        plt.plot(curve["xarr"], curve["ecdf"], color=est_color, label=f"{est_name}")
    plt.ylabel("Cumulative distribution function")
    plt.xlabel("Accuracy increase over baseline")
    plt.legend()
    plt.ylim(0, 1)
    plt.show()

    if save_res is True:
        # Export the same curves as pgfplots "addplot" tables.
        def _write_curve(fh, tex_color, curve, label):
            """Write one solid-line pgfplots table with label ``plot:<label>_comp`` to ``fh``."""
            fh.write(f"\\addplot[semithick, solid, {tex_color}]\n")
            fh.write("table{%\n")
            for x, y in zip(curve["xarr"], curve["ecdf"]):
                fh.write(f"{x:.4f}  {y:.4f}\n")
            fh.write("};\\label{plot:" + label + "_comp}\n")

        with open(f"{folder}/AccuracyComp_{loadname}.txt", "w") as file:
            if "oracleth" in include:
                _write_curve(file, "color7", diff_dict[f"{rho_idx}"]["oraclesth"], "oracleth")

            if "oracle_CODth" in include:
                _write_curve(file, "color8", diff_dict[f"{rho_idx}"]["partoracles_th"], "partoracleth")

            if "proposed" in include:
                for tex_color, est_name in zip(color_list, name_list):
                    _write_curve(file, tex_color, diff_dict[f"{rho_idx}"][f"{est_name}"], f"{est_name}")


    # =============================================================================
    # Accuracy : difference wrt. baseline (Positive probability)
    # =============================================================================
    # Fraction of runs in which each method beats the baseline, per budget index.
    # Column order of ``positive_prob``: 0 = "oraclesth", 1 = "partoracles_th",
    # 2..5 = the estimators in name_list order.
    rho_idx_arr = np.arange(0, 11, 1)
    prob_series = ["oraclesth", "partoracles_th"] + [f"{est_name}" for est_name in name_list]
    positive_prob = np.zeros((len(rho_idx_arr), 6), dtype=np.float32)
    for idx, rho_idx_ in enumerate(rho_idx_arr):
        per_budget = diff_dict[f"{rho_idx_}"]
        for col, series in enumerate(prob_series):
            diff = per_budget[series]["diff"]
            positive_prob[idx, col] = np.count_nonzero(diff > 0)/len(diff)

    print(positive_prob[rho_idx])
    fig = plt.figure(figsize=fsize)
    for col, (est_color, est_name) in enumerate(zip(colors_, name_list), start=2):
        plt.plot(rho_idx_arr, positive_prob[:, col], "o", color=est_color, label=f"{est_name}")
    plt.plot(rho_idx_arr, positive_prob[:, 1], "o", color="tab:cyan", label="Oracle all data")
    plt.plot(rho_idx_arr, positive_prob[:, 0], "o", color="tab:green", label="Oracle")
    plt.ylabel("Probability to increase accuracy over baseline")
    plt.xlabel("Budget")
    plt.legend()
    # plt.ylim(0, 1)
    plt.show()

    # =============================================================================
    # Accuracy : difference wrt. baseline (Average diff)
    # =============================================================================
    # Mean accuracy difference to the baseline per budget index, plus the raw
    # differences for the spread bars. Column order of ``mean_diff`` and
    # ``diff_all``: 0 = "oraclesth", 1 = "partoracles_th", 2..5 = the
    # estimators in name_list order.
    rho_idx_arr = np.arange(0, 11, 1)
    diff_series = ["oraclesth", "partoracles_th"] + [f"{est_name}" for est_name in name_list]

    # Size the sample axis from the stored differences instead of assuming a
    # fixed number of runs, so result files of any length can be loaded.
    n_diff_samples = len(diff_dict[f"{rho_idx_arr[0]}"]["oraclesth"]["diff"])

    mean_diff = np.zeros((len(rho_idx_arr), 6), dtype=np.float32)
    diff_all = np.zeros((len(rho_idx_arr), 6, n_diff_samples), dtype=np.float32)
    for idx, rho_idx_ in enumerate(rho_idx_arr):
        for col, series in enumerate(diff_series):
            diff = diff_dict[f"{rho_idx_}"][series]["diff"]
            mean_diff[idx, col] = np.mean(diff)
            diff_all[idx, col] = diff

    print(mean_diff[rho_idx])

    # Spread bars (mean +/- zval * std). Only "partoracles_th" (column 1) and
    # the first estimator (column 2) are drawn; add further columns here if
    # the other series should be shown as well.
    fig = plt.figure(figsize=fsize)
    for idx, rho_idx_ in enumerate(rho_idx_arr):
        plot_confidence_interval(rho_idx_, diff_all[idx, 1, :], zval, "tab:cyan", "paleturquoise")
        plot_confidence_interval(rho_idx_, diff_all[idx, 2, :], zval, colors_[0], "indianred")
    plt.ylabel("Accuracy increase over baseline")
    plt.xlabel("Budget")
    # plt.legend()
    plt.show()

    # Means only, all six series.
    fig = plt.figure(figsize=fsize)
    plt.plot(rho_idx_arr, mean_diff[:, 0], "o", color="tab:green", label="Oracle")
    plt.plot(rho_idx_arr, mean_diff[:, 1], "o", color="tab:cyan", label="Oracle all data")
    for idx2 in range(4):
        plt.plot(rho_idx_arr, mean_diff[:, idx2+2], "o", color=colors_[idx2], label=f"{name_list[idx2]}")
    plt.ylabel("Accuracy increase over baseline")
    plt.xlabel("Budget")
    plt.legend()
    plt.show()

    # =============================================================================
    # Boxplots
    # =============================================================================
    # Score distributions at the budget of baseline column rho_idx + 1.
    # The proposed / oracle settings are matched to that budget here, with the
    # same closest-mean-budget rule as the ECDF section. Reusing the indices
    # left over from the ECDF loop would pair the baseline at rho_idx with
    # settings matched to the last budget index of that loop.
    box_budget = bought_data_baselines_mean[rho_idx+1]
    box_proposed_budgets = bought_data_proposed_mean[:, :, 0]
    box_pi_th_idx, box_alpha_idx = np.unravel_index(
        np.argmin(np.abs(box_budget-box_proposed_budgets)), box_proposed_budgets.shape)
    box_pi_th_idx_oracle = np.argmin(np.abs(box_budget-bought_data_oraclesth_mean))

    box_proposed_scores = in_data["model_scores_proposed"][:, box_pi_th_idx, box_alpha_idx, :]
    # Box order: partoracles_th, baseline, then estimators 0, 1, 3, 2
    # (Storey, Quantile, Fisher, Sum according to name_list).
    boxplot_list = [in_data["model_scores_partoracles_th"][:, box_pi_th_idx_oracle], in_data["model_scores_baselines"][:, rho_idx+1],
                    box_proposed_scores[:, 0], box_proposed_scores[:, 1],
                    box_proposed_scores[:, 3], box_proposed_scores[:, 2]]
    fig = plt.figure(figsize=fsize)
    plt.boxplot(boxplot_list, showfliers=False, showmeans=True)
    plt.show()


    # =============================================================================
    # Scalar metrics
    # =============================================================================
    # if scalar_metrics is True:
    #     # mean of max
    #     mean_max_scores_baseline = np.mean(np.max(in_data["model_scores_baselines"], axis=1))
    #     mean_max_scores_oraclesth = np.mean(np.max(in_data["model_scores_oraclesth"], axis=1))
    #     mean_max_scores_partoracles_th_picky = np.mean(np.max(in_data["model_scores_partoracles_th_picky"], axis=1))
    #     mean_max_scores_proposed_picky = np.mean(np.max(in_data["model_scores_proposed_picky"], axis=(1, 2)), axis=0)
    #     print("Mean of max", mean_max_scores_baseline, mean_max_scores_proposed_picky, mean_max_scores_partoracles_th_picky, mean_max_scores_oraclesth)

    #     # max of mean
    #     max_mean_scores_baseline = np.max(np.mean(in_data["model_scores_baselines"], axis=1))
    #     max_mean_scores_oraclesth = np.max(np.mean(in_data["model_scores_oraclesth"], axis=1))
    #     max_mean_scores_partoracles_th_picky = np.max(np.mean(in_data["model_scores_partoracles_th_picky"], axis=1))
    #     max_mean_scores_proposed_picky = np.max(np.mean(in_data["model_scores_proposed_picky"], axis=(1, 2)), axis=0)
    #     print("Max of mean", max_mean_scores_baseline, max_mean_scores_proposed_picky, max_mean_scores_partoracles_th_picky, max_mean_scores_oraclesth)

    #     # min of max
    #     min_max_scores_baseline = np.min(np.max(in_data["model_scores_baselines"], axis=1))
    #     min_max_scores_oraclesth = np.min(np.max(in_data["model_scores_oraclesth"], axis=1))
    #     min_max_scores_partoracles_th_picky = np.min(np.max(in_data["model_scores_partoracles_th_picky"], axis=1))
    #     min_max_scores_proposed_picky = np.min(np.max(in_data["model_scores_proposed_picky"], axis=(1, 2)), axis=0)
    #     print("Min of max", min_max_scores_baseline, min_max_scores_proposed_picky, min_max_scores_partoracles_th_picky, min_max_scores_oraclesth)

    #     def mean_of_avg(in_scores, in_bought, min_=None, proposed=False):
    #         budget_thresholds = np.arange(11, 20, 1)
    #         sims = in_scores.shape[0]
    #         if proposed is True:
    #             val = np.zeros((sims, 4), dtype=np.float32)
    #         else:
    #             val = np.zeros(sims, dtype=np.float32)
    #         for sim in range(sims):
    #             if proposed is True:
    #                 ms = in_scores[sim].reshape(-1, 4)
    #                 bd = in_bought[sim].reshape(-1, 4)/m
    #             else:
    #                 ms = in_scores[sim]
    #                 bd = in_bought[sim]/m
    #             for bth in budget_thresholds:      
    #                 if proposed is True:
    #                     for idx in range(4):
    #                         if len(bd[:, idx][bd[:, idx] <= bth]) > 0:
    #                             val[sim, idx] += np.max(ms[:, idx][bd[:, idx] <= bth], axis=0)
    #                         else:
    #                             val[sim, idx] += min_[sim]
    #                 else:
    #                     if len(bd[bd <= bth]) > 0:
    #                         val[sim] += np.max(ms[bd <= bth])
    #                     else:
    #                         val[sim] += min_[sim]
    #             val[sim] = val[sim]/len(budget_thresholds)
    #         return np.mean(val, axis=0)

    #     def avg_of_mean(in_scores, in_bought, min_=None, proposed=False):
    #         budget_thresholds = np.arange(11, 20, 1)
    #         if proposed is True:
    #             val = np.zeros(4, dtype=np.float32)
    #             ms = np.mean(in_scores, axis=0).reshape(-1, 4)
    #             bd = np.mean(in_bought/m, axis=0).reshape(-1, 4)
    #         else:
    #             val = 0
    #             ms = np.mean(in_scores, axis=0)
    #             bd = np.mean(in_bought/m, axis=0)
    #         for bth in budget_thresholds:
    #             if proposed is True:
    #                 for idx in range(4):
    #                     if len(bd[:, idx][bd[:, idx] <= bth]) > 0:
    #                         val[idx] += np.max(ms[:, idx][bd[:, idx] <= bth], axis=0)
    #                     else:
    #                         val[idx] += min_
    #             else:
    #                 if len(bd[bd <= bth]) > 0:
    #                     val += np.max(ms[bd <= bth])
    #                 else:
    #                     val += min_
    #         val = val/len(budget_thresholds)
    #         return val

    #     # mean of avg
    #     mean_avg_scores_baseline = mean_of_avg(in_data["model_scores_baselines"], in_data["bought_data_baselines"])
    #     mean_avg_scores_oraclesth = mean_of_avg(in_data["model_scores_oraclesth"], in_data["bought_data_oraclesth"], min_=in_data["model_scores_baselines"][:, 0])
    #     mean_avg_scores_partoracles_th_picky = mean_of_avg(in_data["model_scores_partoracles_th_picky"], in_data["bought_data_partoracles_th_picky"], min_=in_data["model_scores_baselines"][:, 0])
    #     mean_avg_scores_proposed_picky = mean_of_avg(in_data["model_scores_proposed_picky"], in_data["bought_data_proposed_picky"], min_=in_data["model_scores_baselines"][:, 0], proposed=True)
    #     print("Mean of avg", mean_avg_scores_baseline, mean_avg_scores_proposed_picky, mean_avg_scores_partoracles_th_picky, mean_avg_scores_oraclesth)

    #     # avg of mean
    #     avg_mean_scores_baseline = avg_of_mean(in_data["model_scores_baselines"], in_data["bought_data_baselines"])
    #     avg_mean_scores_oraclesth = avg_of_mean(in_data["model_scores_oraclesth"], in_data["bought_data_oraclesth"], min_=np.mean(in_data["model_scores_baselines"][:, 0]))
    #     avg_mean_scores_partoracles_th_picky = avg_of_mean(in_data["model_scores_partoracles_th_picky"], in_data["bought_data_partoracles_th_picky"], min_=np.mean(in_data["model_scores_baselines"][:, 0]))
    #     avg_mean_scores_proposed_picky = avg_of_mean(in_data["model_scores_proposed_picky"], in_data["bought_data_proposed_picky"], min_=np.mean(in_data["model_scores_baselines"][:, 0]), proposed=True)
    #     print("Avg of mean", avg_mean_scores_baseline, avg_mean_scores_proposed_picky, avg_mean_scores_partoracles_th_picky, avg_mean_scores_oraclesth)


