import matplotlib.pyplot as plt
import math
import os

# Line color for each performance level (blue / green / red). Both the
# method curve and its reference line for a level share this color.
LEVEL_COLORS = dict(
    Expert="#1f77b4",
    Medium="#2ca02c",
    Weak="#d62728",
)

# Linestyle for each level's horizontal reference ("our-<level>") line.
LEVEL_LINESTYLES = dict(
    Expert="--",
    Medium=":",
    Weak="-.",
)

def plot_fix_lambda(data, levels=("Expert", "Medium", "Weak"), method_key="ddgi", ref_key="our",
                    save_path='./evaluation/plot/fix_lambda.png', ncols=4):
    """Plot one subplot per environment: a method curve plus a horizontal reference line per level.

    Args:
        data: Mapping ``env_name -> {level: entry}``. Each ``entry`` is read for
            ``entry["x_values"]`` (x coordinates), ``entry[method_key]`` (y values,
            drawn as a solid line labelled ``level``) and ``entry[ref_key]`` (a
            single value drawn as a horizontal line labelled ``f"our-{level}"``).
            Levels absent from an environment's mapping are skipped.
        levels: Level names to draw, in order. Colors and reference linestyles
            come from ``LEVEL_COLORS`` / ``LEVEL_LINESTYLES``, falling back to
            ``"#666666"`` / ``"--"`` for unknown levels.
        method_key: Entry key holding the curve to plot.
        ref_key: Entry key holding the reference value.
        save_path: Output image path; its parent directory is created when it
            has one. If falsy, the figure is shown interactively instead.
        ncols: Number of subplot columns (must be positive). Rows are added as
            needed with a minimum of 2, so the default reproduces the original
            2x4 grid (20x8 inches) for up to 8 environments.
    """
    env_names = list(data.keys())
    num_envs = len(env_names)

    # Keep at least 2 rows (the historical layout) but grow the grid instead of
    # raising IndexError when there are more environments than cells.
    nrows = max(2, math.ceil(num_envs / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows))
    axes = axes.flatten()

    for idx, env_name in enumerate(env_names):
        ax = axes[idx]
        env_levels = data[env_name]

        for level in levels:
            if level not in env_levels:
                continue

            entry = env_levels[level]
            color = LEVEL_COLORS.get(level, "#666666")
            ref_style = LEVEL_LINESTYLES.get(level, "--")

            ax.plot(entry["x_values"], entry[method_key], label=level, color=color, linewidth=2)
            # Label the reference line on every subplot; duplicate labels are
            # collapsed when the shared figure legend is built below.
            ax.axhline(entry[ref_key], linestyle=ref_style, color=color, linewidth=1.5,
                       label=f"our-{level}")

        ax.set_title(env_name, fontsize=20, weight='bold')
        ax.set_xlabel("Parameter", fontsize=14)
        # Only the left-most column of the grid gets a y-axis label.
        if idx % ncols == 0:
            ax.set_ylabel("Score", fontsize=14)
        ax.grid(True)

    # Drop grid cells that have no environment.
    for j in range(num_envs, len(axes)):
        fig.delaxes(axes[j])

    # One figure-level legend gathered from all used subplots (first handle per
    # label wins), so a level missing from the first environment still appears.
    legend_entries = {}
    for ax in axes[:num_envs]:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    if legend_entries:
        fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                   loc='lower center', ncol=6, fontsize='medium')
    # Reserve the bottom 5% of the figure for the legend.
    fig.tight_layout(rect=[0, 0.05, 1, 1])

    if save_path:
        out_dir = os.path.dirname(save_path)
        # dirname() is "" for a bare filename, and os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path)
    else:
        plt.show()

    plt.close(fig)
    

def legend_fix_lambda(levels=("Expert", "Medium", "Weak"), save_path='./evaluation/plot/fix_lambda_legend.png'):
    """Render a standalone, single-row legend image for the fix-lambda plots.

    For every level two entries are drawn: a solid line labelled ``level`` and a
    line in the level's reference linestyle labelled ``f"our-{level}"``.

    Args:
        levels: Level names, in legend order. Unknown levels fall back to
            ``"#666666"`` / ``"--"``, the same defaults ``plot_fix_lambda`` uses.
        save_path: Output image path; its parent directory is created when it
            has one. If falsy, the legend is shown interactively instead.
    """
    from matplotlib.lines import Line2D

    levels = list(levels)
    # Guard against a zero-width figure when no levels are requested.
    fig, ax = plt.subplots(figsize=(max(len(levels), 1) * 2.5, 0.8))

    handles = []
    for level in levels:
        color = LEVEL_COLORS.get(level, "#666666")
        ref_style = LEVEL_LINESTYLES.get(level, "--")
        # Proxy artists: Line2D objects that exist only to populate the legend.
        handles.append(Line2D([0], [0], color=color, linestyle='-', lw=2, label=level))
        handles.append(Line2D([0], [0], color=color, linestyle=ref_style, lw=2, label=f"our-{level}"))

    if handles:
        # All entries in one row.
        ax.legend(handles=handles, loc='center', ncol=len(handles), frameon=True)
    ax.axis("off")
    fig.tight_layout()

    if save_path:
        out_dir = os.path.dirname(save_path)
        # dirname() is "" for a bare filename, and os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
    else:
        plt.show()
    plt.close(fig)

