from utils_toy import get_args
from utils_toy import main
from utils_toy import EXP_DIR, DATA_DIR
from codes.tasks.nonconvex_2d import read_txt, TwoMinima, CombinedGaussian

# from toydata import data_generate, read_txt, LossSurface
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.ticker import MaxNLocator, MultipleLocator, FormatStrFormatter
import torch

# Global figure styling: text is rendered through LaTeX (labels below use
# commands such as \textsc and \mathcal) with a serif font stack.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["serif", "Times", "Times New Roman"],
})

import numpy as np
import torch
import os

args = get_args()

# Log root for this run: an explicit identifier wins, then the debug folder,
# otherwise a folder name built from args.n.
LOG_DIR = EXP_DIR + "test/"
if args.identifier:
    LOG_DIR += f"{args.identifier}/"
elif args.debug:
    LOG_DIR += "debug/"
else:
    LOG_DIR += f"n{args.n}_f0NA_LT_m0/"

INP_DIR = LOG_DIR
OUT_DIR = LOG_DIR + "output/"

# Run name encoding the experiment settings. The deviation type only appears in
# the name when it is something other than 'unit_vec'.
_dev_suffix = f"_{args.dev_type}" if args.dev_type != 'unit_vec' else ""
filename = (
    f"toy2d_{args.loss}_{args.agg}_{args.attack}{_dev_suffix}_niid{args.noniid}"
    f"_n{args.n}_f{args.f}_nlpobj{args.nlpobj}_nlpsize{args.nlpsize}"
    f"_initp{args.initp[0]}_{args.initp[1]}_lr{args.LR}_iter{args.EPOCH}_seed{args.seed}"
)
LOG_DIR += filename

foldername = filename

# Per-run output folder for points/grads/losses text files and figures.
save_dir = EXP_DIR + 'images_2dnonconvex/' + foldername
os.makedirs(save_dir, exist_ok=True)  # exist_ok makes a prior existence check unnecessary

# Passed to main(); presumably one mini-batch per epoch, so args.EPOCH is the
# iteration count -- TODO confirm against main().
MAX_BATCHES_PER_EPOCH = 1


# center = data_generate(h=args.n - args.f, b=args.f, mu=mu, std=std, seed=args.seed)


def save_txt(data, path):
    """Write each string in ``data`` to ``path`` as its own line (UTF-8, overwrites).

    Args:
        data: Iterable of strings; a newline is appended to every item.
        path: Destination file path.
    """
    # ``with`` guarantees the file is closed even if an item is not a str.
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in data)


def plot_function_and_trajectory(loss_func, trajectory, xlim=(-3, 3), ylim=(-3, 2)):
    """Render a filled-contour map of ``loss_func`` with ``trajectory`` drawn over it.

    The loss is sampled on a 400 x 400 grid spanning ``xlim`` x ``ylim``; each
    sample point is passed to ``loss_func`` as a 2-element ``torch.tensor``.

    Args:
        loss_func: Callable taking a 2-element tensor and returning the loss value.
        trajectory: Sequence of (x, y) points, drawn in order as a red polyline.
        xlim: (min, max) range of the x grid.
        ylim: (min, max) range of the y grid.

    Returns:
        The newly created matplotlib Figure; it is neither shown nor saved here.
    """
    resolution = 400
    grid_x = np.linspace(xlim[0], xlim[1], resolution)
    grid_y = np.linspace(ylim[0], ylim[1], resolution)
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)

    # Outer loop over y so that row r of the loss grid lines up with mesh row r.
    loss_rows = []
    for y_val in grid_y:
        loss_rows.append([loss_func(torch.tensor([x_val, y_val])) for x_val in grid_x])
    loss_grid = np.array(loss_rows)

    fig, ax = plt.subplots(figsize=(8, 6))
    filled = ax.contourf(mesh_x, mesh_y, loss_grid, 50, cmap='jet')
    fig.colorbar(filled, ax=ax)
    path = np.array(trajectory)
    ax.plot(path[:, 0], path[:, 1], '-o', color='red', markersize=3, linewidth=1.5)
    ax.set_title('Trajectory on Loss Surface')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    return fig


# Steps at which a gradient arrow is drawn, keyed by panel title (``method``).
# Raw strings keep the LaTeX backslashes literal without invalid-escape warnings.
_RFA_ARROW_STEPS = {
    "No Attack": [],
    "ALIE": [0, 2, 80, 99],
    r"MinMax ($-\overline{g}^{\mathcal{H}}_t$)": [0, 5, 10, 20, 99],
    r"MinSum ($-\overline{g}^{\mathcal{H}}_t$)": [0, 5, 10, 20, 99],
    r'MinMax ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': [0, 5, 10, 20, 99],
    r'MinSum ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': [0, 5, 10, 20, 99],
    r'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': [0, 5, 10, 20, 99],
    r'MinSum ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': [0, 5, 10, 20, 99],
    "Mimic": [0, 1, 2, 3],
    "SF": [0, 70, 90],
    "FOE": [0, 10, 90],
    r'Solve $(\mathrm{P})$': [0, 1, 99],
    r'$\textsc{Jump}$-100': [0, 1, 99],
    r'$\textsc{Jump}$-10': [0, 11, 12, 99],
    r'$\textsc{Jump}$-1': [0, 56, 57, 99],
}
_KRUM_ARROW_STEPS = {
    "No Attack": [],
    "ALIE": [0, 6, 99],
    r"MinMax ($-\overline{g}^{\mathcal{H}}_t$)": [0, 50, 80, 99],
    r"MinSum ($-\overline{g}^{\mathcal{H}}_t$)": [0, 50, 80, 99],
    r'MinMax ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': [0, 50, 75, 99],
    r'MinSum ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': [0, 50, 75, 99],
    r'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': [0, 50, 80, 99],
    r'MinSum ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': [0, 50, 80, 99],
    "Mimic": [0, 1, 2, 3],
    "SF": [0, 50, 80, 99],
    "FOE": [0, 50, 80, 99],
    r'Solve $(\mathrm{P})$': [0, 37, 38, 99],
    r'$\textsc{Jump}$-100': [0, 40, 41, 99],
    r'$\textsc{Jump}$-10': [0, 44, 45, 99],
    r'$\textsc{Jump}$-1': [0, 56, 57, 99],
}
_TM_CM_ARROW_STEPS = {
    "No Attack": [],
    "ALIE": [0, 2, 80, 99],
    r"MinMax ($-\overline{g}^{\mathcal{H}}_t$)": [0, 40, 80, 90, 99],
    r"MinSum ($-\overline{g}^{\mathcal{H}}_t$)": [0, 40, 80, 90, 99],
    r'MinMax ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': [0, 40, 70, 99],
    r'MinSum ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': [0, 40, 70, 99],
    r'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': [0, 3, 10, 20, 99],
    r'MinSum ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': [0, 3, 10, 20, 99],
    "Mimic": [0, 1, 2, 3],
    "SF": [0, 80, 99],
    "FOE": [0, 60, 99],
    r'Solve $(\mathrm{P})$': [0, 27, 28, 99],
    r'$\textsc{Jump}$-100': [0, 8, 9, 99],
    r'$\textsc{Jump}$-10': [0, 30, 31, 99],
    r'$\textsc{Jump}$-1': [0, 56, 57, 70],
}
# Aggregator name -> arrow table; aggregators not listed get no arrows.
_GRAD_ARROW_STEPS = {
    'rfa': _RFA_ARROW_STEPS,
    'krum': _KRUM_ARROW_STEPS,
    'tm': _TM_CM_ARROW_STEPS,
    'cm': _TM_CM_ARROW_STEPS,
}


def subplot_function_and_trajectory(ax, loss_func, trajectory, gradients, color, idx, method, agg, xlim=(-3, 3),
                                    ylim=(-3, 3), grid_size=(300, 200)):
    """Draw loss contours, an optimisation path and selected gradient arrows on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        loss_func: Callable taking a 2-element ``torch.tensor`` and returning the loss.
        trajectory: Sequence of (x, y) iterates; the first one is marked as the initial point.
        gradients: Per-step gradients indexable by step (each a length-2 vector), or ``[]``.
        color: Unused; kept so existing callers keep working.
        idx: Panel index; panel 0 additionally gets contour labels and a legend.
        method: Panel title; selects which steps get a gradient arrow.
        agg: Aggregator name. 'rfa', 'krum', 'tm' and 'cm' have arrow tables;
            any other value draws no arrows.
        xlim, ylim: (min, max) extent of the contour grid.
        grid_size: (number of x samples, number of y samples) of the contour grid.

    Returns:
        The contour set returned by ``ax.contour``.
    """
    x = np.linspace(xlim[0], xlim[1], grid_size[0])
    y = np.linspace(ylim[0], ylim[1], grid_size[1])
    X, Y = np.meshgrid(x, y)
    Z = np.array([[loss_func(torch.tensor([xi, yi])) for xi in x] for yi in y])

    contour = ax.contour(X, Y, Z, 25)
    if idx == 0:
        ax.clabel(contour, inline=True, fontsize=7)

    trajectory = np.array(trajectory)

    ax.scatter(trajectory[0][0], trajectory[0][1], marker='s', color='black', s=30, label='Initial point', zorder=2)
    ax.plot(trajectory[:, 0], trajectory[:, 1], '-o', color='red', markersize=4, linewidth=1.5, label='Trajectory',
            zorder=3)

    arrow_steps = _GRAD_ARROW_STEPS.get(agg, {}).get(method, [])
    for index in arrow_steps:
        try:
            grad = np.asarray(gradients[index])
            grad_len = np.linalg.norm(grad)
            normalized_grad = grad / grad_len if grad_len != 0 else grad
            # Arrow length 1 / (2 + 3 * 2**-|g|) grows from 0.2 towards 0.5 with the
            # gradient norm, keeping arrows readable whatever the gradient scale.
            fixed_length_grad = normalized_grad * (1 / (2 + 3 * 2 ** (-grad_len)))
            # The arrow points along the negative of the stored gradient.
            end_x = trajectory[index, 0] - fixed_length_grad[0]
            end_y = trajectory[index, 1] - fixed_length_grad[1]
            ax.annotate('', (end_x, end_y), (trajectory[index, 0], trajectory[index, 1]),
                        arrowprops=dict(arrowstyle="->", color='blue', linewidth=1), zorder=4)
        except IndexError:
            # Fewer gradients/points than the requested step (or no gradients at all).
            print('grad < point', index)

    if idx == 0:
        ax.legend(fontsize=14)

    ax.tick_params(axis='both', labelsize=18)
    ax.set_xlabel('x', fontsize=18)
    ax.set_ylabel('y', fontsize=18)
    return contour


def read_data_from_files(prefix, base_dir=None):
    """Load the saved text artefacts of one run from ``<base_dir>/<prefix>/``.

    Args:
        prefix: Run folder name.
        base_dir: Folder containing the run folders; defaults to
            ``EXP_DIR + 'images_2dnonconvex/'``.

    Returns:
        ``(grads, losses, points)`` loaded with ``np.loadtxt`` from
        ``byz_grads.txt``, ``losses.txt`` and ``points.txt``. ``grads`` is an
        empty list when ``byz_grads.txt`` is missing or cannot be parsed.

    Raises:
        OSError: if ``losses.txt`` or ``points.txt`` does not exist.
    """
    if base_dir is None:
        base_dir = EXP_DIR + 'images_2dnonconvex/'
    run_dir = os.path.join(base_dir, prefix)
    try:
        grads = np.loadtxt(os.path.join(run_dir, 'byz_grads.txt'))
    except (OSError, ValueError):
        # Runs without Byzantine gradients (e.g. the no-attack baseline) have no such file.
        print(prefix, 'No byz_grads file')
        grads = []
    losses = np.loadtxt(os.path.join(run_dir, 'losses.txt'))
    points = np.loadtxt(os.path.join(run_dir, 'points.txt'))
    return grads, losses, points


aps, sgs, xs, ys = read_txt(DATA_DIR)

# One toy loss object per honest client (args.n - args.f of them).
_LOSS_CLASSES = {'two': TwoMinima, 'gaussian': CombinedGaussian}
if args.loss not in _LOSS_CLASSES:
    # Without this check loss_funcs would be empty and loss_func would divide by zero.
    raise ValueError(f"Unsupported loss {args.loss!r}; expected one of {sorted(_LOSS_CLASSES)}")
loss_funcs = [_LOSS_CLASSES[args.loss]() for _ in range(args.n - args.f)]


def loss_func(xy):
    """Return the mean of the per-client losses evaluated at point ``xy``.

    Client ``i`` is evaluated with the parameter row
    ``(aps[i], sgs[i], xs[i], ys[i])`` read from DATA_DIR.
    """
    return sum([fun(xy, torch.tensor([(aps[i], sgs[i], xs[i], ys[i])]))
                for i, fun in enumerate(loss_funcs)]) / len(loss_funcs)

if not args.plot:
    # Run the experiment from args.initp; main returns the iterates and gradients.
    points, grads = main(args, LOG_DIR, args.EPOCH, MAX_BATCHES_PER_EPOCH, args.initp)
    print("points", points)
    print("grads", grads)

    # Prepend the starting point; the files below index points, grads and losses
    # by the same step t.
    points.insert(0, args.initp)
    ppath = []
    pgrads = []
    losses = []
    for t, p in enumerate(points):
        ppath.append(str(p[0]) + ' ' + str(p[1]))
        if t < len(grads):
            pgrads.append(str(grads[t][0]) + ' ' + str(grads[t][1]))
        losses.append(str(t) + ' ' + str(loss_func(torch.tensor([p[0], p[1]])).item()))
    save_txt(ppath, save_dir + '/points.txt')
    save_txt(pgrads, save_dir + '/grads.txt')
    save_txt(losses, save_dir + '/losses.txt')

    # points already starts with args.initp, so it is the full trajectory;
    # prepending args.initp again would draw the first point twice.
    trajectory = list(points)

    print(loss_func(torch.tensor([trajectory[-1][0], trajectory[-1][1]])))
    ls_fig = plot_function_and_trajectory(loss_func, trajectory, xlim=(-3, 3), ylim=(-3, 3))

    fig_dir = EXP_DIR + 'images_2dnonconvex/' + filename + '.pdf'
    ls_fig.savefig(fig_dir, format='pdf')
    print('{} saved.'.format(fig_dir))
    plt.close(ls_fig)

else:
    agg = args.agg
    lr = args.LR
    initx = args.initp[0]
    inity = args.initp[1]
    epoch = args.EPOCH
    n = args.n
    f = args.f
    ################################################################################3
    methods = ['Solve $(\mathrm{P})$', '$\\textsc{Jump}$-100', '$\\textsc{Jump}$-10', '$\\textsc{Jump}$-1']
    colors = ['orange', 'r', 'g', 'c', 'b', 'm', 'y', 'k', 'blue']
    fig = plt.figure(figsize=(15, 7))

    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 1], width_ratios=[1, 1, 1], wspace=0.15, hspace=0.15)

    all_losses = []

    idx = 0
    prefix = f'toy2d_two_avg_NA_niidFalse_n{n - f}_f0_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
    grads, losses, points = read_data_from_files(prefix)

    ax = fig.add_subplot(gs[idx])
    if idx % 3 != 0:  # Hide y-axis if it's not the leftmost plot
        ax.get_yaxis().set_visible(False)
    if idx < 3:  # Hide x-axis if it's not the bottom plots
        ax.get_xaxis().set_visible(False)
    ax.set_title(f'No Attack', fontsize=18)
    contour_obj = subplot_function_and_trajectory(ax, loss_func, points, grads, colors[idx], idx, 'No Attack', agg,
                                                  xlim=(-2, 2.5),
                                                  ylim=(-1.5, 1.5))
    all_losses.append(losses[:, 1])

    idx += 1
    for method in methods:
        # if idx == 3:  # Skip the center plot for loss graph
        #     idx += 1
        if method == '$\\textsc{Jump}$-100':
            prefix = f'toy2d_two_{agg}_NOBLE_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize100_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        if method == 'Solve $(\mathrm{P})$':
            prefix = f'toy2d_two_{agg}_NLP_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize100_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == '$\\textsc{Jump}$-10':
            prefix = f'toy2d_two_{agg}_NOBLE_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize10_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == '$\\textsc{Jump}$-1':
            prefix = f'toy2d_two_{agg}_NOBLE_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize1_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'BF':
            prefix = f'toy2d_two_{agg}_BF_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'MinMax':
            prefix = f'toy2d_two_{agg}_MinMax_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'IPM':
            prefix = f'toy2d_two_{agg}_IPM_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'MinSum':
            prefix = f'toy2d_two_{agg}_MinSum_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'Mimic':
            prefix = f'toy2d_two_{agg}_mimic_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'ALIE':
            prefix = f'toy2d_two_{agg}_ALIE_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'MinMax ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)':
            prefix = f'toy2d_two_{agg}_MinMax_std_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)':
            prefix = f'toy2d_two_{agg}_MinMax_sign_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'MinSum ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)':
            prefix = f'toy2d_two_{agg}_MinSum_std_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        elif method == 'MinSum ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)':
            prefix = f'toy2d_two_{agg}_MinSum_sign_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize0_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0'
        grads, losses, points = read_data_from_files(prefix)

        ax = fig.add_subplot(gs[idx])
        print(method, idx)
        if idx % 3 != 0:  # Hide y-axis if it's not the leftmost plot
            ax.get_yaxis().set_visible(False)
        if idx < 3:  # Hide x-axis if it's not the bottom plots
            ax.get_xaxis().set_visible(False)
        if method == 'BF':
            method = 'SF'
        if method == 'IPM':
            method = 'FOE'
        ax.set_title(method, fontsize=18)
        subplot_function_and_trajectory(ax, loss_func, points, grads, colors[idx], idx, method, agg,
                                        xlim=(-2, 2.5),
                                        ylim=(-1.5, 1.5))
        all_losses.append(losses[:, 1])
        idx += 1

    line_styles = {"No Attack": 'solid',
                   'Solve $(\mathrm{P})$': 'dashed',
                   "$\\textsc{Jump}$-100": 'dashdot',
                   "SF": (0, (3, 1, 1, 1, 1, 1)),
                   "$\\textsc{Jump}$-10": (0, (1, 0.8)),
                   "ALIE": (0, (1, 0.8)),
                   "MinMax ($-\overline{g}^{\mathcal{H}}_t$)": (0, (3, 1, 1, 1)),
                   "MinSum ($-\overline{g}^{\mathcal{H}}_t$)": (0, (3, 1, 1, 1)),
                   "Mimic": 'dashdot',
                   "FOE": (0, (3, 1, 1, 1, 1, 1)),
                   "$\\textsc{Jump}$-1": 'solid'
                   }

    methods.insert(0, 'No Attack')
    for i, method in enumerate(methods):
        if method == 'BF':
            methods[i] = 'SF'
        if method == 'IPM':
            methods[i] = 'FOE'
    ax = plt.subplot(2, 3, 6)
    for color, method, losses in zip(colors, methods, all_losses):
        ax.plot(losses, linestyle=line_styles[method], color=color)

    all_loss_values = [loss for losses in all_losses for loss in losses]
    min_loss = min(all_loss_values)
    max_loss = max(all_loss_values)

    ax.set_title('Loss', fontsize=18)
    ax.set_xlabel('Iterations', fontsize=18)
    ax.tick_params(axis='both', labelsize=18)
    # ax.set_ylabel('Loss')
    # ax.xaxis.set_ticks(np.arange(0, len(losses), 10))

    if args.agg == 'krum':
        legend = ax.legend(methods, loc='lower center', fontsize=13, ncol=2, columnspacing=1, handlelength=1.5, framealpha=0.5)
    else:
        legend = ax.legend(methods, loc='lower right', fontsize=13, ncol=2, columnspacing=1, handlelength=1.5, framealpha=0.5)

    for handle in legend.legendHandles:
        handle.set_linewidth(3)

    # Add a colorbar for the entire grid
    # cbar = fig.colorbar(contour_obj, ax=fig.get_axes(), shrink=1)
    # cbar.locator = MaxNLocator(nbins=7)
    # cbar.update_ticks()

    plt.tight_layout()
    # plt.show()

    fig_dir = EXP_DIR + 'images_2dnonconvex/' + f'toy2d_two_{agg}.pdf'
    fig.savefig(fig_dir, format='pdf')
    print('{} saved.'.format(fig_dir))

    ################################################################################
    # Figure 2 -> toy2d_two_{agg}_compare.pdf: one trajectory panel per attack below,
    # plus a loss panel that also overlays the no-attack baseline and Jump-1.

    # Attack key -> (attack tag, nlpsize) as they appear in the run-folder name.
    run_specs = {
        'BF': ('BF', 0),
        r'MinMax ($-\overline{g}^{\mathcal{H}}_t$)': ('MinMax', 0),
        r'MinMax ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': ('MinMax_std', 0),
        'ALIE': ('ALIE', 0),
        'Mimic': ('mimic', 0),
    }
    # Internal attack tags -> names shown in the figures.
    display_names = {'BF': 'SF', 'IPM': 'FOE'}
    colors = ['orange', 'r', 'g', 'c', 'm', 'y', 'b', 'k', 'blue']
    fig = plt.figure(figsize=(15, 7))

    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 1], width_ratios=[1, 1, 1], wspace=0.15, hspace=0.15)

    prefix = (f'toy2d_two_avg_NA_niidFalse_n{n - f}_f0_nlpobj1.0_nlpsize0'
              f'_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0')
    grads, losses, points = read_data_from_files(prefix)
    no_attack_losses = losses[:, 1]

    labels = []
    attack_losses = []
    for idx, (method, (attack, nlpsize)) in enumerate(run_specs.items()):
        prefix = (f'toy2d_two_{agg}_{attack}_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize{nlpsize}'
                  f'_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0')
        grads, losses, points = read_data_from_files(prefix)

        ax = fig.add_subplot(gs[idx])
        print(method, idx)
        if idx % 3 != 0:  # Hide y-axis if it's not the leftmost plot
            ax.get_yaxis().set_visible(False)
        if idx < 3:  # Hide x-axis on the top row
            ax.get_xaxis().set_visible(False)
        label = display_names.get(method, method)
        ax.set_title(label, fontsize=18)
        subplot_function_and_trajectory(ax, loss_func, points, grads, colors[idx], idx, label, agg,
                                        xlim=(-2, 2.5),
                                        ylim=(-1.5, 1.5))
        labels.append(label)
        attack_losses.append(losses[:, 1])

    prefix = (f'toy2d_two_{agg}_NOBLE_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize1'
              f'_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0')
    grads, losses, points = read_data_from_files(prefix)
    jump1_losses = losses[:, 1]

    line_styles = {"No Attack": 'solid',
                   r'MinMax ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': 'dashed',
                   r"$\textsc{Jump}$-100": 'dashdot',
                   "SF": (0, (3, 1, 1, 1, 1, 1)),
                   r'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': (0, (1, 0.8)),
                   "ALIE": (0, (1, 0.8)),
                   r"MinMax ($-\overline{g}^{\mathcal{H}}_t$)": (0, (3, 1, 1, 1)),
                   "Mimic": 'dashdot',
                   r"$\textsc{Jump}$-1": 'solid'
                   }

    # Last grid cell: baseline first, then each attack, then Jump-1 on top.
    ax = plt.subplot(2, 3, 6)
    na, = ax.plot(no_attack_losses, linestyle=line_styles["No Attack"], color=colors[0], label="No Attack")
    lgd = [ax.plot(curve, linestyle=line_styles[label], color=color, label=label)[0]
           for color, label, curve in zip(colors[1:], labels, attack_losses)]
    j1, = ax.plot(jump1_losses, linestyle=line_styles[r'$\textsc{Jump}$-1'], color='b',
                  label=r'$\textsc{Jump}$-1')

    ax.set_title('Loss', fontsize=18)
    ax.set_xlabel('Iterations', fontsize=18)
    ax.tick_params(axis='both', labelsize=18)

    legend1 = ax.legend(handles=[na, j1], loc='upper center', fontsize=13, ncol=2, columnspacing=1,
                        handlelength=1.5, framealpha=0.5, bbox_to_anchor=(0.5, 1.03))
    # A later ax.legend() call replaces the axes legend, so keep the first as an extra artist.
    ax.add_artist(legend1)
    legend2 = ax.legend(handles=lgd, loc='lower right', fontsize=13, ncol=2, columnspacing=1.1,
                        handlelength=1.5, framealpha=0.5, bbox_to_anchor=(1, 0.06))

    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles; the old name
    # was removed in 3.9.
    for legend in (legend1, legend2):
        handles = legend.legend_handles if hasattr(legend, 'legend_handles') else legend.legendHandles
        for handle in handles:
            handle.set_linewidth(3)

    plt.tight_layout()

    fig_dir = EXP_DIR + 'images_2dnonconvex/' + f'toy2d_two_{agg}_compare.pdf'
    fig.savefig(fig_dir, format='pdf')
    print('{} saved.'.format(fig_dir))
    plt.close(fig)

    ################################################################################
    # Figure 3 -> toy2d_two_{agg}_sota.pdf: one trajectory panel per attack below,
    # plus a loss panel with the no-attack baseline and Jump-1 added.

    # Attack key -> (attack tag, nlpsize) as they appear in the run-folder name.
    run_specs = {
        'IPM': ('IPM', 0),
        r'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': ('MinMax_sign', 0),
        r'MinSum ($-\overline{g}^{\mathcal{H}}_t$)': ('MinSum', 0),
        r'MinSum ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': ('MinSum_std', 0),
        r'MinSum ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': ('MinSum_sign', 0),
    }
    # Internal attack tags -> names shown in the figures.
    display_names = {'BF': 'SF', 'IPM': 'FOE'}
    colors = ['orange', 'r', 'g', 'c', 'm', 'y', 'b', 'k', 'blue']
    fig = plt.figure(figsize=(15, 7))

    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 1], width_ratios=[1, 1, 1], wspace=0.15, hspace=0.15)

    prefix = (f'toy2d_two_avg_NA_niidFalse_n{n - f}_f0_nlpobj1.0_nlpsize0'
              f'_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0')
    grads, losses, points = read_data_from_files(prefix)
    labels = ['No Attack']
    all_losses = [losses[:, 1]]

    for idx, (method, (attack, nlpsize)) in enumerate(run_specs.items()):
        prefix = (f'toy2d_two_{agg}_{attack}_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize{nlpsize}'
                  f'_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0')
        grads, losses, points = read_data_from_files(prefix)

        ax = fig.add_subplot(gs[idx])
        print(method, idx)
        if idx % 3 != 0:  # Hide y-axis if it's not the leftmost plot
            ax.get_yaxis().set_visible(False)
        if idx < 3:  # Hide x-axis on the top row
            ax.get_xaxis().set_visible(False)
        label = display_names.get(method, method)
        ax.set_title(label, fontsize=16)
        subplot_function_and_trajectory(ax, loss_func, points, grads, colors[idx], idx, label, agg,
                                        xlim=(-2, 2.5),
                                        ylim=(-1.5, 1.5))
        labels.append(label)
        all_losses.append(losses[:, 1])

    prefix = (f'toy2d_two_{agg}_NOBLE_niidFalse_n{n}_f{f}_nlpobj1.0_nlpsize1'
              f'_initp{initx}_{inity}_lr{lr}_iter{epoch}_seed0')
    grads, losses, points = read_data_from_files(prefix)
    labels.append(r'$\textsc{Jump}$-1')
    all_losses.append(losses[:, 1])

    line_styles = {"No Attack": 'solid',
                   r'MinSum ($-\mathrm{std}{(g^{\mathcal{H}}_t)}$)': 'dashed',
                   r'MinSum ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': 'dashdot',
                   r'MinMax ($-\mathrm{sign}{(g^{\mathcal{H}}_t)}$)': (0, (1, 0.8)),
                   "ALIE": (0, (1, 0.8)),
                   r"MinSum ($-\overline{g}^{\mathcal{H}}_t$)": (0, (3, 1, 1, 1)),
                   "FOE": (0, (3, 1, 1, 1, 1, 1)),
                   r"$\textsc{Jump}$-1": 'solid'
                   }

    # Last grid cell: loss curves for the baseline, every attack, and Jump-1.
    ax = plt.subplot(2, 3, 6)
    for color, label, curve in zip(colors, labels, all_losses):
        ax.plot(curve, linestyle=line_styles[label], color=color)

    ax.set_title('Loss', fontsize=16)
    ax.set_xlabel('Iterations', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)

    legend_loc = 'center right' if args.agg == 'rfa' else 'lower center'
    ax.legend(labels, loc=legend_loc, fontsize=8)

    plt.tight_layout()

    fig_dir = EXP_DIR + 'images_2dnonconvex/' + f'toy2d_two_{agg}_sota.pdf'
    fig.savefig(fig_dir, format='pdf')
    print('{} saved.'.format(fig_dir))
    plt.close(fig)
