import argparse
import glob
import json
import os
import sys

import pandas as pd

# Score columns expected in every task entry of the input JSON, in output order.
SCORE_COLUMNS = ('Weak Performance', 'Strong Performance', 'WTS-Naive', 'WTS-Aux-Loss')


def _parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(
        description='Average per-task scores within each category and merge '
                    'them, under TITLE, into same-named JSON files in SAVE_DIR.')
    parser.add_argument('input_path_prefix', type=str, help='The prefix of the input JSON files.')
    parser.add_argument('save_dir', type=str, help='The directory to save the result.')
    parser.add_argument('title', type=str, help='Name for the average performance.')
    return parser.parse_args(argv)


def average_by_category(data):
    """Return the mean of each score column over all tasks, per category.

    *data* maps category -> task -> {score column -> value}. Every task entry
    must contain all of SCORE_COLUMNS; a missing one raises KeyError.

    Returns a dict mapping category -> {score column -> mean}. Categories
    appear in sorted order, and categories with no tasks are omitted.
    """
    rows = [
        (category, *(scores[column] for column in SCORE_COLUMNS))
        for category, tasks in data.items()
        for scores in tasks.values()
    ]
    if not rows:
        return {}
    df = pd.DataFrame(rows, columns=['Category', *SCORE_COLUMNS])
    # Aggregate only the score columns. Averaging a non-numeric column (such
    # as a task name) raises TypeError in pandas >= 2.0.
    means = df.groupby('Category')[list(SCORE_COLUMNS)].mean()
    return means.to_dict(orient='index')


def merge_averages(res_data, averages, title):
    """Store *averages* under *title* in *res_data* and return a sorted copy.

    Entries for other titles are kept. An existing entry with the same title
    is overwritten. *res_data* is updated in place, and the returned dict has
    both its category level and its title level sorted by key.
    """
    for category, scores in averages.items():
        res_data.setdefault(category, {})[title] = scores
    return {
        category: dict(sorted(entries.items()))
        for category, entries in sorted(res_data.items())
    }


def process_file(input_path, save_dir, title):
    """Average one input JSON file and merge the result into save_dir.

    The result file has the same base name as *input_path*. If it already
    exists, its contents are preserved and the new averages are added under
    *title*.

    Returns the path of the written result file.

    Raises ValueError if the result file would be the input file itself;
    writing it would overwrite the source data.
    """
    save_path = os.path.join(save_dir, os.path.basename(input_path))
    if os.path.exists(save_path) and os.path.samefile(save_path, input_path):
        raise ValueError(
            'Refusing to overwrite input file %r; choose a save_dir different '
            'from the input directory.' % (input_path,))

    with open(input_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    averages = average_by_category(data)

    if os.path.exists(save_path):
        with open(save_path, 'r', encoding='utf-8') as file:
            res_data = json.load(file)
    else:
        res_data = {}

    res_data = merge_averages(res_data, averages, title)
    with open(save_path, 'w', encoding='utf-8') as file:
        json.dump(res_data, file, indent=2)
    return save_path


def main(argv=None):
    """Average every ``<input_path_prefix>*.json`` file and save the results.

    The prefix is passed to ``glob`` unescaped, so it may contain wildcards.
    Matching files are processed in sorted order.
    """
    args = _parse_args(argv)
    os.makedirs(args.save_dir, exist_ok=True)

    pattern = args.input_path_prefix + '*.json'
    input_paths = sorted(glob.glob(pattern))
    if not input_paths:
        # Warn, but keep the exit status at 0 as before.
        print('No JSON files match', repr(pattern), file=sys.stderr)

    for input_path in input_paths:
        process_file(input_path, args.save_dir, args.title)


if __name__ == '__main__':
    main()
