import matplotlib.pyplot as plt
import math
import os


# Fixed color per baseline name, so the per-environment curves and the
# standalone legend image draw each method in the same color. The hex codes
# are the first five entries of matplotlib's default "tab10" color cycle.
BASELINE_COLORS = dict(
    BC="#1f77b4",
    DP="#ff7f0e",
    DBC="#2ca02c",
    DD="#d62728",
    DDGI="#9467bd",
)

def plot_dataset(data, baselines, save_path='./evaluation/plot/dataset.png', cols=4):
    """Plot per-environment score curves for each baseline on a grid of subplots.

    Each environment gets its own subplot. For every requested baseline present
    in that environment's data, the mean is drawn as a line and a translucent
    band is filled between ``mean - std`` and ``mean + std``.

    Args:
        data: Mapping from environment name to that environment's data. Each
            value is indexed with the key ``"num of dataset"`` for the x-axis
            values and with each baseline name for a sequence of
            ``(mean, std)`` pairs. The pairs presumably align one-to-one with
            the x values; this is not checked here.
        baselines: Baseline names to draw, in plotting order. Names missing
            from an environment's data, or mapped to an empty series, are
            skipped for that environment.
        save_path: File to write the figure to. Missing parent directories
            are created. If falsy, the figure is shown interactively instead.
        cols: Number of subplot columns in the grid (default 4).

    Raises:
        ValueError: If ``data`` is empty or ``cols`` is less than 1.
    """
    num_envs = len(data)
    if num_envs == 0:
        raise ValueError("plot_dataset: `data` must contain at least one environment")
    if cols < 1:
        raise ValueError(f"plot_dataset: `cols` must be >= 1, got {cols!r}")
    rows = math.ceil(num_envs / cols)

    # squeeze=False keeps `axes` 2-D even when rows == 1, so the
    # [row][col] indexing below works for every grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)

    for idx, (env_name, env_data) in enumerate(data.items()):
        ax = axes[idx // cols][idx % cols]
        timesteps = env_data["num of dataset"]

        for baseline in baselines:
            if baseline not in env_data:
                continue
            series = env_data[baseline]
            # zip(*[]) cannot be unpacked into (means, stds); skip empty series.
            if len(series) == 0:
                continue
            means, stds = zip(*series)
            # Unknown baselines get color=None, i.e. matplotlib's color cycle.
            color = BASELINE_COLORS.get(baseline)
            ax.plot(timesteps, means, label=baseline, color=color)
            ax.fill_between(
                timesteps,
                [m - s for m, s in zip(means, stds)],
                [m + s for m, s in zip(means, stds)],
                alpha=0.2,
                color=color,
            )

        ax.set_title(env_name, fontsize=20, weight='bold')
        ax.set_xlabel("Num of dataset", fontsize=14)
        ax.set_ylabel("Score", fontsize=14)
        ax.grid(True)

    # Remove the unused trailing cells of the last grid row.
    for i in range(num_envs, rows * cols):
        fig.delaxes(axes[i // cols][i % cols])

    # No shared legend is drawn on this figure (legend_dataset renders one as a
    # separate image). The rect still leaves a 5% strip free at the bottom.
    fig.tight_layout(rect=[0, 0.05, 1, 1])

    if save_path:
        save_dir = os.path.dirname(save_path)
        # A bare filename has dirname "", and os.makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
    else:
        plt.show()
    plt.close(fig)
    
    
def legend_dataset(labels, save_path="./evaluation/plot/legend_dataset.png", linewidth=4, figsize=(6, 0.6)):
    """Render a standalone, single-row legend image for the given baselines.

    Args:
        labels: Baseline names to list, in display order. Each swatch's color
            is looked up in BASELINE_COLORS. Names without an entry fall back
            to color=None, which uses matplotlib's default line color.
        save_path: File to write the legend image to. Missing parent
            directories are created. If falsy, the figure is shown
            interactively instead.
        linewidth: Line width of each legend swatch.
        figsize: Figure size in inches, as ``(width, height)``.
    """
    from matplotlib.lines import Line2D

    # Materialize once so a one-shot iterable can be used for both the
    # handles and the legend text.
    labels = list(labels)

    fig, ax = plt.subplots(figsize=figsize)
    # Proxy artists: zero-length lines that exist only to give the legend a
    # colored swatch per label.
    handles = [Line2D([0], [0], color=BASELINE_COLORS.get(b), lw=linewidth) for b in labels]
    # One column per label puts every entry on a single row. Clamp to 1 so an
    # empty label list does not pass ncol=0.
    ax.legend(handles, labels, loc='center', ncol=max(len(labels), 1), frameon=True)
    ax.axis("off")
    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # A bare filename has dirname "", and os.makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
    else:
        plt.show()
    plt.close(fig)

