import os

from adjustText import adjust_text
from matplotlib import pyplot as plt

from util import STEPS, BETAS, METHODS, EVALUATORS, load_json


def gather_scores(seed=0):
    """Load the GPT-4 score file of every (method, beta, step) run for one seed.

    For each combination the file
    ``iter-{step}/{method}-beta-{beta}-seed-{seed}/scores_gpt4.json`` is read
    with ``load_json``.

    Args:
        seed: Seed whose run directories are read.

    Returns:
        A dict mapping each method in ``METHODS`` to a dict that maps
        ``'{beta}_{step}'`` to whatever ``load_json`` returned for that run.
    """
    collected = {}
    for method in METHODS:
        per_run = {}
        for beta in BETAS:
            for step in STEPS:
                run_dir = f'iter-{step}/{method}-beta-{beta}-seed-{seed}'
                run_key = f'{beta}_{step}'
                per_run[run_key] = load_json(f"{run_dir}/scores_gpt4.json")
        collected[method] = per_run
    return collected


def plot_metrics(scores, metrics, ax_labels, save_suffix, annotate_betas=False):
    """Scatter-plot one metric against another for every (step, method, beta) run.

    All points are drawn on a single Axes. The figure is written to
    ``output/beta_scatter_all_{save_suffix}.png`` and ``.pdf`` (the ``output``
    directory is created if missing) and is then displayed.

    Args:
        scores: Nested mapping indexed as
            ``scores[method][f'{beta}_{step}'][metric]``. Runs whose
            ``'{beta}_{step}'`` key is absent are skipped.
        metrics: Pair of metric keys; ``metrics[0]`` is plotted on the x axis
            and ``metrics[1]`` on the y axis.
        ax_labels: Pair of axis labels (x, y). Underscores are replaced by
            spaces and the result is title-cased before display.
        save_suffix: Suffix inserted into the output file names.
        annotate_betas: If true, write each point's beta value next to it and
            let ``adjust_text`` reposition the annotations. Defaults to
            False, which draws no annotations.
    """
    # A single Axes is used, so no axis sharing is requested.
    fig, ax = plt.subplots(1, 1, figsize=(5, 3.5))

    # Marker, color and legend text per method.
    method_markers = {
        'green': 's',  # square
        'full': 'o',  # circle
    }
    method_colors = {
        'green': 'g',  # green
        'full': 'r',  # red
    }
    legend_names = {
        'green': 'cleaned data',
        'full': 'all data',
    }

    labelled_methods = set()
    texts = []
    for step in STEPS:
        for method in METHODS:
            for beta in BETAS:
                key = f'{beta}_{step}'
                if key not in scores[method]:
                    continue

                # Only the first point actually drawn for a method carries a
                # legend label; an empty label keeps later points out of the
                # legend. This still yields an entry when the very first
                # (beta, step) run of a method is missing.
                if method in labelled_methods:
                    label = ""
                else:
                    label = legend_names.get(method, method)
                    labelled_methods.add(method)

                x = scores[method][key][metrics[0]]
                y = scores[method][key][metrics[1]]
                ax.scatter(
                    x, y,
                    alpha=0.75,
                    s=90,
                    color=method_colors[method], marker=method_markers[method],
                    label=label
                )
                if annotate_betas:
                    texts.append(ax.text(x, y, f'{beta}', fontsize=10, ha='right'))

    # Axis decoration does not depend on the step, so it is applied once.
    # NOTE(review): str.title() also lower-cases inner capitals (e.g.
    # 'MD-Judge' becomes 'Md-Judge'); kept as-is to preserve current output.
    ax.set_xlabel(ax_labels[0].replace('_', ' ').title(), fontsize=12)
    ax.set_ylabel(ax_labels[1].replace('_', ' ').title(), fontsize=12)
    ax.grid(True)
    ax.set_xlim(0.175, 0.8)
    ax.set_ylim(0.45, 1.01)

    # Reposition annotations only after the limits are final, and only when
    # there is something to adjust.
    if texts:
        adjust_text(texts, ax=ax)

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, loc='lower left', bbox_to_anchor=(0, 0), ncol=1, fontsize=11, frameon=True)
    fig.tight_layout()

    # Save BEFORE showing: with an interactive backend plt.show() blocks and
    # the figure is torn down when its window closes, so saving afterwards
    # can write an empty image.
    os.makedirs('output', exist_ok=True)
    fig.savefig(f'output/beta_scatter_all_{save_suffix}.png', bbox_inches='tight')
    fig.savefig(f'output/beta_scatter_all_{save_suffix}.pdf', bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def main():
    """Load the seed-0 scores and draw one helpfulness-vs-safety plot per evaluator."""
    # y-axis title for each supported safety evaluator.
    y_titles = {
        'mdjudge': 'Adult safety score (MD-Judge)',
        'llamaguard': 'Adult safety score (LlamaGuard)'
    }
    all_scores = gather_scores(seed=0)

    for evaluator in EVALUATORS:
        metric_keys = ['helpful_biased', f'{evaluator}_safety_biased']
        axis_titles = ['Helpfulness win rate', y_titles[evaluator]]
        plot_metrics(all_scores, metric_keys, axis_titles, save_suffix=evaluator)


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
