import numpy as np
import matplotlib.pyplot as plt

# Load the experiment results. The .npy file stores a single pickled Python
# object: .item() unwraps the loaded object array, and the string-key lookups
# below show it is a dict. allow_pickle=True is required to read it.
# NOTE(review): unpickling can execute arbitrary code -- only load result files
# you produced yourself, never ones from an untrusted source.
data = np.load(
    'precision_raw_data_cumu_d8_Delta5.0_sigma5.0_runs100.npy',
    allow_pickle=True,
).item()

# Number-of-arms values the experiment sweeps over (echoed for a quick check).
K_values = data['K_values']
print(K_values)

# Per-K mean / std of rounds to precision, one pair per algorithm. Prefixes
# follow the legend labels used when plotting:
#   ts_*     -> SimpleLinTS
#   ug_*     -> Uniform Exploration
#   cumuts_* -> CumuLinTS
ts_means, ts_stds = data['ts_means'], data['ts_stds']
ug_means, ug_stds = data['ug_means'], data['ug_stds']
cumuts_means, cumuts_stds = data['cumuts_means'], data['cumuts_stds']

# Per-K success rates, same algorithm prefixes as above.
ts_success_rates = data['ts_success_rates']
ug_success_rates = data['ug_success_rates']
cumuts_success_rates = data['cumuts_success_rates']

all_results = data['all_results']

# Create the log-log plot with all three algorithms.
# Both axes are log2-transformed by hand, so a straight line of slope s means
# mean rounds to precision grow like K**s.
plt.figure(figsize=(10, 6))
# Configure ticks only AFTER the figure exists: calling plt.tick_params() first
# would create and style an implicit, otherwise empty figure instead of this one.
plt.tick_params(axis='both', which='major', labelsize=15)
plt.plot(np.log2(K_values), np.log2(ts_means), 'o-', label='SimpleLinTS', linewidth=2, markersize=8)
plt.plot(np.log2(K_values), np.log2(ug_means), 's-', label='Uniform Exploration', linewidth=2, markersize=8)
plt.plot(np.log2(K_values), np.log2(cumuts_means), '^-', label='CumuLinTS', linewidth=2, markersize=8)

plt.xlabel('log₂(K) - Log of K values', fontsize=20)
plt.ylabel('log₂(T) - Log of Mean Rounds to Precision', fontsize=18)
plt.title('Log-Log Plot: Rounds to Precision vs Number of Arms', fontsize=14)
plt.legend(fontsize=14)
plt.grid(True, alpha=0.3)
plt.savefig('precision_loglog_plot_with_cumuts.png', dpi=300, bbox_inches='tight')
plt.show()

# Also create a version with shaded error regions in log-log space.
plt.figure(figsize=(10, 6))

# Band edges are the exact log2 of (mean - std) and (mean + std), not a
# first-order approximation. The lower edge is clamped to 1 round so log2
# stays finite (and >= 0) when the std exceeds the mean.
ts_log_err_lower = np.log2(np.maximum(ts_means - ts_stds, 1))
ts_log_err_upper = np.log2(ts_means + ts_stds)
ug_log_err_lower = np.log2(np.maximum(ug_means - ug_stds, 1))
ug_log_err_upper = np.log2(ug_means + ug_stds)
cumuts_log_err_lower = np.log2(np.maximum(cumuts_means - cumuts_stds, 1))
cumuts_log_err_upper = np.log2(cumuts_means + cumuts_stds)

# Plot with error regions
plt.tick_params(axis='both', which='major', labelsize=15)
plt.fill_between(np.log2(K_values), ts_log_err_lower, ts_log_err_upper, alpha=0.2, color='C0')
plt.fill_between(np.log2(K_values), ug_log_err_lower, ug_log_err_upper, alpha=0.2, color='C1')
plt.fill_between(np.log2(K_values), cumuts_log_err_lower, cumuts_log_err_upper, alpha=0.2, color='C2')

# Pin each line to the same colour as its band so the pairing is explicit.
# The first series is SimpleLinTS (ts_*); it was previously mislabelled
# 'CumuLinTS', which duplicated the CumuLinTS legend entry.
plt.plot(np.log2(K_values), np.log2(ts_means), 'o-', color='C0', label='SimpleLinTS', linewidth=2, markersize=8)
plt.plot(np.log2(K_values), np.log2(ug_means), 's-', color='C1', label='Uniform Exploration', linewidth=2, markersize=8)
plt.plot(np.log2(K_values), np.log2(cumuts_means), '^-', color='C2', label='CumuLinTS', linewidth=2, markersize=8)

plt.xlabel('log₂(K) - Log of K values', fontsize=20)
plt.ylabel('log₂(T) - Log of Mean Rounds to Precision', fontsize=18)
plt.title('Log-Log Plot with Error Regions: Rounds to Precision vs Number of Arms', fontsize=14)
plt.legend(fontsize=14)
plt.grid(True, alpha=0.3)
plt.savefig('precision_loglog_plot_with_cumuts_errors.png', dpi=300, bbox_inches='tight')
plt.show()

# Print a side-by-side table of mean rounds to precision for every K.
# Column abbreviations: TS = SimpleLinTS (ts_*), UG = Uniform Exploration
# (ug_*), LinTS = CumuLinTS (cumuts_*).
print("\nAlgorithm Comparison Summary:")
print("=" * 60)
print(f"{'K':<10} {'TS Mean':<12} {'UG Mean':<12} {'LinTS Mean':<12}")
print("-" * 60)
for i, k in enumerate(K_values):
    row_means = (ts_means[i], ug_means[i], cumuts_means[i])
    # Each mean is a 12-wide, 1-decimal cell; cells are space-separated.
    cells = " ".join(f"{m:<12.1f}" for m in row_means)
    print(f"{k:<10} {cells}")

# Find the K with the lowest mean rounds for each algorithm
# (np.argmin picks the first such K on ties).
best_ts_idx = np.argmin(ts_means)
best_ug_idx = np.argmin(ug_means)
best_cumuts_idx = np.argmin(cumuts_means)

# Algorithm names match the legend labels used in the plots
# ('SimpLinTS' was a typo for 'SimpleLinTS').
print("\nBest Performance (Lowest Mean Rounds):")
print(f"SimpleLinTS: K={K_values[best_ts_idx]}, Mean={ts_means[best_ts_idx]:.1f} rounds")
print(f"Uniform Exploration: K={K_values[best_ug_idx]}, Mean={ug_means[best_ug_idx]:.1f} rounds")
print(f"CumuLinTS: K={K_values[best_cumuts_idx]}, Mean={cumuts_means[best_cumuts_idx]:.1f} rounds")

# Calculate relative performance: at each K, divide every algorithm's mean by
# the smallest mean at that K, so the winning algorithm shows a ratio of 1.000.
# Column abbreviations: TS = SimpleLinTS (ts_*), UG = Uniform Exploration
# (ug_*), LinTS = CumuLinTS (cumuts_*).
print("\nRelative Performance (compared to best in each K):")
print("=" * 60)
print(f"{'K':<10} {'Best':<15} {'TS Ratio':<12} {'UG Ratio':<12} {'LinTS Ratio':<12}")
print("-" * 60)
for i, k in enumerate(K_values):
    row_means = [ts_means[i], ug_means[i], cumuts_means[i]]
    # Pick the winner's position directly (first one wins on ties) instead of
    # min() followed by list.index(): re-finding the minimum by equality can
    # raise ValueError when it is NaN (NaN != NaN), e.g. if the means are
    # ndarrays and no run at this K produced a finite mean.
    best_pos = min(range(len(row_means)), key=row_means.__getitem__)
    best_at_k = row_means[best_pos]
    best_alg = ['TS', 'UG', 'LinTS'][best_pos]
    ts_ratio = ts_means[i] / best_at_k
    ug_ratio = ug_means[i] / best_at_k
    cumuts_ratio = cumuts_means[i] / best_at_k
    print(f"{k:<10} {best_alg:<15} {ts_ratio:<12.3f} {ug_ratio:<12.3f} {cumuts_ratio:<12.3f}")