import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

# Global figure styling applied to every plot in this script:
# Arial text, ASCII hyphens for minus signs on tick labels, and a fixed
# font-size hierarchy (figure title > axes title > axes label > ticks).
matplotlib.rcParams.update({
    'font.family': 'Arial',
    'axes.unicode_minus': False,
    'font.size': 14,
    'axes.titlesize': 18,
    'axes.labelsize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 16,
    'figure.titlesize': 22,
})

def shorten_agent_counts(agent_counts):
    """Abbreviate agent-count descriptions into compact x-tick labels.

    Replaces the literal phrase " perception agent " with "P" and
    " control agent(s)" with "C", then removes remaining spaces, so
    "1 perception agent 2 control agents" becomes "1P2C".

    Args:
        agent_counts: pandas Series of strings (the 'Agent Counts' column).

    Returns:
        A new Series of abbreviated labels; missing values stay missing.
    """
    # Plain substring replacement: the phrases contain no regex syntax, and
    # older pandas versions defaulted to regex=True.
    # ' control agents' must be handled before ' control agent', otherwise
    # the plural form would leave a stray 's' behind.
    return (
        agent_counts.str.replace(' perception agent ', 'P', regex=False)
        .str.replace(' control agents', 'C', regex=False)
        .str.replace(' control agent', 'C', regex=False)
        .str.replace(' ', '', regex=False)
    )


def log_axis_limits(values, low_pad=0.8, high_pad=1.2):
    """Return padded (low, high) y-limits suitable for a log-scaled axis.

    Only strictly positive entries are considered, because a log axis cannot
    show zero or negative values (matplotlib ignores a non-positive limit on
    a log-scaled axis).

    Args:
        values: DataFrame of numeric values shared by every subplot.
        low_pad: multiplier applied to the smallest positive value.
        high_pad: multiplier applied to the largest positive value.

    Returns:
        Tuple ``(low, high)``.

    Raises:
        ValueError: if ``values`` contains no positive number.
    """
    positive = values.where(values > 0)
    low = positive.min().min()
    high = positive.max().max()
    if pd.isna(low):
        raise ValueError("no positive FLOPs values to plot on a log scale")
    return low * low_pad, high * high_pad


try:
    df = pd.read_csv('1010_5U5O.csv', header=0, index_col=0)
    df.reset_index(inplace=True)
    df.rename(columns={df.columns[0]: 'Complexity'}, inplace=True)

    # Panels are laid out left-to-right in this order.
    scale_order = ['N', 'S', 'M', 'L', 'X']
    raw_complexity = df['Complexity']
    df['Complexity'] = pd.Categorical(raw_complexity, categories=scale_order, ordered=True)

    # Labels outside scale_order become NaN in the Categorical and would
    # otherwise produce an empty "Scale: nan" panel; report and drop them.
    unknown = df['Complexity'].isna()
    if unknown.any():
        print(f"warning: ignoring unknown Complexity labels: "
              f"{sorted(map(str, raw_complexity[unknown].unique()))}")
        df = df.loc[~unknown].copy()
    df.sort_values('Complexity', inplace=True)

    df['Short Agent Counts'] = shorten_agent_counts(df['Agent Counts'])

    complexities = [level for level in scale_order if (df['Complexity'] == level).any()]
    if not complexities:
        raise ValueError("no rows with a recognised Complexity level")
    num_plots = len(complexities)

    flops_columns = ['E2E_Total_FLOPs(G)', 'OEL_Total_FLOPs(G)']
    # One shared y-range for all panels so they are directly comparable.
    y_limit_min, y_limit_max = log_axis_limits(df[flops_columns])

    # squeeze=False always yields a 2-D array, so one panel needs no special case.
    fig, axes = plt.subplots(nrows=1, ncols=num_plots, figsize=(14, 4), squeeze=False)
    axes = axes[0]

    colors = {'E2E': '#0072B2', 'OEL': '#D55E00'}
    markers = {'E2E': 'o', 'OEL': 's'}

    for i, level in enumerate(complexities):
        ax = axes[i]
        df_level = df[df['Complexity'] == level]

        for method in ('E2E', 'OEL'):
            ax.plot(df_level['Short Agent Counts'], df_level[f'{method}_Total_FLOPs(G)'],
                    marker=markers[method], color=colors[method], linestyle='-',
                    label=method, linewidth=2)

        ax.set_title(f'Scale: {level}', fontweight='bold')
        ax.set_yscale('log')
        ax.set_ylim(y_limit_min, y_limit_max)
        ax.margins(x=0.4)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='lightgray', alpha=0.6)

        if i == 0:
            ax.set_ylabel('Total FLOPs (G) (log scale)', fontweight='bold')

    # Every panel draws the same two series, so the first panel's handles
    # serve as the single figure-level legend.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=2, frameon=False)
    fig.text(0.5, 0.16, 'Agent Counts', ha='center', va='center', fontweight='bold', fontsize=16)

    fig.suptitle(
        '10*10 5U5O',
        fontweight='bold', fontsize=20, y=0.93
    )

    # Fit the subplots into the band y in [0.17, 0.97] of the figure, leaving
    # room underneath for the shared x-label (y=0.16) and the legend.
    fig.tight_layout(rect=[0, 0.17, 1, 0.97])
    fig.subplots_adjust(wspace=0.25)

    fig.savefig('FLOPs_by_complexity_unified_yaxis.png', dpi=300)
    fig.savefig('FLOPs_by_complexity_unified_yaxis.pdf')
    plt.show()

except FileNotFoundError:
    print("cannot find file")
except Exception as e:
    # Top-level boundary of the script: keep the short message but also show
    # where the failure happened instead of silently discarding the traceback.
    import traceback
    print(f"error: {e}")
    traceback.print_exc()
