import pickle
import numpy as np
import matplotlib
from matplotlib.ticker import FormatStrFormatter
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

# Experiment configuration.
data_name_list = ['MNIST', 'KMNIST', 'FMNIST']  # datasets; one plot panel per dataset below
n_rep = 5  # independent repetitions, one random seed each
rd_seed_list = np.arange(start=100, stop=100 + n_rep)  # seeds 100 .. 100 + n_rep - 1
n_commun = 1000  # communication rounds per run
M = 500  # number of clients

# Result buffers for every sampling method, indexed [dataset, repetition, round].
# Each name gets its own freshly allocated, zero-filled float64 array.
_result_shape = (len(data_name_list), n_rep, n_commun)
(
    loss_list_Ada_OSMD, accu_list_Ada_OSMD,
    loss_list_uniform, accu_list_uniform,
    loss_list_MABS, accu_list_MABS,
    loss_list_VRB, accu_list_VRB,
    loss_list_Avare, accu_list_Avare,
) = (np.zeros(_result_shape) for _ in range(10))
del _result_shape

# load samples distribution info
# Row k of train_samples_dist receives the M per-client values stored in
# result_logistic/train_samples_dist_<data_name_list[k]>.pickle.
train_samples_dist = np.zeros((len(data_name_list), M))
for row, name in zip(train_samples_dist, data_name_list):
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with open(f'result_logistic/train_samples_dist_{name}.pickle', 'rb') as handle:
        row[:] = pickle.load(handle)  # row is a view, so this writes into train_samples_dist

# load the results
# Each (file tag, buffer) pair names one family of per-seed pickles,
#   result_logistic/<tag>_<dataset>_rdseed=<seed>.pickle,
# whose per-round history is copied into buffer[dataset, repetition, :].
result_targets = (
    ('loss_list_uniform', loss_list_uniform),
    ('accu_list_uniform', accu_list_uniform),
    ('loss_list_Ada_OSMD', loss_list_Ada_OSMD),
    ('accu_list_Ada_OSMD', accu_list_Ada_OSMD),
    ('loss_list_MABS', loss_list_MABS),
    ('accu_list_MABS', accu_list_MABS),
    ('loss_list_VRB', loss_list_VRB),
    ('accu_list_VRB', accu_list_VRB),
    ('loss_list_Avare', loss_list_Avare),
    ('accu_list_Avare', accu_list_Avare),
)
for i, data_name in enumerate(data_name_list):
    for j, rd_seed in enumerate(rd_seed_list):
        for tag, target in result_targets:
            path = f'result_logistic/{tag}_{data_name}_rdseed={rd_seed}.pickle'
            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(path, 'rb') as handle:
                values = np.asarray(pickle.load(handle))
            # Fail loudly on a history of the wrong length instead of letting
            # numpy broadcast e.g. a single stored value across every round.
            if values.shape != (n_commun,):
                raise ValueError(
                    f'{path}: expected {n_commun} per-round values, got shape {values.shape}'
                )
            target[i, j, :] = values

# Aggregate over repetitions (axis 1), giving (dataset, round) arrays.
# Losses are averaged in log space (mean of log-loss across seeds).
# Accuracies are averaged arithmetically, and the accu_list_* names are
# rebound to those means.
loss_list_uniform_logmean = np.log(loss_list_uniform).mean(1)
accu_list_uniform = accu_list_uniform.mean(1)
loss_list_Ada_OSMD_logmean = np.log(loss_list_Ada_OSMD).mean(1)
accu_list_Ada_OSMD = accu_list_Ada_OSMD.mean(1)
loss_list_MABS_logmean = np.log(loss_list_MABS).mean(1)
accu_list_MABS = accu_list_MABS.mean(1)
loss_list_VRB_logmean = np.log(loss_list_VRB).mean(1)
accu_list_VRB = accu_list_VRB.mean(1)
loss_list_Avare_logmean = np.log(loss_list_Avare).mean(1)
accu_list_Avare = accu_list_Avare.mean(1)

# plot
# NOTE(review): std_plot is not referenced by any of the plotting code below.
std_plot = 1.0

# plot samples distribution info
# Histogram (100 bins, raw counts) of the per-client sample sizes. Only the
# first dataset's row (data_name_list[0]) is drawn.
# NOTE(review): confirm that one dataset is meant to represent all of them.
fig, ax = plt.subplots(1, 1, figsize=[20., 8.])
ax.hist(train_samples_dist[0], bins=100, density=False)
for set_text, text in (
    (ax.set_xlabel, 'number of samples'),
    (ax.set_ylabel, 'number of users'),
    (ax.set_title, 'Sample size distribution'),
):
    set_text(text, fontsize=40)
ax.tick_params(axis='both', which='major', labelsize=40)
fig.savefig('plots_logistic/samples_distribution.png', dpi=400, bbox_inches='tight')
plt.close(fig)

# plot loss
# One panel per dataset: the seed-averaged log-loss of every sampling
# method against the communication round.
rounds = np.arange(1, n_commun + 1)  # x axis: rounds 1..n_commun
loss_series = (
    # (per-dataset log-loss array, colour, legend label, extra line style)
    (loss_list_uniform_logmean, "r", "Uniform", {"linewidth": 3}),
    (loss_list_Ada_OSMD_logmean, "g", "Ada-OSMD", {}),
    (loss_list_MABS_logmean, "m", "MABS", {}),
    (loss_list_VRB_logmean, "c", "VRB", {}),
    (loss_list_Avare_logmean, "y", "Avare", {}),
)
# One column per dataset. squeeze=False keeps `axes` 2-D even for a single dataset.
fig, axes = plt.subplots(1, len(data_name_list), figsize=[32., 12.], squeeze=False)
for i, data_name in enumerate(data_name_list):
    panel = axes[0, i]
    for values, color, label, style in loss_series:
        panel.plot(rounds, values[i, :], color=color, label=label, **style)
    panel.tick_params(axis='both', which='major', labelsize=25)
    panel.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    panel.set_ylabel(r'log(Loss)', fontsize=25)
    panel.set_xlabel('Communication Rounds', fontsize=25)
    panel.set_title(data_name, fontsize=25)
# Build the shared legend from the first panel's labelled lines, so each label
# is tied to its artist rather than matched to plot order by position.
legend_handles, legend_labels = axes[0, 0].get_legend_handles_labels()
fig.legend(legend_handles, legend_labels, loc='upper center', ncol=len(legend_handles), fontsize=25)
fig.savefig('plots_logistic/logistic_loss_cv.png', dpi=400, bbox_inches='tight')
fig.savefig('plots_logistic/logistic_loss_cv.eps', dpi=400, bbox_inches='tight', format='eps')
plt.close(fig)

# plot accuracy
# One panel per dataset: the seed-averaged accuracy of every sampling method
# against the communication round.
rounds = np.arange(1, n_commun + 1)  # x axis: rounds 1..n_commun
accu_series = (
    # (per-dataset accuracy array, colour, legend label, extra line style)
    (accu_list_uniform, "r", "Uniform", {"linewidth": 3}),
    (accu_list_Ada_OSMD, "g", "Ada-OSMD", {}),
    (accu_list_MABS, "m", "MABS", {}),
    (accu_list_VRB, "c", "VRB", {}),
    # Yellow, as in the loss plots; cyan would make Avare indistinguishable from VRB.
    (accu_list_Avare, "y", "Avare", {}),
)
# One column per dataset. squeeze=False keeps `axes` 2-D even for a single dataset.
fig, axes = plt.subplots(1, len(data_name_list), figsize=[32., 12.], squeeze=False)
for i, data_name in enumerate(data_name_list):
    panel = axes[0, i]
    for values, color, label, style in accu_series:
        panel.plot(rounds, values[i, :], color=color, label=label, **style)
    panel.tick_params(axis='both', which='major', labelsize=25)
    panel.yaxis.set_ticks(np.linspace(0.0, 1.0, num=5))
    panel.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    panel.set_ylabel(r'Accuracy', fontsize=25)
    panel.set_xlabel('Communication Rounds', fontsize=25)
    panel.set_title(data_name, fontsize=25)
# Build the shared legend from the first panel's labelled lines, so each label
# is tied to its artist rather than matched to plot order by position.
legend_handles, legend_labels = axes[0, 0].get_legend_handles_labels()
fig.legend(legend_handles, legend_labels, loc='upper center', ncol=len(legend_handles), fontsize=25)
fig.savefig('plots_logistic/logistic_accuracy_cv.png', dpi=400, bbox_inches='tight')
fig.savefig('plots_logistic/logistic_accuracy_cv.eps', dpi=400, bbox_inches='tight', format='eps')
plt.close(fig)

# plot loss and accuracy
# A 2-row grid with one column per dataset: seed-averaged log-loss on top,
# seed-averaged accuracy below, and a single shared legend.
rounds = np.arange(1, n_commun + 1)  # x axis: rounds 1..n_commun
# (per-dataset array, colour, legend label, extra line style) for each method.
loss_series = (
    (loss_list_uniform_logmean, "r", "Uniform", {"linewidth": 3}),
    (loss_list_Ada_OSMD_logmean, "g", "Ada-OSMD", {}),
    (loss_list_MABS_logmean, "m", "MABS", {}),
    (loss_list_VRB_logmean, "c", "VRB", {}),
    (loss_list_Avare_logmean, "y", "Avare", {}),
)
accu_series = (
    (accu_list_uniform, "r", "Uniform", {"linewidth": 3}),
    (accu_list_Ada_OSMD, "g", "Ada-OSMD", {}),
    (accu_list_MABS, "m", "MABS", {}),
    (accu_list_VRB, "c", "VRB", {}),
    (accu_list_Avare, "y", "Avare", {}),
)
# squeeze=False keeps `axes` 2-D even for a single dataset.
fig, axes = plt.subplots(2, len(data_name_list), figsize=[32., 24.], squeeze=False)
for i, data_name in enumerate(data_name_list):
    loss_ax, accu_ax = axes[0, i], axes[1, i]

    for values, color, label, style in loss_series:
        loss_ax.plot(rounds, values[i, :], color=color, label=label, **style)
    loss_ax.tick_params(axis='both', which='major', labelsize=30)
    loss_ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    loss_ax.set_ylabel(r'log(Loss)', fontsize=30)
    loss_ax.set_title(data_name, fontsize=35)

    for values, color, label, style in accu_series:
        accu_ax.plot(rounds, values[i, :], color=color, label=label, **style)
    accu_ax.tick_params(axis='both', which='major', labelsize=30)
    accu_ax.yaxis.set_ticks(np.linspace(0.0, 1.0, num=5))
    accu_ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    accu_ax.set_ylabel(r'Accuracy', fontsize=30)
    accu_ax.set_xlabel('Communication Rounds', fontsize=35)
# Build the shared legend from the top-left panel's labelled lines, so each
# label is tied to its artist rather than matched to plot order by position.
legend_handles, legend_labels = axes[0, 0].get_legend_handles_labels()
fig.legend(legend_handles, legend_labels, loc='upper center', ncol=len(legend_handles), fontsize=25)
fig.savefig('plots_logistic/logistic_cv.png', dpi=400, bbox_inches='tight')
fig.savefig('plots_logistic/logistic_cv.eps', dpi=400, bbox_inches='tight', format='eps')
plt.close(fig)