import matplotlib.pyplot as plt
from util import EVALUATORS, TASKS, load_json


def _category_number(key):
    """Extract the category-number text from a safety score key.

    Takes the part of ``key`` before the first ':', splits it on '_', and
    drops the first character of the second token. For a key presumably
    shaped like ``'mdjudge_c12:...'`` this yields ``'12'`` (key format is
    inferred from the parsing logic -- confirm against the scores JSON).
    """
    return key.split(':')[0].split('_')[1][1:]


def plot_all_combined(scores, output_dirs):
    """Plot combined safety and helpfulness scores for all evaluators.

    Draws one subplot per entry of ``EVALUATORS``. Within each subplot,
    every run in ``output_dirs`` appears as one line of safety score per
    category. A single legend is shared by all subplots, and it labels each
    run with its helpfulness score. The figure is saved as
    ``all_performances_combined.png`` and ``all_performances_combined.pdf``
    in the current working directory and then shown.

    Args:
        scores: Mapping from each output dir to its score dict. Each score
            dict must contain a ``'helpful'`` entry. Its safety entries have
            keys containing the evaluator name, parsed by
            ``_category_number``.
        output_dirs: Run names to plot, in legend/colour order.
    """
    colors = ['black', 'orange', 'blue', 'green', 'violet', 'red']
    evaluator_names = {
        'mdjudge': 'MD-Judge',
        'llamaguard': 'Llama Guard 3',
    }
    helpful_key = 'helpful'

    # Category keys are read from the 'SFT' run when it is loaded; otherwise
    # the first plotted run is used rather than failing with KeyError.
    reference_dir = 'SFT' if 'SFT' in scores else output_dirs[0]

    # One panel per evaluator. squeeze=False keeps ``axes`` a 2-D array even
    # for a single evaluator. The width scales so two panels stay 13in wide.
    n_panels = len(EVALUATORS)
    fig, axes = plt.subplots(1, n_panels, figsize=(6.5 * n_panels, 3.5),
                             squeeze=False)

    legend_handles = []
    for idx, (ax, evaluator) in enumerate(zip(axes[0], EVALUATORS)):
        # Sort numerically so e.g. category 10 follows 9 rather than 1.
        safety_keys = sorted(
            (k for k in scores[reference_dir] if evaluator in k),
            key=lambda k: int(_category_number(k)),
        )
        cat_numbers = [_category_number(k) for k in safety_keys]

        for color_idx, output_dir in enumerate(output_dirs):
            run_scores = scores[output_dir]
            safety_scores = [run_scores[k] for k in safety_keys]
            # Cycle the palette so more than six runs do not IndexError.
            color = colors[color_idx % len(colors)]
            line, = ax.plot(cat_numbers, safety_scores, label=output_dir,
                            linewidth=2.5, alpha=0.8, color=color)
            # Only the first panel contributes entries to the shared legend.
            if idx == 0:
                helpful_score = run_scores[helpful_key]
                legend_handles.append(
                    (line, f'{output_dir} ({helpful_score:.2f})'))

        ax.tick_params(labelsize=12)
        ax.set_xlabel('Category number', fontsize=12)
        display_name = evaluator_names.get(evaluator, evaluator)
        ax.set_ylabel(f'Safety score ({display_name})', fontsize=12)

    # Shared legend above the panels, one column per run.
    fig.legend([handle for handle, _ in legend_handles],
               [label for _, label in legend_handles],
               loc='upper center', fontsize=12, columnspacing=0.7,
               ncol=max(len(output_dirs), 1), bbox_to_anchor=(0.5, 1.05))

    # Reserve the top 5% of the figure for the legend.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig('all_performances_combined.png', bbox_inches='tight')
    fig.savefig('all_performances_combined.pdf', bbox_inches='tight')
    plt.show()


def main():
    """Load every run's GPT-4 score file and draw the combined figure.

    The run directories come from element ``[1]`` of each entry in ``TASKS``
    (presumably the output directory name -- confirm against ``util``).
    """
    run_dirs = [task[1] for task in TASKS]
    scores = {
        output_dir: load_json(f'{output_dir}/scores_gpt4.json')
        for output_dir in run_dirs
    }
    plot_all_combined(scores, run_dirs)


if __name__ == "__main__":
    # Build and save the figure only when run as a script, not on import.
    main()
