import glob
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import argparse

# Seaborn look used for the plots created below: the 'whitegrid' style with the
# 'colorblind' palette. The 'darkgrid' style is kept commented out as an alternative.
# sns.set_style('darkgrid')
sns.set_style('whitegrid')
sns.set_palette('colorblind')

def load_results(file_path: str, category: str, avg: bool = False) -> pd.DataFrame:
    """Load one run's results for ``category`` as a long-format DataFrame.

    The JSON file is expected to map each category name to a mapping of
    ``{dataset name: {performance type: value}}``.

    Args:
        file_path: Path to the JSON results file.
        category: Top-level key of the JSON document to read.
        avg: When False, rows are restricted to the known dataset names and
            sorted in a fixed order. When True, every row is kept in the
            order it appears in the file.

    Returns:
        A DataFrame with the columns ``Dataset``, ``Performance Type`` and
        ``Performance`` (one row per dataset / performance-type pair).

    Raises:
        KeyError: If ``category`` is not a key of the JSON document.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)[category]

    # Outer keys (datasets) become rows, inner keys (performance types) columns.
    performance_data = (
        pd.DataFrame(data).T.reset_index().rename(columns={'index': 'Dataset'})
    )

    # Keep only the known columns, in a fixed display order.
    column_names = ['Dataset', 'Weak Performance', 'WTS-Naive', 'WTS-Aux-Loss', 'Strong Performance']
    column_names = [col for col in column_names if col in performance_data.columns]
    performance_data = performance_data[column_names]

    if not avg:
        # Keep only the known datasets that are present, in a fixed order.
        present = set(performance_data['Dataset'])
        dataset_names = [
            name
            for name in ["sst2", "qqp", "mnli", "mnli-mm", "qnli", "rte"]
            if name in present
        ]
        # Copy the filtered rows: assigning a column on the bare filter result
        # is a chained assignment and raises SettingWithCopyWarning on
        # pandas versions without copy-on-write.
        performance_data = performance_data[performance_data['Dataset'].isin(dataset_names)].copy()
        # An ordered categorical makes sort_values follow dataset_names
        # instead of alphabetical order.
        performance_data['Dataset'] = pd.Categorical(
            performance_data['Dataset'], categories=dataset_names, ordered=True
        )
        performance_data = performance_data.sort_values('Dataset')

    # Long format: one row per (dataset, performance type), ready for plotting.
    return performance_data.melt(
        id_vars='Dataset', var_name='Performance Type', value_name='Performance'
    )


# Parse the command-line arguments.
parser = argparse.ArgumentParser(description='Plot the performance data')
parser.add_argument('file_path_prefix', type=str, help='Prefix of the path to the JSON file containing the performance data')
# NOTE(review): --title is accepted but is not applied to the figure anywhere below.
parser.add_argument('--title', type=str, default='Weak vs Strong Performance across Datasets', help='Title of the plot')
parser.add_argument('--avg', action='store_true', help='Whether to plot the average performance')
args = parser.parse_args()

file_path_prefix = args.file_path_prefix
avg = args.avg

# Every JSON file matching the prefix is one run. glob returns paths in
# arbitrary order, so sort them to keep the concatenation order (and with it
# the order in which categories first appear on the x axis) reproducible.
file_paths = sorted(glob.glob(file_path_prefix + '*.json'))
if not file_paths:
    # Without this guard the first column access below fails with an opaque TypeError.
    parser.error(f"no JSON files match the prefix {file_path_prefix!r}")

# Y-axis label for each plotted category, in subplot order (top to bottom).
y_labels = {
    'adversarial': 'Adversarial Robustness (%)\n(better \u2192)',
    'original': 'Task Performance (%)\n(better \u2192)',
}

# Create a figure with two stacked subplots, one for each category.
fig, axes = plt.subplots(2, 1, figsize=(6, 12))

for ax, (category, y_label) in zip(axes, y_labels.items()):
    # Pool all runs into one frame; seaborn aggregates the repeated rows.
    performance_data_melted = pd.concat(
        [load_results(file_path, category, avg=avg) for file_path in file_paths],
        ignore_index=True,
    )

    # Axis limits are computed before any rows are dropped below, so they
    # still cover the strong-performance values.
    min_performance = performance_data_melted['Performance'].min()
    max_performance = performance_data_melted['Performance'].max()

    if avg:
        is_strong = performance_data_melted['Performance Type'] == 'Strong Performance'
        is_phase_3 = performance_data_melted['Dataset'] == 'Phase 3'
        # Mean strong performance on 'Phase 3' across runs, drawn as a reference line.
        strong_ceiling = performance_data_melted.loc[is_strong & is_phase_3, 'Performance'].mean()
        # The strong model is shown only as that line, not as bars.
        performance_data_melted = performance_data_melted[~is_strong]

    sns.barplot(ax=ax, x='Dataset', y='Performance', hue='Performance Type', data=performance_data_melted)

    ax.set_ylabel(y_label, fontsize=24)
    ax.set_ylim(min_performance, max_performance)
    ax.tick_params(axis='y', labelsize=18)
    ax.tick_params(axis='x', labelsize=18)

    if avg:
        ax.set_xlabel('')
        # Collect the bar entries first, then add the ceiling line to the legend.
        handles, labels = ax.get_legend_handles_labels()
        ceiling_line = ax.axhline(y=strong_ceiling, color='r', linestyle='--', label='Strong Ceiling')
        handles.append(ceiling_line)
        labels.append('Strong Ceiling')
        ax.legend(handles=handles, labels=labels, fontsize=16)
        # Replace the x-tick labels with the display names. These apply only
        # to the averaged plot; in the per-dataset plot the ticks are dataset names.
        # NOTE(review): assumes the averaged data has exactly three x-axis
        # entries, in this order - confirm against the JSON files.
        x_tick_labels = ['No TFT', 'Weak TFT', 'Weak +\nWTS TFT']
        ax.set_xticklabels(x_tick_labels)
    else:
        ax.set_xlabel('Dataset', fontsize=24)
        ax.legend(fontsize=16)

plt.tight_layout()
plt.savefig(f"{file_path_prefix}.png")