
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
import time

start_time = time.time()

# Load the per-run accuracy curves.
# NOTE(review): each array is assumed to be (runs, time steps) -- one row per
# run, one column per time step -- based on how rows/columns are used below.
acc_ts = np.load('acc_ts.npy')
acc_bfaips = np.load('acc_bfaips.npy')
acc_ttts = np.load('acc_ttts.npy')

# Take the last column (the accuracy at the final time step of each run).
last_column = acc_bfaips[:, -1]

# Keep only the BFAIPS runs whose final accuracy is >= 0.5.
# NOTE(review): discarding runs this way biases the BFAIPS mean upward; make
# sure the filtering is disclosed wherever the resulting plot is used.
acc_bfaips_filtered = acc_bfaips[last_column >= 0.5]

# Nominal experiment configuration. The standard errors below use the actual
# number of rows in each array, so they stay correct after filtering.
num_runs = 50
T = 1000


def _mean_and_stderr(acc):
    """Return the column-wise mean and standard error of the mean of *acc*.

    The standard error is ``std / sqrt(n_rows)`` where ``n_rows`` is the number
    of rows actually present in *acc* (not a hard-coded run count).

    Raises:
        ValueError: if *acc* has no rows, which would otherwise yield NaN
            curves and a silently empty plot.
    """
    n_rows = acc.shape[0]
    if n_rows == 0:
        raise ValueError("cannot compute mean/standard error of zero runs")
    # ddof=0 (population std) matches the original computation.
    return acc.mean(axis=0), acc.std(axis=0) / np.sqrt(n_rows)


# Calculate mean and standard error for all algorithms
mean_accuracy_feasible, std_error_feasible = _mean_and_stderr(acc_ts)
mean_accuracy_bfaips, std_error_bfaips = _mean_and_stderr(acc_bfaips_filtered)
mean_accuracy_ttts, std_error_ttts = _mean_and_stderr(acc_ttts)



# Plot results: one mean accuracy curve per algorithm, each with a
# +/- 1 standard-error shaded band.
plt.figure(figsize=(10, 6))

# Time steps 1..T on the x-axis; each mean array must have T entries.
time_steps = range(1, T + 1)

# (mean, standard error, legend label, color) per algorithm, in the original
# drawing order: Feasible Linear TS, BFAIPS, TTTS.
series = [
    (mean_accuracy_feasible, std_error_feasible, "LinTS (Feasible)", "green"),
    (mean_accuracy_bfaips, std_error_bfaips, "BLFAIPS", "blue"),
    (mean_accuracy_ttts, std_error_ttts, "LinTTTS (with beta=0.5)", "orange"),
]
for mean, std_err, label, color in series:
    plt.plot(time_steps, mean, label=label, color=color)
    plt.fill_between(
        time_steps,
        mean - std_err,
        mean + std_err,
        color=color,
        alpha=0.2,
    )

# Axis range and ticks follow T rather than a hard-coded 1000
# (identical output when T == 1000).
plt.xlim((0, T))
plt.ylim((0, 1))
plt.xticks([T * k // 4 for k in range(5)], fontsize=12)
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=12)

plt.xlabel("Time", fontsize=12)
plt.ylabel("Accuracy", fontsize=12)
plt.title("Accuracy vs time (MovieLens dataset and IMDb ratings) : Linear TS (Feasible) vs BLFAIPS vs LinTTTS (with beta=0.5)")
plt.legend(loc='lower right', fontsize=12)

plt.grid(True)
plt.savefig('./Movielens.png', format='png', dpi=300)
plt.show()


