import pickle
import matplotlib.pyplot as plt
import numpy as np
import re
from matplotlib.lines import Line2D
from matplotlib.ticker import FormatStrFormatter

# Global Matplotlib styling applied to every figure this script creates:
# serif text, DejaVu serif glyphs for mathtext, and automatic layout.
plt.rcParams["font.family"] = "serif"
plt.rcParams["mathtext.fontset"] = "dejavuserif"
plt.rcParams["figure.autolayout"] = True

# Line color / marker per horizon H for the "Standard" curves.
# Horizons missing from this table are drawn in gray with an "x" marker.
STYLE_MAP = {
    h: {"color": color, "marker": marker}
    for h, color, marker in (
        (1, "#0173B2", "o"),
        (8, "#029E73", "s"),
        (16, "#F20A99", "D"),
        (32, "#D55E00", "v"),
        (64, "#CC78BC", "^"),
    )
}
# Single style shared by every "Exact" curve (drawn dashed).
ORACLE_STYLE = dict(color="black", marker="P")

def subsample_series(x, mean, std, every):
    """Thin three aligned series to every ``every``-th point, keeping the last one.

    Parameters
    ----------
    x, mean, std : array-like
        Aligned 1-D series of equal length.
    every : int
        Keep indices ``0, every, 2*every, ...``; values ``<= 1`` disable
        subsampling.

    Returns
    -------
    tuple
        ``(x, mean, std)`` restricted to the kept indices. When ``every <= 1``
        the inputs are returned unchanged (the very same objects). Empty
        inputs are returned as empty arrays.
    """
    if every <= 1:
        return x, mean, std
    # Coerce so that list inputs also support fancy indexing below.
    x, mean, std = np.asarray(x), np.asarray(mean), np.asarray(std)
    n = len(x)
    if n == 0:
        # Nothing to thin; indexing idx[-1] on an empty index array would fail.
        return x, mean, std
    idx = np.arange(0, n, every)
    # Always keep the final point so the thinned curve still reaches the end.
    if idx[-1] != n - 1:
        idx = np.append(idx, n - 1)
    return x[idx], mean[idx], std[idx]

def as_2d_histories(history):
    """Coerce a history into a 2-D float array of shape ``(n_runs, n_steps)``.

    - scalar      -> ``(1, 1)``
    - 1-D         -> a single run, ``(1, n_steps)``
    - 2-D         -> returned as-is (after float conversion)
    - higher rank -> all leading axes are flattened into the run axis; the
      last axis stays the step axis.

    Ragged nested lists cannot be converted and raise ``ValueError`` from
    ``np.asarray``.
    """
    arr = np.asarray(history, dtype=float)
    if arr.ndim == 0:
        # A bare scalar is one run with one step; the generic reshape below
        # would fail on ``arr.shape[-1]`` for a 0-d array.
        return arr.reshape(1, 1)
    if arr.ndim == 1:
        return arr[None, :]
    if arr.ndim == 2:
        return arr
    # Compute the run count explicitly instead of using -1, which numpy cannot
    # infer when the step axis has length 0.
    n_runs = int(np.prod(arr.shape[:-1]))
    return arr.reshape(n_runs, arr.shape[-1])

def get_best_config(results, tail_avg=20):
    """Pick the highest-scoring configuration for each horizon ``H``.

    Parameters
    ----------
    results : dict
        Maps a configuration label to its history (anything accepted by
        ``as_2d_histories``). The horizon is parsed from an ``H=<int>``
        substring of the label; labels without one are filed under ``H=1``.
    tail_avg : int
        Number of trailing steps averaged per run to score it; the per-run
        tail means are then averaged across runs. ``tail_avg=0`` slices
        ``[-0:]``, i.e. the whole history.

    Returns
    -------
    dict
        ``{h: (label, score, hist2d)}`` holding the best label per horizon.
        Ties keep the first label encountered. A configuration whose score
        is NaN (e.g. NaN values in the history, or an empty history) ranks
        below every finite score and is recorded with score ``-inf``.
    """
    best_per_h = {}
    for label, history in results.items():
        h_match = re.search(r"H=(\d+)", label)
        h_val = int(h_match.group(1)) if h_match else 1
        hist2d = as_2d_histories(history)
        score = float(np.mean(np.mean(hist2d[:, -tail_avg:], axis=1)))
        if np.isnan(score):
            # ``x > nan`` is always False. Without this, a NaN-scored label
            # seen first for an H could never be replaced by a real score.
            score = -np.inf
        if (h_val not in best_per_h) or (score > best_per_h[h_val][1]):
            best_per_h[h_val] = (label, score, hist2d)
    return best_per_h


def plot_performance_comparison(std_path, exact_path, subsample_every=100):
    """Plot best-per-H mean curves (+/- one std band) for two result files.

    Each pickle is expected to hold a dict with a ``"results"`` mapping
    (label -> history, see ``get_best_config``) and optional ``"env"`` /
    ``"size"`` entries. ``"env"`` and ``"size"`` are taken from the first
    file that loads and are used in the output name
    ``comparison_<env>_<size>.pdf``.

    Parameters
    ----------
    std_path : str or path-like
        Pickle for the "Standard" runs. Curves are solid and styled per H
        via ``STYLE_MAP``; unknown H values fall back to gray with an "x".
    exact_path : str or path-like
        Pickle for the "Exact" runs. All curves are dashed in ``ORACLE_STYLE``.
    subsample_every : int
        Passed to ``subsample_series`` to thin the plotted points.

    Files that are missing or fail to load are reported and skipped. If
    neither loads, nothing is drawn or written.
    """
    fontsize = 18

    all_data = []
    for type_name, path in [("Standard", std_path), ("Exact", exact_path)]:
        try:
            # NOTE: pickle.load can execute arbitrary code; only load result
            # files produced by trusted experiment runs.
            with open(path, "rb") as f:
                all_data.append((type_name, pickle.load(f)))
        except FileNotFoundError:
            print(f"Warning: {path} not found.")
        except Exception as err:
            # Best-effort: a corrupt/unreadable file should not abort the batch,
            # but it must not be misreported as "not found" either.
            print(f"Warning: could not load {path}: {err}")

    if not all_data:
        return

    # Create the figure only once there is something to draw, so the early
    # return above cannot leak an open figure.
    fig, ax = plt.subplots(figsize=(4, 3))
    env_name = all_data[0][1].get("env", "Unknown")
    size_str = all_data[0][1].get("size", "Unknown")

    for type_name, data in all_data:
        best_configs = get_best_config(data["results"])
        for h_val in sorted(best_configs.keys()):
            _, _, hist2d = best_configs[h_val]
            # Mean / std across runs (axis 0 of the (n_runs, n_steps) array).
            mean, std = np.mean(hist2d, axis=0), np.std(hist2d, axis=0)
            x = np.arange(len(mean))

            if type_name == "Exact":
                style, ls = ORACLE_STYLE, "--"
            else:
                style, ls = STYLE_MAP.get(h_val, {"color": "gray", "marker": "x"}), "-"

            xs, ms, ss = subsample_series(x, mean, std, subsample_every)
            # Roughly 8 markers per curve regardless of its length.
            markevery = max(1, len(xs) // 8)

            ax.plot(xs, ms, linestyle=ls, linewidth=4.0, color=style["color"],
                    marker=style["marker"], markersize=9, markevery=markevery)

            ax.fill_between(xs, ms - ss, ms + ss, color=style["color"],
                            alpha=0.05, edgecolor="none", linewidth=0)

    ax.set_xlabel("Actor Iterations", fontsize=fontsize)
    ax.set_ylabel(r"Objective $\tilde{J}_{\lambda}(\theta_{k})$", fontsize=fontsize)

    ax.grid(linestyle="--", alpha=0.5)
    ax.tick_params(axis='both', labelsize=fontsize-2)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

    out_file = f"comparison_{env_name}_{size_str}.pdf"
    # Save and close the figure created above explicitly rather than relying
    # on pyplot's implicit "current figure".
    fig.savefig(out_file, bbox_inches="tight")
    plt.close(fig)
    print(f"Plot saved to: {out_file}")


def save_unified_legend(output_name="legend_unified.pdf"):
    """Write a standalone single-row legend figure to ``output_name``.

    Entries: one solid line per ``STYLE_MAP`` horizon in ascending order,
    skipping H=1, followed by a dashed "Exact Critic" entry drawn in
    ``ORACLE_STYLE``. The file is saved with a transparent background.
    """
    fontsize = 18
    handles = [
        Line2D([0], [0], color=STYLE_MAP[h_key]["color"], marker=STYLE_MAP[h_key]["marker"],
               linewidth=4.0, markersize=7, label=f"H={h_key}")
        for h_key in sorted(STYLE_MAP.keys())
        if h_key != 1
    ]
    handles.append(
        Line2D([0], [0], color=ORACLE_STYLE["color"], marker=ORACLE_STYLE["marker"],
               linestyle='--', linewidth=4.0, markersize=7, label="Exact Critic")
    )

    # Width grows with the entry count; ncol=len(handles) puts them on one row.
    legend_fig = plt.figure(figsize=(len(handles) * 2.2, 0.8))
    legend_fig.legend(handles=handles, loc="center", fontsize=fontsize,
                      frameon=False, ncol=len(handles), handlelength=2.5,
                      handletextpad=0.4, columnspacing=1.5)

    legend_fig.savefig(output_name, bbox_inches="tight", transparent=True)
    plt.close(legend_fig)
    print(f"Legend saved to: {output_name}")

if __name__ == "__main__":
    # (Standard results, Exact results) pickle pairs, one comparison plot each.
    dataset_pairs = (
        ("results_gridworld_3x4_gridsearch.pkl", "results_gridworld_3x4_EXACT.pkl"),
        ("results_synthetic_S16_gridsearch.pkl", "results_synthetic_S16_EXACT.pkl"),
    )
    for standard_file, exact_file in dataset_pairs:
        plot_performance_comparison(standard_file, exact_file, subsample_every=100)

    save_unified_legend()