import os
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

def _draw_grid_lines(draw_line, n_cells):
    """Draw white separator lines on cell boundaries along one image axis.

    A thin line is drawn every ``max(1, n_cells // 18)`` cells and a thicker
    one every ``max(1, n_cells // 6)`` cells. A boundary that matches both
    spacings receives both lines.

    Args:
        draw_line: ``ax.axhline`` or ``ax.axvline``; it is called with the
            boundary coordinate as its first positional argument.
        n_cells: Number of cells along the axis being decorated.
    """
    # The spacing of each grid level is loop-invariant, so compute it once.
    fine_step = max(1, n_cells // 18)
    coarse_step = max(1, n_cells // 6)

    for i in range(n_cells + 1):
        # With imshow's default extent, cell i is centred on coordinate i,
        # so i - 0.5 is the boundary just before it.
        boundary = i - 0.5
        if i % fine_step == 0:
            draw_line(boundary, color='white', linewidth=0.3, alpha=0.5, zorder=2)
        if i % coarse_step == 0:
            draw_line(boundary, color='white', linewidth=0.8, alpha=0.8, zorder=2)


def generate_attention_heatmap(attn_matrix, layer_idx, id='', output_dir='./plots'):
    """Render an attention matrix as a heatmap and save it as PNG and PDF.

    The image uses the ``viridis`` colormap with a fixed colour range of
    0 to 0.35, has no axis ticks or labels, is overlaid with a white grid,
    and carries a colorbar on its right-hand side. Two files are written to
    ``output_dir``:
    ``{id}_layer_{layer_idx + 1}_attention_heatmap.png`` and the same name
    with a ``.pdf`` extension.

    Side effects: creates ``output_dir`` if needed, prints the matrix shape,
    its max/min and the saved PDF path, and sets the global matplotlib
    rcParams ``axes.unicode_minus``, ``pdf.fonttype`` and ``font.family``.

    Args:
        attn_matrix: 2-D array-like accepted by ``Axes.imshow`` that also
            exposes ``.shape`` and ``.max()/.min()`` results with ``.item()``
            (presumably a NumPy array or a CPU tensor; confirm with callers).
        layer_idx: Layer index; file names use ``layer_idx + 1``.
        id: Prefix for the output file names. The name shadows the builtin
            but is kept so existing keyword callers keep working.
        output_dir: Directory that receives the two image files.
    """
    # Apply all rc settings up front so they are in effect for everything
    # drawn below, not only for the final PDF save.
    plt.rcParams['axes.unicode_minus'] = False  # Fix negative sign display
    plt.rcParams['pdf.fonttype'] = 42  # Embed TrueType fonts, not Type 3
    plt.rcParams['font.family'] = 'DejaVu Sans'

    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        # Plot heatmap
        print(f'Size: {attn_matrix.shape}')
        print(f'Attention matrix max: {attn_matrix.max().item()}, min: {attn_matrix.min().item()}')
        im = ax.imshow(attn_matrix, cmap='viridis', vmin=0, vmax=0.35)

        ax.set_xticklabels([])  # Remove x-axis labels
        ax.set_yticklabels([])  # Remove y-axis labels
        ax.tick_params(axis='both', which='both', length=0)  # Remove tick marks

        # Add grid lines. Rows and columns are handled separately so that a
        # non-square matrix gets lines matching each of its dimensions.
        n_rows, n_cols = attn_matrix.shape[0], attn_matrix.shape[1]
        _draw_grid_lines(ax.axhline, n_rows)
        _draw_grid_lines(ax.axvline, n_cols)

        # Add colorbar in a dedicated axes so it matches the image height.
        divider = make_axes_locatable(ax)
        cbar_ax = divider.append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cbar_ax)

        # Fixed ticks spanning the fixed colour range used by imshow above.
        cbar.set_ticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35])
        cbar.set_ticklabels(['0', '0.05', '0.10', '0.15', '0.20', '0.25', '0.30', '0.35'])
        cbar.ax.tick_params(labelsize=30)

        # Save heatmap in both formats; save through ``fig`` explicitly
        # rather than relying on pyplot's notion of the current figure.
        stem = f'{id}_layer_{layer_idx+1}_attention_heatmap'
        png_path = os.path.join(output_dir, f'{stem}.png')
        fig.savefig(png_path, bbox_inches='tight', dpi=300, pad_inches=0.1)

        filepath = os.path.join(output_dir, f'{stem}.pdf')
        fig.savefig(filepath, bbox_inches='tight', dpi=300, pad_inches=0.1)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    print(f"Heatmap saved: {filepath}")
