import os
from matplotlib import pyplot as plt
from util import SEEDS, STEPS, BETAS, METHODS, load_json


def gather_scores(method):
    """Average the per-seed GPT-4 scores of ``method`` for every beta/step pair.

    For each combination of ``BETAS`` and ``STEPS`` the file
    ``iter-{step}/{method}-beta-{beta}-seed-{seed}/scores_gpt4.json`` is read
    for every seed in ``SEEDS`` and its values are accumulated, each divided
    by the number of seeds.

    Args:
        method: Method name used to build the run directory names.

    Returns:
        A dict keyed by ``'{beta}_{step}'``; each value is a dict mapping a
        score name to the sum of ``value / len(SEEDS)`` over all seeds.
        (Presumably each JSON file is a flat name -> number mapping; this
        depends on what ``load_json`` returns.)
    """
    num_seeds = float(len(SEEDS))
    averaged = {}
    for beta in BETAS:
        for step in STEPS:
            # Start from a fresh bucket for this (beta, step) combination.
            bucket = averaged[f'{beta}_{step}'] = {}
            for seed in SEEDS:
                run_dir = f'iter-{step}/{method}-beta-{beta}-seed-{seed}'
                seed_scores = load_json(f"{run_dir}/scores_gpt4.json")
                for name, value in seed_scores.items():
                    # Each seed contributes an equal 1/len(SEEDS) share.
                    bucket[name] = bucket.get(name, 0.0) + value / num_seeds
    return averaged


def plot_metrics(scores, metrics, ax_titles, method):
    """Plot each metric against the iteration count, one line per beta.

    One subplot is drawn per entry of ``metrics``. The figure is written to
    ``output/{method}_iter_beta.png`` and ``output/{method}_iter_beta.pdf``
    and then displayed. The ``output`` directory is not created here and
    must already exist.

    Args:
        scores: Mapping keyed by ``'{beta}_{step}'`` whose values map metric
            names to numbers (the structure returned by ``gather_scores``).
        metrics: Metric names to plot, one subplot each.
        ax_titles: Y-axis labels, paired with ``metrics`` by position.
        method: Method name, used only in the output file names.
    """
    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when only a single metric is requested.
    fig, axes = plt.subplots(
        1, len(metrics), figsize=(4.5 * len(metrics), 3.5), squeeze=False
    )
    markers = ['s', 'o', 'D', '^', 'v']  # Different markers for each line

    # Fixed (lower, upper) y-axis limits per known metric. Metrics that are
    # not listed keep matplotlib's automatic limits.
    y_limits = {
        'false_rejection_ratio': (0., 0.4),
        'corruption_ratio': (0., 0.13),
        'helpful_biased': (0., 0.8),
        'mdjudge_safety_biased': (0.4, 1.),
    }

    for i, (metric, ax_label) in enumerate(zip(metrics, ax_titles)):
        ax = axes[0, i]
        for j, beta in enumerate(BETAS):
            steps = []
            metric_values = []
            for step in STEPS:
                key = f'{beta}_{step}'
                if key in scores:
                    steps.append(step)
                    metric_values.append(scores[key][metric])
            ax.plot(
                steps, metric_values, label=r'$\beta/\lambda=${}'.format(beta),
                # Cycle through the markers if there are more betas than markers.
                marker=markers[j % len(markers)], linewidth=2.5, markersize=7.5
            )

        ax.grid(True)
        # NOTE: str.title() also lowercases the rest of each word, so a label
        # such as "MD-Judge" is rendered as "Md-Judge".
        ax.set_ylabel(ax_label.title(), fontsize=14)
        ax.set_xlabel('Number of iterations', fontsize=14)
        ax.set_xticks(STEPS)  # Show only x ticks that match the iteration steps
        ax.set_xticklabels(STEPS)
        limits = y_limits.get(metric)
        if limits is not None:
            ax.set_ylim(*limits)
        ax.tick_params(labelsize=12)

    # Create a single legend for all subplots; every subplot draws the same
    # set of beta lines, so the handles of the last one are representative.
    handles, labels = axes[0, -1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=len(BETAS), bbox_to_anchor=(0.5, 1.1), fontsize=12)
    fig.tight_layout()

    # Save BEFORE showing: with an interactive backend plt.show() blocks until
    # the window is closed, after which the figure is gone and a later
    # savefig would write an empty image.
    fig.savefig(f'output/{method}_iter_beta.png', bbox_inches='tight')
    fig.savefig(f'output/{method}_iter_beta.pdf', bbox_inches='tight')
    plt.show()


def main():
    """Gather the scores and plot the metrics for every method in ``METHODS``.

    Ensures the ``output`` directory exists, then for each method averages
    its scores with ``gather_scores`` and plots the helpfulness and
    MD-Judge safety metrics with ``plot_metrics``.
    """
    # exist_ok=True makes this a no-op when the directory is already there
    # and avoids a separate existence check.
    os.makedirs('output', exist_ok=True)

    for method in METHODS:
        scores = gather_scores(method)
        plot_metrics(
            scores,
            ['helpful_biased', 'mdjudge_safety_biased'],
            ['Helpfulness win rate', 'Adult safety score (MD-Judge)'],
            method
        )


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
