import os
import pickle

import json
import re
import numpy as np
from matplotlib import pyplot as plt


def trim_dataset(a, b, n_workers: int):
    """Cut trailing rows so the row count of ``a`` is a multiple of ``n_workers``.

    ``a`` must be two-dimensional (its shape is unpacked into rows/columns) and
    ``b`` is truncated to the same number of leading entries. When no trimming is
    needed the original objects are returned untouched.
    """
    rows, _ = a.shape
    surplus = rows % n_workers
    if not surplus:
        return a, b
    keep = rows - surplus
    return a[:keep], b[:keep]


def normalize(b):
    """Map a vector of class labels to {0, 1}.

    - labels exactly {1, 2}:  ``b - 1`` (dtype preserved)
    - labels exactly {-1, 1}: ``(b + 1) / 2`` (float result)
    - anything else:          1.0 where ``b == b[0]``, 0.0 elsewhere
      (one-vs-rest against the first label)

    Returns the transformed label array.
    """
    b_unique = np.unique(b)
    # np.array_equal returns False (instead of raising / returning a bare bool) when
    # the shapes differ, so label sets with more than two values reach the else-branch.
    if np.array_equal(b_unique, [1, 2]):
        # Transform labels {1, 2} to {0, 1}
        b = b - 1
    elif np.array_equal(b_unique, [-1, 1]):
        # Transform labels {-1, 1} to {0, 1}
        b = (b + 1) / 2
    else:
        # Replace class labels with 0's and 1's
        b = 1. * (b == b[0])
    return b


def normalise_and_trim(a, b, num_cpus: int):
    """Trim ``(a, b)`` to a multiple of ``num_cpus`` rows, then normalise the labels.

    Trimming happens first, so normalisation only sees the retained labels.
    Returns the trimmed ``a`` and the normalised, trimmed ``b``.
    """
    trimmed_a, trimmed_b = trim_dataset(a, b, num_cpus)
    return trimmed_a, normalize(trimmed_b)


def _build_prefixes_labels(locodl_k, adiana_k, locodl_prefix_k=None):
    """Return the (file-name prefix, legend label) pairs plotted by ``load_runs``.

    Every experiment plots the same 19 curves in the same order: four variants each
    of LoCoDL, ADIANA, DIANA and 5GCS-CC, then CompressedScaffnew, GradSkip and
    Scaffold. Only the Rand-k value of LoCoDL and ADIANA differs between
    experiments; DIANA and 5GCS-CC always use Rand-1.

    Args:
        locodl_k: k (as a string) shown in the LoCoDL legend labels.
        adiana_k: k (as a string) used for ADIANA in both prefix and label.
        locodl_prefix_k: k used in the LoCoDL *file prefix*; defaults to
            ``locodl_k``. An empty string makes the prefix match any LoCoDL
            Rand-k file.
    """
    if locodl_prefix_k is None:
        locodl_prefix_k = locodl_k
    pairs = []
    for alg, prefix_k, label_k in (
        ('LoCoDL', locodl_prefix_k, locodl_k),
        ('ADIANA', adiana_k, adiana_k),
        ('DIANA', '1', '1'),
        ('5GCS-CC', '1', '1'),
    ):
        pairs.extend([
            (f'{alg}: Rand-{prefix_k}', f'{alg}: Rand-{label_k}'),
            # The trailing '.' stops this prefix from also matching the
            # '<alg>: Natural + Rand-k' files (matching is done with startswith).
            (f'{alg}: Natural.', f'{alg}: Natural'),
            (f'{alg}: Natural + Rand-{prefix_k}', f'{alg}: Rand-{label_k} + Natural'),
            (f'{alg}: Sign-1', f'{alg}: ' + r'$l_1$-select'),
        ])
    pairs.extend([
        ('CompressedScaffnew', 'CompressedScaffnew: s=2'),
        ('gradskip', 'GradSkip'),
        ('scaffold', 'Scaffold'),
    ])
    return pairs


def load_runs(args):
    """Load the saved runs of one experiment directory and plot them.

    ``args`` must provide 'load_directory', 'downlink_factor', 'uplink_factor',
    'dataset' and 'n_workers'. The directory name selects which Rand-k variants are
    loaded (see ``configs``); the directory must also contain ``f_star.txt`` whose
    first line is the optimal objective value.

    Raises:
        ValueError: if 'load_directory' matches none of the known experiments
            (this used to surface as an UnboundLocalError).
    """
    load_dir = args['load_directory']

    # (substrings of load_dir, LoCoDL file-prefix k, LoCoDL label k, ADIANA k).
    # Matching is by substring and the first hit wins, so order matters.
    configs = [
        (('w1a_n87',), '', '4', '75'),
        (('w1a_n619',), '1', '1', '75'),
        (('a5a_n87',), '2', '2', '30'),
        (('a5a_n288',), '1', '1', '20'),
        (('diabetes_n6',), '2', '2', '2'),
        (('australian_n9',), '2', '2', '3'),
        (('australian_n41', 'australian_n225'), '1', '1', '3'),
        (('diabetes_n37', 'breast-cancer_n29'), '1', '1', '2'),
        (('diabetes_n73',), '1', '1', '2'),
    ]
    for keys, locodl_prefix_k, locodl_k, adiana_k in configs:
        if any(key in load_dir for key in keys):
            prefixes_labels = _build_prefixes_labels(locodl_k, adiana_k, locodl_prefix_k)
            break
    else:
        raise ValueError(f"No plot configuration matches load_directory {load_dir!r}")

    # Per-curve styles, aligned with prefixes_labels: blocks of four for LoCoDL,
    # ADIANA, DIANA and 5GCS-CC, then CompressedScaffnew, GradSkip, Scaffold.
    colors = (['#d62728'] * 4 + ['#1f77b4'] * 4 + ['#ff7f0e'] * 4 + ['#2ca02c'] * 4
              + ['#9467bd', '#8c564b', '#e377c2'])
    markers = ['D', 'D', 'D', 'x'] * 4 + ['p', '*', 's']
    fillstyles = (['right', 'left', 'none', 'none'] * 2
                  + ['right', 'left', 'none', 'full']
                  + ['right', 'left', 'none', 'none']
                  + ['top', 'full', 'full'])

    prefixes, labels = zip(*prefixes_labels)
    trace_groups = [load_alg(prefix=prefix, load_dir=load_dir) for prefix in prefixes]

    with open(f'{load_dir}/f_star.txt', 'r') as f:
        f_star = float(f.readline())

    plot_total_com(labels=labels, trace_groups=trace_groups, f_opt=f_star,
                   downlink_factor=args['downlink_factor'], uplink_factor=args['uplink_factor'], save_dir=load_dir,
                   plt_title=f"Dataset {args['dataset']}, {args['n_workers']} workers, ",
                   colors=colors, markers=markers, fillstyles=fillstyles)


def load_alg(prefix, load_dir):
    """Load every pickled run in ``load_dir`` whose file name starts with ``prefix``.

    Each file must contain a ``(trace, uplink_cost, downlink_cost)`` tuple. Files
    are read in sorted name order so the result does not depend on the arbitrary
    order of ``os.listdir``.

    Returns:
        ``(runs, uplink_cost, downlink_cost)``: the list of traces, plus the costs
        stored in the last file read.

    Raises:
        SystemExit: with a message (non-zero exit status) if no file matches.
    """
    alg_files = sorted(name for name in os.listdir(load_dir) if name.startswith(prefix))
    if not alg_files:
        # Previously exit() was called, which relies on the `site` module and exits
        # with status 0 even though this is a failure.
        raise SystemExit(f"Couldn't load files for prefix: {prefix}")

    alg_runs = []
    alg_uplink_cost = None
    alg_downlink_cost = None
    for filename in alg_files:
        # NOTE: pickle can execute arbitrary code; only load result files you produced.
        with open(os.path.join(load_dir, filename), 'rb') as f:
            trace, alg_uplink_cost, alg_downlink_cost = pickle.load(f)
        alg_runs.append(trace)
    return alg_runs, alg_uplink_cost, alg_downlink_cost


def run(runs, loss, worker_losses, dim, downlink_factor, uplink_factor, save_dir, f_star, n_repeats=1, threshold=2e-6, plt_title=None, **kwargs):
    """Run each configured algorithm ``n_repeats`` times and plot total communication.

    Args:
        runs: iterable of ``(alg, alg_args, index)`` triples. ``alg`` is called as
            ``alg(x0=..., loss=..., worker_losses=..., args=alg_args, index=index,
            threshold=threshold, **kwargs)`` and must return
            ``(trace, label, uplink_cost, downlink_cost)``.
        dim: length of the float32 zero vector used as the starting point.
        f_star: optimal objective value forwarded to the plot as ``f_opt``.
        n_repeats: repetitions per algorithm; the label and costs of the last
            repetition are kept.
        Remaining arguments are forwarded to ``alg`` or ``plot_total_com``.

    Raises:
        ValueError: if ``n_repeats < 1`` (no label or costs would be produced).
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

    trace_groups = []
    labels = []

    for alg, alg_args, index in runs:
        trace_group = []
        for _ in range(n_repeats):
            # Fresh starting point per repeat, so an algorithm that updates x0 in
            # place cannot leak state into the next repetition.
            x0 = np.zeros(dim, dtype=np.float32)
            trace, label, alg_uplink_cost, alg_downlink_cost = alg(
                x0=x0, loss=loss, worker_losses=worker_losses,
                args=alg_args, index=index, threshold=threshold, **kwargs)
            trace_group.append(trace)

        trace_groups.append((trace_group, alg_uplink_cost, alg_downlink_cost))
        labels.append(label)

    plot_total_com(labels=labels, trace_groups=trace_groups, f_opt=f_star, downlink_factor=downlink_factor,
                   uplink_factor=uplink_factor, save_dir=save_dir, plt_title=plt_title)


def _x_limit_for(save_dir):
    """Return the right-hand x-axis limit (communicated bits) for the dataset in save_dir.

    Directory names matching none of the known datasets (e.g. breast-cancer) get
    ``np.inf``, i.e. no truncation, instead of leaving the limit undefined.
    """
    if "a5a" in save_dir or "w1a" in save_dir:
        return 2e6
    if "diabetes" in save_dir or "australian" in save_dir:
        return 3e5
    return np.inf


def plot_total_com(labels, trace_groups, f_opt, downlink_factor, uplink_factor, save_dir, plt_title=None, colors=None, markers=None, fillstyles=None, linestyles=None, show_title=False):
    """Plot suboptimality f(x) - f* against communicated bits and save it as a PDF.

    Args:
        labels: legend label per trace group; ``labels[0]`` is also part of the
            output file name.
        trace_groups: sequence of ``(traces, uplink_cost, downlink_cost)``. Each
            trace must expose ``its`` (its length is the iteration count) and
            ``loss_vals_all`` (a mapping whose first value holds the per-iteration
            loss values).
        f_opt: optimal objective value subtracted from every loss value.
        downlink_factor, uplink_factor: weights of the per-iteration downlink /
            uplink costs. If both are 0 the x-axis is the iteration index instead.
        save_dir: output directory; its name also selects axis limits and whether
            some curves are smoothed.
        plt_title: optional title, drawn only when ``show_title`` is True.
        colors, markers, fillstyles, linestyles: optional per-group style lists,
            indexed like ``trace_groups``.
        show_title: opt-in switch for the title. Defaults to False, which matches
            the previous behaviour of always discarding ``plt_title``.

    Side effects: mutates global ``plt.rcParams``, writes
    ``{save_dir}/total_com_{labels[0]}.pdf``, and prints progress plus the label of
    the group reaching the smallest final x-value among those whose first run ends
    below 2e-4 suboptimality.
    """
    print('Plotting...')

    plt.rcParams['xtick.labelsize'] = 20
    plt.rcParams['ytick.labelsize'] = 20
    plt.rcParams['legend.fontsize'] = 13
    plt.rcParams['axes.titlesize'] = 22
    plt.rcParams['axes.labelsize'] = 22
    plt.rcParams["figure.figsize"] = [15, 10]
    plt.rcParams['mathtext.fontset'] = 'stix'
    plt.yscale('log')
    plt.ylabel(r'$f(x) - f*$')
    plt.xlabel('Number of Communicated Bits')
    plt.ylim(bottom=4e-6 if "australian" in save_dir else 2e-6, top=1)
    plt.grid()

    x_lim = _x_limit_for(save_dir)
    if np.isfinite(x_lim):
        plt.xlim(0, x_lim)
    else:
        plt.xlim(left=0)

    # Only on these datasets are the LoCoDL / CompressedScaffnew curves smoothed.
    smooth_dataset = any(name in save_dir for name in ("a5a", "diabetes", "breast-cancer"))
    window_size = 20
    all_markers = ['s', '^', 'D', 'o', 'v', '*', 'd', 'X', 'h', 'H', None]
    min_it = np.inf
    best = ''
    for i, (traces_in_group, alg_uplink_cost, alg_downlink_cost) in enumerate(trace_groups):
        iterations = min(len(trace.its) for trace in traces_in_group)
        if uplink_factor == 0 and downlink_factor == 0:
            x_axis = np.arange(iterations)
            print(iterations)
        else:
            # 32 bits per communicated value (presumably float32 entries -- confirm).
            bits_per_iteration = (uplink_factor * alg_uplink_cost + downlink_factor * alg_downlink_cost) * 32
            x_axis = np.arange(iterations) * bits_per_iteration

        # Keep only the points inside the visible x-range.
        x_axis = x_axis[x_axis < x_lim]
        iterations = len(x_axis)
        if iterations == 0:
            print(f'Nothing to plot for {labels[i]}')
            continue

        loss_values_in_group = [
            np.asarray(next(iter(t.loss_vals_all.values())))[:iterations] - f_opt
            for t in traces_in_group
        ]

        if loss_values_in_group[0][-1] < 2e-4 and x_axis[-1] < min_it:
            min_it = x_axis[-1]
            best = labels[i]

        marker = markers[i] if markers is not None else all_markers[i % len(all_markers)]
        color = colors[i] if colors is not None else None
        linestyle = linestyles[i] if linestyles is not None else None
        fillstyle = fillstyles[i] if fillstyles is not None else None

        mean_loss_values = np.mean(loss_values_in_group, axis=0)
        if smooth_dataset and ("LoCoDL" in labels[i] or 'CompressedScaffnew' in labels[i]):
            # Moving average. The window is capped at the curve length because
            # np.convolve(mode='same') returns max(len(a), len(v)) samples, which
            # would no longer line up with x_axis for very short curves.
            width = min(window_size, len(mean_loss_values))
            smoothed_values = np.convolve(mean_loss_values, np.ones(width) / width, mode='same')
        else:
            smoothed_values = mean_loss_values

        # Space l1-select markers differently so they do not overlap other markers.
        marker_span = iterations * 0.7 if 'select' in labels[i] else iterations
        plt.plot(x_axis, smoothed_values, lw=2, label=labels[i],
                 # markevery must be >= 1; n // 10 is 0 for curves under 10 points.
                 markevery=max(1, int(marker_span // 10)), marker=marker, markersize=12,
                 color=color, linestyle=linestyle, fillstyle=fillstyle)

    plt.legend(loc='upper right')
    if show_title and plt_title is not None:
        plt.title(plt_title)
    plt.tight_layout()
    save_path = f'{save_dir}/total_com_{labels[0]}.pdf'
    plt.savefig(save_path)
    # Report the path actually written (the old message omitted the label suffix).
    print(f'Plot saved to: {save_path}')
    plt.close()

    print(best)
