import numpy as np
import matplotlib.pyplot as plt
import os

# Adjust these parameters to match your saved data file. They determine both the
# directory the raw data is read from and the directory the figures are saved to.
d = 50
K = d + 1  # d + 1
bp = 0.3
sbp = -0.3
precision = 0.0001
n_runs = 100

# Load the data. The directory name encodes the parameters above, so the data file
# is derived from it rather than hard-coded: a hard-coded path silently goes stale
# (and disagrees with where figures are saved) as soon as a parameter is changed.
dir_path = f"./precision_exp_d{d}_K{K}_bp{bp}_sbp{sbp}_precision{precision}_runs{n_runs}"
data_file = f"{dir_path}/raw_data.npy"

# Check if file exists
if not os.path.exists(data_file):
    print(f"Data file not found: {data_file}")
    print("Please ensure you've run the experiment first.")
else:
    # The .npy file holds a pickled dict (hence allow_pickle=True and .item()).
    # Unpickling can execute arbitrary code, so only load files produced by your
    # own experiment runs.
    data = np.load(data_file, allow_pickle=True).item()

    # Convert to ndarrays once so the arithmetic below is vectorized and behaves
    # the same whether the experiment saved Python lists or arrays.
    theta_norm_values = np.asarray(data['theta_norm_values'])
    lin_means = np.asarray(data['lin_means'], dtype=float)
    lin_stds = np.asarray(data['lin_stds'], dtype=float)
    th_means = np.asarray(data['th_means'], dtype=float)
    th_stds = np.asarray(data['th_stds'], dtype=float)
    lin_success_rates = np.asarray(data['lin_success_rates'], dtype=float)
    th_success_rates = np.asarray(data['th_success_rates'], dtype=float)
    all_results = data['all_results']  # not plotted; kept for interactive inspection
    parameters = data['parameters']
    n_settings = len(theta_norm_values)

    # Shared annotation / styling used by every figure.
    subtitle = f'(d={d}, K={K}, bp={bp}, sbp={sbp}, {n_runs} runs)'
    lin_style = {'label': 'LinTS+MLE', 'color': 'blue'}
    th_style = {'label': 'TryHardTS', 'color': 'green'}

    # Figures are written next to the data; make sure the directory exists even
    # if dir_path and data_file have drifted apart.
    os.makedirs(dir_path, exist_ok=True)

    # --- Figure 1: mean rounds with error bars ---------------------------------
    plt.figure(figsize=(12, 8))
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.errorbar(theta_norm_values, lin_means, yerr=lin_stds,
                 marker='o', capsize=5, capthick=2, linewidth=2, markersize=8,
                 **lin_style)
    plt.errorbar(theta_norm_values, th_means, yerr=th_stds,
                 marker='s', capsize=5, capthick=2, linewidth=2, markersize=8,
                 **th_style)
    plt.xlabel('θ* Norm', fontsize=18)
    plt.ylabel(f'Mean Rounds to Precision {precision}', fontsize=18)
    plt.title(f'Rounds to Achieve Precision {precision}\n{subtitle}', fontsize=16)
    plt.legend(fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{dir_path}/rounds_comparison_enhanced.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Figure 2: mean rounds with shaded mean ± std bands --------------------
    plt.figure(figsize=(12, 8))
    plt.tick_params(axis='both', which='major', labelsize=14)
    # A round count cannot be negative, so the lower edge of each band is clipped at 0.
    plt.fill_between(theta_norm_values,
                     np.maximum(lin_means - lin_stds, 0),
                     lin_means + lin_stds,
                     alpha=0.3, color='blue')
    plt.fill_between(theta_norm_values,
                     np.maximum(th_means - th_stds, 0),
                     th_means + th_stds,
                     alpha=0.3, color='green')
    plt.plot(theta_norm_values, lin_means, linewidth=2, markersize=8, marker='o',
             **lin_style)
    plt.plot(theta_norm_values, th_means, linewidth=2, markersize=8, marker='s',
             **th_style)
    # NOTE(review): this figure labels the x-axis 'M' while every other figure uses
    # 'θ* Norm' for the same values — presumably deliberate notation; confirm.
    plt.xlabel('M', fontsize=18)
    plt.ylabel(f'Mean Rounds to Precision {precision}', fontsize=18)
    plt.title(f'Rounds to Achieve Precision with Error Regions\n{subtitle}', fontsize=16)
    plt.legend(fontsize=14, loc='best')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{dir_path}/rounds_comparison_shaded.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Figure 3: success rates -----------------------------------------------
    plt.figure(figsize=(12, 8))
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.plot(theta_norm_values, lin_success_rates, 'o-', linewidth=2, markersize=8,
             **lin_style)
    plt.plot(theta_norm_values, th_success_rates, 's-', linewidth=2, markersize=8,
             **th_style)
    plt.xlabel('θ* Norm', fontsize=18)
    plt.ylabel('Success Rate', fontsize=18)
    plt.title(f'Success Rate (Achieving Precision {precision} within {parameters["max_T"]} rounds)\n'
              f'{subtitle}', fontsize=16)
    plt.legend(fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.ylim([0, 1.05])
    plt.tight_layout()
    plt.savefig(f'{dir_path}/success_rates_enhanced.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Figure 4: relative performance ----------------------------------------
    # Each algorithm's mean rounds divided by the smaller of the two means at the
    # same θ* norm, so the better algorithm sits at exactly 1.0.
    best_means = np.minimum(lin_means, th_means)
    ratios_lin = lin_means / best_means
    ratios_th = th_means / best_means

    plt.figure(figsize=(12, 8))
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.plot(theta_norm_values, ratios_lin, 'o-', linewidth=2, markersize=8, **lin_style)
    plt.plot(theta_norm_values, ratios_th, 's-', linewidth=2, markersize=8, **th_style)
    plt.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, label='Best Performance')
    plt.xlabel('θ* Norm', fontsize=18)
    plt.ylabel('Relative Performance (Ratio to Best)', fontsize=18)
    plt.title(f'Relative Algorithm Performance\n{subtitle}', fontsize=16)
    plt.legend(fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{dir_path}/relative_performance.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Text summary ------------------------------------------------------------
    print("\nAlgorithm Comparison Summary:")
    print("=" * 80)
    print(f"{'θ* Norm':<12} {'LinTS Mean':<15} {'LinTS Std':<12} {'TryHardTS Mean':<15} {'TryHardTS Std':<12}")
    print("-" * 80)
    for theta, lin_m, lin_s, th_m, th_s in zip(theta_norm_values, lin_means, lin_stds,
                                               th_means, th_stds):
        print(f"{theta:<12} {lin_m:<15.1f} {lin_s:<12.1f} {th_m:<15.1f} {th_s:<12.1f}")

    # Best theta_norm (lowest mean rounds) for each algorithm.
    best_lin_idx = np.argmin(lin_means)
    best_th_idx = np.argmin(th_means)

    print("\nBest Performance (Lowest Mean Rounds):")
    print(f"LinTS+MLE: θ*_norm={theta_norm_values[best_lin_idx]}, Mean={lin_means[best_lin_idx]:.1f} rounds")
    print(f"TryHardTS: θ*_norm={theta_norm_values[best_th_idx]}, Mean={th_means[best_th_idx]:.1f} rounds")

    # Strict comparisons: equal (or NaN) means count for neither algorithm, so a
    # tie is no longer silently credited to TryHardTS.
    lin_better = lin_means < th_means
    th_better = th_means < lin_means

    print("\nRelative Performance (compared to best algorithm at each θ* norm):")
    print("=" * 80)
    print(f"{'θ* Norm':<12} {'Best Alg':<15} {'LinTS Ratio':<15} {'TryHardTS Ratio':<15} {'Winner Margin':<15}")
    print("-" * 80)
    for i, theta in enumerate(theta_norm_values):
        if lin_better[i]:
            best_alg = 'LinTS+MLE'
        elif th_better[i]:
            best_alg = 'TryHardTS'
        else:
            best_alg = 'Tie'
        margin = abs(lin_means[i] - th_means[i]) / best_means[i] * 100  # percentage margin
        # Format the number with its '%' first, then pad, so the sign stays attached.
        margin_str = f"{margin:.1f}%"
        print(f"{theta:<12} {best_alg:<15} {ratios_lin[i]:<15.3f} {ratios_th[i]:<15.3f} {margin_str:<15}")

    print("\n" + "=" * 80)
    print("Overall Algorithm Comparison:")
    print("-" * 80)

    lin_wins = int(np.sum(lin_better))
    th_wins = int(np.sum(th_better))
    ties = n_settings - lin_wins - th_wins

    print(f"LinTS+MLE performs better in {lin_wins}/{n_settings} settings")
    print(f"TryHardTS performs better in {th_wins}/{n_settings} settings")
    if ties:
        print(f"Tied (or not comparable) in {ties}/{n_settings} settings")

    # Averages across all theta_norms; the ± value is the mean of the per-setting
    # standard deviations, not the spread of the average itself.
    avg_lin = np.mean(lin_means)
    avg_th = np.mean(th_means)
    print("\nAverage rounds across all θ* norms:")
    print(f"LinTS+MLE: {avg_lin:.1f} ± {np.mean(lin_stds):.1f}")
    print(f"TryHardTS: {avg_th:.1f} ± {np.mean(th_stds):.1f}")

    print("\nAverage success rates:")
    print(f"LinTS+MLE: {np.mean(lin_success_rates)*100:.1f}%")
    print(f"TryHardTS: {np.mean(th_success_rates)*100:.1f}%")

    # --- Figure 5: 2x2 combined overview -----------------------------------------
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'Comprehensive Algorithm Comparison\n{subtitle}', fontsize=18)

    # Subplot 1: mean rounds with error bars
    ax = axes[0, 0]
    ax.errorbar(theta_norm_values, lin_means, yerr=lin_stds,
                marker='o', capsize=5, linewidth=2, markersize=8, **lin_style)
    ax.errorbar(theta_norm_values, th_means, yerr=th_stds,
                marker='s', capsize=5, linewidth=2, markersize=8, **th_style)
    ax.set_xlabel('θ* Norm', fontsize=14)
    ax.set_ylabel(f'Mean Rounds to Precision {precision}', fontsize=14)
    ax.set_title('Mean Rounds Comparison', fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)

    # Subplot 2: success rates
    ax = axes[0, 1]
    ax.plot(theta_norm_values, lin_success_rates, 'o-', linewidth=2, markersize=8, **lin_style)
    ax.plot(theta_norm_values, th_success_rates, 's-', linewidth=2, markersize=8, **th_style)
    ax.set_xlabel('θ* Norm', fontsize=14)
    ax.set_ylabel('Success Rate', fontsize=14)
    ax.set_title('Success Rate Comparison', fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1.05])

    # Subplot 3: relative performance
    ax = axes[1, 0]
    ax.plot(theta_norm_values, ratios_lin, 'o-', linewidth=2, markersize=8, **lin_style)
    ax.plot(theta_norm_values, ratios_th, 's-', linewidth=2, markersize=8, **th_style)
    ax.axhline(y=1.0, color='black', linestyle='--', alpha=0.5)
    ax.set_xlabel('θ* Norm', fontsize=14)
    ax.set_ylabel('Ratio to Best', fontsize=14)
    ax.set_title('Relative Performance', fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)

    # Subplot 4: signed difference in mean rounds. Positive means TryHardTS needed
    # more rounds (LinTS better, blue); negative means TryHardTS better (green);
    # exact ties are grey rather than being attributed to either algorithm.
    ax = axes[1, 1]
    performance_diff = th_means - lin_means
    colors = [
        'blue' if diff > 0 else ('green' if diff < 0 else 'gray')
        for diff in performance_diff
    ]
    ax.bar(theta_norm_values, performance_diff, color=colors, alpha=0.7)
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    ax.set_xlabel('θ* Norm', fontsize=14)
    ax.set_ylabel('TryHardTS - LinTS (rounds)', fontsize=14)
    ax.set_title('Performance Difference\n(Positive = LinTS better, Negative = TryHardTS better)',
                 fontsize=14)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{dir_path}/comprehensive_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nAll plots saved to: {dir_path}/")
    print("Generated files:")
    print("  - rounds_comparison_enhanced.png")
    print("  - rounds_comparison_shaded.png")
    print("  - success_rates_enhanced.png")
    print("  - relative_performance.png")
    print("  - comprehensive_comparison.png")