import numpy as np
import matplotlib.pyplot as plt

##############################################################################################################################
                                             # Import data
##############################################################################################################################

#### Per-configuration regret summaries for Simulation 1 (KL-UCB transfer).
# Each archive holds "tsave" (shape (M,)) and "Regret" (shape (R, M), one row per
# repetition). Naming: S1R1 is the no-prior baseline, and prior k is stored as S1R(k+1).


def _load_regret_stats(path):
    """Load one regret archive and summarise it across repetitions.

    Parameters
    ----------
    path : str
        Path to an ``.npz`` file containing the arrays ``"tsave"`` and ``"Regret"``.

    Returns
    -------
    tsave : ndarray, shape (M,)
        Time points at which regret was recorded.
    regret : ndarray, shape (R, M)
        Raw regret, one row per repetition.
    mean : ndarray, shape (M,)
        Mean regret over the R repetitions.
    sem : ndarray, shape (M,)
        Standard error of the mean: sample std (ddof=1) divided by sqrt(R).
        With a single repetition this is NaN, because ``std(ddof=1)`` is undefined.

    Raises
    ------
    ValueError
        If ``Regret`` is not 2-D, or its column count differs from the size of ``tsave``.
    """
    # The context manager closes the underlying zip file. Indexing an NpzFile
    # reads the array fully into memory, so the arrays remain valid afterwards.
    with np.load(path) as archive:
        tsave = archive["tsave"]
        regret = archive["Regret"]

    # Fail early with a clear message instead of a cryptic plotting error later.
    if regret.ndim != 2 or np.size(tsave) != regret.shape[1]:
        raise ValueError(
            f"{path}: expected Regret of shape (R, M) and tsave of shape (M,), "
            f"got {regret.shape} and {np.shape(tsave)}"
        )

    n_reps = regret.shape[0]
    mean = regret.mean(axis=0)
    sem = regret.std(axis=0, ddof=1) / np.sqrt(n_reps)
    return tsave, regret, mean, sem


#### No prior
tsaveS1R1, RegretS1R1, mean_regretS1R1, sem_regretS1R1 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim1noprior.npz")

#### Prior 1
tsaveS1R2, RegretS1R2, mean_regretS1R2, sem_regretS1R2 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim1prior1.npz")

#### Prior 2
tsaveS1R3, RegretS1R3, mean_regretS1R3, sem_regretS1R3 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim1prior2.npz")

#### Prior 3
tsaveS1R4, RegretS1R4, mean_regretS1R4, sem_regretS1R4 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim1prior3.npz")

#### Prior 4 (stored under the S1R5 names, which is what the plot below reads)
tsaveS1R5, RegretS1R5, mean_regretS1R5, sem_regretS1R5 = _load_regret_stats(
    "regret_KL_UCB_Transfer_Sim1prior4.npz")

# Dimensions (repetitions, time points) of the last archive loaded.
R, M = RegretS1R5.shape

##############################################################################################################################
                                             # Plot Simulation 1
##############################################################################################################################

plt.figure(figsize=(5,3))

# Curve table: (time points, mean regret, standard error, legend label, colour, line style).
# Curves are drawn in list order, each as a shaded band followed by its mean line.
_curve_specs = [
    (tsaveS1R1, mean_regretS1R1, sem_regretS1R1, "No Prior", "b", '-'),
    (tsaveS1R2, mean_regretS1R2, sem_regretS1R2, "Prior 1", "r", '--'),
    (tsaveS1R3, mean_regretS1R3, sem_regretS1R3, "Prior 2", "g", '-.'),
    (tsaveS1R4, mean_regretS1R4, sem_regretS1R4, "Prior 3", "m", ':'),
    (tsaveS1R5, mean_regretS1R5, sem_regretS1R5, "Prior 4", "c", (0, (5, 1))),
]

for _t, _mu, _err, _label, _colour, _style in _curve_specs:
    # Band spans mean ± one standard error across repetitions.
    plt.fill_between(_t, _mu - _err, _mu + _err, alpha=0.3, color=_colour)
    plt.plot(_t, _mu, lw=1.5, label=_label, color=_colour, linestyle=_style)

# Axes: log-scaled horizon T on x, cumulative regret R_T on y.
plt.xscale('log')
plt.xlabel('$T$')
plt.ylabel('$R_T$')
plt.legend()
plt.grid(True, which='both', ls='--', alpha=0.4)
plt.tight_layout()
plt.savefig("plot1ACML.pdf", format="pdf")
plt.show()