import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
import glob

def load_ablation_results(results_dir):
    """Read and parse ``ablation_summary.json`` located in *results_dir*.

    Args:
        results_dir: Directory expected to contain ``ablation_summary.json``.

    Returns:
        The object produced by ``json.load`` for the summary file.

    Raises:
        FileNotFoundError: If the summary file does not exist; the message
            contains the full path that was looked up.
    """
    summary_file = os.path.join(results_dir, 'ablation_summary.json')

    # Happy path first: parse and return as soon as the file is known to exist.
    if os.path.exists(summary_file):
        with open(summary_file, 'r') as handle:
            return json.load(handle)

    raise FileNotFoundError(f"Summary file not found: {summary_file}")

def extract_params_and_metrics(all_results):
    """Gather hyper-parameters and accuracy statistics into parallel arrays.

    A configuration is skipped when it has no (or empty) ``'statistics'``
    entry, or when any of ``alpha``, ``beta`` or ``lr`` is missing from its
    ``'params'`` mapping.

    Args:
        all_results: Mapping from configuration name to a mapping holding
            ``'params'`` and ``'statistics'``.

    Returns:
        Dict of NumPy arrays, one element per retained configuration:
        ``'alphas'``, ``'betas'``, ``'lrs'``, ``'test_accs'``
        (``mean_best_acc``), ``'test_stds'`` (``std_best_acc``),
        ``'train_accs'`` (``mean_final_acc``) and ``'train_stds'``
        (``std_final_acc``).
    """
    # Output column -> key in the per-configuration statistics mapping.
    stat_sources = (
        ('test_accs', 'mean_best_acc'),
        ('test_stds', 'std_best_acc'),
        ('train_accs', 'mean_final_acc'),
        ('train_stds', 'std_final_acc'),
    )

    columns = {'alphas': [], 'betas': [], 'lrs': []}
    for column, _ in stat_sources:
        columns[column] = []

    for entry in all_results.values():
        # Nothing to plot for configurations without aggregated statistics.
        if not ('statistics' in entry and entry['statistics']):
            continue

        hyper = entry['params']
        summary = entry['statistics']

        alpha = hyper.get('alpha', None)
        beta = hyper.get('beta', None)
        lr = hyper.get('lr', None)
        if any(value is None for value in (alpha, beta, lr)):
            continue

        columns['alphas'].append(alpha)
        columns['betas'].append(beta)
        columns['lrs'].append(lr)
        for column, stat_key in stat_sources:
            columns[column].append(summary[stat_key])

    return {column: np.array(items) for column, items in columns.items()}

def create_heatmap_grid(alphas, betas, values):
    """Place scattered ``(alpha, beta, value)`` samples on a dense 2-D grid.

    Rows follow the sorted unique alphas and columns the sorted unique betas.
    Cells with no sample stay ``NaN``; when several samples share the same
    ``(alpha, beta)`` pair, the one appearing last wins.

    Args:
        alphas: Alpha value of every sample.
        betas: Beta value of every sample, indexed like *alphas*.
        values: Quantity to store in the grid, indexed like *alphas*.

    Returns:
        Tuple ``(grid, unique_alphas, unique_betas)``.
    """
    # np.unique already yields its result in ascending order.
    unique_alphas = np.unique(alphas)
    unique_betas = np.unique(betas)

    def _axis_position(axis_values, target):
        # Index of the first axis entry equal to the sample's coordinate.
        return np.where(axis_values == target)[0][0]

    # Alpha runs along the y-axis (rows), beta along the x-axis (columns).
    grid = np.full((unique_alphas.size, unique_betas.size), np.nan)

    for sample, alpha in enumerate(alphas):
        row = _axis_position(unique_alphas, alpha)
        col = _axis_position(unique_betas, betas[sample])
        grid[row, col] = values[sample]

    return grid, unique_alphas, unique_betas

def _lightened_viridis():
    """Return viridis with every RGBA entry blended 30% toward white."""
    rgba = plt.cm.viridis(np.linspace(0, 1, 256))
    return mcolors.ListedColormap(0.7 * rgba + 0.3 * np.ones_like(rgba))


def _draw_and_save_heatmap(grid, alphas, betas, cmap, title, colorbar_label,
                           value_format, output_path, saved_prefix):
    """Render one annotated alpha/beta heatmap, save it, show it, close it.

    Args:
        grid: 2-D array indexed ``[alpha_index, beta_index]``; NaN marks
            cells without data.
        alphas: Tick values for the y-axis (one per grid row).
        betas: Tick values for the x-axis (one per grid column).
        cmap: Colormap passed to ``imshow``.
        title: Axes title.
        colorbar_label: Label drawn next to the colorbar.
        value_format: Format spec (e.g. ``'.2f'``) for the per-cell text.
        output_path: Destination file for ``savefig``.
        saved_prefix: Text printed in front of *output_path* after saving.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Robust colour scaling: clip the colour range to the 5th-95th
    # percentiles so a few extreme cells do not flatten the rest.
    vmin = np.nanpercentile(grid, 5)
    vmax = np.nanpercentile(grid, 95)
    image = ax.imshow(grid, cmap=cmap, aspect='auto', origin='lower',
                      vmin=vmin, vmax=vmax)

    ax.set_xticks(range(len(betas)))
    ax.set_yticks(range(len(alphas)))
    ax.set_xticklabels([f'{b:.3f}' for b in betas], rotation=45, ha='right')
    ax.set_yticklabels([f'{a:.3f}' for a in alphas])
    ax.set_xlabel(r'Beta ($\beta$)', fontsize=14, fontweight='bold')
    ax.set_ylabel(r'Alpha ($\alpha$)', fontsize=14, fontweight='bold')
    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)

    colorbar = fig.colorbar(image, ax=ax)
    colorbar.set_label(colorbar_label, fontsize=16, fontweight='bold')

    # Write the numeric value into every populated cell.
    for (row, col), cell in np.ndenumerate(grid):
        if not np.isnan(cell):
            ax.text(col, row, f'{cell:{value_format}}',
                    ha="center", va="center", color="white",
                    fontsize=10, fontweight='bold')

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"{saved_prefix}{output_path}")
    plt.show()
    # Release the figure explicitly; on non-interactive backends show() does
    # not destroy it and open figures would otherwise accumulate.
    plt.close(fig)


def plot_heatmaps(results_dir):
    """Create three separate heatmap figures from an ablation summary.

    Loads ``ablation_summary.json`` from *results_dir* and, over the
    (alpha, beta) grid, plots learning rate, test accuracy and training
    accuracy. Each figure is written as a PDF into *results_dir*
    (``heatmap_learning_rate.pdf``, ``heatmap_test_accuracy.pdf``,
    ``heatmap_training_accuracy.pdf``) and then shown.

    Args:
        results_dir: Directory holding the summary file; also the output
            directory for the PDFs.

    Raises:
        FileNotFoundError: If the summary file is missing.
        ValueError: If no configuration provides alpha, beta, lr and
            statistics, i.e. there is nothing to plot.
    """
    summary = load_ablation_results(results_dir)
    metrics = extract_params_and_metrics(summary['all_results'])

    if metrics['alphas'].size == 0:
        raise ValueError(
            f"No usable ablation results (alpha/beta/lr with statistics) in: {results_dir}"
        )

    # NOTE(review): grids are keyed by (alpha, beta) only, so if several
    # configurations share a pair the last one encountered is shown.
    lr_grid, alphas, betas = create_heatmap_grid(
        metrics['alphas'], metrics['betas'], metrics['lrs'])
    test_acc_grid, _, _ = create_heatmap_grid(
        metrics['alphas'], metrics['betas'], metrics['test_accs'])
    train_acc_grid, _, _ = create_heatmap_grid(
        metrics['alphas'], metrics['betas'], metrics['train_accs'])

    # Use the same colormap for all three figures.
    common_cmap = _lightened_viridis()

    # (grid, title, colorbar label, cell format, file name, message prefix)
    figures = (
        (lr_grid, r'Learning Rate vs $\alpha$ and $\beta$', 'Learning Rate',
         '.2e', 'heatmap_learning_rate.pdf',
         "Learning Rate heatmap saved to: "),
        (test_acc_grid, r'Test Accuracy vs $\alpha$ and $\beta$',
         'Test Accuracy (%)', '.2f', 'heatmap_test_accuracy.pdf',
         " Test Accuracy heatmap saved to: "),
        (train_acc_grid, r'Training Accuracy vs $\alpha$ and $\beta$',
         'Training Accuracy (%)', '.2f', 'heatmap_training_accuracy.pdf',
         " Training Accuracy heatmap saved to: "),
    )

    for grid, title, colorbar_label, value_format, file_name, saved_prefix in figures:
        _draw_and_save_heatmap(
            grid, alphas, betas, common_cmap, title, colorbar_label,
            value_format, os.path.join(results_dir, file_name), saved_prefix,
        )
    

# ---------------------------
# Main execution
# ---------------------------
if __name__ == "__main__":
    # Find the most recent ablation results directory. The glob pattern also
    # matches plain files (e.g. an archive or log named "ablation_results_*"),
    # which could sort last and be mistaken for a results directory, so keep
    # directories only.
    ablation_dirs = [
        path for path in glob.glob("ablation_results_*") if os.path.isdir(path)
    ]

    if not ablation_dirs:
        print(" No ablation results directories found!")
        print("Please run the ablation study first.")
    else:
        # "Most recent" means the lexicographically greatest name; presumably
        # the suffix is a sortable timestamp — TODO confirm.
        latest_dir = max(ablation_dirs)
        print(f"Loading results from: {latest_dir}")
        plot_heatmaps(latest_dir)