import os

from matplotlib import pyplot as plt

from util import STEPS, BETAS, METHODS, EVALUATORS, load_json


def gather_scores(seed=0, filename='scores_gpt4.json'):
    """Load every (method, beta, step) score file for one seed.

    Each file is read from
    ``iter-{step}/{method}-beta-{beta}-seed-{seed}/{filename}`` via ``load_json``.

    Args:
        seed: Seed suffix of the run directories to read.
        filename: Name of the score file inside each run directory. Defaults
            to ``'scores_gpt4.json'`` (the previously hard-coded name), so
            existing callers are unaffected.

    Returns:
        dict mapping each method in ``METHODS`` to a dict keyed by
        ``f'{beta}_{step}'`` (for every beta in ``BETAS`` and step in
        ``STEPS``) whose values are whatever ``load_json`` returned.
    """
    # Iteration order (method -> beta -> step) matches the order files are
    # loaded and the insertion order of the resulting keys.
    return {
        method: {
            f'{beta}_{step}': load_json(
                f'iter-{step}/{method}-beta-{beta}-seed-{seed}/{filename}'
            )
            for beta in BETAS
            for step in STEPS
        }
        for method in METHODS
    }


def plot_metrics(scores, evaluator, beta, ax, method_colors, method_markers):
    """Plot one evaluator's safety score against training step, one series per method.

    For every method in ``METHODS`` and every step in ``STEPS``, reads
    ``scores[method][f'{beta}_{step}'][f'{evaluator}_safety_biased']``;
    steps whose key is absent are skipped. Points are drawn as large
    translucent markers and joined by a line that carries the legend label.
    Axis labels, grid and x ticks are then set on ``ax``.

    Args:
        scores: Nested mapping ``method -> '{beta}_{step}' -> score dict``
            (the structure produced by ``gather_scores``). Each score dict
            must contain the key ``f'{evaluator}_safety_biased'``.
        evaluator: Evaluator name; selects the score key and the default
            y-axis label.
        beta: The beta value whose runs are plotted on this axis.
        ax: Matplotlib ``Axes`` to draw on.
        method_colors: Mapping ``method -> matplotlib color``.
        method_markers: Mapping ``method -> matplotlib marker``.
    """
    metric_key = f'{evaluator}_safety_biased'
    for method in METHODS:
        method_scores = scores[method]
        color = method_colors[method]
        marker = method_markers[method]

        x_vals = []
        y_vals = []
        for step in STEPS:
            key = f'{beta}_{step}'
            if key in method_scores:
                x_vals.append(step)
                y_vals.append(method_scores[key][metric_key])

        if x_vals:
            # One scatter call per method (instead of one per point) draws the
            # same markers as a single artist.
            ax.scatter(x_vals, y_vals, alpha=0.75, s=90, color=color, marker=marker)
        # The line is drawn even when empty so every method keeps a legend entry.
        ax.plot(x_vals, y_vals, marker=marker, color=color, label=method)

    ax.set_xlabel('Step', fontsize=14)
    ax.set_ylabel(evaluator.replace('_', ' ').title(), fontsize=12)
    ax.grid(True)
    ax.set_xticks(STEPS)
    ax.tick_params(labelsize=12)


def main():
    """Gather seed-0 scores and save one beta-comparison figure per evaluator.

    For each evaluator in ``EVALUATORS``, draws a single row of subplots (one
    per beta in ``BETAS``, sharing the y axis), adds one figure-level legend,
    and saves the figure as ``output/beta_scatter_all_{evaluator}.png`` and
    ``.pdf``. The ``output`` directory is created if it does not exist.
    """
    scores = gather_scores(seed=0)
    safety_metric_titles = {
        'mdjudge': 'Adult safety score (MD-Judge)',
        'llamaguard': 'Adult safety score (LlamaGuard)'
    }

    # Per-method styling: green runs as green circles, full runs as red squares.
    method_colors = {'green': 'g', 'full': 'r'}
    method_markers = {'green': 'o', 'full': 's'}

    # savefig does not create missing directories.
    os.makedirs('output', exist_ok=True)

    num_betas = len(BETAS)
    for evaluator in EVALUATORS:
        # squeeze=False always returns a 2-D array of Axes, so indexing works
        # even when there is only a single beta.
        fig, axes = plt.subplots(
            1, num_betas, figsize=(3.5 * num_betas, 3.5), sharey=True, squeeze=False
        )
        row = axes[0]
        # Fall back to the same label plot_metrics uses for unknown evaluators.
        y_title = safety_metric_titles.get(evaluator, evaluator.replace('_', ' ').title())

        for j, (ax, beta) in enumerate(zip(row, BETAS)):
            plot_metrics(scores, evaluator, beta, ax, method_colors, method_markers)
            # Only the leftmost subplot keeps a y label (the y axis is shared).
            if j == 0:
                ax.set_ylabel(y_title, fontsize=13)
            else:
                ax.set_ylabel('')
            ax.set_title(r'$\beta/\lambda=${}'.format(beta), fontsize=13)

        # Create a single legend for all subplots.
        handles, labels = row[0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1),
                   ncol=2, fontsize=13, frameon=True)
        fig.tight_layout()
        fig.savefig(f'output/beta_scatter_all_{evaluator}.png', bbox_inches='tight')
        fig.savefig(f'output/beta_scatter_all_{evaluator}.pdf', bbox_inches='tight')
        plt.close(fig)


if __name__ == "__main__":
    # Generate the figures only when run as a script, not when imported.
    main()
