import argparse
import os
import pickle
from statistics import NormalDist

import matplotlib.pyplot as plt
import numpy as np


def load_results(save_dir: str, exp_name: str, trials: int, alg_list: list):
    """Load per-trial result pickles and compute running metrics per algorithm.

    For every algorithm in ``alg_list`` and every trial ``i`` in ``1..trials``
    this reads ``{save_dir}/{exp_name}_trial{i}_{alg_name}.pkl``. Each pickled
    object must expose a ``rounds`` sequence whose items have ``reward`` and
    ``cost`` attributes.

    Args:
        save_dir: Directory containing the result pickles.
        exp_name: Experiment name prefix used in the file names.
        trials: Number of trials to load (files are numbered from 1).
        alg_list: Algorithm identifiers as they appear in the file names.

    Returns:
        Dict mapping a display name to a tuple ``(rewards, costs, rois)`` of
        arrays with one row per trial and one column per round:
        running average reward, cumulative cost (spent budget), and running
        average ROI (0 wherever the running average cost is not positive).
        Identifiers containing "Wang" or "OPT3" are relabelled for legends.

    Raises:
        ValueError: If ``trials`` is less than 1, or if the trials of one
            algorithm have different numbers of rounds.
        FileNotFoundError: If a result file is missing.
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")
    # Legend labels for baselines; the first matching substring wins.
    display_names = (
        ("Wang", "Wang et al. (2023)"),
        ("OPT3", "Guo et al. (2025)"),
    )
    results = {}
    for alg_name in alg_list:
        rewards, costs, rois = [], [], []
        for i in range(1, trials + 1):
            file = os.path.join(save_dir, f"{exp_name}_trial{i}_{alg_name}.pkl")
            # NOTE: unpickling can execute arbitrary code; only load result
            # files you produced yourself.
            with open(file, "rb") as f:
                res = pickle.load(f)
            r_seq = np.array([log.reward for log in res.rounds], dtype=float).cumsum()
            c_seq = np.array([log.cost for log in res.rounds], dtype=float).cumsum()
            t = np.arange(1, len(r_seq) + 1)
            avg_r = r_seq / t
            avg_c = c_seq / t
            # ROI is left at 0 where no cost has been spent yet (avoids 0/0).
            avg_roi = np.divide(avg_r, avg_c, out=np.zeros_like(avg_r), where=avg_c > 0)
            rewards.append(avg_r)
            # Spent budget is the cumulative cost itself; using it directly
            # avoids the round-off of reconstructing it as avg_c * t.
            costs.append(c_seq)
            rois.append(avg_roi)
        if len({len(seq) for seq in rewards}) > 1:
            raise ValueError(
                f"trials of {alg_name!r} have different numbers of rounds: "
                f"{[len(seq) for seq in rewards]}"
            )
        label = next((lab for key, lab in display_names if key in alg_name), alg_name)
        results[label] = (np.array(rewards), np.array(costs), np.array(rois))
    return results


def mean_ci(data: np.ndarray, axis=0, ci=0.95):
    """Return the mean and the half-width of a normal-approximation CI.

    The half-width is ``z * std / sqrt(n)``. Here ``std`` uses ``ddof=1``,
    ``n`` is the number of samples along ``axis`` (all elements when
    ``axis`` is None), and ``z`` is the two-sided standard-normal quantile
    for confidence level ``ci`` (about 1.96 for 0.95).

    Args:
        data: Samples; reduced along ``axis``.
        axis: Axis holding the independent samples (e.g. trials).
        ci: Confidence level, strictly between 0 and 1.

    Returns:
        ``(mean, h)`` such that the interval is ``mean ± h``. ``h`` is all
        zeros when fewer than two samples are available, because the sample
        standard deviation is undefined there.

    Raises:
        ValueError: If ``ci`` is not strictly between 0 and 1.
    """
    if not 0 < ci < 1:
        raise ValueError(f"ci must be in (0, 1), got {ci}")
    data = np.asarray(data)
    mean = data.mean(axis=axis)
    # The sample count is the length of the reduced axis, not always shape[0].
    n = data.size if axis is None else data.shape[axis]
    if n < 2:
        return mean, np.zeros_like(mean)
    z = NormalDist().inv_cdf((1 + ci) / 2)
    se = data.std(axis=axis, ddof=1) / np.sqrt(n)
    return mean, z * se

# Shared matplotlib styling. plot_all multiplies figsize's width by the
# number of panels; title_fontsize and roi_fontsize are not read by plot_all.
PLOT_STYLE = dict(
    figsize=(7, 5),        # (width, height) of a single panel
    title_fontsize=18,
    label_fontsize=18,     # axis labels
    legend_fontsize=15,
    tick_fontsize=15,
    line_width=3,          # mean curves
    ci_alpha=0.3,          # opacity of the confidence band
    grid_alpha=0.8,
    roi_linewidth=1.5,     # dashed ROI-threshold line
    roi_fontsize=12,
)

def plot_all(args, results, ylabel_list=("Average Reward", "Spent Budget", "Average ROI")):
    """Plot the mean and confidence band of each metric for every algorithm.

    One subplot is drawn per entry of ``ylabel_list``. Each subplot shows one
    line per algorithm in ``results`` with a shaded ``mean_ci`` band. Any
    subplot whose label contains "ROI" also gets a dashed ROI-threshold line.
    The figure is saved to
    ``Figure/{args.save_dir}_{args.name}_all_metrics.pdf`` and then shown.

    Args:
        args: Object with ``save_dir`` and ``name`` attributes (CLI namespace).
        results: Mapping from algorithm label to a ``(rewards, costs, rois)``
            tuple of ``(trials, rounds)`` arrays, as built by ``load_results``.
        ylabel_list: Y-axis labels, one per subplot. Subplot ``k`` plots
            metric ``k`` of each tuple. The exception is exactly two labels,
            where the panels plot reward and ROI (metrics 0 and 2).

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("results is empty; nothing to plot")
    alg_list = list(results)
    # The x-axis length is taken from the first algorithm's reward array.
    T = next(iter(results.values()))[0].shape[1]
    x = np.arange(1, T + 1)
    color_list = ['blue', 'orange', 'green', 'purple']
    n_metrics = len(ylabel_list)
    # Two panels mean "reward and ROI": skip the spent-budget metric (index 1).
    metric_indices = [0, 2] if n_metrics == 2 else list(range(n_metrics))
    # Same FPA rule as main() uses to pick baselines, so the threshold always
    # matches the experiment that was loaded.
    roi_threshold = 1.8 if "FPA" in args.save_dir else 1.3
    fig, axes = plt.subplots(
        1,
        n_metrics,
        figsize=(PLOT_STYLE["figsize"][0] * n_metrics, PLOT_STYLE["figsize"][1]),
        squeeze=False,
    )
    for ax, ylabel, metric_idx in zip(axes[0], ylabel_list, metric_indices):
        for i, alg_name in enumerate(alg_list):
            # Guo et al. is always teal; other colors cycle so that more than
            # four algorithms do not raise IndexError.
            color = 'teal' if 'Guo' in alg_name else color_list[i % len(color_list)]
            mean, h = mean_ci(results[alg_name][metric_idx], axis=0)
            ax.plot(x, mean, label=alg_name, color=color, linewidth=PLOT_STYLE["line_width"])
            ax.fill_between(x, mean - h, mean + h, alpha=PLOT_STYLE["ci_alpha"], color=color)
        ax.set_xlabel("Round", fontsize=PLOT_STYLE["label_fontsize"])
        ax.set_ylabel(ylabel, fontsize=PLOT_STYLE["label_fontsize"])
        ax.tick_params(axis='both', labelsize=PLOT_STYLE["tick_fontsize"])
        if 'ROI' in ylabel:
            ax.axhline(
                y=roi_threshold,
                color="red",
                linestyle="--",
                linewidth=PLOT_STYLE["roi_linewidth"],
                label=f"ROI threshold={roi_threshold}"
            )
        ax.grid(True, linestyle='-', alpha=PLOT_STYLE["grid_alpha"])
        ax.legend(fontsize=PLOT_STYLE["legend_fontsize"])
    fig.tight_layout(w_pad=2.5)
    out_path = f"Figure/{args.save_dir}_{args.name}_all_metrics.pdf"
    # savefig does not create directories; ensure the target folder exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=1200)
    plt.show()

def main():
    """Command-line entry point: load trial results and plot reward and ROI.

    Directories whose name contains "FPA" are compared against the Wang
    baseline; all other directories are compared against OPT3.
    """
    parser = argparse.ArgumentParser(
        description="Plot averaged reward and ROI across trials."
    )
    # All three options are needed downstream; without `required=True` a
    # missing flag surfaced as an obscure TypeError on None.
    parser.add_argument("--save_dir", type=str, required=True,
                        help="Directory holding the result pickles.")
    parser.add_argument("--name", type=str, required=True,
                        help="Experiment name prefix of the result files.")
    parser.add_argument("--trials", type=int, required=True,
                        help="Number of trials to load (files numbered from 1).")
    args = parser.parse_args()
    if args.trials < 1:
        parser.error("--trials must be at least 1")
    if "FPA" in args.save_dir:
        alg_list = ['L2FOB', 'Wang']
    else:
        alg_list = ['L2FOB', 'OPT3']
    results = load_results(args.save_dir, args.name, args.trials, alg_list)
    plot_all(args, results, ylabel_list=["Averaged Reward", "Averaged ROI"])

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
