import os
import re
import pandas as pd
import numpy as np
from scipy.stats import sem, t
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import argparse

# Render all figure text through LaTeX using a serif (Times-like) face.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times", "Times New Roman", "serif"],
})

# Command line: a positional results directory plus an optional directory
# that receives the generated plot.
parser = argparse.ArgumentParser(description="Results' directory")
parser.add_argument('dir_path', type=str, help='mnist directory path to process')
parser.add_argument('--output_dir', type=str, help='Output directory to save plots')
args = parser.parse_args()
dir_path = args.dir_path

# Tau values expected on the x-axis of the plots.
attacks = [1, 4, 16, 32]

# One row per run directory:
# [agg, attack, alpha, f, momentum, mixing, seed, max_accuracy]
acc_data_list = []

# Hyper-parameters are encoded in the run directory name. The dot in the
# literal "nlpobj1.0" tag is escaped so it cannot match an arbitrary character,
# and the pattern is compiled once instead of on every iteration.
run_dir_pattern = re.compile(
    r'(\w+)_(\w+)_niid(\d+\.\d+)_n(\d+)_f(\d+)_m(\d\.\d+)'
    r'_nlpsize(\d+)_nlpobj1\.0_mix(\w+)_s0_seed(\d+)')

for dir_name in os.listdir(dir_path):
    directory = os.path.join(dir_path, dir_name)
    if not os.path.isdir(directory):
        continue
    match = run_dir_pattern.search(dir_name)
    if not match:
        continue

    # alpha and momentum stay strings: the plotting filters compare them as text.
    agg, attack, alpha, _, f, momentum, nlpsize, mixing, seed = match.groups()
    if attack == 'SSNLP':
        # SSNLP runs are identified by their integer size instead of their name.
        attack = int(nlpsize)
    f = int(f)
    mixing = mixing == 'True'
    seed = int(seed)

    # The second whitespace-separated column of each line is read as the test
    # accuracy. Blank lines (e.g. a trailing newline) are ignored rather than
    # raising IndexError.
    try:
        with open(os.path.join(directory, 'accs_test.txt')) as file:
            accuracies = [float(line.split()[1]) for line in file if line.strip()]
    except FileNotFoundError:
        # Runs without an accuracy log are skipped.
        continue
    if not accuracies:
        # An empty log has no maximum; skip it instead of crashing in max().
        continue

    acc_data_list.append([agg, attack, alpha, f, momentum, mixing, seed, max(accuracies)])

# Per-run accuracy table: one row for each run collected above.
acc_columns = ['agg', 'attack', 'alpha', 'f', 'momentum', 'mixing', 'seed', 'max_accuracy']
acc_data = pd.DataFrame(acc_data_list, columns=acc_columns)

# Aggregate over seeds: every column except the seed and the measured value
# identifies a configuration.
acc_group_keys = ['agg', 'attack', 'alpha', 'f', 'momentum', 'mixing']
acc_stats = acc_data.groupby(acc_group_keys).agg({'max_accuracy': ['mean', 'std', 'count']})
acc_grouped_data = acc_stats.reset_index()

# Flatten the two-level header into names such as "max_accuracy mean";
# the key columns keep their plain names because their second level is empty.
acc_grouped_data.columns = [' '.join(pair).strip() for pair in acc_grouped_data.columns.values]

# Standard error of the mean: std / sqrt(number of runs).
acc_run_counts = acc_grouped_data['max_accuracy count']
acc_grouped_data['std_err'] = acc_grouped_data['max_accuracy std'] / np.sqrt(acc_run_counts)

# Half-width of the two-sided Student-t confidence interval
# (count - 1 degrees of freedom).
confidence = 0.95
acc_t_quantile = t.ppf((1 + confidence) / 2, acc_run_counts - 1)
acc_grouped_data['conf_interval'] = acc_grouped_data['std_err'] * acc_t_quantile




# One row per run directory:
# [agg, attack, alpha, f, momentum, mixing, seed, max_time]
time_data_list = []

# Hyper-parameters are encoded in the run directory name. The dot in the
# literal "nlpobj1.0" tag is escaped so it cannot match an arbitrary character,
# and the pattern is compiled once instead of on every iteration.
run_dir_pattern = re.compile(
    r'(\w+)_(\w+)_niid(\d+\.\d+)_n(\d+)_f(\d+)_m(\d\.\d+)'
    r'_nlpsize(\d+)_nlpobj1\.0_mix(\w+)_s0_seed(\d+)')

for dir_name in os.listdir(dir_path):
    directory = os.path.join(dir_path, dir_name)
    if not os.path.isdir(directory):
        continue
    match = run_dir_pattern.search(dir_name)
    if not match:
        continue

    # alpha and momentum stay strings: the plotting filters compare them as text.
    agg, attack, alpha, _, f, momentum, nlpsize, mixing, seed = match.groups()
    if attack == 'SSNLP':
        # SSNLP runs are identified by their integer size instead of their name.
        attack = int(nlpsize)
    f = int(f)
    mixing = mixing == 'True'
    seed = int(seed)

    # The first whitespace-separated column of each line is read as a timing
    # value (presumably cumulative runtime, so the maximum would be the total
    # -- TODO confirm). Blank lines are ignored rather than raising IndexError.
    try:
        with open(os.path.join(directory, 'time.txt')) as file:
            timings = [float(line.split()[0]) for line in file if line.strip()]
    except FileNotFoundError:
        # Runs without a timing log are skipped.
        continue
    if not timings:
        # An empty log has no maximum; skip it instead of crashing in max().
        continue

    time_data_list.append([agg, attack, alpha, f, momentum, mixing, seed, max(timings)])

# Per-run timing table. The value column is called 'max_accuracy' but holds
# the per-run maximum read from time.txt.
time_columns = ['agg', 'attack', 'alpha', 'f', 'momentum', 'mixing', 'seed', 'max_accuracy']
data_df = pd.DataFrame(time_data_list, columns=time_columns)

# Self-merge: pair up every two runs that share configuration and seed, so the
# '_x' and '_y' columns describe two attacks under identical conditions.
pair_keys = ['agg', 'alpha', 'f', 'momentum', 'mixing', 'seed']
merged_df = data_df.merge(data_df, on=pair_keys)

# Keep the pairs whose left side is an SSNLP run (integer attack label) and
# whose right side is the 'BF' run.
left_is_ssnlp = merged_df['attack_x'].apply(lambda label: isinstance(label, int))
right_is_bf = merged_df['attack_y'] == 'BF'
filtered_df = merged_df[left_is_ssnlp & right_is_bf].copy()

# Per-seed ratio of the SSNLP value to the BF value.
filtered_df['ratio'] = filtered_df['max_accuracy_x'] / filtered_df['max_accuracy_y']

# Aggregate the ratio over seeds for every configuration.
ratio_group_keys = ['agg', 'attack_x', 'alpha', 'f', 'momentum', 'mixing']
ratio_stats = filtered_df.groupby(ratio_group_keys).agg({'ratio': ['mean', 'std', 'count']})
time_grouped_data = ratio_stats.reset_index()

# Flatten the two-level header into names such as "ratio mean".
time_grouped_data.columns = [' '.join(pair).strip() for pair in time_grouped_data.columns.values]

# Standard error of the mean: std / sqrt(number of seeds).
ratio_counts = time_grouped_data['ratio count']
time_grouped_data['std_err'] = time_grouped_data['ratio std'] / ratio_counts.apply(lambda count: count**0.5)

# Half-width of the two-sided Student-t confidence interval
# (count - 1 degrees of freedom).
confidence = 0.95
ratio_t_quantile = t.ppf((1 + confidence) / 2, ratio_counts - 1)
time_grouped_data['conf_interval'] = time_grouped_data['std_err'] * ratio_t_quantile



# Plotting
# Marker / error-bar styling shared by both panels.
acc_style = dict(fmt='s', color='blue', ecolor='c',
                 capsize=7, linewidth=3, elinewidth=2, capthick=3)
time_style = dict(fmt='o', color='green', ecolor='orange',
                  capsize=7, linewidth=2, elinewidth=1.5, capthick=2.5)


def _select_runs(table, agg, alpha, f, momentum, mixing):
    """Return the rows of *table* that belong to one configuration.

    alpha and momentum are compared as strings, f as an int and mixing as a
    bool, matching how the columns are filled from the directory names.
    """
    return table[
        (table['agg'] == agg) &
        (table['alpha'] == alpha) &
        (table['f'] == f) &
        (table['momentum'] == momentum) &
        (table['mixing'] == mixing)
    ]


def _plot_series(axis, table, attack_col, mean_col, linestyle, label, style):
    """Draw one mean +/- confidence-interval series against tau.

    Each group of *table* becomes one error-bar marker at x = its value in
    *attack_col*; the markers are then joined with a line. The line uses the
    tau values actually present in *table*, so a missing or additional tau no
    longer causes a length mismatch against a hard-coded x list.

    Returns an empty error-bar container carrying *label*, to be used as the
    legend handle.
    """
    taus, means = [], []
    group_keys = ['agg', attack_col, 'alpha', 'f', 'momentum', 'mixing']
    for name, group in table.groupby(group_keys):
        print('tau')
        print(name)
        print(group[mean_col])
        tau = name[1]  # position of attack_col in group_keys
        axis.errorbar(tau, group[mean_col], yerr=group['conf_interval'], **style)
        taus.append(tau)
        means.append(group[mean_col].item())
    axis.plot(taus, means, linestyle=linestyle, color=style['color'])
    return axis.errorbar([], [], yerr=[], label=label, **style)


def _plot_panel(fig, cell, title, agg, runtime_ticks, acc_table, time_table,
                alpha='0.1', f=3, momentum='0.9', mixing=True):
    """Draw one panel: accuracy on the left axis, runtime ratio on the right.

    Parameters:
        fig: figure that receives the subplot.
        cell: GridSpec cell for the subplot.
        title: panel title.
        agg: value of the 'agg' column to plot.
        runtime_ticks: tick positions for the right-hand (runtime ratio) axis.
        acc_table: aggregated accuracy table ('attack', 'max_accuracy mean',
            'conf_interval' columns).
        time_table: aggregated runtime-ratio table ('attack_x', 'ratio mean',
            'conf_interval' columns).
        alpha, f, momentum, mixing: configuration to select from both tables.

    Returns the pair (left axis, right axis).
    """
    ax = fig.add_subplot(cell)
    ax.set_title(title, fontsize=20)
    ax.set_xlabel('$\\tau$', fontsize=20)
    ax.set_ylabel('Accuracy', fontsize=20, color='blue')
    ax.set_ylim(0, 80)
    ax.set_xlim(0, 34)
    ax.tick_params(axis='both', labelsize=20)
    ax.grid(True)
    ax.xaxis.set_ticks(np.arange(0, 33, 8))
    ax.yaxis.set_ticks(np.arange(0, 81, 20))

    # Second y-axis sharing the same x-axis.
    ax2 = ax.twinx()
    ax2.yaxis.set_ticks(runtime_ticks)
    ax2.set_ylabel('Runtime ratio', fontsize=20, color='green')
    ax2.tick_params(axis='y', labelsize=20)
    ax2.set_ylim(bottom=0)

    acc_rows = _select_runs(acc_table, agg, alpha, f, momentum, mixing)
    # 'BF' rows carry a string attack label rather than a numeric tau, so they
    # are left out of the accuracy series.
    acc_rows = acc_rows[acc_rows['attack'] != 'BF']
    time_rows = _select_runs(time_table, agg, alpha, f, momentum, mixing)

    line1 = _plot_series(ax, acc_rows, 'attack', 'max_accuracy mean',
                         'dashed', "Maximal accuracy", acc_style)
    line2 = _plot_series(ax2, time_rows, 'attack_x', 'ratio mean',
                         'dashdot', "Runtime ratio", time_style)

    # One combined legend for the series of both axes.
    lines = [line1, line2]
    labels = [line.get_label() for line in lines]
    ax.legend(lines, labels, loc='lower right', fontsize=18)
    return ax, ax2


fig = plt.figure(figsize=(16, 6))
gs = gridspec.GridSpec(1, 2, height_ratios=[1], width_ratios=[1, 1], wspace=0.4, hspace=0)

# (panel title, 'agg' key, right-axis tick positions)
panels = [
    ('CM', 'cm', np.arange(0, 401, 100)),
    ('GM', 'rfa', np.arange(0, 601, 150)),
]
for idx, (title, agg, runtime_ticks) in enumerate(panels):
    _plot_panel(fig, gs[idx], title, agg, runtime_ticks,
                acc_grouped_data, time_grouped_data)

plt.tight_layout()

# --output_dir is optional: fall back to the current directory instead of
# failing in os.path.join(None, ...), and create the directory if it does not
# exist yet so savefig cannot fail on a missing path.
output_dir = args.output_dir or os.curdir
os.makedirs(output_dir, exist_ok=True)

output_filename = "tau_01_1.pdf"
output_path = os.path.join(output_dir, output_filename)
fig.savefig(output_path, format='pdf')