import matplotlib.pyplot as plt
import math
import os

# Plot color for each level name; used for both the curves and the legend.
LEVEL_COLORS = dict(
    Expert="#1f77b4",
    Medium="#2ca02c",
    Weak="#d62728",
)

# Line style for each level name. NOTE(review): not referenced by the
# plotting functions visible in this module, which draw solid lines.
LEVEL_LINESTYLES = dict(
    Expert="--",
    Medium=":",
    Weak="-.",
)

def plot_alpha(data, levels=("Expert", "Medium", "Weak"), method_key="ddgi",
               save_path='./evaluation/plot/alpha.png', *, ncols=4,
               initial_alpha=0.5):
    """Plot one score-vs-parameter curve per level for every environment.

    Each environment in ``data`` gets its own subplot in a grid with
    ``ncols`` columns. A dashed vertical line marks ``initial_alpha``. Grid
    cells left over after the last environment are removed.

    Args:
        data: Mapping ``env_name -> {level_name -> entry}``. Each ``entry``
            must contain ``"x_values"`` and a ``method_key`` item holding the
            y values plotted against them.
        levels: Level names to draw, in order. A level missing for a given
            environment is silently skipped.
        method_key: Key in each entry that holds the y values.
        save_path: Output image path. Its parent directory is created if
            needed. When falsy, the figure is shown interactively instead.
        ncols: Number of subplot columns (default 4).
        initial_alpha: x position of the dashed "initial alpha" marker.
    """
    env_names = list(data.keys())
    num_envs = len(env_names)

    # Keep the original 2x4 layout for up to 2*ncols environments, but add
    # rows when there are more instead of running out of axes (IndexError).
    nrows = max(2, math.ceil(num_envs / ncols))
    # squeeze=False guarantees a 2-D array of axes for every grid shape.
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows),
                             squeeze=False)
    axes = axes.flatten()

    for idx, env_name in enumerate(env_names):
        ax = axes[idx]
        env_levels = data[env_name]

        for level in levels:
            if level not in env_levels:
                continue
            entry = env_levels[level]
            ax.plot(entry["x_values"], entry[method_key], label=level,
                    color=LEVEL_COLORS.get(level, "#666666"), linewidth=2)

        ax.axvline(initial_alpha, color="black", linestyle="--", linewidth=1.5)
        # ylim is read after the curves are drawn so the label is placed at
        # the top of the current y range.
        ax.text(initial_alpha, ax.get_ylim()[1], "initial alpha", rotation=90,
                verticalalignment='top', horizontalalignment='right', fontsize=9)

        ax.set_title(env_name, fontsize=20, weight='bold')
        ax.set_xlabel("Parameter", fontsize=14)
        if idx % ncols == 0:  # y label only on the left-most column
            ax.set_ylabel("Score", fontsize=14)
        ax.grid(True)

    # Drop the grid cells that did not receive an environment.
    for unused_ax in axes[num_envs:]:
        fig.delaxes(unused_ax)

    # Reserve the bottom 5% of the figure (room for a shared legend).
    fig.tight_layout(rect=[0, 0.05, 1, 1])

    try:
        if save_path:
            save_dir = os.path.dirname(save_path)
            # dirname is "" for a bare filename; os.makedirs("") would raise.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path)
        else:
            plt.show()
    finally:
        # Close this specific figure even if saving fails.
        plt.close(fig)
    

def legend_alpha(levels=("Expert", "Medium", "Weak"),
                 save_path='./evaluation/plot/alpha_legend.png'):
    """Render a standalone, single-row legend image for the given levels.

    Each level is shown as a solid line in its ``LEVEL_COLORS`` color.
    Unknown levels fall back to grey ``#666666``, the same fallback
    ``plot_alpha`` uses.

    Args:
        levels: Level names to include, left to right.
        save_path: Output image path. Its parent directory is created if
            needed. When falsy, the legend is shown interactively instead.
    """
    from matplotlib.lines import Line2D

    handles = [
        Line2D([0], [0], color=LEVEL_COLORS.get(level, "#666666"), lw=2,
               linestyle='-', label=level)
        for level in levels
    ]
    # Avoid a zero-width figure and ncol=0 when no levels are given.
    ncol = max(1, len(handles))

    fig, ax = plt.subplots(figsize=(ncol * 2.5, 0.8))
    ax.legend(handles=handles, loc='center', ncol=ncol, frameon=True)
    ax.axis("off")
    fig.tight_layout()

    try:
        if save_path:
            save_dir = os.path.dirname(save_path)
            # dirname is "" for a bare filename; os.makedirs("") would raise.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
        else:
            plt.show()
    finally:
        # Close this specific figure even if saving fails.
        plt.close(fig)
