import os
from matplotlib import pyplot as plt
from util import STEPS, BETAS, METHODS, SEEDS, EVALUATORS, load_json


def gather_scores():
    """Collect seed-averaged GPT-4 scores for every method/beta/step run.

    For each combination of ``METHODS`` x ``BETAS`` x ``STEPS`` the per-seed
    ``scores_gpt4.json`` files are summed key by key and divided by the
    number of seeds.  The baseline runs ``SFT`` and ``DPO(H)`` are loaded
    as-is from their own directories.

    Returns:
        dict: ``result[method]['{beta}_{step}']`` holds the averaged metrics
        of one configuration; ``result['SFT']`` and ``result['DPO(H)']``
        hold the baseline metrics exactly as stored on disk.
    """
    seed_count = float(len(SEEDS))

    scores = {}
    for method in METHODS:
        scores[method] = {}

    for method in METHODS:
        per_method = scores[method]
        for beta in BETAS:
            for step in STEPS:
                tag = f'{beta}_{step}'
                for seed in SEEDS:
                    run_dir = f'iter-{step}/{method}-beta-{beta}-seed-{seed}'
                    seed_scores = load_json(f"{run_dir}/scores_gpt4.json")
                    if tag in per_method:
                        # Later seeds are added onto the first seed's dict.
                        running = per_method[tag]
                        for metric in seed_scores:
                            running[metric] += seed_scores[metric]
                    else:
                        # The first seed's dict doubles as the accumulator.
                        per_method[tag] = seed_scores
                # Turn the per-metric sums into means over the seeds.
                totals = per_method[tag]
                for metric in totals:
                    totals[metric] /= seed_count

    # Baselines are single runs, so there is nothing to average.
    for baseline in ['SFT', 'DPO(H)']:
        scores[baseline] = load_json(f"{baseline}/scores_gpt4.json")

    return scores


def plot_metrics(scores, ax_labels, method, helpful_metric, safety_metrics):
    """Draw helpfulness-vs-safety scatter plots, one panel per safety metric.

    Every (method, step, beta) entry found in ``scores`` contributes two
    points per panel: its ``biased`` and its ``debiased`` result.  When
    ``helpful_metric`` is ``'helpful'`` the ``SFT`` and ``DPO(H)`` baselines
    are plotted as well, and the region that is below or to the left of the
    SFT point is shaded.

    The figure is written to ``output/`` as PNG and PDF and then shown.

    Args:
        scores: Nested mapping in the layout returned by ``gather_scores``:
            ``scores[method]['{beta}_{step}']`` and the baseline entries
            ``scores['SFT']`` / ``scores['DPO(H)']`` map metric names such as
            ``'{helpful_metric}_biased'`` and
            ``'{safety_metric}_safety_debiased'`` to numbers.
        ax_labels: Sequence whose first item is used as the x-axis label
            (underscores become spaces, then title-cased).  Further items
            are not used by this function.
        method: A key of ``scores``, or ``"both"`` to overlay every method
            listed in ``METHODS``.
        helpful_metric: Prefix of the metric plotted on the x axis.
        safety_metrics: Sequence of safety evaluator names, one panel each.
            Each name must be ``'mdjudge'`` or ``'llamaguard'``.

    Raises:
        KeyError: If a safety metric has no axis title, or a required metric
            is missing from ``scores``.
    """
    # Titles are stored already capitalised: applying str.title() to them
    # would turn "MD-Judge" into "Md-Judge".
    safety_metric_title = {
        'mdjudge': 'Adult Safety Score (MD-Judge)',
        'llamaguard': 'Adult Safety Score (Llama Guard 3)',
    }
    # Colors and markers for the two states and the two baselines.
    state_colors = {'biased': 'red', 'debiased': 'blue', 'DPO(H)': 'orange', 'SFT': 'black'}
    state_markers = {'biased': 'o', 'debiased': 's', 'DPO(H)': '^', 'SFT': 'x'}
    state_labels = {'biased': 'w/o TSDI', 'debiased': 'w/ TSDI'}
    baselines = ('SFT', 'DPO(H)')
    with_baselines = helpful_metric == 'helpful'
    methods = METHODS if method == "both" else [method]

    fig, axes = plt.subplots(1, len(safety_metrics), figsize=(10, 3.5), sharey=False)
    # With a single column plt.subplots returns a bare Axes, not an array.
    if len(safety_metrics) == 1:
        axes = [axes]

    all_handles = []
    all_labels = []

    for panel_index, (ax, safety_metric) in enumerate(zip(axes, safety_metrics)):
        # Legend entries are collected from the first panel only so that the
        # shared legend lists each series once.
        first_panel = panel_index == 0

        ax.set_xlabel(ax_labels[0].replace('_', ' ').title(), fontsize=12)
        ax.set_ylabel(safety_metric_title[safety_metric], fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.4)
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)
        for side in ('left', 'bottom'):
            ax.spines[side].set_linewidth(1.5)
            ax.spines[side].set_color('gray')

        if with_baselines:
            for key in baselines:
                label = key if first_panel else ""
                sc = ax.scatter(
                    scores[key][f'{helpful_metric}_biased'],
                    scores[key][f'{safety_metric}_safety_biased'],
                    alpha=0.75,
                    s=80,
                    color=state_colors[key], marker=state_markers[key],
                    label=label
                )
                if first_panel:
                    all_handles.append(sc)
                    all_labels.append(label)

        for _method in methods:
            for step in STEPS:
                for beta in BETAS:
                    key = f'{beta}_{step}'
                    if key not in scores[_method]:
                        continue
                    # Only the very first configuration carries a label, so
                    # each state appears once in the legend.
                    show_label = (
                        first_panel
                        and _method == methods[0]
                        and beta == BETAS[0]
                        and step == STEPS[0]
                    )
                    for state in ('biased', 'debiased'):
                        label = state_labels[state] if show_label else ""
                        sc = ax.scatter(
                            scores[_method][key][f'{helpful_metric}_{state}'],
                            scores[_method][key][f'{safety_metric}_safety_{state}'],
                            alpha=0.75,
                            s=80,
                            color=state_colors[state], marker=state_markers[state],
                            label=label
                        )
                        if show_label:
                            all_handles.append(sc)
                            all_labels.append(label)

        if with_baselines:
            # Mark the SFT baseline and shade everything that is below it or
            # to its left.
            xlim_0, xlim_1 = ax.get_xlim()
            ylim_0, ylim_1 = ax.get_ylim()
            x_pivot = scores['SFT'][f'{helpful_metric}_biased']
            y_pivot = scores['SFT'][f'{safety_metric}_safety_biased']
            ax.axvline(x=x_pivot, color='dimgray', linestyle='--', alpha=0.5)
            ax.axhline(y=y_pivot, color='dimgray', linestyle='--', alpha=0.5)
            ax.fill_between(x=[xlim_0, xlim_1], y1=ylim_0, y2=y_pivot, color='gray', alpha=0.15)
            ax.fill_betweenx(y=[ylim_0, ylim_1], x1=xlim_0, x2=x_pivot, color='gray', alpha=0.15)
            # The fills take part in autoscaling and would widen the view,
            # leaving an unshaded strip at the edges; pin the limits the
            # shading was computed from.
            ax.set_xlim(xlim_0, xlim_1)
            ax.set_ylim(ylim_0, ylim_1)

    # Create a shared legend (skipped when nothing was labelled, e.g. when
    # the first beta/step combination is absent from ``scores``).
    if all_handles:
        fig.legend(
            all_handles, all_labels,
            loc='upper center', bbox_to_anchor=(0.5, 1.1),
            ncol=len(all_labels), fontsize=12, frameon=True
        )

    fig.tight_layout()

    # Save through the figure object and before plt.show(): with interactive
    # backends the figure may be gone once show() returns, which would make
    # a later plt.savefig() write an empty image.
    os.makedirs('output', exist_ok=True)
    # Joining all metric names yields the same file name as before for two
    # metrics and also works for one or more than two.
    metrics_tag = '_'.join(str(name) for name in safety_metrics)
    stem = f'output/beta_scatter_{method}_{metrics_tag}_{helpful_metric}'
    fig.savefig(f'{stem}.png', bbox_inches='tight')
    fig.savefig(f'{stem}.pdf', bbox_inches='tight')

    plt.show()
    # Release the figure; this function is called repeatedly and open
    # figures would otherwise pile up.
    plt.close(fig)


def main():
    """Main function to gather scores and plot metrics.

    Makes sure the ``output`` directory exists, loads all scores once and
    then calls ``plot_metrics`` for every combination of x-axis metric
    (``'helpful'``, ``'rejection'``) and method in ``METHODS``, passing
    ``EVALUATORS`` as the safety metrics.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('output', exist_ok=True)

    # x-axis label for each helpfulness metric.
    helpful_metric_title = {
        'helpful': 'Helpful win rate',
        'rejection': 'Compliance rate'
    }

    scores = gather_scores()
    for helpful_metric in ['helpful', 'rejection']:
        for method in METHODS:
            plot_metrics(
                scores,
                [helpful_metric_title[helpful_metric], 'Safety Score'],
                method,
                helpful_metric,
                EVALUATORS
            )


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
