import pandas as pd
import matplotlib.pyplot as plt


def load_evaluated_data(file_path):
    """Read the evaluated RLHF dataset from CSV into a DataFrame.

    The CSV's first column becomes the row index.

    Args:
        file_path: Path to the CSV (or any source ``pandas.read_csv`` accepts).

    Returns:
        pandas.DataFrame indexed by the file's first column.
    """
    evaluated = pd.read_csv(file_path, index_col=0)
    return evaluated


def determine_safety_comparison(df):
    """Label each row by whether the chosen response scored at least as safe as the rejected one.

    Adds (or overwrites) a ``data_fit`` column on *df* in place and returns the
    same frame. Rows where ``chosen_safety_score >= rejected_safety_score`` get
    the LaTeX label for s(x, y_w) >= s(x, y_l); every other row gets the label
    for s(x, y_l) > s(x, y_w). These exact strings are later used as count
    column names (and therefore legend entries) for the plot.

    Args:
        df: DataFrame with ``chosen_safety_score`` and ``rejected_safety_score``
            columns.

    Returns:
        The same DataFrame object, with ``data_fit`` populated.
    """
    # Vectorised comparison instead of a row-wise ``apply``: it avoids a Python
    # loop per row, and it also works on an empty frame. There, ``apply`` returns
    # a multi-column DataFrame that cannot be assigned to one column.
    chosen_at_least_as_safe = df['chosen_safety_score'] >= df['rejected_safety_score']
    # NOTE(review): a NaN score makes the comparison False, so such rows are
    # labelled y_l > y_w (same as the original behaviour) -- confirm intended.
    df['data_fit'] = chosen_at_least_as_safe.map({
        True: r'$s(x, y_w) \geq s(x, y_l)$',
        False: r'$s(x, y_l) > s(x, y_w)$',
    })
    return df


def prepare_plot_data(df):
    """Tally, per SALAD category, how many rows fall on each side of the safety comparison.

    A row belongs to a category when either ``salad_category_0`` or
    ``salad_category_1`` equals it. The categories themselves are the distinct
    non-null values of ``salad_category_0`` only.
    NOTE(review): a category that appears solely in ``salad_category_1`` is
    therefore never reported -- confirm this is intended.

    The category number is the integer after the first character of the text
    before the first ``':'``. This presumably assumes labels like
    ``'O12: <name>'`` -> 12; a label of another shape raises ValueError.

    Args:
        df: DataFrame with ``salad_category_0``, ``salad_category_1`` and
            ``data_fit`` columns.

    Returns:
        DataFrame with one row per category (ordered by category string) and
        columns ``'Category Number'`` plus one count column per ``data_fit``
        label.
    """
    ge_label = r'$s(x, y_w) \geq s(x, y_l)$'
    gt_label = r'$s(x, y_l) > s(x, y_w)$'

    rows = []
    for category in sorted(set(df['salad_category_0'].dropna())):
        in_category = (df['salad_category_0'] == category) | (df['salad_category_1'] == category)
        fits = df.loc[in_category, 'data_fit']
        category_number = int(category.split(':')[0][1:])
        rows.append((
            category_number,
            int((fits == ge_label).sum()),
            int((fits == gt_label).sum()),
        ))

    return pd.DataFrame(rows, columns=['Category Number', ge_label, gt_label])


def create_stacked_bar_plot(df, output_dir='figures'):
    """Create and save a stacked bar plot of the safety comparison counts.

    Writes ``stacked_bar_plot.pdf`` and ``stacked_bar_plot.png`` into
    *output_dir*.

    Args:
        df: Table with a ``'Category Number'`` column plus two count columns.
            The first count column is drawn in green and the second in red; the
            count column names become the legend entries.
        output_dir: Directory for the saved figures. It is created if missing.
            Defaults to ``'figures'``, the previously hard-coded location.
    """
    from pathlib import Path

    df = df.sort_values('Category Number')
    ax = df.set_index('Category Number').plot(
        kind='bar', stacked=True, figsize=(7, 3.5), color=['green', 'red']
    )
    fig = ax.get_figure()
    try:
        ax.set_xlabel('Category Number', fontsize=14)
        ax.set_ylabel('Count', fontsize=14)
        ax.legend(fontsize=14, loc="upper right")
        # Redraw tick labels upright and centred under each bar (df is sorted,
        # so the label order matches the bar order).
        ax.set_xticklabels(df['Category Number'], rotation=0, ha='center')

        out_dir = Path(output_dir)
        # savefig does not create missing directories; without this a fresh
        # checkout without figures/ fails with FileNotFoundError.
        out_dir.mkdir(parents=True, exist_ok=True)
        for suffix in ('pdf', 'png'):
            fig.savefig(out_dir / f'stacked_bar_plot.{suffix}', bbox_inches='tight')
    finally:
        # Close this specific figure even if saving fails, so it is not leaked.
        plt.close(fig)


def main():
    """Run the full pipeline: load, label, tally per category, then plot."""
    evaluated = determine_safety_comparison(load_evaluated_data('rlhf_with_eval.csv'))
    create_stacked_bar_plot(prepare_plot_data(evaluated))


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
