

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
import time

start_time = time.time()  # wall-clock reference; not read anywhere else in this script

# Load the saved per-run cumulative accuracy arrays, one .npy file per algorithm,
# in this order: feasible LinTS, BFAIPS, TTTS, oracle, PEPS (bfaips_ew).
(
    cumulative_accuracies_feasible,
    cumulative_accuracies_bfaips,
    cumulative_accuracies_ttts,
    cumulative_accuracies_optimal,
    cumulative_accuracies_bfaips_ew,
) = map(np.load, (
    'cumulative_accuracies_feasible.npy',
    'cumulative_accuracies_bfaips.npy',
    'cumulative_accuracies_ttts.npy',
    'cumulative_accuracies_optimal.npy',
    'cumulative_accuracies_bfaips_ew.npy',
))




from scipy.interpolate import interp1d



# Stretch the PEPS (bfaips_ew) curves along the time axis by PEPS_TIME_STRETCH.
# The value kept at index t is the original curve linearly interpolated at
# t / PEPS_TIME_STRETCH, i.e. sampled at 0, 0.5, 1, 1.5, ... for a factor of 2.
# The array keeps its original shape (n_rows, n_cols).
# NOTE(review): the reason for the factor of 2 is not visible in this script --
# confirm it matches how PEPS counts time steps relative to the other algorithms.
PEPS_TIME_STRETCH = 2

# Original dimensions: one row per run, one column per time step (assumed).
n_rows, n_cols = cumulative_accuracies_bfaips_ew.shape

# Original time indices: 0, 1, ..., n_cols - 1.
x_orig = np.arange(n_cols)

# Query points with an exact step of 1 / PEPS_TIME_STRETCH. A linspace over
# [0, n_cols - 1] with 2 * n_cols points has step (n_cols - 1) / (2 * n_cols - 1),
# slightly less than 0.5, which would not stretch the curve by exactly 2.
# All query points lie inside [0, n_cols - 1], so no extrapolation happens.
x_interp = x_orig / PEPS_TIME_STRETCH

# Linear interpolation, applied to each run (row) independently.
cumulative_accuracies_stretched = np.array([
    np.interp(x_interp, x_orig, row)
    for row in cumulative_accuracies_bfaips_ew
])

cumulative_accuracies_bfaips_ew = cumulative_accuracies_stretched









# Check that the data was loaded successfully.
print("Feasible Accuracies:", cumulative_accuracies_feasible)
print("BFAIPS Accuracies:", cumulative_accuracies_bfaips)
print("TTTS Accuracies:", cumulative_accuracies_ttts)
print("Optimal Accuracies:", cumulative_accuracies_optimal)
print("PEPS Accuracies:", cumulative_accuracies_bfaips_ew)

# Nominal experiment settings. The standard errors below use each array's
# actual number of runs (rows) rather than num_runs, so they stay correct if a
# file holds a different number of runs.
num_runs = 50
T = 2000  # horizon plotted on the x-axis; every mean curve must have T points


def _mean_and_stderr(runs, name):
    """Return the (mean, standard error) of ``runs`` across its first axis.

    ``runs`` is expected to be shaped (number_of_runs, T): one row per run,
    one column per time step. The standard error is NumPy's default
    (ddof=0) standard deviation divided by sqrt(number of rows).

    Raises:
        ValueError: if the mean curve does not have exactly T time steps,
            which would otherwise surface later as an opaque plotting error.
    """
    mean = runs.mean(axis=0)
    if mean.shape != (T,):
        raise ValueError(
            f"{name}: expected {T} time steps per run, got mean shape {mean.shape}"
        )
    return mean, runs.std(axis=0) / np.sqrt(runs.shape[0])


# Calculate mean and standard error for all algorithms
mean_accuracy_feasible, std_error_feasible = _mean_and_stderr(
    cumulative_accuracies_feasible, "feasible"
)
mean_accuracy_bfaips, std_error_bfaips = _mean_and_stderr(
    cumulative_accuracies_bfaips, "bfaips"
)
mean_accuracy_ttts, std_error_ttts = _mean_and_stderr(
    cumulative_accuracies_ttts, "ttts"
)
mean_accuracy_optimal, std_error_optimal = _mean_and_stderr(
    cumulative_accuracies_optimal, "optimal"
)
mean_accuracy_bfaips_ew, std_error_bfaips_ew = _mean_and_stderr(
    cumulative_accuracies_bfaips_ew, "bfaips_ew"
)

# Plot results: for each algorithm, the mean accuracy curve over runs plus a
# shaded band of +/- one standard error.
plt.figure(figsize=(10, 6))

time_steps = range(1, T + 1)

# (mean curve, standard error, legend label, color). Each band is drawn in the
# same color as its curve so the bands of different algorithms can be told
# apart (PEPS previously had a blue band, identical to BLFAIPS's).
plot_series = [
    (mean_accuracy_feasible, std_error_feasible, "LinTS (Feasible)", "green"),
    (mean_accuracy_bfaips, std_error_bfaips, "BLFAIPS", "blue"),
    (mean_accuracy_ttts, std_error_ttts, "LinTTTS (with Optimal beta)", "orange"),
    (mean_accuracy_optimal, std_error_optimal, "Oracle", "red"),
    (mean_accuracy_bfaips_ew, std_error_bfaips_ew, "PEPS", "black"),
]
for mean_curve, std_error, label, color in plot_series:
    plt.plot(time_steps, mean_curve, label=label, color=color)
    plt.fill_between(
        time_steps,
        mean_curve - std_error,
        mean_curve + std_error,
        color=color,
        alpha=0.2,
    )

plt.xlim((0, 2000))
plt.ylim((0, 1))
plt.xticks([0, 250, 500, 750, 1000, 1250, 1500, 1750, 2000], fontsize=12)
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=12)

plt.xlabel("Time", fontsize=12)
plt.ylabel("Accuracy", fontsize=12)
plt.legend(loc='lower right', fontsize=12)

plt.grid(True)
plt.savefig('Comparison_with_PEPS.png', format='png', dpi=300)
plt.savefig("Comparison_with_PEPS.pdf", format="pdf", bbox_inches='tight')

plt.show()


