import os
import pickle

import json
import re
import numpy as np
from matplotlib import pyplot as plt


def trim_dataset(a, b, n_workers: int):
    """Drop trailing rows so the row count is a multiple of ``n_workers``.

    Args:
        a: 2-D array; its first axis is the row axis.
        b: Array sliced with the same row count as ``a``.
        n_workers: Number of workers the rows must divide evenly between.

    Returns:
        The pair ``(a, b)``, truncated to the largest multiple of
        ``n_workers`` rows, or returned untouched when already divisible.
    """
    total_rows, _n_cols = a.shape
    leftover = total_rows % n_workers
    if leftover == 0:
        return a, b

    keep = total_rows - leftover
    return a[:keep], b[:keep]


def normalize(b):
    """Map class labels onto {0, 1}.

    Args:
        b: NumPy array of class labels.

    Returns:
        * labels {1, 2}  -> ``b - 1``
        * labels {-1, 1} -> ``(b + 1) / 2``
        * anything else  -> ``1.0`` where the label equals ``b[0]`` and
          ``0.0`` elsewhere (so the class of the first sample becomes 1)
        * empty input is returned unchanged
    """
    b_unique = np.unique(b)
    if b_unique.size == 0:
        # Nothing to relabel; also avoids indexing b[0] below.
        return b

    # np.array_equal compares shape first, so label sets that do not have
    # exactly two values fall through to the generic branch instead of
    # failing on a shape-mismatched elementwise comparison.
    if np.array_equal(b_unique, [1, 2]):
        # Transform labels {1, 2} to {0, 1}
        b = b - 1
    elif np.array_equal(b_unique, [-1, 1]):
        # Transform labels {-1, 1} to {0, 1}
        b = (b + 1) / 2
    else:
        # Replace class labels with 0's and 1's
        b = 1. * (b == b[0])
    return b


def normalise_and_trim(a, b, num_cpus: int):
    """Trim ``a``/``b`` with ``trim_dataset`` and relabel ``b`` with ``normalize``.

    Args:
        a: Array passed to ``trim_dataset`` as the first argument.
        b: Array passed to ``trim_dataset`` as the second argument; the
            trimmed result is then passed through ``normalize``.
        num_cpus: Forwarded to ``trim_dataset`` as ``n_workers``.

    Returns:
        Tuple ``(trimmed_a, normalized_trimmed_b)``.
    """
    trimmed_a, trimmed_b = trim_dataset(a, b, num_cpus)
    return trimmed_a, normalize(trimmed_b)


def _standard_prefixes_labels(locodl_k, adiana_k, locodl_prefix_k=None):
    """Build the full 19-entry ``(file prefix, legend label)`` table.

    Each of LoCoDL, ADIANA, DIANA and 5GCS-CC contributes four entries
    (Rand-k, Natural, Natural + Rand-k, Sign-1), followed by
    CompressedScaffnew, GradSkip and Scaffold. DIANA and 5GCS-CC always use
    k=1.

    Args:
        locodl_k: k shown in the LoCoDL legend labels.
        adiana_k: k used for both the ADIANA file prefixes and labels.
        locodl_prefix_k: k used in the LoCoDL *file prefixes* only; defaults
            to ``locodl_k``. Pass ``''`` to match any ``Rand-<k>`` file.
    """
    if locodl_prefix_k is None:
        locodl_prefix_k = locodl_k

    table = []
    for family, prefix_k, label_k in (
        ('LoCoDL', locodl_prefix_k, locodl_k),
        ('ADIANA', adiana_k, adiana_k),
        ('DIANA', 1, 1),
        ('5GCS-CC', 1, 1),
    ):
        table.extend([
            (f'{family}: Rand-{prefix_k}', f'{family}: Rand-{label_k}'),
            (f'{family}: Natural.', f'{family}: Natural'),
            (f'{family}: Natural + Rand-{prefix_k}', f'{family}: Rand-{label_k} + Natural'),
            (f'{family}: Sign-1', f'{family}: ' + r'$l_1$-select'),
        ])
    table.extend([
        ('CompressedScaffnew', 'CompressedScaffnew: s=2'),
        ('gradskip', 'GradSkip'),
        ('scaffold', 'Scaffold'),
    ])
    return table


def load_runs(args):
    """Load saved runs from a results directory and plot them.

    The set of series (file-name prefix + legend label) is chosen by
    substring-matching ``args['load_directory']``; the checks are ordered,
    and the first match wins. Each prefix is loaded with ``load_alg``, the
    reference optimum is read from the first line of ``f_star.txt`` in the
    same directory, and everything is handed to ``plot_total_com``.

    Args:
        args: Mapping with the keys ``'load_directory'``,
            ``'downlink_factor'`` and ``'uplink_factor'``.
    """
    load_dir = args['load_directory']

    # Hand-picked styling; only four entries each.
    colors = ["#C00000", "#B8860B", "#8B008B", "#006666"]
    markers = ['*', 'd', '^', 'H']
    fillstyles = ['full', 'full', 'full', 'full']

    if "w1a_n87" in load_dir:
        # File prefixes carry no k here, while the legend shows Rand-4.
        prefixes_labels = _standard_prefixes_labels(4, 75, locodl_prefix_k='')
    elif "w1a_n619" in load_dir:
        prefixes_labels = _standard_prefixes_labels(1, 75)
    elif "a5a_n87" in load_dir:
        prefixes_labels = _standard_prefixes_labels(2, 30)
    elif "a5a_n288" in load_dir:
        prefixes_labels = _standard_prefixes_labels(1, 20)
    elif "diabetes_n6" in load_dir:
        # Deliberately reduced subset of the standard table.
        prefixes_labels = [
            ('LoCoDL: Rand-2', 'LoCoDL: Rand-2'),
            ('LoCoDL: Natural.', 'LoCoDL: Natural'),
            ('ADIANA: Natural + Rand-2', 'ADIANA: Rand-2 + Natural'),
        ]
    elif "australian_n9" in load_dir:
        prefixes_labels = _standard_prefixes_labels(2, 3)
    elif "australian_n41" in load_dir or "australian_n225" in load_dir:
        prefixes_labels = _standard_prefixes_labels(1, 3)
    elif ("diabetes_n37" in load_dir or "breast-cancer_n29" in load_dir
          or "diabetes_n73" in load_dir):
        prefixes_labels = _standard_prefixes_labels(1, 2)
    elif "dirichlet" in load_dir:
        # Only the LoCoDL and ADIANA families (first 8 entries).
        prefixes_labels = _standard_prefixes_labels(1, 2)[:8]
    elif "mnist" in load_dir:
        # Deliberately reduced subset of the standard table.
        prefixes_labels = [
            ('LoCoDL: Rand-131', 'LoCoDL: Rand-131'),
            ('LoCoDL: Natural.', 'LoCoDL: Natural'),
            ('LoCoDL: Natural + Rand-131', 'LoCoDL: Rand-131 + Natural'),
            ('ADIANA: Natural.', 'ADIANA: Natural'),
            ('ADIANA: Sign-1', 'ADIANA: ' + r'$l_1$-select'),
        ]
    elif 'real-sim_n10_' in load_dir or 'real-sim_n100_' in load_dir:
        prefixes_labels = [
            ('BiCoLoR: k=1000, Natural, gamma=512', 'BiCoLoR: k-spars.'),
            ('BiCoLoR: k=20958, Natural + Rand-1000, gamma=512', 'BiCoLoR: ind. rand-K'),
            ('2Direction: stepsize: 256', '2Direction'),
            ('EF21-P+DIANA: stepsize: 512', 'EF21-P+DIANA'),
        ]
    elif 'w8a_n10_alpha1_c1_reg0_r1' in load_dir:
        # Must be tested before the more general 'w8a_n10_' check below.
        prefixes_labels = [
            ('BiCoLoR: k=100, Natural, gamma=64, convex=True', 'BiCoLoR'),
            ('BiCoLoR: k=100, Natural, gamma=64, convex=False', r'BiCoLoR: const. $p$'),
            ('2Direction: stepsize: 32', '2Direction'),
            ('EF21-P+DIANA: stepsize: 256', 'EF21-P+DIANA'),
        ]
    elif 'w8a_n10_' in load_dir:
        prefixes_labels = [
            ('BiCoLoR: k=100, Natural', 'BiCoLoR: k-spars.'),
            ('BiCoLoR: k=300, Natural + Rand-100, gamma=16', 'BiCoLoR: ind. rand-K'),
            ('2Direction: stepsize: 32', '2Direction'),
            ('EF21-P+DIANA: stepsize: 128', 'EF21-P+DIANA'),
        ]
    elif 'w8a_n100_' in load_dir:
        prefixes_labels = [
            ('BiCoLoR: k=100, Natural, gamma=8', 'BiCoLoR: k-spars.'),
            ('BiCoLoR: k=300, Natural + Rand-100, gamma=8', 'BiCoLoR: ind. rand-K'),
            ('2Direction: stepsize: 32', '2Direction'),
            ('EF21-P+DIANA: stepsize: 256', 'EF21-P+DIANA'),
        ]
    else:
        # Step-size sweep; plotted with default styling.
        prefixes_labels = [
            ('BiCoLoR: k=1000, Natural, gamma=128', 'BiCoLoR: stepsize 128'),
            ('BiCoLoR: k=1000, Natural, gamma=256', 'BiCoLoR: stepsize 256'),
            ('BiCoLoR: k=1000, Natural, gamma=512', 'BiCoLoR: stepsize 512'),
            ('2Direction: stepsize: 128', '2Direction: stepsize: 128'),
            ('2Direction: stepsize: 256', '2Direction: stepsize: 256'),
            ('2Direction: stepsize: 512', '2Direction: stepsize: 512'),
            ('EF21-P+DIANA: stepsize: 256', 'EF21-P+DIANA: stepsize: 256'),
            ('EF21-P+DIANA: stepsize: 512', 'EF21-P+DIANA: stepsize: 512'),
        ]
        colors, markers, fillstyles = None, None, None

    # The fixed style lists cover four series only. With more series,
    # indexing them per series would run past the end, so hand over None and
    # let the plotting routine use its defaults instead.
    if colors is not None and len(prefixes_labels) > min(
            len(colors), len(markers), len(fillstyles)):
        colors, markers, fillstyles = None, None, None

    prefixes, labels = zip(*prefixes_labels)

    trace_groups = [load_alg(prefix=prefix, load_dir=load_dir) for prefix in prefixes]

    with open(f'{load_dir}/f_star.txt', 'r') as f:
        f_star = float(f.readline())

    plot_total_com(labels=labels, trace_groups=trace_groups, f_opt=f_star,
                   downlink_factor=args['downlink_factor'], uplink_factor=args['uplink_factor'], save_dir=load_dir,
                   colors=colors, markers=markers, fillstyles=fillstyles)


def load_alg(prefix, load_dir):
    """Load every pickled run in ``load_dir`` whose file name starts with ``prefix``.

    Each file is expected to unpickle to a 3-tuple
    ``(trace, uplink_cost, downlink_cost)``.

    Args:
        prefix: File-name prefix identifying one algorithm configuration.
        load_dir: Directory containing the pickled runs.

    Returns:
        Tuple ``(alg_runs, alg_uplink_cost, alg_downlink_cost_group)``:
        the list of traces, the uplink cost from the last file read, and the
        list of per-run downlink costs (same order as ``alg_runs``).

    Raises:
        SystemExit: If no file matches ``prefix`` (message goes to stderr and
            the process exits with a non-zero status).
    """
    # Sorted so the run order, and therefore which file supplies the returned
    # uplink cost, does not depend on the file system's listing order.
    alg_files = sorted(
        filename for filename in os.listdir(load_dir)
        if filename.startswith(prefix)
    )
    if not alg_files:
        raise SystemExit(f"Couldn't load files for prefix: {prefix}")

    alg_runs = []
    alg_uplink_cost = None
    alg_downlink_cost_group = []
    for filename in alg_files:
        print(filename)
        # NOTE(review): unpickling executes arbitrary code from the file;
        # only point this at result directories you trust.
        with open(os.path.join(load_dir, filename), 'rb') as f:
            trace, alg_uplink_cost, alg_downlink_cost = pickle.load(f)
        alg_runs.append(trace)
        alg_downlink_cost_group.append(alg_downlink_cost)

    return alg_runs, alg_uplink_cost, alg_downlink_cost_group


def run(runs, loss, worker_losses, dim, downlink_factor, uplink_factor, save_dir, f_star, n_repeats=1, threshold=2e-6, plt_title=None, **kwargs):
    """Execute every configured algorithm ``n_repeats`` times and plot the results.

    Args:
        runs: Iterable of ``(alg, alg_args, index)`` triples. ``alg`` is
            called with keyword arguments ``x0``, ``loss``, ``worker_losses``,
            ``args``, ``index``, ``threshold`` plus ``**kwargs`` and must
            return ``(trace, label, uplink_cost, downlink_cost)``.
        loss: Forwarded to each algorithm.
        worker_losses: Forwarded to each algorithm.
        dim: Length of the all-zeros float32 starting point ``x0``.
        downlink_factor: Forwarded to ``plot_total_com``.
        uplink_factor: Forwarded to ``plot_total_com``.
        save_dir: Forwarded to ``plot_total_com``.
        f_star: Forwarded to ``plot_total_com`` as ``f_opt``.
        n_repeats: Number of repetitions per algorithm; must be at least 1.
        threshold: Forwarded to each algorithm.
        plt_title: Optional plot title, forwarded to ``plot_total_com``.
        **kwargs: Extra keyword arguments forwarded to each algorithm.

    Raises:
        ValueError: If ``n_repeats`` is smaller than 1.
    """
    if n_repeats < 1:
        # Without at least one repeat there is no label and no trace to plot.
        raise ValueError(f"n_repeats must be at least 1, got {n_repeats}")

    trace_groups = []
    labels = []

    for alg, alg_args, index in runs:
        trace_group = []
        alg_uplink_cost = None
        alg_downlink_cost_group = []
        label = None

        for _ in range(n_repeats):
            # Fresh starting point per repeat, so an algorithm that updates
            # x0 in place cannot leak its final iterate into the next repeat.
            x0 = np.zeros(dim, dtype=np.float32)
            trace, label, alg_uplink_cost, alg_downlink_cost = alg(x0=x0, loss=loss, worker_losses=worker_losses,
                                                                   args=alg_args, index=index, threshold=threshold, **kwargs)
            trace_group.append(trace)
            alg_downlink_cost_group.append(alg_downlink_cost)

        trace_groups.append((trace_group, alg_uplink_cost, alg_downlink_cost_group))
        # The label returned by the last repeat names the whole group.
        labels.append(label)

    plot_total_com(labels=labels, trace_groups=trace_groups, f_opt=f_star, downlink_factor=downlink_factor,
                   uplink_factor=uplink_factor, save_dir=save_dir, plt_title=plt_title)


def _cycled(values, index, default=None):
    """Return ``values[index]``, wrapping around when ``values`` is shorter.

    ``default`` is returned when ``values`` is ``None`` or empty.
    """
    if values is None or len(values) == 0:
        return default
    return values[index % len(values)]


def plot_total_com(labels, trace_groups, f_opt, downlink_factor, uplink_factor, save_dir, plt_title=None, colors=None, markers=None, fillstyles=None, linestyles=None):
    """Plot mean +/- std of ``f(x) - f*`` per algorithm and save it as a PDF.

    Args:
        labels: Legend label per trace group; ``labels[0]`` is also used in
            the output file name.
        trace_groups: One ``(traces, uplink_cost, downlink_costs)`` triple per
            algorithm. Each trace must expose ``its`` (only its length is
            used) and ``loss_vals_all`` (a mapping; its first value is the
            loss curve). ``downlink_costs[k]`` is either a scalar per-step
            cost or a NumPy array of per-step costs for run ``k``.
        f_opt: Reference optimum subtracted from every loss curve.
        downlink_factor: Weight of the downlink cost on the x axis.
        uplink_factor: Weight of the uplink cost on the x axis. When both
            factors are 0 the x axis is the plain iteration index.
        save_dir: Output directory; substrings of it also select the x-axis
            limit and whether curves are smoothed.
        plt_title: Optional figure title.
        colors, markers, fillstyles, linestyles: Optional per-series style
            lists; they are cycled when shorter than the number of series.

    Notes:
        Costs are multiplied by 32 before plotting (presumably bits per
        transmitted float -- confirm against the cost definition).
    """
    print('Plotting...')

    plt.rcParams['xtick.labelsize'] = 40
    plt.rcParams['ytick.labelsize'] = 40
    plt.rcParams['legend.fontsize'] = 40
    plt.rcParams['axes.titlesize'] = 40
    plt.rcParams['axes.labelsize'] = 40
    plt.rcParams["figure.figsize"] = [16, 12]
    plt.rcParams['mathtext.fontset'] = 'stix'
    plt.rcParams['text.usetex'] = True
    plt.yscale('log')
    plt.ylabel(r'$f(x) - f*$')

    plt.xlabel('Number of Communicated Bits')
    if "australian" in save_dir:
        # NOTE: superseded by the data-driven plt.ylim call after the loop.
        plt.ylim(bottom=4e-6, top=1)

    plt.grid()

    # Dataset-specific right edge of the x axis; inf means "use the data".
    x_lim = np.inf
    if "a5a" in save_dir or "w1a" in save_dir:
        x_lim = 2e6
    elif "diabetes" in save_dir:
        x_lim = 3e5
    elif "australian" in save_dir:
        x_lim = 3e5

    max_x_lim = -np.inf
    bottom_y_lim = np.inf
    top_y_lim = -np.inf

    all_markers = ['s', '^', 'D', 'o', 'v', '*', 'd', 'X', 'h', 'H', None]

    for i, trace_group in enumerate(trace_groups):
        traces_in_group, alg_uplink_cost, alg_downlink_cost_group = trace_group
        min_iterations = min(len(trace.its) for trace in traces_in_group)

        if (uplink_factor == 0) and (downlink_factor == 0):
            x_axis = np.arange(min_iterations)
            x_axis_group = [x_axis] * len(traces_in_group)
        elif isinstance(alg_downlink_cost_group[0], np.ndarray):
            # Per-step downlink costs: accumulate them separately for each run.
            x_axis_group = []
            for alg_downlink_cost in alg_downlink_cost_group:
                alg_downlink_cost = np.cumsum(alg_downlink_cost) * downlink_factor * 32
                x_axis_group.append(np.arange(len(alg_downlink_cost)) * (uplink_factor * alg_uplink_cost) * 32 + alg_downlink_cost)
        else:
            # Constant cost per step: every run shares one x axis, built from
            # the first run's downlink cost.
            x_axis_group = [np.arange(min_iterations) * (uplink_factor * alg_uplink_cost + downlink_factor * alg_downlink_cost_group[0]) * 32] * len(traces_in_group)

        # Restrict each run to the visible x range, so that the statistics,
        # y limits and marker spacing below describe what is actually shown.
        clipped_axes = []
        for x_axis in x_axis_group:
            visible = x_axis[x_axis < x_lim]
            # Keep one sample when nothing is below the limit, so the
            # interpolation below still has a support point.
            clipped_axes.append(visible if len(visible) else x_axis[:1])
        x_axis_group = clipped_axes

        loss_values_in_group = []
        for j, t in enumerate(traces_in_group):
            loss_values = list(t.loss_vals_all.values())[0][:len(x_axis_group[j])] - f_opt
            if len(loss_values) < len(x_axis_group[j]):
                # np.interp requires x and y samples of equal length.
                x_axis_group[j] = x_axis_group[j][:len(loss_values)]
            loss_values_in_group.append(loss_values)

        # Step 1: Choose a common x-axis (union of all x-axes)
        x_common = np.unique(np.concatenate(x_axis_group))

        max_x_lim = max(max_x_lim, x_common[-1])

        # Step 2: Interpolate all loss curves onto the common x-axis
        interpolated_curves = [
            np.interp(x_common, x_axis_group[j], curve)
            for j, curve in enumerate(loss_values_in_group)
        ]

        # Step 3: Compute the mean curve and std curve
        mean_curve = np.mean(interpolated_curves, axis=0)
        std_curve = np.std(interpolated_curves, axis=0)

        # y limits are taken from the unsmoothed mean curve.
        bottom_y_lim = min(bottom_y_lim, mean_curve[-1])
        top_y_lim = max(top_y_lim, mean_curve[0])

        marker = _cycled(markers, i, default=all_markers[i % len(all_markers)])
        color = _cycled(colors, i)
        linestyle = _cycled(linestyles, i)
        fillstyle = _cycled(fillstyles, i)

        if "a5a" in save_dir or "diabetes" in save_dir or "breast-cancer" in save_dir:
            if "LoCoDL" in labels[i] or 'CompressedScaffnew' in labels[i]:
                # Capped at the curve length: with mode='same' a longer
                # window would return more samples than x_common has.
                window_size = min(20, len(mean_curve))
                window = np.ones(window_size) / window_size
                # Apply the moving average filter
                mean_curve = np.convolve(mean_curve, window, mode='same')
                std_curve = np.convolve(std_curve, window, mode='same')

        mark_step = max(1, len(x_common) // 10)
        if 'select' in labels[i]:
            # Different marker spacing for the "select" series so its markers
            # do not sit on top of the other curves' markers.
            # NOTE(review): replaces a scaling of an undefined `iterations`
            # variable -- confirm this is the intended effect.
            mark_step = max(1, int(mark_step * 0.7))

        plt.plot(x_common, mean_curve, lw=5, label=labels[i],
                 markevery=mark_step, marker=marker, markersize=24 if marker != '*' else 28,
                 color=color, linestyle=linestyle, fillstyle=fillstyle,
                 alpha=0.7
                 )

        plt.fill_between(x_common,
                         mean_curve - std_curve,
                         mean_curve + std_curve,
                         alpha=0.4, color=color)

    x_lim = max_x_lim if x_lim == np.inf else x_lim
    plt.xlim(left=0, right=x_lim)
    plt.ylim(bottom=bottom_y_lim, top=top_y_lim)

    plt.legend(loc='upper right')
    if plt_title is not None:
        plt.title(plt_title)

    plt.tight_layout()
    out_path = f'{save_dir}/total_com_{labels[0]}.pdf'
    plt.savefig(out_path)
    print(f'Plot saved to: {out_path}')
    plt.close()