import json
import pandas as pd
import argparse
import os
import re
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

def extract_parameter_value(s, mode):
    """Parse the swept hyper-parameter value encoded in a run directory name.

    The pattern searched depends on *mode*:

    * ``'wts-aux-loss'`` / ``'alpha'`` -> ``alpha-<int>p<int>``
    * ``'epochs'``                     -> ``epochs-<int>``
    * ``'warm_up'``                    -> ``warm_up-<int>p<int>``
    * anything else                    -> ``lambda-<int>p<int>``

    Decimal values use ``'p'`` as the decimal point, e.g. ``'lambda-0p6'``
    yields ``0.6``.

    Returns:
        ``int`` for ``'epochs'`` mode, ``float`` otherwise, or ``None`` when
        the name does not contain the expected pattern.
    """
    rules = (
        (('wts-aux-loss', 'alpha'), r'alpha-(\d+p\d+)'),
        (('epochs',), r'epochs-(\d+)'),
        (('warm_up',), r'warm_up-(\d+p\d+)'),
    )
    pattern = next(
        (regex for modes, regex in rules if mode in modes),
        r'lambda-(\d+p\d+)',
    )

    found = re.search(pattern, s)
    if not found:
        return None

    token = found.group(1)
    if mode == 'epochs':
        return int(token)
    # 'p' stands in for the decimal point in directory names.
    return float(token.replace('p', '.'))

# Name of the DataFrame column holding the swept parameter, per --mode.
# Modes not listed here ('independent', 'wts-naive') sweep lambda.
PARAM_COLUMNS = {
    'wts-aux-loss': 'Alpha',
    'warm_up': 'Warm-Up',
    'epochs': 'Epochs',
}

# Score keys read from every task entry of the JSON results, per --mode.
# The same names are used as DataFrame columns and as plot titles. The key
# order also defines the order of the --mode choices on the command line.
SCORE_KEYS = {
    'independent': ['Weak Performance', 'Strong Performance'],
    'wts-naive': ['WTS-Naive'],
    'wts-aux-loss': ['WTS-Aux-Loss'],
    'warm_up': ['WTS-Aux-Loss'],
    'epochs': ['Weak Performance', 'WTS-Naive'],
}


def parameter_column(mode):
    """Return the name of the swept-parameter column used for *mode*."""
    return PARAM_COLUMNS.get(mode, 'Lambda')


def load_results(input_dir, mode):
    """Load every ``*.json`` result file found in the run subdirectories.

    Subdirectories of *input_dir* whose names do not encode a parameter value
    for *mode* (``extract_parameter_value`` returns ``None``) are skipped.
    Each JSON file is read as ``{category: {task: {score_key: value}}}``, and
    the score keys taken from every task are ``SCORE_KEYS[mode]``.

    Returns:
        A DataFrame with columns ``[parameter column, 'Category', 'Task',
        *SCORE_KEYS[mode]]`` and one row per (file, category, task). It is
        empty when no matching results were found.
    """
    score_keys = SCORE_KEYS[mode]
    columns = [parameter_column(mode), 'Category', 'Task', *score_keys]

    subdirs = sorted(
        name for name in os.listdir(input_dir)
        if os.path.isdir(os.path.join(input_dir, name))
    )
    rows = []
    for subdir in subdirs:
        parameter_value = extract_parameter_value(subdir, mode)
        if parameter_value is None:
            continue
        subdir_path = os.path.join(input_dir, subdir)
        print(f'Processing {subdir_path} ...')
        # Sorted so the load order (and the printed log) is deterministic.
        for file in sorted(os.listdir(subdir_path)):
            if not file.endswith('.json'):
                continue
            print(f'  Loading {file} ...')
            with open(os.path.join(subdir_path, file), encoding='utf-8') as f:
                data = json.load(f)
            for category, tasks in data.items():
                for task, scores in tasks.items():
                    rows.append((parameter_value, category, task,
                                 *(scores[key] for key in score_keys)))

    # Build the frame once instead of pd.concat-ing per file (quadratic).
    return pd.DataFrame(rows, columns=columns)


def average_scores(df, param_col):
    """Average the scores over tasks for each (parameter value, category).

    The non-numeric 'Task' column is dropped first. pandas >= 2.0 no longer
    silently ignores it in ``mean()`` and would raise ``TypeError``.

    Returns:
        A DataFrame with columns ``[param_col, 'Category', *score columns]``,
        sorted by parameter value and then category.
    """
    return (
        df.drop(columns='Task')
        .groupby([param_col, 'Category'])
        .mean()
        .reset_index()
    )


def plot_results(df, input_dir, mode):
    """Save one original-vs-adversarial scatter plot per score type.

    *df* is the output of ``average_scores``. Each point is one parameter
    value, placed at its mean 'original' score (x) and its mean 'adversarial'
    score (y), and labelled with the parameter value. Plots are written to
    ``<input_dir>/<score type>-<mode>.png``.
    """
    param_col = parameter_column(mode)
    for performance_type in SCORE_KEYS[mode]:
        # One row per parameter value, one column per category.
        table = df.pivot(index=param_col, columns='Category',
                         values=performance_type)
        if 'original' not in table or 'adversarial' not in table:
            print(f'Skipping {performance_type}: both original and '
                  f'adversarial results are required')
            continue

        # Offset point labels by 1% of each axis' data range.
        x_offset = 0.01 * (table['original'].max() - table['original'].min())
        y_offset = 0.01 * (table['adversarial'].max() - table['adversarial'].min())

        fig = plt.figure()
        for parameter_value, row in table.iterrows():
            x, y = row['original'], row['adversarial']
            if pd.isna(x) or pd.isna(y):
                print(f'  Skipping {param_col}={parameter_value} for '
                      f'{performance_type}: missing original or adversarial score')
                continue
            plt.scatter(x, y, s=80, color='tab:blue')
            plt.text(x + x_offset, y + y_offset, str(parameter_value), fontsize=10)

        plt.xlabel('Original', fontsize=16)
        plt.ylabel('Adversarial', fontsize=16)
        plt.title(performance_type, fontsize=18)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.tight_layout()
        # Save the plot as a PNG file in the input directory.
        plt.savefig(os.path.join(input_dir, f'{performance_type}-{mode}.png'))
        # Release the figure so repeated plots don't accumulate in memory.
        plt.close(fig)


def main():
    """Parse CLI arguments, aggregate the JSON results and write the plots."""
    parser = argparse.ArgumentParser()
    parser.add_argument('input_dir', type=str, help='Input directory')
    parser.add_argument("--mode", type=str, choices=list(SCORE_KEYS),
                        default="independent", help="Analysis mode")
    args = parser.parse_args()

    df = load_results(args.input_dir, args.mode)
    if df.empty:
        raise SystemExit(
            f'No results found in {args.input_dir} for mode {args.mode!r}')
    df = average_scores(df, parameter_column(args.mode))
    plot_results(df, args.input_dir, args.mode)


if __name__ == '__main__':
    main()
