import os
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from typing import List


def plot_per_token_attention_from_files(prompt: str, tokenizer, attention_dir: str):
    """Return a Gradio update that shows a saved final-step attention map.

    Only the saved "Bright Ending" map is shown; a true per-token view would
    need the attention hook to save one image per token.

    Files are looked up as ``step49_*.png`` inside ``attention_dir``
    (example name: ``step49_input_blocks.1.1.transformer_blocks.0.attn2.png``)
    and the first match in sorted order is displayed.

    Args:
        prompt: Prompt text; used only in the label of the returned update.
        tokenizer: Unused here; kept so the signature stays stable for a
            future per-token implementation.
        attention_dir: Directory containing the saved attention-map PNGs.

    Returns:
        The result of ``gr.update``: the image as a numpy array plus a label,
        or ``value=None`` with an explanatory label when no file matches.
    """
    # NOTE(review): "step49" is treated as the final step, which presumably
    # assumes a 50-step sampling run -- confirm against the code that saves
    # these files.
    # glob.escape keeps special characters in the directory name ('[', '*',
    # '?') from being interpreted as wildcards.
    search_pattern = os.path.join(glob.escape(attention_dir), "step49_*.png")
    map_files = sorted(glob.glob(search_pattern))

    if not map_files:
        return gr.update(value=None, label="No saved attention map found for final step.")

    # The original code called `Image.open`, but `Image` is never imported in
    # this module, so that line raised NameError. Reading the PNG through
    # matplotlib (already imported) yields a numpy array, needs no extra
    # dependency, and leaves no lazily-opened file handle behind.
    image = plt.imread(map_files[0])

    return gr.update(value=image, label=f"Bright Ending Map for '{prompt}' (Final Step)")

    
def _head_averaged_token_maps(layer_map, num_heads: int):
    """Average one stored attention map over heads and unflatten it to 2-D.

    Takes a tensor shaped (batch*heads, h*w, num_tokens) and returns a
    float32 CPU tensor shaped (batch, num_tokens, side, side).
    """
    if layer_map.dim() != 3:
        raise ValueError(
            f"Expected a 3-D attention map, got shape {tuple(layer_map.shape)}."
        )
    rows, pixels, n_tokens = layer_map.shape
    if num_heads <= 0 or rows % num_heads != 0:
        raise ValueError(
            f"First dimension ({rows}) is not a multiple of num_heads={num_heads}."
        )
    side = int(round(float(np.sqrt(pixels))))
    if side * side != pixels:
        raise ValueError(
            f"Attention map has {pixels} spatial positions, which is not a square grid."
        )
    # detach/float/cpu: plotting needs a grad-free float32 numpy-compatible
    # tensor, and half/bfloat16 maps cannot all be converted to numpy directly.
    per_image = (
        layer_map.detach().float().cpu()
        .reshape(-1, num_heads, pixels, n_tokens)
        .mean(dim=1)
    )
    return per_image.permute(0, 2, 1).reshape(-1, n_tokens, side, side)


def plot_cross_attention_maps(
    prompt: str,
    tokenizer,
    controller,
    from_where: List[str] = ["down", "mid", "up"],
    output_path: str = "attention_maps.png",
    num_heads: int = 8,
):
    """Generate and save a row of cross-attention heat maps, one per prompt token.

    Every map stored under ``"<place>_cross"`` in ``controller.attention_store``
    (for each place in ``from_where``) is averaged over heads, brought to a
    common spatial size, and averaged over layers. The maps of the first image
    in the batch are then drawn side by side and written to ``output_path``.

    Args:
        prompt: Text prompt; tokenized to label the panels and used in the title.
        tokenizer: Object providing ``encode(prompt)`` and ``get_vocab()``.
        controller: Object whose ``attention_store`` mapping holds lists of
            tensors shaped (batch*heads, h*w, num_tokens).
        from_where: Key prefixes to read from the store. The default list is
            never mutated.
        output_path: Destination image file.
        num_heads: Number of attention heads folded into the first tensor
            dimension (defaults to the previously hard-coded value of 8).

    Returns:
        None. Prints a message and returns early when there are no stored
        maps or nothing to plot.

    Raises:
        ValueError: If a stored map is not 3-D, its first dimension is not a
            multiple of ``num_heads``, or its spatial size is not a square.
    """
    # All stored maps for the requested places are used (not only the last step).
    stored_maps = []
    for place in from_where:
        stored_maps.extend(controller.attention_store.get(f"{place}_cross", []) or [])

    if not stored_maps:
        print("No attention maps found in the controller.")
        return

    tokens = tokenizer.encode(prompt)
    decoder = {v: k for k, v in tokenizer.get_vocab().items()}

    per_layer = [_head_averaged_token_maps(m, num_heads) for m in stored_maps]

    # Different blocks may store different spatial resolutions, which cannot
    # be stacked directly; upsample the smaller maps to the largest one.
    # Maps already at the target size are left untouched.
    target = max(m.shape[-1] for m in per_layer)
    per_layer = [
        m if m.shape[-1] == target
        else torch.nn.functional.interpolate(
            m, size=(target, target), mode="bilinear", align_corners=False
        )
        for m in per_layer
    ]

    # Average over layers, then keep the first image of the batch.
    first_image_maps = torch.stack(per_layer, dim=0).mean(dim=0)[0]  # (tokens, side, side)

    # Never index past the token columns actually present in the maps.
    n_panels = min(len(tokens), first_image_maps.shape[0])
    if n_panels == 0:
        print("No tokens to visualize.")
        return

    # squeeze=False always returns a 2-D array of Axes, so the single-token
    # case needs no special handling (the old list wrapper broke `.ravel()`).
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(n_panels * 3, 4), squeeze=False)
    axes = axes_grid[0]
    try:
        im = None
        for ax, token_id, token_map in zip(axes, tokens, first_image_maps):
            token_text = decoder.get(int(token_id), "???")
            im = ax.imshow(token_map.numpy(), cmap='viridis')
            ax.set_title(f"`{token_text}`")
            ax.axis('off')

        # Each panel is colour-scaled independently; the shared colorbar
        # reflects the scale of the last panel only.
        fig.colorbar(im, ax=axes.tolist(), orientation='horizontal', pad=0.05)
        fig.suptitle(f"Cross-Attention Maps for: '{prompt}'", fontsize=16)
        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        fig.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved per-token attention visualization to: {output_path}")

