import argparse
import numpy as np
from matplotlib import pyplot as plt


def setup_arguments_parser(args=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, matching the
            original behaviour; passing a list makes the function usable from
            tests and notebooks.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Doubly Accelerated Federated Learning '
                                                 'with Local Training and Compressed Communication')
    parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
    parser.add_argument('--n_workers', type=int, default=3000, help='Number of workers/nodes')
    parser.add_argument('--it_local', type=int, default=None)
    parser.add_argument('--it_max', type=int, default=10000)
    parser.add_argument('--prob', type=float, default=None, help='Probability of server communication')
    parser.add_argument('--s', type=int, default=None, help='Number of ones per row in compression pattern')
    parser.add_argument('--eta', type=float, default=None)
    parser.add_argument('--dataset', type=str, default='w8a', help='[w8a, a9a, colon-cancer, duke, real-sim]')
    parser.add_argument('--lr', type=float, default=None, help='Learning rate')
    parser.add_argument('--c_uplink', type=float, default=1, help='Cost of sending one real number to the server')
    parser.add_argument('--c_downlink', type=float, default=0.0,
                        help='Cost of receiving one real number from the server')
    # Float default (argparse would convert a string default via `type`,
    # but a typed literal states the intent directly).
    parser.add_argument('--reg', type=float, default=0.003)
    parser.add_argument('--pbars', type=str, default='steps', help='[all, steps, none]')
    parser.add_argument('--load_dir', type=str, default=None, help='Directory with saved runs, None for new runs')
    return parser.parse_args(args)


def trim_dataset(a, b, n_workers: int):
    """Drop trailing rows so the row count divides evenly among workers.

    Args:
        a: 2-D feature matrix (its shape is unpacked into rows and columns).
        b: Labels aligned with the rows of ``a``.
        n_workers: Number of workers the rows will be split across.

    Returns:
        ``(a, b)`` unchanged when the row count is already a multiple of
        ``n_workers``; otherwise both truncated to the largest such multiple.
    """
    n_rows, _n_cols = a.shape
    leftover = n_rows % n_workers
    if leftover == 0:
        return a, b
    keep = n_rows - leftover
    return a[:keep], b[:keep]


def normalize(b):
    """Map class labels to {0, 1}.

    Recognised encodings are converted arithmetically:
      * {1, 2}  -> ``b - 1`` (dtype of ``b`` preserved)
      * {-1, 1} -> ``(b + 1) / 2`` (float result)
    Any other label set (including ones with more or fewer than two distinct
    values) is binarised against the first label: entries equal to ``b[0]``
    become 1.0 and every other entry becomes 0.0.

    NOTE(review): labels that are already {0, 1} fall into the generic branch,
    so they are inverted whenever ``b[0] == 0`` — confirm this is intended.

    Args:
        b: Numpy array of labels (assumed non-empty; the generic branch
            reads ``b[0]``).

    Returns:
        Array of the same length with values in {0, 1}.
    """
    b_unique = np.unique(b)
    # np.array_equal checks shapes before values, so a label set whose size
    # is not 2 is simply "not equal" instead of triggering an elementwise
    # broadcasting error (which is what `(b_unique == [1, 2]).all()` did).
    if np.array_equal(b_unique, [1, 2]):
        # Transform labels {1, 2} to {0, 1}
        b = b - 1
    elif np.array_equal(b_unique, [-1, 1]):
        # Transform labels {-1, 1} to {0, 1}
        b = (b + 1) / 2
    else:
        # Replace class labels with 0's and 1's
        b = 1. * (b == b[0])
    return b


def normalise_and_trim(a, b, num_cpus: int):
    """Trim the dataset to a multiple of ``num_cpus`` rows, then normalise labels.

    Trimming happens first, so ``normalize`` only sees the labels that are kept.

    Returns:
        ``(a_trimmed, b_normalised)``.
    """
    trimmed_a, trimmed_b = trim_dataset(a, b, num_cpus)
    return trimmed_a, normalize(trimmed_b)


def run(runs, loss, worker_losses, dim, c_downlink, c_uplink, save_dir):
    """Execute or load a list of algorithm runs and plot them against total communication.

    Each entry of ``runs`` is one of:
      * a saved run of length 2: ``(results, label)``, where ``results`` is the
        4-tuple ``(trace, uplink_numbers, downlink_numbers, f_opt)``;
      * a new run: ``(alg, label, alg_args)``; ``alg`` is called as
        ``alg(x0, loss, worker_losses, alg_args)`` with ``x0 = np.zeros(dim)``
        and must return the same 4-tuple.

    For every run, the communication cost is
    ``c_uplink * uplink_numbers + c_downlink * downlink_numbers``. The smallest
    ``f_opt`` reported by any run is used as the reference optimum. The
    result is plotted by ``plot_total_com`` into ``save_dir``.
    """
    traces = []
    labels = []
    communications = []
    f_opt = np.inf
    # `spec` (not `run`) so the loop variable does not shadow this function.
    for spec in runs:
        if len(spec) == 2:
            # Previously saved results: nothing to compute.
            results, label = spec[0], spec[1]
        else:
            alg, label, alg_args = spec[0], spec[1], spec[2]
            x0 = np.zeros(dim)
            results = alg(x0, loss, worker_losses, alg_args)
        trace, uplink_communicated_numbers, downlink_communicated_numbers, alg_f_opt = results

        traces.append(trace)
        labels.append(label)
        communications.append(c_uplink * np.asarray(uplink_communicated_numbers) +
                              c_downlink * np.asarray(downlink_communicated_numbers))

        if alg_f_opt < f_opt:
            f_opt = alg_f_opt
    plot_total_com(communications, labels, traces, f_opt, save_dir)


def plot_total_com(communications, labels, traces, f_opt, save_dir):
    """Plot the optimality gap ``f(x) - f*`` against total communication, one line per run.

    The figure is saved to ``{save_dir}/total_com.png`` and then closed.

    Args:
        communications: Per-run sequences of communication cost (x-axis).
        labels: Legend label for each run.
        traces: Per-run trace objects. The first value of
            ``trace.loss_vals_all`` supplies the loss curve; assumed to be
            array-like so that ``- f_opt`` broadcasts (TODO confirm).
        f_opt: Reference optimum subtracted from every loss curve.
        save_dir: Directory the PNG is written to.
    """
    style = {
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 20,
        'axes.titlesize': 22,
        'axes.labelsize': 22,
        'figure.figsize': [15, 10],
    }
    # rc_context restores the global rcParams on exit, so these sizes do not
    # leak into every figure created later in the process.
    with plt.rc_context(style):
        # A fresh figure avoids drawing onto whatever figure was current.
        fig, ax = plt.subplots()
        try:
            ax.set_yscale('log')
            ax.set_ylabel(r'$f(x) - f*$')
            ax.set_xlabel('TotalCom')
            ax.grid()

            for comm, label, trace in zip(communications, labels, traces):
                losses = list(trace.loss_vals_all.values())[0]
                ax.plot(comm, losses - f_opt, label=label)

            ax.legend()
            fig.savefig(f'{save_dir}/total_com.png')
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)
