# -*- coding: utf-8 -*-
"""
Created on Sun Sep 24 03:37:14 2023

@author: xiato
"""
import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_theme(style="ticks")

# A 5x3-inch, 300-dpi figure holding two vertically stacked panels.
plt.figure(figsize=(5, 3), dpi=300)

# Top panel: first cell of a 2-row, 1-column grid.
ax = plt.subplot(2, 1, 1)

# ---- Top panel: FedNTD -------------------------------------------------------
# Logs for the 100-client runs. NOTE(review): the directory is "Quality_3"
# while the panel title reads "Quantity(3)" -- confirm they are the same split.
path = 'E://federated//FLea3//logs_all//MobileNet//100//Quality_3//'

print('FedNTD')
log_path = path + 'FedNTD' + '.log'
acc = []
with open(log_path, encoding="utf8") as f:
    for line in f:
        # One accuracy per 'global data' line: the token at index 1 of the
        # text between the first and second ':' after splitting on ' '.
        # NOTE(review): presumes lines like '... global data: <acc> ...'.
        if 'global data' in line:
            temp = line.split(':')[1].split(' ')[1]
            acc.append(float(temp))
if not acc:
    raise ValueError("no 'global data' lines found in " + log_path)
print(np.max(acc))

# Forward-looking 3-round window: the point plotted at round t summarises
# rounds t..t+2, padding past the last round with the final accuracy.
# Padding from acc[-1] and slicing to len(acc) (rather than a hard-coded
# index 99) keeps every window the same length for any number of rounds.
window = 3
padded = acc + [acc[-1]] * (window - 1)
windows = [padded[k:k + len(acc)] for k in range(window)]

acc_mean = np.mean(windows, axis=0)
acc_min = np.min(windows, axis=0)
acc_max = np.max(windows, axis=0)
plt.plot(acc_mean, label='FedNTD')
# Shaded band = min..max inside the smoothing window (not run-to-run spread);
# x follows the number of parsed rounds instead of a hard-coded range(100).
x = range(len(acc))
plt.fill_between(x, acc_min, acc_max, alpha=0.25)

# ---- Top panel: FedMix -------------------------------------------------------
print('FedMix')
log_path = path + 'FedMix' + '.log'
acc = []
with open(log_path, encoding="utf8") as f:
    for line in f:
        # One accuracy per 'global data' line: the token at index 1 of the
        # text between the first and second ':' after splitting on ' '.
        # NOTE(review): presumes lines like '... global data: <acc> ...'.
        if 'global data' in line:
            temp = line.split(':')[1].split(' ')[1]
            acc.append(float(temp))
if not acc:
    raise ValueError("no 'global data' lines found in " + log_path)

# Forward-looking 3-round window: the point plotted at round t summarises
# rounds t..t+2, padding past the last round with the final accuracy.
# Padding from acc[-1] and slicing to len(acc) (rather than a hard-coded
# index 99) keeps every window the same length for any number of rounds.
window = 3
padded = acc + [acc[-1]] * (window - 1)
windows = [padded[k:k + len(acc)] for k in range(window)]

acc_mean = np.mean(windows, axis=0)
acc_min = np.min(windows, axis=0)
acc_max = np.max(windows, axis=0)
plt.plot(acc_mean, label='FedMix')
# Shaded band = min..max inside the smoothing window (not run-to-run spread);
# x follows the number of parsed rounds instead of a hard-coded range(100).
x = range(len(acc))
plt.fill_between(x, acc_min, acc_max, alpha=0.2)



# ---- Top panel: FLea ---------------------------------------------------------
print('FLea')
log_path = path + 'FLea' + '.log'
acc = []
with open(log_path, encoding="utf8") as f:
    for line in f:
        # One accuracy per 'global data' line: the token at index 1 of the
        # text between the first and second ':' after splitting on ' '.
        # NOTE(review): presumes lines like '... global data: <acc> ...'.
        if 'global data' in line:
            temp = line.split(':')[1].split(' ')[1]
            acc.append(float(temp))
if not acc:
    raise ValueError("no 'global data' lines found in " + log_path)

# Forward-looking 4-round window: the point plotted at round t summarises
# rounds t..t+3, padding past the last round with the final accuracy.
# NOTE(review): the FedNTD/FedMix curves in this script use a 3-round window;
# a wider forward window smooths more and, on a rising curve, lifts the
# mean -- confirm the mismatch is intended before comparing curves.
window = 4
padded = acc + [acc[-1]] * (window - 1)
windows = [padded[k:k + len(acc)] for k in range(window)]

acc_mean = np.mean(windows, axis=0)
acc_min = np.min(windows, axis=0)
acc_max = np.max(windows, axis=0)
plt.plot(acc_mean, label='FLea')
# Shaded band = min..max inside the smoothing window (not run-to-run spread);
# x follows the number of parsed rounds instead of a hard-coded range(100).
x = range(len(acc))
plt.fill_between(x, acc_min, acc_max, alpha=0.25)
# ---- Top panel: axes formatting ----------------------------------------------
plt.title('100 clients (Quantity(3))')
plt.ylabel('Accuracy %')
plt.grid()
# y values are fractions; the tick labels present them as percentages.
plt.ylim(0.1, 0.65)
plt.yticks([0.2, 0.4, 0.6], ['20', '40', '60'])
plt.xlim(0, 100)

# ---- Bottom panel: FedNTD ----------------------------------------------------
# Bottom panel: second cell of the 2-row, 1-column grid.
ax = plt.subplot(2, 1, 2)

# Logs for the 500-client runs.
path = 'E://federated//FLea3//logs_all//MobileNet//500//'

print('FedNTD')
log_path = path + 'FedNTD_500_dir0.1' + '.log'
acc = []
with open(log_path, encoding="utf8") as f:
    for line in f:
        # One accuracy per 'global data' line: the token at index 1 of the
        # text between the first and second ':' after splitting on ' '.
        # NOTE(review): presumes lines like '... global data: <acc> ...'.
        if 'global data' in line:
            temp = line.split(':')[1].split(' ')[1]
            acc.append(float(temp))
if not acc:
    raise ValueError("no 'global data' lines found in " + log_path)

# Forward-looking 3-round window: the point plotted at round t summarises
# rounds t..t+2, padding past the last round with the final accuracy.
# Padding from acc[-1] and slicing to len(acc) (rather than a hard-coded
# index 99) keeps every window the same length for any number of rounds.
window = 3
padded = acc + [acc[-1]] * (window - 1)
windows = [padded[k:k + len(acc)] for k in range(window)]

acc_mean = np.mean(windows, axis=0)
acc_min = np.min(windows, axis=0)
acc_max = np.max(windows, axis=0)
plt.plot(acc_mean, label='FedNTD')
# Shaded band = min..max inside the smoothing window (not run-to-run spread);
# x follows the number of parsed rounds instead of a hard-coded range(100).
x = range(len(acc))
plt.fill_between(x, acc_min, acc_max, alpha=0.25)

# ---- Bottom panel: FedMix ----------------------------------------------------
print('FedMix')
log_path = path + 'FedMix_500_dir0.1' + '.log'
acc = []
with open(log_path, encoding="utf8") as f:
    for line in f:
        # One accuracy per 'global data' line: the token at index 1 of the
        # text between the first and second ':' after splitting on ' '.
        # NOTE(review): presumes lines like '... global data: <acc> ...'.
        if 'global data' in line:
            temp = line.split(':')[1].split(' ')[1]
            acc.append(float(temp))
if not acc:
    raise ValueError("no 'global data' lines found in " + log_path)

# Forward-looking 3-round window: the point plotted at round t summarises
# rounds t..t+2, padding past the last round with the final accuracy.
# Padding from acc[-1] and slicing to len(acc) (rather than a hard-coded
# index 99) keeps every window the same length for any number of rounds.
window = 3
padded = acc + [acc[-1]] * (window - 1)
windows = [padded[k:k + len(acc)] for k in range(window)]

acc_mean = np.mean(windows, axis=0)
acc_min = np.min(windows, axis=0)
acc_max = np.max(windows, axis=0)
plt.plot(acc_mean, label='FedMix')
# Shaded band = min..max inside the smoothing window (not run-to-run spread);
# x follows the number of parsed rounds instead of a hard-coded range(100).
x = range(len(acc))
plt.fill_between(x, acc_min, acc_max, alpha=0.2)



# ---- Bottom panel: FLea ------------------------------------------------------
print('FLea')
log_path = path + 'FLea_500_dir0.1' + '.log'
acc = []
with open(log_path, encoding="utf8") as f:
    for line in f:
        # One accuracy per 'global data' line: the token at index 1 of the
        # text between the first and second ':' after splitting on ' '.
        # NOTE(review): presumes lines like '... global data: <acc> ...'.
        if 'global data' in line:
            temp = line.split(':')[1].split(' ')[1]
            acc.append(float(temp))
if not acc:
    raise ValueError("no 'global data' lines found in " + log_path)

# Forward-looking 4-round window: the point plotted at round t summarises
# rounds t..t+3, padding past the last round with the final accuracy
# (previously a hard-coded acc[99]).
# NOTE(review): the FedNTD/FedMix curves in this script use a 3-round window;
# confirm the mismatch is intended before comparing curves.
window = 4
padded = acc + [acc[-1]] * (window - 1)
windows = [padded[k:k + len(acc)] for k in range(window)]

acc_mean = np.mean(windows, axis=0)
acc_min = np.min(windows, axis=0)
acc_max = np.max(windows, axis=0)
plt.plot(acc_mean, label='FLea')
x = range(len(acc))
# NOTE(review): unlike the other curves, the min..max band is not drawn for
# this one; confirm that is intentional. Uncomment to show it.
# plt.fill_between(x, acc_min, acc_max, alpha=0.25)
# ---- Bottom panel: axes formatting, legend, layout ---------------------------
plt.title('500 clients (Dirichlet(0.1))')

plt.xlabel('Round')
plt.ylabel('Accuracy %')
plt.grid()
# y values are fractions; the tick labels present them as percentages.
plt.ylim(0.1, 0.5)
plt.yticks([0.2, 0.4], ['20', '40'])
plt.xlim(0, 100)
# Only the bottom axes gets a legend. Its lower-right corner is anchored at
# axes coords (1.02, 1.2), i.e. above and slightly right of this panel.
plt.legend(loc='lower right', bbox_to_anchor=(1.02, 1.2))

plt.subplots_adjust(hspace=0.57, wspace=0.22)