"""
This script analyzes and plots the success rates of different models across datasets and downstream tasks.
The script reads the results from results/*.json and plots the success rates for each model on each dataset.
"""

import os
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Render every piece of figure text in a serif face (Times New Roman) for a uniform look.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman']

# Constants
count = 100  # exclusive stop of the chunk start indices scanned for each result set

# Dataset keys used in result file names, paired by position with their display names.
datasets = ['bio', 'chem', 'cyber']
dataset_names = ['WMDP-Bio', 'WMDP-Chem', 'WMDP-Cyber']

# Downstream task styles; main() draws one subplot row per style.
styles = ['choose', 'option', 'generate']

# Model keys used in result file names, paired by position with their display names.
models = ['base', 'rmu', 'npo-bio', 'npo-cyber']
model_names = ['Base', 'RMU', 'NPO-Bio', 'NPO-Cyber']

# (model, dataset) pairs whose bars are hatched; the legend labels the hatch "Forget Set".
hatch = [('rmu', 'bio'), ('rmu', 'cyber'), ('npo-bio', 'bio'), ('npo-cyber', 'cyber')]

def analyze(files):
    """
    Return the mean ``success`` value over every entry in the given JSON files.

    Each existing file must contain a JSON list of objects that have a
    ``'success'`` key (presumably bool or 0/1 -- TODO confirm against the
    code that writes the results). Paths that do not exist are skipped
    silently.

    Args:
        files: Iterable of paths to JSON result files.

    Returns:
        The mean of all collected ``success`` values, or NaN when no entries
        were found (no existing files, or only empty lists).
    """
    success = []
    for file in files:
        if not os.path.exists(file):
            continue
        # Context manager guarantees the handle is closed after parsing.
        with open(file, encoding='utf-8') as fh:
            data = json.load(fh)
        success.extend(entry['success'] for entry in data)
    if not success:
        # np.mean([]) would also give NaN, but with a RuntimeWarning; be explicit.
        return np.nan
    return np.mean(success)

def collect_results_for_style(style, results_dir='results'):
    """
    Collect the mean success rate of every (dataset, model) pair for one task style.

    Results are read from chunked files named
    ``{results_dir}/{model}_wmdp-{dataset}_{style}_{start}_{start+delta}.json``
    for ``start`` in ``range(0, count, delta)``, where ``delta`` is 5 for the
    'generate' style and 10 otherwise. Missing chunk files are skipped by
    ``analyze``.

    Args:
        style: Downstream task style (one of ``styles``).
        results_dir: Directory containing the result JSON files. Defaults to
            'results', the directory the script has always read from.

    Returns:
        Dict mapping ``(dataset, model)`` to the value returned by ``analyze``
        for that pair's chunk files.
    """
    # 'generate' results are split into smaller chunks than the other styles.
    delta = 5 if style == 'generate' else 10
    results = {}
    for d in datasets:
        for m in models:
            files = [
                f'{results_dir}/{m}_wmdp-{d}_{style}_{i}_{i + delta}.json'
                for i in range(0, count, delta)
            ]
            results[(d, m)] = analyze(files)
    return results

def plot(results, ax, style):
    """
    Draw a grouped bar chart of success rates onto a Matplotlib axis.

    One group of bars is drawn per dataset, with one bar per model inside each
    group, and every bar is annotated with its value. Bars whose
    ``(model, dataset)`` pair is listed in ``hatch`` are drawn with a '///'
    hatch. Only un-hatched 'chem' bars carry a legend label, so each model
    name appears once in the legend handles.

    Args:
        results: Dict mapping ``(dataset, model)`` to a success rate, as
            returned by ``collect_results_for_style``.
        ax: Matplotlib Axes to draw on.
        style: Downstream task style. X tick labels are shown only when it is
            'generate' (presumably the bottom subplot row).

    Returns:
        ``(handles, labels)`` from ``ax.get_legend_handles_labels()`` so the
        caller can build one shared legend.
    """
    # Success rate of each model, ordered like `datasets`.
    performance = {
        model: [results[(dataset, model)] for dataset in datasets]
        for model in models
    }

    x = np.arange(len(datasets))  # centre of each dataset group
    width = 0.2  # width of a single bar
    # Shift bars left so the group of len(models) bars is centred on x
    # (1.5 * width for the default four models).
    group_offset = (len(models) - 1) / 2 * width
    palette = sns.color_palette("Spectral", len(models))

    def bar_x(model_idx, dataset_idx):
        """Return the x position of one model's bar within a dataset group."""
        return x[dataset_idx] + model_idx * width - group_offset

    for i, model in enumerate(models):
        for j, dataset in enumerate(datasets):
            bar_kwargs = {'color': palette[i], 'edgecolor': 'black', 'linewidth': 1.2}
            if (model, dataset) in hatch:
                bar_kwargs['hatch'] = '///'
            elif dataset == 'chem':
                # Label only the 'chem' bars so each model appears once in the legend.
                bar_kwargs['label'] = model_names[i]
            ax.bar(bar_x(i, j), performance[model][j], width, **bar_kwargs)

    if style == 'generate':
        ax.set_xticks(x)
        ax.set_xticklabels(dataset_names, fontsize=18)
    else:
        ax.set_xticks([])
        ax.set_xticklabels([])
    # Headroom above 1.0 keeps value labels on full-height bars inside the axes.
    ax.set_ylim(0, 1.15)

    yticks = np.arange(0, 1.1, 0.5)
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{tick:.1f}" for tick in yticks], fontsize=18)
    ax.set_ylabel("Success Rate", fontsize=20, labelpad=20)

    # Write each bar's value just above its top.
    for i, model in enumerate(models):
        for j, value in enumerate(performance[model]):
            ax.text(
                bar_x(i, j),
                value + 0.02,
                f"{value:.2f}",
                ha='center',
                va='bottom',
                fontsize=18,
            )

    return ax.get_legend_handles_labels()

def main():
    """
    Build the success-rate figure with one subplot row per downstream task style.

    For each entry in ``styles`` the per-(dataset, model) success rates are
    collected and plotted. A single shared legend is placed below the
    subplots, containing each model plus a hatched "Forget Set" entry. The
    figure is saved to ``figures/success-all.png`` and
    ``figures/success-all.pdf`` (the ``figures`` directory is created if
    missing) and then shown.
    """
    # squeeze=False keeps `axes` 2-D even when len(styles) == 1, so indexing
    # is the same for any number of styles.
    fig, axes = plt.subplots(
        nrows=len(styles), ncols=1, figsize=(18, 12), squeeze=False
    )

    all_handles = []
    all_labels = []

    for i, style in enumerate(styles):
        ax = axes[i, 0]
        results = collect_results_for_style(style)
        handles, labels = plot(results, ax, style)
        ax.set_title(f"Downstream Task: {style.upper()}", fontsize=20)
        # Gather legend entries from every subplot for one figure-level legend.
        all_handles.extend(handles)
        all_labels.extend(labels)

    # Every subplot contributes the same model labels; deduplicate by label.
    # The dict keeps first-seen label order, and this also works when no
    # labels were collected (unlike zip(*...), which raises on empty input).
    legend_entries = dict(zip(all_labels, all_handles))
    legend_labels = list(legend_entries) + ['Forget Set']
    legend_handles = list(legend_entries.values()) + [
        Patch(facecolor='white', edgecolor='black', hatch='///')
    ]

    fig.legend(
        legend_handles,
        legend_labels,
        loc='lower center',
        ncol=len(legend_handles),  # all entries on a single row
        fontsize=18,
        frameon=True,
        framealpha=0.5,
    )

    # Reserve the bottom 7% of the figure for the legend so it does not
    # overlap the subplots.
    plt.tight_layout(rect=[0, 0.07, 1, 1])

    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/success-all.png')
    plt.savefig('figures/success-all.pdf')
    plt.show()

if __name__ == '__main__':
    main()