# -*- coding: utf-8 -*-
"""
Created on Thu Nov 16 10:09:45 2023

@author: xiato
"""
import numpy as np
import matplotlib.pyplot as plt


# Directory holding the training logs.
path = 'run//'
# One figure shared by both curves drawn in this script.
plt.figure(dpi=300, figsize=(3, 2))
acc = []

# Curve 1: collect one accuracy value from every line containing
# 'The last decorrelation'. The slice [-7:-1] keeps the six characters just
# before the final character of the stripped line.
# NOTE(review): assumes the value is a fixed-width 6-char number followed by
# exactly one trailing character -- confirm against the log format.
for file in ['FLea_audio_alpha8_e3']:
    with open(path + file + '.log', encoding="utf8") as f:
        for line in f:
            if 'The last decorrelation' in line:
                temp = line.strip()[-7:-1]
                acc.append(float(temp))

# Aggregate every K consecutive values into one plotted point (presumably
# one value per client per round -- TODO confirm). At most n_rounds points
# are plotted, as before, but the count is also bounded by the data so a
# short log does not create empty slices (np.min/np.max raise on those).
# A trailing partial group is kept, matching the original slicing.
K = 7
n_rounds = 100
n_groups = min(n_rounds, (len(acc) + K - 1) // K)
groups = [acc[i * K:(i + 1) * K] for i in range(n_groups)]
acc_mean = [np.mean(g) for g in groups]
acc_min = [np.min(g) for g in groups]
acc_max = [np.max(g) for g in groups]

# Draw the line and its min/max band on the same x positions (previously the
# band used 1-based x while the line used the default 0-based index, shifting
# the shading one step right). Raw string: '\m' is not a valid escape.
x = range(len(acc_mean))
plt.plot(x, acc_mean, label=r'W/o $\mathcal{L}_{dec}$')
plt.fill_between(x, acc_min, acc_max, alpha=0.2)
print('average:', np.mean(acc_mean))

 
# Curve 2: collect one accuracy value from every line containing
# 'global data'; the value is the second space-separated token after the
# first ':'.
# NOTE(review): the file name below ends with a space before '.log'
# ('FLea_audio_frac_0.9 .log'). Kept as-is; confirm whether it is intended.
for file in ['FLea_audio_frac_0.9 ']:
    acc = []
    log_path = path + file + '.log'
    with open(log_path, encoding="utf8") as f:
        for line in f:
            if 'global data' in line:
                temp = line.split(':')[1].split(' ')[1]
                acc.append(float(temp))
    if not acc:
        raise ValueError(f"no 'global data' lines found in {log_path!r}")
    print(np.max(acc))

    # Copies of the curve shifted left by one and two steps, padded at the end
    # with the final value, so every point is compared with its next two
    # neighbours (a forward window of 3). Padding with acc[-1] instead of a
    # hard-coded acc[99] keeps all three copies the same length as acc for
    # any run length.
    padded = acc + [acc[-1]] * 2
    acc1 = padded[1:len(acc) + 1]
    acc2 = padded[2:len(acc) + 2]

acc_stack = [acc, acc1, acc2]
acc_mean = np.mean(acc_stack, axis=0)
acc_min = np.min(acc_stack, axis=0)
acc_max = np.max(acc_stack, axis=0)
# x follows the data length rather than a fixed 100 points, so fill_between
# does not fail with a shape mismatch on shorter/longer runs.
x = range(len(acc_mean))
plt.plot(x, acc_mean)
plt.fill_between(x, acc_min, acc_max, alpha=0.25)





# plt.figure(dpi=300, figsize=(3,2))
# Accs = [] 
# acc = []
                
# for file in ['FLea_cifar10_alpha3_e3']:   
#     with open (path + file + '.log', encoding="utf8") as f: 
#         for line in f:
#             if 'The last decorrelation' in line:
#                 temp = line.strip()[-7:-1]
#                 acc.append(float(temp)) 
 
# K=10
# acc_mean = [np.mean(acc[i*K:i*K+K]) for i in range(100)]
# acc_min = [np.min(acc[i*K:i*K+K]) for i in range(100)]
# acc_max = [np.max(acc[i*K:i*K+K]) for i in range(100)]
 
# plt.plot(acc_mean, label='W/o $\mathcal{L}_{dec}$')
# x = range(1,len(acc_mean)+1)
# plt.fill_between(x, acc_min, acc_max, alpha=0.2) 
# print('average:', np.mean(acc_mean))