# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 13:53:19 2023

@author: xiato
"""


 

import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_theme(style="ticks")


###
# Generate results
###

path = 'E://federated//FLea3//logs_all//MobileNet//100//Quality_3//'

# The FedMix log may contain several runs, each introduced by a line containing
# 'Algorithm' (e.g. "Algorithm: <name>"). For every run, print its name and the
# best accuracy found on its 'global data' lines.
algorithm = None  # stays None if accuracy lines appear before any header line
acc = []
with open(path + 'FedMix.log', encoding="utf8") as f:
    for line in f:
        if 'Algorithm' in line:
            # A new header closes the previous run: report it before resetting.
            if acc:
                print(algorithm, np.max(acc))
            # strip() drops the trailing newline so name and score share a line.
            algorithm = line.split(':')[1].strip()
            acc = []
        if 'global data' in line:
            # Accuracy is the first token after the first ':'.
            acc.append(float(line.split(':')[1].split(' ')[1]))

# Report the last run, which is not followed by another header.
if acc:
    print(algorithm, np.max(acc))
                
        
        
###
# Communication
###    

# Rounds each method needs before its global-data accuracy first reaches each
# target threshold (one entry per threshold, in threshold order).
# FedAvg_c holds the FedMix results; the name is kept because the plotting code
# reads it. np.nan marks a threshold that was never reached, so the three lists
# always stay aligned with the threshold grid.
FedAvg_c = []
FedNTD_c = []
FedFea_c = []

# Per-round communication cost of each method, used only for the printed
# "cost to target" (divided by 1048576 = 2**20; presumably bytes -> MiB, confirm).
FedAvg_rd = 45938440
FedNTD_rd = 45938440
FLea_rd = FedNTD_rd + 90112000


def _read_global_acc(log_name):
    """Return the per-round accuracies parsed from ``path + log_name + '.log'``.

    Every line containing 'global data' contributes the first token after its
    first ':' as a float.
    """
    values = []
    with open(path + log_name + '.log', encoding="utf8") as f:
        for line in f:
            if 'global data' in line:
                values.append(float(line.split(':')[1].split(' ')[1]))
    return values


def _rounds_to_reach(acc, threshold):
    """Return the 1-based round at which ``acc`` first reaches ``threshold``.

    Returns np.nan when the threshold is never reached.
    """
    return next((i + 1 for i, a in enumerate(acc) if a >= threshold), np.nan)


def _shifted_band(acc, window):
    """Return (mean, min, max) over ``window`` left-shifted copies of ``acc``.

    Copy k holds ``acc[i + k]`` at position i, padded with the last value past
    the end. This is a forward-looking window over a single run, not a spread
    over repeated runs. ``acc`` must be non-empty.
    """
    a = np.asarray(acc, dtype=float)
    n = len(a)
    idx = np.minimum(np.arange(n)[None, :] + np.arange(window)[:, None], n - 1)
    stacked = a[idx]
    return stacked.mean(axis=0), stacked.min(axis=0), stacked.max(axis=0)


# (label, log file, rounds list to fill, cost per round, smoothing window, band alpha)
_comm_methods = [
    ('FedNTD', 'FedNTD', FedNTD_c, FedNTD_rd, 3, 0.25),
    ('FedMix', 'FedMix', FedAvg_c, FedAvg_rd, 3, 0.2),
    ('FLea', 'FLea', FedFea_c, FLea_rd, 4, 0.25),
]
# Parse each log once rather than once per threshold.
_comm_acc = {label: _read_global_acc(log) for label, log, *_ in _comm_methods}

for threshold in np.linspace(0.1, 0.6, num=21):
    for label, _, rounds, per_round, _, _ in _comm_methods:
        print(label)
        reached = _rounds_to_reach(_comm_acc[label], threshold)
        if not np.isnan(reached):
            print(reached, reached * per_round / 1048576)
        rounds.append(reached)

# The training curves do not depend on the threshold, so draw them once.
plt.figure(dpi=300, figsize=(4, 2))
for label, _, _, _, window, alpha in _comm_methods:
    acc = _comm_acc[label]
    if not acc:
        continue  # nothing parsed from this log; nothing to draw
    acc_mean, acc_min, acc_max = _shifted_band(acc, window)
    plt.plot(acc_mean, label=label)
    plt.fill_between(range(len(acc)), acc_min, acc_max, alpha=alpha)

# Optional dashed guide lines at 50% / 40% accuracy, with vertical markers at
# hand-picked rounds.
# x = 15
# x2 = 45
# plt.plot([0,x2],[0.5,0.5],'r--')
# plt.plot([x,x],[0.,0.5],'r--')
# plt.plot([x2,x2],[0.,0.5],'r--')
# x = 6
# x2 = 34
# x3 = 20
# plt.plot([0,x2],[0.4,0.4],'y--')
# plt.plot([x,x],[0.,0.4],'y--')
# plt.plot([x2,x2],[0.,0.4],'y--')
# plt.plot([x3,x3],[0.,0.4],'y--')

plt.legend()
plt.xlabel('Round')
plt.ylabel('Accuracy %')
plt.grid()
plt.ylim([0.1, 0.6])
# Accuracies are fractions; label the ticks as percentages.
plt.yticks([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], ['10', '20', '30', '40', '50', '60'])
# plt.title('IID data')
plt.xlim([0, 100])



# Plot the rounds each method needs to reach each target accuracy (one point per
# threshold, indexed by threshold position).
# NOTE(review): the shaded band is synthetic. Each point gets one upward and one
# downward random offset drawn from the per-method ranges below; the band is not
# derived from repeated runs. The plotted line is the mean of the point and its
# two jittered copies, so it is offset from the measured value. Confirm this is
# intended before publishing.
# The generator is seeded so the figure is identical on every run.
plt.figure(dpi=300, figsize=(4, 3))
_rng = np.random.default_rng(0)

# (label, rounds per threshold, upward jitter range, downward jitter range)
for label, fed_c, (up_lo, up_hi), (down_lo, down_hi) in [
    ('FedNTD', FedNTD_c, (0.8, 3.5), (0.8, 3.5)),
    ('FedMix', FedAvg_c, (1, 4), (1, 5)),
    ('FLea', FedFea_c, (0.8, 4), (0.8, 4)),
]:
    base = np.asarray(fed_c, dtype=float)
    fed_c1 = base + _rng.uniform(up_lo, up_hi, size=base.shape)
    fed_c2 = base - _rng.uniform(down_lo, down_hi, size=base.shape)
    stacked = np.stack([base, fed_c1, fed_c2])
    acc_mean = stacked.mean(axis=0)
    acc_min = stacked.min(axis=0)
    acc_max = stacked.max(axis=0)
    plt.plot(acc_mean, linewidth=2, label=label)
    plt.fill_between(range(len(base)), acc_min, acc_max, alpha=0.2)

#plt.ylim([0,4])
plt.legend(fontsize=13)
plt.ylabel('#Rounds', fontsize=15)
plt.xlabel('Target accuracy %', fontsize=15)
#plt.xticks([0,3,6,9,12,15], ['0','19','28', '37', '46', '55'], fontsize=15)
# Tick positions assume the 21-point threshold grid np.linspace(0.1, 0.6, 21)
# used to build the round lists (index 4k <-> 10% + 10k%).
plt.xticks([0, 4, 8, 12, 16, 20], ['10', '20', '30', '40', '50', '60'], fontsize=15)
plt.yticks(fontsize=15)
plt.grid()



# Plot the raw per-round global accuracy of FedAvg, FedNTD and FLea on one
# figure. Print each method's best accuracy and collect it in Accs, in file order.
plt.figure(dpi=300, figsize=(5, 3))
Accs = []
for name in ['FedAvg', 'FedNTD', 'FLea']:
    acc = []
    with open(path + name + '.log', encoding="utf8") as f:
        for line in f:
            # Logs use one of two phrasings for the global-model accuracy.
            if 'GM acc on global data' in line or 'Global Model Acc' in line:
                acc.append(float(line.split(':')[1].split(' ')[1]))
    if not acc:
        # np.max([]) would raise; record NaN so Accs stays aligned with the files.
        print(name, 'no accuracy lines found')
        Accs.append(np.nan)
        continue
    best = np.max(acc)
    Accs.append(best)
    plt.plot(acc, label=name)
    print(name, best)

plt.legend()
plt.xlabel('Round')
plt.ylabel('Acc')
plt.grid()
plt.ylim([0.1, 0.6])
# plt.title('IID data')
plt.xlim([0, 100])
print('===========================================')