import json
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

###############config_variables###############
# Result files to load, in the order their curves are drawn and labelled.
# NOTE(review): these names already start with 'result/'; the loader in this
# script prefixes './result/', giving './result/result/...' — confirm that is
# the real on-disk layout.
(
    data_file_name1,
    data_file_name2,
    data_file_name3,
    data_file_name4,
) = (
    'result/simulation_result_5nodes_cifar100_iidFalse_H50_alpha1_iter10000_batch_size64',
    'result/simulation_result_5nodes_cifar100_iidFalse_H50_alpha10_iter10000_batch_size64',
    'result/scaffold_simulation_result_5nodes_cifar100_iidFalse_H50_alpha1_iter20000_batch_size64',
    'result/scaffold_simulation_result_5nodes_cifar100_iidFalse_H50_alpha10_iter20000_batch_size64',
)
# Suffix of the saved PDF. NOTE(review): says 'cifar10' while every input file
# above is a cifar100 run — confirm the intended name.
final_figure_name = "iidFalse_cifar10_scaffold"
# Curves are truncated to / resampled on range(0, iteration, frequency).
iteration = 10 ** 4
frequency = 100
###############config_variables###############

def extract_data(loss_data):
    """Split raw simulation results into per-metric nested lists.

    ``loss_data`` is indexed as ``loss_data[i][j]`` (outer list = one entry per
    loaded result set, inner list = one entry per run). Each run entry ``run``
    is read as:

    * ``run[0]`` -- a sequence of records; element 0 of each record is taken
      as the loss and element 1 as the accuracy,
    * ``run[1]`` -- copied verbatim into ``t`` (presumably the sample
      positions/times of the records — TODO confirm),
    * ``run[2]`` -- copied verbatim into ``c`` (presumably a communication
      counter per record — TODO confirm).

    Returns:
        ``(l, a, t, c)``: four lists sharing the ``[i][j]`` nesting of
        ``loss_data``; ``l[i][j]`` / ``a[i][j]`` hold the losses / accuracies of
        every record of run ``j``.
    """
    l, a, t, c = [], [], [], []
    for runs in loss_data:
        # Use each run's own record count rather than that of loss_data[0][0]:
        # runs of a different length (e.g. longer simulations) must neither be
        # silently truncated nor trigger an IndexError.
        l.append([[record[0] for record in run[0]] for run in runs])
        a.append([[record[1] for record in run[0]] for run in runs])
        t.append([run[1] for run in runs])
        c.append([run[2] for run in runs])
    return l, a, t, c


def linear_estimate(data, t, ind):
    """Linearly interpolate ``data`` (sampled at positions ``t``) at each point of ``ind``.

    For every query point ``i`` the first index ``j`` with ``t[j] >= i`` is
    found and the value is interpolated between samples ``j - 1`` and ``j``.
    ``t`` is presumably sorted ascending — TODO confirm for all callers.

    Edge cases:
        * ``i <= t[0]``: there is no earlier sample, so ``data[0]`` is returned.
        * ``i > t[-1]``: no sample qualifies and the point is skipped, so the
          result can be shorter than ``ind``.

    Args:
        data: sample values, indexed in parallel with ``t``.
        t: sample positions.
        ind: iterable of query positions.

    Returns:
        List of interpolated values, one per query point that is not skipped.
    """
    res = []
    for i in ind:
        for j, t_j in enumerate(t):
            if t_j >= i:
                if j == 0:
                    # Clamp instead of using index j - 1 == -1, which would wrap
                    # around to the *last* sample (and divide by zero when t has
                    # a single element).
                    res.append(data[0])
                else:
                    # t[j - 1] < i <= t[j], so the denominator is strictly positive.
                    t_prev = t[j - 1]
                    res.append(data[j - 1] + (i - t_prev) / (t_j - t_prev) * (data[j] - data[j - 1]))
                break
    return res


# Load each configured result file and accumulate them into `data` with `+=`
# (assumes every file holds a JSON list, so the runs are concatenated in order).
# NOTE(review): the configured names already begin with 'result/', so the
# paths opened here are './result/result/...' — confirm the on-disk layout.
_result_paths = [
    "./result/" + _name
    for _name in (data_file_name1, data_file_name2, data_file_name3, data_file_name4)
]
with open(_result_paths[0], "r") as fp:
    data = json.load(fp)
for _path in _result_paths[1:]:
    with open(_path, "r") as fp:
        data2 = json.load(fp)
    data += data2

#%%
iter_loss, iter_accuracy, iter_t, iter_comm = extract_data(data)
#%%
# Resample every run's loss / accuracy / comm series onto the common grid
# `time`, using each run's own iter_t entry as the sample positions, so runs
# can later be averaged point by point.
t_loss = []
t_accuracy = []
t_comm = []
time = range(0, iteration, frequency)  # NOTE: shadows the stdlib module name
for method_idx, (loss_runs, acc_runs, t_runs, comm_runs) in enumerate(
    zip(iter_loss, iter_accuracy, iter_t, iter_comm)
):
    print(method_idx)  # progress output
    t_loss.append([linear_estimate(series, ts, time) for series, ts in zip(loss_runs, t_runs)])
    t_accuracy.append([linear_estimate(series, ts, time) for series, ts in zip(acc_runs, t_runs)])
    t_comm.append([linear_estimate(series, ts, time) for series, ts in zip(comm_runs, t_runs)])
#%%
# For every loaded result set, average (and take the population std of) its
# runs along axis 0, then keep only the first iteration // frequency points.
_n_points = iteration // frequency


def _run_mean_std(per_method_runs):
    """Return ``(mean, std)`` arrays over runs for each entry, truncated to ``_n_points``."""
    means = [np.average(runs, axis=0)[:_n_points] for runs in per_method_runs]
    stds = [np.std(runs, axis=0)[:_n_points] for runs in per_method_runs]
    return np.array(means), np.array(stds)


av_iter_loss, std_iter_loss = _run_mean_std(iter_loss)
av_t_loss, std_t_loss = _run_mean_std(t_loss)
av_iter_accuracy, std_iter_accuracy = _run_mean_std(iter_accuracy)
av_t_accuracy, std_t_accuracy = _run_mean_std(t_accuracy)
av_iter_comm, std_iter_comm = _run_mean_std(iter_comm)
av_t_comm, std_t_comm = _run_mean_std(t_comm)



#%%
# Set the font size before the figure exists so every text element of both
# panels (tick labels included) is created with it.
plt.rcParams.update({'font.size': 19})
fig, ax = plt.subplots(figsize=(24, 10), nrows=1, ncols=2)
colors = matplotlib.cm.tab20(range(20))  # available for manual colour overrides; not used below
b = 0  # first row (result set) of the averaged arrays to plot
markers = ["o", "X", "P", "^", "v", "s", "h", "<", ">", "d", "*"]
every = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]  # not used below
order = [10, 0, 9, 5, 5, 6, 5, 6, 6]  # not used below
# One label per plotted row, in the order the result files were loaded.
legend_labels = [
    r"FedAvg/ $\tau$=50",
    r"FedALS/ $\tau$=50/ $\alpha$ = 10",
    r"SCAFFOLD/ $\tau$=50",
    r"FedALS + SCAFFOLD/ $\tau$=50/ $\alpha$ = 10",
]


def _plot_mean_std(axis, x, mean, std, ylabel):
    """Plot rows ``b:`` of ``mean`` on ``axis`` with a +/-1 std band, markers, legend and grid."""
    lines = axis.plot(x, mean[b:].T)
    for k, (line, row_mean, row_std) in enumerate(zip(lines, mean[b:], std[b:])):
        line.set_marker(markers[k])
        line.set_markevery(5)
        line.set_markersize(10)
        # Shade with the line's own colour so each band always matches its curve.
        axis.fill_between(x, row_mean - row_std, row_mean + row_std,
                          color=line.get_color(), alpha=0.2)
    axis.set_ylabel(ylabel)
    axis.set_xlabel('Iteration ($10^2$)')
    axis.legend(legend_labels)
    axis.grid(True, which="both")
    # Per-axes tick styling: plt.xticks/plt.yticks would only reach the current axes.
    axis.tick_params(labelsize=19)


###############loss_iteration###############
y = range(0, iteration // frequency)
_plot_mean_std(ax[0], y, av_iter_loss, std_iter_loss, 'Training global loss')
_plot_mean_std(ax[1], y, av_iter_accuracy, std_iter_accuracy, 'Test accuracy')
###############loss_iteration###############

# savefig does not create missing directories.
os.makedirs("./figures", exist_ok=True)
fig.savefig("./figures/paper_" + final_figure_name + ".pdf", dpi=600, bbox_inches='tight', format='pdf')


