import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# --- Training logs -----------------------------------------------------------
# Each log is kept as a raw object array; its rows are split on ';' further
# down. Other runs that were previously swapped in here:
#   baseline: 'multi_attack_trained_models/baseline_none_wave_0_200000_15000_index_20k.npy/log.csv'
#   next:     'multi_attack_trained_models/active_uncertainty_wave_10_0_200000_250__20k.npy_None/log.csv'
#             'multi_attack_trained_models/check_active_random_wave_1_0_200000_0__20k.npy_None/log.csv'
#             'multi_attack_trained_models/check_2_active_random_wave_1_0_200000_250__20k.npy_[0, 1, 2, 3, 4]/log.csv'
#             'multi_attack_trained_models/check_5_active_random_wave_1_0_200000_250__20k.npy_None/log.csv'
#             'multi_attack_trained_models/check_6_active_random_wave_1_0_200000_0__20k.npy_None_128/log.csv'
#             'multi_attack_trained_models/test_active_uncertainty_wave_3_0_200000_200__50k.npy_None_128/log.csv'
#             'multi_attack_trained_models/test_retrain_active_uncertainty_wave_3_0_200000_2000_ndex.npy_None_768/log.csv'
log_baseline = pd.read_csv(
    'multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy/log.csv'
).to_numpy()
log_next = pd.read_csv(
    'multi_attack_trained_models/test_update_active_uncertainty_balance_wave_10_0_200000_2000/log.csv'
).to_numpy()
print(len(log_next))
# Optional: keep only the last 768*3 rows of the continued run.
# log_next = log_next[-768*3:]

# --- Attack-rank results -----------------------------------------------------
# attack_rank is used as loaded; attack_rank_multi is averaged over axis 1
# (presumably one column per key -- the plot label says "10 keys"; confirm).
attack_rank = pd.read_csv(
    'multi_attack_trained_models/test_baseline_random_wave_0_200000_20000_index_50k.npy/attack_rank.csv',
    index_col=0,
).to_numpy()
attack_rank_multi = pd.read_csv(
    'multi_attack_trained_models/test_baseline_random_wave_0_200000_20000_index_50k.npy/attack_rank_multi.csv',
    index_col=0,
).to_numpy().mean(axis=1)
print(attack_rank_multi.shape)
# Field positions inside each ';'-separated log line: field 2 is plotted as
# the training loss and field 4 as the validation loss.
_TRAIN_LOSS_COL = 2
_VAL_LOSS_COL = 4


def _log_column(log, col):
    """Return field ``col`` of every row of ``log`` as a list of floats.

    ``log`` is the ``.values`` array of a log file read with pandas' default
    ',' separator, so the first cell of each row holds the ';'-separated
    fields of one line (assumes no field contains a ',' -- TODO confirm).
    """
    return [float(row[0].split(';')[col]) for row in log]


# Sanity check on the first baseline row: the validation-loss field and
# field 0 (presumably the epoch number -- confirm against the log header).
print(float(log_baseline[0][0].split(';')[_VAL_LOSS_COL]))
print(float(log_baseline[0][0].split(';')[0]))

# Training loss: baseline run followed by the continued run, as one curve.
loss_baseline = _log_column(log_baseline, _TRAIN_LOSS_COL)
print(loss_baseline[:10])
loss_baseline.extend(_log_column(log_next, _TRAIN_LOSS_COL))

# Disabled alternative (kept for reference): append the continued run's rows
# in pairs, taking rows 0-1, 4-5, 8-9, ... (pairs at i*2 for even i only).
#loss_baseline_next = loss_baseline
# for i in range(int(len(log_next)/2)):
#     if i%2 == 0:
#         loss_baseline.append(float(log_next[i*2][0].split(';')[2]))
#         loss_baseline.append(float(log_next[i*2+1][0].split(';')[2]))
#     if i < 10 and i%2==0:
#         print(i*2)
#         print(i*2+1)
#         print(log_next[i*2][0].split(';')[2])
#         print(log_next[i*2+1][0].split(';')[2])

# Debug: show the raw training-loss field of the first 10 continued-run rows.
# Slicing avoids walking the whole log just to print a prefix.
for i, row in enumerate(log_next[:10]):
    print(i)
    print(row[0].split(';')[_TRAIN_LOSS_COL])

# Validation loss, concatenated the same way as the training loss.
loss_baseline_val = _log_column(log_baseline, _VAL_LOSS_COL)
print(loss_baseline_val[:10])
loss_baseline_val.extend(_log_column(log_next, _VAL_LOSS_COL))

# Training vs. validation loss over the concatenated baseline + continued runs.
plt.figure(figsize=(10,7))
plt.grid()
plt.ylabel('Loss Value')
plt.xlabel('Training epoch')
# Each curve needs a label: plt.legend() with no labelled artists draws no
# legend and only emits a warning.
plt.plot(loss_baseline, label='train loss')
#plt.plot(loss_baseline[:len(log_baseline)])
plt.plot(loss_baseline_val, label='validation loss')
plt.legend()
plt.savefig('loss_check.png')
# Close (not just clear) the figure so pyplot releases it; the next plot
# opens its own figure.
plt.close()

# Baseline validation loss (scaled to a max of 1) against the attack rank
# (scaled by _RANK_SCALE), so both curves share one y-axis.
_RANK_SCALE = 3329  # presumably the largest possible key rank -- TODO confirm
_val_loss_baseline = np.asarray(loss_baseline_val[:len(log_baseline)], dtype=float)

plt.figure(figsize=(10,7))
plt.grid()
# Both curves are normalized, so the axis does not show raw loss values.
plt.ylabel('Normalized value')
plt.xlabel('Training epoch')
plt.plot(_val_loss_baseline / _val_loss_baseline.max(), label='val')
#plt.plot(loss_baseline[:len(log_baseline)])
plt.plot(attack_rank / _RANK_SCALE, label='attack rank')
plt.plot(attack_rank_multi / _RANK_SCALE, label='attack rank average 10 keys', c='r')
plt.legend()
plt.savefig('rank_check.png')

# Disabled plots (kept for reference):
# plt.figure(figsize=(10,7))
# plt.grid()
# plt.ylabel('Loss Value')
# plt.xlabel('Training epoch')
# plt.plot(loss_baseline[:len(log_baseline)], label='train loss')
# plt.plot(loss_baseline_val[:len(log_baseline)], label='validation loss')
# plt.legend()
# plt.savefig('loss_check_eval.png')
#
# plt.figure(figsize=(10,7))
# plt.grid()
# plt.ylabel('Loss Value')
# plt.xlabel('Training epoch')
# plt.plot(loss_baseline, label='train loss')
# plt.plot(loss_baseline_val, label='validation loss')
# plt.legend()
# plt.savefig('loss_check_eval_full.png')
#
#
# plt.figure(figsize=(10,7))
# plt.grid()
# plt.ylabel('Loss Value')
# plt.xlabel('Training epoch')
# plt.plot(loss_baseline,label='Active training with only 2 samples, loss of 15000 baseline samples')
# plt.plot(loss_baseline[:len(log_baseline)], label='Baseline Training with 15000 samples, minmax')
# #plt.plot(loss_baseline_val)
# plt.legend()
# plt.savefig('loss_check_eval_schedule.png')