from matplotlib import pyplot as plt
from matplotlib import style
import numpy as np

# Compare train-loss-vs-running-time curves of several methods.
# Each method has a result directory holding one "<run_id>.npy" file per run;
# every file is a 2-D array whose column LOSS_COL is the train loss and whose
# column TIME_COL is the running time at that logging point.

LOSS_COL = 0
TIME_COL = 2
RUN_IDS = (1, 2, 3, 4, 5)

# (result directory, legend label, colour, alpha of the min/max band),
# listed in plotting order.
METHODS = (
    ("TTSA_result", "TTSA", "gold", 0.1),
    ("BSA_result", "BSA", "lightseagreen", 0.1),
    ("stocBiO_result", "stocBiO", "tomato", 0.1),
    ("AID_FP_result", "AID-FP", "darkgoldenrod", 0.1),
    ("SUSTAIN_result", "SUSTAIN", "royalblue", 0.1),
    ("VRBO_result", "VRBO", "darkorchid", 0.1),
    ("MRBO_result", "MRBO", "orange", 0.1),
    ("TriBo_new_result", "TriBO", "green", 0.05),
)

# Set to a filename (e.g. 'train_loss.png') to also write the figure to disk.
SAVE_PATH = None


def load_runs(result_dir, run_ids=RUN_IDS, loss_col=LOSS_COL, time_col=TIME_COL):
    """Load the loss and time curves of every run stored in *result_dir*.

    Args:
        result_dir: Directory containing one ``<run_id>.npy`` file per run.
        run_ids: Identifiers of the runs to load.
        loss_col: Column of each run array holding the train loss.
        time_col: Column of each run array holding the running time.

    Returns:
        ``(losses, times)``: two arrays of shape ``(n_runs, n_points)``.
        Runs of unequal length are truncated to the shortest run, because
        ``np.stack`` requires equal lengths.

    Raises:
        ValueError: if *run_ids* is empty.
    """
    runs = [np.load(f"{result_dir}/{run_id}.npy") for run_id in run_ids]
    if not runs:
        raise ValueError("run_ids must contain at least one run")
    n_points = min(len(run) for run in runs)
    losses = np.stack([run[:n_points, loss_col] for run in runs])
    times = np.stack([run[:n_points, time_col] for run in runs])
    return losses, times


def summarize_runs(losses, times):
    """Reduce per-run curves to the statistics that get plotted.

    Args:
        losses: Array of shape ``(n_runs, n_points)``.
        times: Array of the same shape.

    Returns:
        ``(time_mean, loss_mean, loss_min, loss_max)``, each of shape
        ``(n_points,)``, computed across runs at every logging point.
    """
    return (
        times.mean(axis=0),
        losses.mean(axis=0),
        losses.min(axis=0),
        losses.max(axis=0),
    )


def plot_method(result_dir, label, color, alpha=0.1):
    """Plot one method's mean loss curve plus a shaded min/max band across runs."""
    losses, times = load_runs(result_dir)
    time_mean, loss_mean, loss_min, loss_max = summarize_runs(losses, times)
    plt.plot(time_mean, loss_mean, label=label, color=color)
    # loss_max >= loss_min everywhere, so no `where` mask is needed. A mask like
    # `where=(max > min)` would leave unfilled gaps on both sides of any point
    # where all runs agree (e.g. an identical initial loss).
    plt.fill_between(time_mean, loss_max, loss_min, color=color, alpha=alpha)


def main():
    """Draw every method's curve, label the axes and show (optionally save) the figure."""
    plt.style.use('ggplot')
    plt.figure(figsize=(7, 5))
    for result_dir, label, color, alpha in METHODS:
        plot_method(result_dir, label, color, alpha)

    # range
    plt.xlim(0, 400)
    plt.ylim(0, 25)
    plt.xlabel('Running Time /s', fontsize=15)
    plt.ylabel('Train Loss', fontsize=15)
    plt.legend(fontsize=15, framealpha=0.5)

    # Save before show(): once the window is closed the current figure may be
    # empty, so calling savefig() after show() can write a blank image.
    if SAVE_PATH is not None:
        plt.savefig(SAVE_PATH)
    plt.show()


if __name__ == "__main__":
    main()
