import json
import pandas as pd
import argparse
import os
# import re
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from trade_off_plot import extract_parameter_value

# Global seaborn styling applied to every figure this script draws:
# 'whitegrid' axes style (switch to 'darkgrid' for a shaded background)
# and seaborn's 'colorblind' colour palette.
sns.set_style('whitegrid')
sns.set_palette('colorblind')

if __name__ == '__main__':
    # Sensitivity plot script.
    #
    # Reads per-run JSON score files from the subdirectories of `input_dir`,
    # averages the scores over tasks per (parameter value, category), and draws
    # one subplot per category showing the scores against the swept parameter
    # (lambda or alpha, chosen with --mode). The figure is written to
    # `<input_dir>/<mode>_sensitivity_plots.png`.

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('input_dir', type=str, help='Input directory')
    parser.add_argument("--mode", type=str, choices=['lambda', 'alpha'], default="lambda", help="Analysis mode")

    args = parser.parse_args()
    input_dir = args.input_dir
    mode = args.mode

    # Column name of the swept parameter. Both modes share the same loading and
    # aggregation steps and differ only in this name and in how they are plotted.
    param_col = 'Lambda' if mode == 'lambda' else 'Alpha'
    score_cols = ['Weak Performance', 'WTS-Naive', 'WTS-Aux-Loss', 'Strong Performance']

    # Subdirectories of the input directory, visited in sorted order
    subdir_list = sorted(
        entry for entry in os.listdir(input_dir)
        if os.path.isdir(os.path.join(input_dir, entry))
    )

    # Load every JSON file and keep one aggregated frame per file. The frames
    # are concatenated once at the end instead of growing a DataFrame in the loop.
    frames = []
    for subdir in subdir_list:
        parameter_value = extract_parameter_value(subdir, mode)
        if parameter_value is None:
            # No parameter value could be extracted for this mode: skip the subdirectory
            continue
        subdir_path = os.path.join(input_dir, subdir)
        for file in os.listdir(subdir_path):
            if not file.endswith('.json'):
                continue
            with open(os.path.join(subdir_path, file), encoding='utf-8') as f:
                data = json.load(f)

            # JSON layout consumed here: {category: {task: {score name: value}}}
            rows = [
                (parameter_value, category, task, *(scores[col] for col in score_cols))
                for category, tasks in data.items()
                for task, scores in tasks.items()
            ]
            if not rows:
                # An empty file contributes nothing; skipping it also avoids
                # concatenating an all-object empty frame with numeric ones.
                continue

            df_file = pd.DataFrame(rows, columns=[param_col, 'Category', 'Task', *score_cols])
            # Take the average of the scores over all tasks
            frames.append(
                df_file.groupby([param_col, 'Category']).mean(numeric_only=True).reset_index()
            )

    if not frames:
        # Fail with a clear message instead of a KeyError on df['Category'] below
        raise SystemExit(f'No usable JSON results found in {input_dir} for mode {mode!r}')

    # ignore_index gives the combined frame unique row labels
    df = pd.concat(frames, ignore_index=True)

    # Plot the scores with respect to the parameter value (lambda or alpha) for each category
    categories = df['Category'].unique()
    num_categories = len(categories)
    # squeeze=False keeps `axes` a 2-D array even when there is a single
    # category, so the per-category indexing below works for any count.
    fig, axes = plt.subplots(num_categories, 1, figsize=(6, 6 * num_categories), squeeze=False)
    palette = sns.color_palette()

    for ax, category in zip(axes[:, 0], categories):
        category_df = df[df['Category'] == category]

        if mode == 'lambda':
            # Strong ceiling: mean strong performance of the rows with lambda == 0.2
            # (NaN if no such row was loaded).
            strong_ceiling = category_df.loc[category_df['Lambda'] == 0.2, 'Strong Performance'].mean()
            for score in ('Weak Performance', 'WTS-Naive', 'WTS-Aux-Loss'):
                sns.lineplot(x='Lambda', y=score, data=category_df, label=score, ax=ax)
        else:
            # Alpha mode: the ceiling and both baselines are drawn as horizontal
            # lines at their mean over all alpha values; only WTS-Aux-Loss is
            # plotted against alpha.
            strong_ceiling = category_df['Strong Performance'].mean()
            weak_baseline = category_df['Weak Performance'].mean()
            wts_naive = category_df['WTS-Naive'].mean()
            ax.axhline(y=weak_baseline, label='Weak Baseline', color=palette[0])
            ax.axhline(y=wts_naive, label='WTS-Naive', color=palette[1])
            sns.lineplot(x='Alpha', y='WTS-Aux-Loss', data=category_df, label='WTS-Aux-Loss', ax=ax,
                         color=palette[2])

        # Plot the strong performance ceiling
        ax.axhline(y=strong_ceiling, color='r', linestyle='--', label='Strong Ceiling')
        ax.set_xlabel('Lambda' if mode == 'lambda' else 'Max Alpha', fontsize=24)
        ax.set_xticks(np.arange(0, 1.01, 0.2))
        if category == 'adversarial':
            # The arrow in the label shows which direction is better
            ax.set_ylabel('Adversarial Robustness (%)\n(better \u2192)', fontsize=24)
        elif category == 'original':
            ax.set_ylabel('Task Performance (%)\n(better \u2192)', fontsize=24)
        else:
            ax.set_ylabel('Performance', fontsize=24)
        ax.legend(fontsize=16)
        ax.tick_params(axis='both', labelsize=18)

    plt.tight_layout()
    output_path = f'{input_dir}/{mode}_sensitivity_plots.png'
    plt.savefig(output_path)
    print(f'Plot saved to {output_path}')
