import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

plt.style.use('fast')

# Global typography / line settings for every figure produced by this script.
# (An alternative mathtext font that was tried: 'dejavusans'.)
mpl.rcParams.update({
    'mathtext.fontset': 'cm',
    'pdf.fonttype': 42,  # Type 42 (TrueType) fonts in PDF output
    'ps.fonttype': 42,   # Type 42 (TrueType) fonts in PostScript output
    'lines.linewidth': 2.0,
    'legend.fontsize': 'large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'axes.labelsize': 'xx-large',
})

# Ratio values to run (a wider set previously used: 0.1, 0.3, 0.5, 0.7, 0.9, 1.0).
ratios = [0.5, 0.9, 1.0]

# Matplotlib's 'tab10' palette, in its standard order.
color_ar_1 = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

# Marker symbols available for per-series styling.
markers = ['x', '.', '+', '1', 'p', '*', 'D', '.', 's']

# Two vertically stacked subplots (2 rows x 1 column), 7 x 10 inches.
fig, axs = plt.subplots(2, figsize=(7, 10), constrained_layout=True)


def single_plot(args, print_name, results):
    """Plot per-round loss and accuracy for one run and save the artifacts.

    Side effects:
      * writes ``./losses_{print_name}.txt`` and ``./acc_{print_name}.txt``,
        each holding its series as space-separated values on a single line;
      * saves a two-panel figure (loss on top, accuracy below) to
        ``{print_name}.pdf``.

    Args:
        args: object exposing a ``k_ratio`` attribute, used as the legend label.
        print_name: tag interpolated into every output file name.
        results: ``(res_loss, res_acc)`` pair of per-round sequences.
    """
    i = 0  # index into the module-level colour palette ``color_ar_1``
    loss_file = f"./losses_{print_name}.txt"
    acc_file = f"./acc_{print_name}.txt"
    k_ratio = args.k_ratio
    res_loss, res_acc = results

    # Build a fresh figure on every call. Reusing the module-level figure is
    # broken: the first call's plt.close() closes it, so on later calls
    # plt.savefig() would save a different (typically new, empty) current
    # figure while the new lines pile up on the closed figure's axes.
    fig, axs = plt.subplots(2, figsize=(7, 10), constrained_layout=True)
    try:
        # Each series gets its own x-range so a length mismatch between the
        # loss and accuracy histories does not raise.
        axs[0].plot(np.arange(len(res_loss)), res_loss,
                    color=color_ar_1[i], label=k_ratio)
        axs[1].plot(np.arange(len(res_acc)), res_acc,
                    color=color_ar_1[i], label=k_ratio)

        with open(loss_file, 'w') as file:
            file.write(' '.join(str(loss) for loss in res_loss))
        with open(acc_file, 'w') as file:
            file.write(' '.join(str(acc) for acc in res_acc))

        for ax, ylabel in zip(axs, ('Loss', 'Accuracy')):
            ax.legend()
            ax.set_ylabel(ylabel)
            ax.set_xlabel('Communication rounds')

        fig.savefig(f"{print_name}.pdf")
    finally:
        # Always release the figure, even if writing or saving fails.
        plt.close(fig)