# -*- coding: utf-8 -*-
"""
Created on Thu Aug 31 10:18:44 2023

@author: xiatong
"""


import numpy as np
import matplotlib.pyplot as plt


# Figure: reconstruction-loss curves loaded from three saved .npy arrays.
loss = np.load('reconstrution_loss_noDec.npy')
loss2 = np.load('reconstrution_loss_Dec_0.65.npy')
loss3 = np.load('reconstrution_loss_Dec_0.3.npy')

plt.figure(dpi=300, figsize=(4, 3))

# Raw strings: '\m' is not a valid Python escape sequence, so the mathtext
# backslash must not be left to the string-literal parser.
plt.plot(loss, label=r'W/o $\mathcal{L}_{dec}$')
plt.plot(loss2, label=r'With $\mathcal{L}_{dec}$ (c=0.65)')
# NOTE(review): this curve is loaded from the 'Dec_0.3' file but is labelled
# c=0.40 -- confirm which of the two values is the correct one.
plt.plot(loss3, label=r'With $\mathcal{L}_{dec}$ (c=0.40)')
plt.xlim([0, 3000])

plt.xlabel('Iteration')
plt.ylabel('MSE')
plt.grid()
plt.legend()



## epoch 20
# loss1 = [0.9586, 0.9574, 0.9486, 0.9417, 0.9289, 0.9210, 0.9181, 0.9118, 0.9070, 0.9034, 0.8980, 0.8884, 0.8731, 0.8489, 0.8285, 0.8099, 0.7951, 0.7780, 0.7593, 0.7501]
# loss2 = [0.9578, 0.9521, 0.9414, 0.9345, 0.9285, 0.9222, 0.9179, 0.9124, 0.9083, 0.9028, 0.8965, 0.8817, 0.8729, 0.8529, 0.8330, 0.8080, 0.7997, 0.7792, 0.7687, 0.7584]
# loss3 = [0.9535, 0.9476, 0.9380, 0.9285, 0.9170, 0.9085, 0.9032, 0.8985, 0.8938, 0.8907, 0.8855, 0.8780, 0.8647, 0.8510, 0.8297, 0.8099, 0.7878, 0.7681, 0.7520, 0.7456]
# loss4 = [0.9599, 0.9542, 0.9437, 0.9403, 0.9375, 0.9287, 0.9230, 0.9176, 0.9149, 0.9122, 0.9079, 0.9057, 0.8998, 0.8933, 0.8804, 0.8668, 0.8402, 0.8092, 0.7698, 0.7290]
# loss5 = [0.9534, 0.9489, 0.9400, 0.9304, 0.9219, 0.9142, 0.9103, 0.9073, 0.8995, 0.8938, 0.8851, 0.8725, 0.8554, 0.8391, 0.8167, 0.8021, 0.7877, 0.7671, 0.7583, 0.7488]

##Epoch 30
# loss1 = [0.9575, 0.9535, 0.9413, 0.9341, 0.9251, 0.9203, 0.9205, 0.9140, 0.9094, 0.9069, 0.8997, 0.8933, 0.8861, 0.8698, 0.8520, 0.8223, 0.7759, 0.7438, 0.7017, 0.6812]
# loss2 = [0.9503, 0.9455, 0.9355, 0.9257, 0.9240, 0.9141, 0.9065, 0.8991, 0.8989, 0.8945, 0.8862, 0.8789, 0.8617, 0.8476, 0.8239, 0.7971, 0.7703, 0.7557, 0.7366, 0.7204]
# loss3 = [0.9543, 0.9476, 0.9383, 0.9275, 0.9182, 0.9106, 0.9064, 0.9054, 0.9009, 0.8954, 0.8866, 0.8760, 0.8607, 0.8447, 0.8182, 0.7974, 0.7772, 0.7586, 0.7435, 0.7312]
# loss4 = [0.9566, 0.9526, 0.9441, 0.9373, 0.9305, 0.9240, 0.9187, 0.9158, 0.9121, 0.9105, 0.9060, 0.8968, 0.8890, 0.8758, 0.8582, 0.8401, 0.8163, 0.7972, 0.7787, 0.7629]
# loss5 = [0.9547, 0.9459, 0.9349, 0.9200, 0.9104, 0.9041, 0.9020, 0.8986, 0.8938, 0.8927, 0.8854, 0.8775, 0.8657, 0.8503, 0.8351, 0.8124, 0.7937, 0.7745, 0.7533, 0.7423]

# ##Epoch 40
# loss1 = [0.9573, 0.9510, 0.9444, 0.9364, 0.9270, 0.9225, 0.9131, 0.9105, 0.9063,
#         0.9002, 0.8915, 0.8837, 0.8729, 0.8549, 0.8349, 0.8138, 0.7914, 0.7751,
#         0.7577, 0.7436, 0.7307, 0.7269, 0.7123, 0.7109, 0.7121, 0.7037, 0.6972,
#         0.6941, 0.6870, 0.6807, 0.6794, 0.6783, 0.6712, 0.6713, 0.6618, 0.6681,
#         0.6501, 0.6622, 0.6581, 0.6632] 
# loss2 = [0.9543, 0.9492, 0.9422, 0.9352, 0.9290, 0.9185, 0.9141, 0.9139, 0.9121,
#         0.9056, 0.9041, 0.8978, 0.8855, 0.8664, 0.8470, 0.8249, 0.8016, 0.7873,
#         0.7702, 0.7586, 0.7467, 0.7327, 0.7250, 0.7171, 0.7076, 0.7009, 0.6983,
#         0.6921, 0.6950, 0.6835, 0.6826, 0.6787, 0.6671, 0.6695, 0.6616, 0.6622,
#         0.6651, 0.6687, 0.6591, 0.6610] 
# loss3 = [0.9515, 0.9476, 0.9415, 0.9336, 0.9213, 0.9118, 0.9062, 0.8982, 0.8968,
#         0.8943, 0.8882, 0.8859, 0.8808, 0.8697, 0.8602, 0.8465, 0.8223, 0.8019,
#         0.7830, 0.7694, 0.7486, 0.7359, 0.7342, 0.7205, 0.7141, 0.7063, 0.6984,
#         0.6941, 0.6852, 0.6944, 0.6833, 0.6774, 0.6681, 0.6742, 0.6713, 0.6663,
#         0.6584, 0.6603, 0.6553, 0.6554]
# loss4 = [0.9577, 0.9525, 0.9417, 0.9322, 0.9279, 0.9207, 0.9123, 0.9074, 0.9054,
#         0.9013, 0.8965, 0.8929, 0.8835, 0.8773, 0.8616, 0.8491, 0.8215, 0.8027,
#         0.7765, 0.7643, 0.7533, 0.7372, 0.7286, 0.7161, 0.7055, 0.6991, 0.6926,
#         0.6877, 0.6828, 0.6702, 0.6667, 0.6628, 0.6652, 0.6599, 0.6467, 0.6512,
#         0.6441, 0.6378, 0.6386, 0.6374]
# loss5 = [0.9560, 0.9483, 0.9431, 0.9350, 0.9299, 0.9234, 0.9147, 0.9101, 0.9063,
#         0.9040, 0.8947, 0.8896, 0.8792, 0.8638, 0.8444, 0.8200, 0.8055, 0.7820,
#         0.7613, 0.7554, 0.7376, 0.7289, 0.7224, 0.7114, 0.7060, 0.7031, 0.6959,
#         0.6889, 0.6836, 0.6830, 0.6778, 0.6708, 0.6672, 0.6658, 0.6624, 0.6586,
#         0.6621, 0.6545, 0.6511, 0.6477]

## Epoch 50
# Hard-coded results of five runs, 50 values each (one per epoch).
# NOTE(review): the names say "loss" but the figure that plots these series
# labels the y-axis 'Correlation' -- confirm what the values represent.
loss1 = [
    0.9581, 0.9494, 0.9403, 0.9310, 0.9244, 0.9181, 0.9148, 0.9106, 0.9044,
    0.9021, 0.8920, 0.8872, 0.8729, 0.8614, 0.8423, 0.8285, 0.8157, 0.8083,
    0.7914, 0.7838, 0.7731, 0.7662, 0.7597, 0.7522, 0.7459, 0.7436, 0.7395,
    0.7331, 0.7234, 0.7222, 0.7197, 0.7178, 0.7120, 0.7127, 0.7033, 0.6980,
    0.6953, 0.6902, 0.6882, 0.6870, 0.6787, 0.6788, 0.6753, 0.6824, 0.6727,
    0.6737, 0.6740, 0.6649, 0.6618, 0.6670,
]
loss2 = [
    0.9529, 0.9433, 0.9324, 0.9225, 0.9196, 0.9194, 0.9157, 0.9120, 0.9106,
    0.9081, 0.9054, 0.9013, 0.8980, 0.8922, 0.8846, 0.8792, 0.8619, 0.8399,
    0.8251, 0.7993, 0.7842, 0.7669, 0.7399, 0.7188, 0.6979, 0.6977, 0.6836,
    0.6735, 0.6610, 0.6655, 0.6480, 0.6452, 0.6409, 0.6394, 0.6341, 0.6293,
    0.6266, 0.6193, 0.6098, 0.6193, 0.6093, 0.6061, 0.6126, 0.6047, 0.6035,
    0.5909, 0.5949, 0.5939, 0.5978, 0.5914,
]
loss3 = [
    0.9474, 0.9361, 0.9213, 0.9147, 0.9117, 0.9042, 0.9016, 0.9000, 0.9013,
    0.8947, 0.8903, 0.8871, 0.8818, 0.8764, 0.8710, 0.8669, 0.8575, 0.8466,
    0.8387, 0.8217, 0.8056, 0.7928, 0.7837, 0.7762, 0.7647, 0.7569, 0.7474,
    0.7470, 0.7383, 0.7272, 0.7178, 0.7151, 0.7053, 0.7063, 0.7015, 0.6943,
    0.6955, 0.6910, 0.6836, 0.6844, 0.6735, 0.6716, 0.6669, 0.6668, 0.6685,
    0.6635, 0.6652, 0.6615, 0.6469, 0.6480,
]
loss4 = [
    0.9504, 0.9388, 0.9287, 0.9225, 0.9192, 0.9152, 0.9138, 0.9080, 0.9069,
    0.9073, 0.9056, 0.8980, 0.8975, 0.8934, 0.8897, 0.8804, 0.8754, 0.8682,
    0.8558, 0.8452, 0.8320, 0.8209, 0.8120, 0.8081, 0.7948, 0.7806, 0.7750,
    0.7652, 0.7636, 0.7546, 0.7493, 0.7395, 0.7322, 0.7259, 0.7193, 0.7242,
    0.7108, 0.7095, 0.7059, 0.6934, 0.6974, 0.6889, 0.6897, 0.6896, 0.6847,
    0.6749, 0.6766, 0.6743, 0.6728, 0.6780,
]
loss5 = [
    0.9547, 0.9475, 0.9363, 0.9270, 0.9219, 0.9181, 0.9138, 0.9100, 0.9074,
    0.9014, 0.8957, 0.8885, 0.8816, 0.8646, 0.8563, 0.8461, 0.8284, 0.8160,
    0.8087, 0.7935, 0.7842, 0.7793, 0.7687, 0.7562, 0.7493, 0.7440, 0.7347,
    0.7281, 0.7322, 0.7267, 0.7198, 0.7119, 0.7088, 0.7030, 0.6945, 0.6928,
    0.6928, 0.6980, 0.6860, 0.6912, 0.6940, 0.6763, 0.6765, 0.6767, 0.6771,
    0.6769, 0.6667, 0.6667, 0.6532, 0.6616,
]



# Figure: the five hard-coded 50-epoch series (loss1..loss5), one line each.
plt.figure(dpi=300, figsize=(3, 3))

for run in (loss1, loss2, loss3, loss4, loss5):
    plt.plot(run, '.-')
plt.grid()
plt.xlabel('Epoch')
# NOTE(review): the plotted variables are named "loss*" -- confirm that
# 'Correlation' is the intended axis label.
plt.ylabel('Correlation')
# Cover all 50 plotted points (indices 0..49). The former [0, 20] limit was a
# leftover from the 20-epoch data and was widened by the xticks call below
# anyway, so state the effective range explicitly.
plt.xlim([0, 49])
# Index i holds epoch i + 1, so index 0 is labelled '1' (not '0'), in line
# with the other labels (index 9 -> '10', ..., index 49 -> '50').
plt.xticks([0, 9, 19, 29, 39, 49], ['1', '10', '20', '30', '40', '50'])


def _read_decorrelation(log_file):
    """Return the values found on 'The last decorrelation' lines of *log_file*.

    Each value is parsed from the fixed character window ``[-7:-1]`` of the
    stripped line (the six characters before the final one).

    Raises:
        OSError: if the log file cannot be opened.
        ValueError: if a matching line's window is not a valid float.
    """
    values = []
    with open(log_file, encoding="utf8") as f:
        for line in f:
            if 'The last decorrelation' in line:
                values.append(float(line.strip()[-7:-1]))
    return values


def _per_round_stats(values, per_round=10, max_rounds=100):
    """Return ``(means, mins, maxs)`` over consecutive groups of *values*.

    Values are grouped in order, *per_round* at a time; at most *max_rounds*
    groups are used. A shorter trailing group is kept, and no empty groups
    are produced, so a log with fewer than ``per_round * max_rounds``
    entries no longer makes ``np.min`` fail on an empty slice.
    """
    groups = [values[i:i + per_round] for i in range(0, len(values), per_round)]
    groups = groups[:max_rounds]
    means = [np.mean(g) for g in groups]
    mins = [np.min(g) for g in groups]
    maxs = [np.max(g) for g in groups]
    return means, mins, maxs


def _plot_decorrelation(log_file, label):
    """Plot the per-round mean of *log_file* with a shaded min/max band."""
    acc_mean, acc_min, acc_max = _per_round_stats(_read_decorrelation(log_file))
    # Use one shared 0-based x for both artists. The band used to be drawn at
    # x = 1..N while the mean line sat at x = 0..N-1, shifting it by a round.
    x = range(len(acc_mean))
    plt.plot(x, acc_mean, label=label)
    plt.fill_between(x, acc_min, acc_max, alpha=0.2)


# Figure: per-round value of the 'The last decorrelation' log entries for a
# run without and a run with the decorrelation loss.
plt.figure(dpi=300, figsize=(3, 2))

path = 'E://federated//FLea3//logs_all//MobileNet//100//Quality_3//'
_plot_decorrelation(path + 'FLea_100_alpha' + '.log', r'W/o $\mathcal{L}_{dec}$')

# `path` is left pointing at this directory after the figure is drawn.
path = 'E://federated//FLea2//logs_all//MobileNet//Quality_3//seed0//'
_plot_decorrelation(path + 'FLea' + '.log', r'With $\mathcal{L}_{dec}$')

plt.legend()
plt.xlabel('Round')
plt.ylabel('Correlation')
plt.grid()
#plt.ylim([0.1,1])
# plt.title('IID data')
#plt.xlim([0, 100])
# x positions are 0-based round indices, labelled with 1-based round numbers.
plt.xticks([0, 24, 49, 74, 99], ['1', '25', '50', '75', '100'])

print('===========================================')


# Figure: accuracy per logged evaluation for each method's log file.
# Reads '<path><method>.log'; `path` is whatever was assigned last above.
plt.figure(dpi=300, figsize=(4, 3))
Accs = []
for file in ['FedAvg', 'FedNTD', 'FLea', 'FLea_l2']:
    with open(path + file + '.log', encoding="utf8") as f:
        # The value is the second space-separated token after the first ':'
        # on lines carrying either accuracy marker.
        acc = [
            float(line.split(':')[1].split(' ')[1])
            for line in f
            if 'GM acc on global data' in line or 'Global Model Acc' in line
        ]

    best = np.max(acc)
    Accs.append(best)
    plt.plot(acc, label=file)
    print(file, best)

plt.legend()
plt.xlabel('Round')
plt.ylabel('Acc')
plt.grid()
plt.ylim([0.1, 0.7])
# plt.title('IID data')
plt.xlim([0, 100])
print('===========================================')


# Figure: accuracy (%) against training-set size, log-scaled x axis.
# All values below are hard-coded in this script.
sample = [100, 500, 1000, 2000, 5000, 10000, 20000, 50000]
FedMix = [85.0, 95.04, 97.54, 100, 100, 100, 100, 100]
Flea = [70.82, 80.17, 85.95, 90.38, 93.60, 95, 95.3, 95.2]
Flea2 = [50, 61, 64.26, 68.87, 85, 89, 89, 89]

plt.figure(dpi=300, figsize=(4, 3))
curves = (
    (FedMix, 'FedMix'),
    (Flea, 'FLea (c=0.65)'),
    (Flea2, 'FLea (c=0.40)'),
)
for series, name in curves:
    plt.plot(sample, series, '-s', label=name)
plt.xscale("log")
plt.grid()
plt.legend()
plt.xlabel('Training size')
plt.ylabel('Accuracy %')
