import argparse
import os
import pickle
import re
from statistics import NormalDist

import matplotlib.pyplot as plt
import numpy as np


def mean_ci(data: np.ndarray, axis=0, ci=0.95):
    """Return the mean along ``axis`` and the half-width of a normal-approximation CI.

    Args:
        data: Array (or array-like) of samples; the samples being averaged lie
            along ``axis`` (e.g. one row per trial when ``axis=0``).
        axis: Axis holding the independent samples.
        ci: Two-sided confidence level, strictly between 0 and 1
            (0.95 gives z ~= 1.96).

    Returns:
        ``(mean, h)`` such that ``[mean - h, mean + h]`` is the interval.
        ``h`` uses the sample standard deviation (``ddof=1``), so it is NaN
        where fewer than two samples exist along ``axis``.

    Raises:
        ValueError: if ``ci`` is not strictly between 0 and 1.

    Note:
        This uses the normal (z) quantile; with very few samples a Student-t
        quantile would give a somewhat wider interval.
    """
    if not 0.0 < ci < 1.0:
        raise ValueError(f"ci must be strictly between 0 and 1, got {ci!r}")
    data = np.asarray(data)
    # The sample count is the length of the reduced axis, not always shape[0].
    n = data.shape[axis]
    mean = data.mean(axis=axis)
    se = data.std(axis=axis, ddof=1) / np.sqrt(n)
    # Two-sided critical value: e.g. ci=0.95 -> inv_cdf(0.975).
    z = NormalDist().inv_cdf(0.5 + ci / 2.0)
    return mean, z * se


def _gamma_sort_key(gamma_str):
    """Sort key ordering gamma strings numerically; unparsable ones (e.g. "1.2.3") go last."""
    try:
        return (0, float(gamma_str), gamma_str)
    except ValueError:
        return (1, 0.0, gamma_str)


def _running_averages(rounds):
    """Return (running mean reward, running ROI) arrays for one trial's round logs.

    ROI at round t is cumulative reward / cumulative cost; it is 0 wherever the
    cumulative cost is not positive (avoids division by zero).
    """
    r_seq = np.array([log.reward for log in rounds]).cumsum()
    c_seq = np.array([log.cost for log in rounds]).cumsum()
    t = np.arange(1, len(r_seq) + 1)
    avg_r = r_seq / t
    avg_c = c_seq / t
    avg_roi = np.divide(avg_r, avg_c, out=np.zeros_like(avg_r), where=avg_c > 0)
    return avg_r, avg_roi


def load_gamma_results(save_dir: str, exp_name_prefix: str, trials: int):
    """Load per-trial results for a gamma sweep and compute running averages.

    Scans ``save_dir`` for files named exactly
    ``{exp_name_prefix}_gamma_{gamma}_trial{k}_L2FOB.pkl``. Each pickle must
    expose ``.rounds``, an iterable of logs with ``.reward`` and ``.cost``.

    Args:
        save_dir: Directory containing the result pickles.
        exp_name_prefix: Experiment-name prefix; matched literally, not as a regex.
        trials: Currently unused; kept for interface compatibility.

    Returns:
        Dict mapping a LaTeX label ``$\\eta_\\gamma = {gamma}$`` to a tuple
        ``(rewards, rois)`` of arrays with one row per trial (rows ordered by
        trial id). Labels are ordered by ascending numeric gamma. All trials of
        one gamma are assumed to have the same number of rounds.
        Returns an empty dict when no file matches.
    """
    # re.escape so characters like '.' in the prefix match literally.
    pattern = re.compile(
        rf"{re.escape(exp_name_prefix)}_gamma_([0-9.]+)_trial(\d+)_L2FOB\.pkl"
    )
    # gamma string -> list of (trial_id, avg_reward, avg_roi)
    per_gamma = {}
    for fname in sorted(os.listdir(save_dir)):
        # fullmatch rejects names with trailing junk such as "...pkl.bak".
        m = pattern.fullmatch(fname)
        if m is None:
            continue
        gamma_val, trial_id = m.group(1), int(m.group(2))
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(save_dir, fname), "rb") as f:
            res = pickle.load(f)
        avg_r, avg_roi = _running_averages(res.rounds)
        per_gamma.setdefault(gamma_val, []).append((trial_id, avg_r, avg_roi))

    results = {}
    # Deterministic ordering: os.listdir order is arbitrary, and the plotting
    # code assigns colours by insertion order.
    for gamma_val in sorted(per_gamma, key=_gamma_sort_key):
        runs = sorted(per_gamma[gamma_val], key=lambda run: run[0])
        label = r'$\eta_\gamma = {}$'.format(gamma_val)
        results[label] = (
            np.array([run[1] for run in runs]),
            np.array([run[2] for run in runs]),
        )
    return results


# Shared matplotlib styling for the gamma-sweep figure. "figsize" is the
# per-panel (width, height); the figure width is scaled by the panel count.
# The remaining entries are font sizes, line widths, alpha (transparency)
# levels, the export DPI and subplot spacing.
PLOT_STYLE = dict(
    figsize=(7, 5),
    title_fontsize=18,
    label_fontsize=18,
    legend_fontsize=15,
    tick_fontsize=15,
    line_width=2,
    ci_alpha=0.3,
    grid_alpha=0.8,
    roi_linewidth=1.5,
    roi_fontsize=12,
    dpi=1200,
    wspace=0.35,
    hspace=0.3,
)


def _draw_gamma_panel(ax, results, ylabel, style, colors, roi_threshold):
    """Draw one subplot: a mean curve with a CI band per gamma, plus the ROI threshold line."""
    # A panel whose label mentions 'Reward' shows rewards; any other shows ROI.
    show_reward = 'Reward' in ylabel
    for i, (label, (reward_data, roi_data)) in enumerate(results.items()):
        data = reward_data if show_reward else roi_data
        mean, h = mean_ci(data, axis=0)
        # Per-series x-axis so gamma groups with different horizons plot correctly.
        x = np.arange(1, data.shape[1] + 1)
        color = colors[i % len(colors)]
        ax.plot(x, mean, label=label, color=color, linewidth=style["line_width"])
        ax.fill_between(x, mean - h, mean + h, alpha=style["ci_alpha"], color=color)
    ax.set_xlabel("Round", fontsize=style["label_fontsize"])
    ax.set_ylabel(ylabel, fontsize=style["label_fontsize"])
    ax.tick_params(axis='both', labelsize=style["tick_fontsize"])
    ax.grid(True, linestyle='-', alpha=style["grid_alpha"])
    if 'ROI' in ylabel and roi_threshold is not None:
        ax.axhline(
            y=roi_threshold,
            color="red",
            linestyle="--",
            linewidth=style["roi_linewidth"],
            label="ROI threshold",
        )
    ax.legend(fontsize=style["legend_fontsize"])


def plot_gamma_sweep(args, ylabel_list=("Average Reward", "Average ROI"), style=PLOT_STYLE,
                     roi_threshold=1.8):
    """Plot running average reward / ROI (mean +/- CI over trials) for each gamma.

    Args:
        args: Namespace with ``save_dir``, ``name`` and ``trials`` attributes,
            forwarded to ``load_gamma_results``.
        ylabel_list: One subplot per label, laid out in a single row. Labels
            containing 'Reward' plot rewards; all others plot ROI. Labels
            containing 'ROI' also get a dashed threshold line.
        style: Styling dict (see ``PLOT_STYLE`` for the keys read here).
        roi_threshold: y-value of the dashed ROI threshold line; ``None`` hides it.

    Side effects:
        Saves ``Figure/{args.save_dir}_gamma_sweep.pdf`` (creating its parent
        directory if needed) and then calls ``plt.show()``.

    Raises:
        ValueError: if ``ylabel_list`` is empty.
        FileNotFoundError: if no matching result files were found.
    """
    ylabel_list = list(ylabel_list)
    if not ylabel_list:
        raise ValueError("ylabel_list must contain at least one label")
    results = load_gamma_results(args.save_dir, args.name, args.trials)
    if not results:
        raise FileNotFoundError(
            f"No '{args.name}_gamma_*_trial*_L2FOB.pkl' result files found in {args.save_dir!r}"
        )
    colors = ('orange', 'blue', 'green', 'purple', 'brown')
    n_panels = len(ylabel_list)
    # squeeze=False always yields a 2-D axes array, even for a single panel.
    fig, axes = plt.subplots(
        1, n_panels,
        figsize=(style["figsize"][0] * n_panels, style["figsize"][1]),
        squeeze=False,
    )
    for ax, ylabel in zip(axes[0], ylabel_list):
        _draw_gamma_panel(ax, results, ylabel, style, colors, roi_threshold)
    fig.tight_layout(w_pad=2.5)
    out_path = f"Figure/{args.save_dir}_gamma_sweep.pdf"
    # The output directory is not guaranteed to exist; create it on demand.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=style["dpi"])
    plt.show()


def main():
    """Parse command-line arguments and render the gamma-sweep figure."""
    parser = argparse.ArgumentParser(
        description="Plot averaged reward and ROI across eta_gamma values."
    )
    # Both are needed to locate the result files; without them the loader would
    # scan the current directory for a literal "None" prefix.
    parser.add_argument("--save_dir", type=str, required=True,
                        help="Directory containing the *_L2FOB.pkl result files.")
    parser.add_argument("--name", type=str, required=True,
                        help="Experiment-name prefix used in the result filenames.")
    parser.add_argument("--trials", type=int, default=None,
                        help="Number of trials per gamma value (forwarded to the loader).")
    args = parser.parse_args()
    plot_gamma_sweep(args, ylabel_list=["Averaged Reward", "Averaged ROI"])


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
