import os
import re
import pandas as pd
import numpy as np
from scipy.stats import sem, t
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import argparse

# Global matplotlib text settings: render text through LaTeX (needed for the
# "\textsc{...}" legend entry used later) and prefer a Times-like serif font.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times", "Times New Roman", "serif"],
})

# Parse command-line arguments.
parser = argparse.ArgumentParser(description="Results' directory")
parser.add_argument('cifar_dir_path', type=str, help='cifar directory path to process')
parser.add_argument('mnist_dir_path', type=str, help='mnist directory path to process')
# Default to the current directory so that omitting --output_dir no longer
# crashes with a TypeError when the path is used below.
parser.add_argument('--output_dir', type=str, default='.',
                    help='Output directory to save plots (default: current directory)')
parser.add_argument('--agg', type=str, help='aggregation method')
parser.add_argument('--alpha', type=str, help='alpha')
# "--mixing" alone could never be switched off (store_true with default=True);
# "--no-mixing" writes False to the same destination.  The default stays True,
# so existing invocations behave exactly as before.
parser.add_argument(
        "--mixing",
        dest="mixing",
        action="store_true",
        default=True,
        help="select runs with mixing enabled (default)",
    )
parser.add_argument(
        "--no-mixing",
        dest="mixing",
        action="store_false",
        help="select runs with mixing disabled",
    )
args = parser.parse_args()

# Confidence level of the intervals drawn around the mean accuracy curves.
confidence = 0.95

# Create the output directory if it does not exist yet (exist_ok makes a
# separate existence check unnecessary).
os.makedirs(args.output_dir, exist_ok=True)

# Display names of the attacks to load; runs of any other attack are ignored.
attacks = ["No Attack", "ALIE", "LF", "SF", "Mimic", "MinMax", "MinSum", "FOE", "$\\textsc{Jump}$ (ours)"]

# Attack identifiers as they appear in run-directory names -> display names.
# Identifiers not listed here are kept unchanged.
_ATTACK_DISPLAY_NAMES = {
    'mimic': 'Mimic',
    'SSNLP': '$\\textsc{Jump}$ (ours)',
    'BF': 'SF',
    'IPM': 'FOE',
    'NA': 'No Attack',
}

_RESULT_COLUMNS = ['agg', 'attack', 'alpha', 'f', 'momentum', 'mixing', 'seed', 'epoch', 'accuracy']
_GROUP_KEYS = ['agg', 'attack', 'alpha', 'f', 'momentum', 'mixing', 'epoch']

# Run-directory name patterns.  Captured groups, in order:
# agg, attack, alpha, n, f, momentum, mixing, seed.
# Literal dots are escaped and both "nlpsize0" and "nlpsize1" are accepted.
_CIFAR_RUN_PATTERN = re.compile(
    r'(\w+)_(\w+)_niid(\d+\.\d+)_n(\d+)_f(\d+)_m(\d\.\d+)'
    r'_nlpsize[01]_nlpobj1\.0_mix(\w+)_clip5\.0_s0_seed(\d+)')
_MNIST_RUN_PATTERN = re.compile(
    r'(\w+)_(\w+)_niid(\d+\.\d+)_n(\d+)_f(\d+)_m(\d\.\d+)'
    r'_nlpsize[01]_nlpobj1\.0_mix(\w+)_s0_seed(\d+)')


def _load_runs(results_dir, run_pattern, steps_per_eval, wanted_attacks):
    """Collect one row per accuracy measurement from every run under *results_dir*.

    Each non-empty sub-directory whose name matches *run_pattern* and whose
    (display-named) attack is in *wanted_attacks* is expected to contain an
    ``accs_test.txt`` file with whitespace-separated lines whose first field
    is a 1-based integer index and whose second field is an accuracy.

    Returns a list of rows
    ``[agg, attack, alpha, f, momentum, mixing, seed, step, accuracy]`` where
    ``step = (index - 1) * steps_per_eval``; ``alpha`` and ``momentum`` are
    kept as the strings found in the directory name.
    """
    rows = []
    for dir_name in os.listdir(results_dir):
        directory = os.path.join(results_dir, dir_name)
        # Ignore plain files and run directories that are still empty.
        if not os.path.isdir(directory) or not os.listdir(directory):
            continue
        match = run_pattern.search(dir_name)
        if match is None:
            continue
        agg, attack, alpha, _n, f, momentum, mixing, seed = match.groups()
        attack = _ATTACK_DISPLAY_NAMES.get(attack, attack)
        if attack not in wanted_attacks:
            continue

        accs_path = os.path.join(directory, 'accs_test.txt')
        if not os.path.isfile(accs_path):
            # An unfinished or failed run should not abort the whole plot.
            print(f"Warning: skipping {directory}: no accs_test.txt found")
            continue
        with open(accs_path) as file:
            for line in file:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines instead of failing on fields[0].
                    continue
                step = (int(fields[0]) - 1) * steps_per_eval
                rows.append([agg, attack, alpha, int(f), momentum,
                             mixing == 'True', int(seed), step, float(fields[1])])
    return rows


def _summarise(rows, confidence_level):
    """Build the raw and the per-configuration aggregated accuracy tables.

    Returns ``(data, grouped)``.  ``data`` holds one row per measurement.
    ``grouped`` has one row per (agg, attack, alpha, f, momentum, mixing,
    epoch) with the columns ``accuracy mean``, ``accuracy std``,
    ``accuracy count``, ``std_err`` and ``conf_interval`` (half-width of the
    two-sided Student-t interval at *confidence_level*).  The ``epoch``
    column holds the step value computed by ``_load_runs``.
    """
    data = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
    grouped = data.groupby(_GROUP_KEYS).agg({
        'accuracy': ['mean', 'std', 'count']
    }).reset_index()
    # Flatten the two-level column index ("accuracy", "mean") -> "accuracy mean".
    grouped.columns = [' '.join(col).strip() for col in grouped.columns.values]
    grouped['std_err'] = grouped['accuracy std'] / np.sqrt(grouped['accuracy count'])
    grouped['conf_interval'] = grouped['std_err'] * t.ppf(
        (1 + confidence_level) / 2, grouped['accuracy count'] - 1)
    return data, grouped


# CIFAR-10 results: indices in accs_test.txt are scaled by 50 steps.
cifar_dir_path = args.cifar_dir_path
cifar_data_list = _load_runs(cifar_dir_path, _CIFAR_RUN_PATTERN, 50, attacks)
cifar_data, cifar_grouped_data = _summarise(cifar_data_list, confidence)

# MNIST results: indices in accs_test.txt are scaled by 32 steps.
mnist_dir_path = args.mnist_dir_path
mnist_data_list = _load_runs(mnist_dir_path, _MNIST_RUN_PATTERN, 32, attacks)
mnist_data, mnist_grouped_data = _summarise(mnist_data_list, confidence)


# Plotting

# Colour used for each attack's curve and confidence band.
colors = {
    "No Attack": 'blue',
    "$\\textsc{Jump}$ (ours)": 'black',
    "ALIE": 'm',
    "LF": 'g',
    "SF": 'r',
    "Mimic": 'y',
    "MinMax": 'c',
    "MinSum": 'k',
    "FOE": 'orange',
}

# Line style used for each attack's curve: named styles or (offset, dash
# pattern) tuples.  MinMax and MinSum share a dash pattern and differ only
# by colour.
line_styles = {
    "No Attack": 'solid',
    "$\\textsc{Jump}$ (ours)": 'solid',
    "ALIE": 'dashed',
    "FOE": 'dashdot',
    "LF": (0, (1, 1)),
    "SF": (0, (1, 0.8)),
    "MinMax": (0, (3, 1, 1, 1)),
    "MinSum": (0, (3, 1, 1, 1)),
    "Mimic": (0, (3, 1, 1, 1, 1, 1)),
}

attacks = ["No Attack", "ALIE", "LF", "SF", "Mimic", "MinMax", "$\\textsc{Jump}$ (ours)"]

# One tall figure holding a 5 x 3 grid of equally sized panels.
fig = plt.figure(figsize=(16, 20))
fig.subplots_adjust(left=0.055, right=0.945, top=0.97, bottom=0.07)
gs = gridspec.GridSpec(
    5, 3,
    height_ratios=[1] * 5,
    width_ratios=[1] * 3,
    wspace=0.15,
    hspace=0.15,
)

def _draw_accuracy_curves(ax, grouped, *, agg, alpha, num_byzantine, momentum, mixing, labelled=False):
    """Draw every accuracy curve of one panel on *ax*.

    Two sets of rows are selected from the aggregated table *grouped*:

    * rows matching *agg*, *alpha*, ``f == num_byzantine``, *momentum* and
      *mixing*;
    * rows with ``f == 0`` matching *alpha* and *momentum*, whatever their
      aggregator or mixing setting (presumably the no-attack baseline).

    One line (mean accuracy) and one shaded band (mean +/- confidence
    interval) is drawn per (agg, attack, alpha, f, momentum, mixing) group.
    Lines get the attack name as legend label only when *labelled* is true,
    so that a figure-level legend lists every attack once.

    The selection values are parameters and are never rebound while
    iterating, unlike the former in-line loops whose tuple unpacking
    overwrote the module-level ``f`` and ``momentum``.
    """
    attacked = (
        (grouped['agg'] == agg) &
        (grouped['alpha'] == alpha) &
        (grouped['f'] == num_byzantine) &
        (grouped['momentum'] == momentum) &
        (grouped['mixing'] == mixing)
    )
    baseline = (
        (grouped['alpha'] == alpha) &
        (grouped['f'] == 0) &
        (grouped['momentum'] == momentum)
    )
    curve_keys = ['agg', 'attack', 'alpha', 'f', 'momentum', 'mixing']
    for key, curve in grouped[attacked | baseline].groupby(curve_keys):
        attack = key[1]
        color = colors.get(attack, 'r')
        # Only pass "label" when requested so unlabelled lines keep
        # matplotlib's default (hidden) label.
        label_kwargs = {'label': attack} if labelled else {}
        ax.plot(curve['epoch'], curve['accuracy mean'],
                linestyle=line_styles.get(attack, 'solid'), color=color, **label_kwargs)
        ax.fill_between(curve['epoch'],
                        curve['accuracy mean'] - curve['conf_interval'],
                        curve['accuracy mean'] + curve['conf_interval'],
                        alpha=0.2, color=color, edgecolor=None)


def _format_axes(ax, *, x_max, x_step, y_max):
    """Apply the axis limits, ticks (y ticks every 20) and grid shared by all panels."""
    ax.set_ylim(0, y_max)
    ax.set_xlim(0, x_max)
    ax.tick_params(axis='both', labelsize=20)
    ax.xaxis.set_ticks(np.arange(0, x_max + 1, x_step))
    ax.yaxis.set_ticks(np.arange(0, y_max + 1, 20))
    ax.grid(True)


# Panels fill the grid row by row: one row per momentum value, one column per
# value of f.  idx is the flat GridSpec index of the panel being drawn.
idx = -1

# MNIST: two rows (momentum 0.0 and 0.9).
for momentum in ('0.0', '0.9'):
    for f in (1, 3, 5):
        idx += 1
        ax = fig.add_subplot(gs[idx])
        _draw_accuracy_curves(
            ax, mnist_grouped_data,
            agg=args.agg, alpha=args.alpha, num_byzantine=f,
            momentum=momentum, mixing=args.mixing,
            labelled=(idx == 0),  # legend entries come from the first panel only
        )
        if momentum == '0.0':
            # Column titles only on the top row.
            ax.set_title(f'$f={f}$', fontsize=20, pad=10)
        if f == 1:
            # y-axis label only in the first column.
            ax.set_ylabel('Accuracy', fontsize=20)
        _format_axes(ax, x_max=800, x_step=200, y_max=100)

# CIFAR-10: three rows (momentum 0.0, 0.9 and 0.99).
for momentum in ('0.0', '0.9', '0.99'):
    for f in (1, 3, 5):
        idx += 1
        ax = fig.add_subplot(gs[idx])
        _draw_accuracy_curves(
            ax, cifar_grouped_data,
            agg=args.agg, alpha=args.alpha, num_byzantine=f,
            momentum=momentum, mixing=args.mixing,
        )
        if momentum == '0.99':
            # x-axis label only on the bottom row of the figure.
            ax.set_xlabel('Number of steps', fontsize=20)
        if f == 1:
            ax.set_ylabel('Accuracy', fontsize=20, labelpad=14)
        _format_axes(ax, x_max=1500, x_step=300, y_max=90)




# Single figure-level legend, built from the labelled lines of the panels.
fig.legend(
    loc='lower center',
    ncol=9,
    fontsize=20,
    labelspacing=1,
    columnspacing=1,
    handlelength=1,
)

# NOTE(review): tight_layout recomputes the subplot margins, so it likely
# overrides the manual subplots_adjust/GridSpec spacing set earlier; confirm
# the legend does not overlap the bottom-row axis labels.
fig.tight_layout()

# Save the figure as "<agg>_<alpha>_<mixing>.pdf" inside the output directory.
output_path = os.path.join(
    args.output_dir,
    f"{args.agg}_{args.alpha}_{args.mixing}.pdf",
)
fig.savefig(output_path, format='pdf')



