import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import os
from tqdm.notebook import tqdm
from finetuning_utils import get_attention_hook, remove_all_forward_hooks, split_attention_heads
from PIL import Image


def generate_visualisations(model, concept, epoch, folder, num_inference_steps=50):
    """Generate one image for ``concept`` and save per-head cross-attention overlays.

    Steps:
      1. Call ``remove_all_forward_hooks`` on ``model.unet`` and set
         ``requires_grad = False`` on every UNet parameter. The gradient flags
         are NOT restored on return.
      2. Register a capture hook (from ``get_attention_hook``) on every module
         whose name ends with ``"attn2"``, recording ``module.heads`` per layer.
      3. Run the pipeline once on ``concept``. The capture hooks registered
         here are removed afterwards, even if generation raises.
      4. Split the captured activations per head (``split_attention_heads``)
         and plot/save them with ``plot_all_attention_head_overlays``.

    Args:
        model: Pipeline object exposing ``.unet`` and callable as
            ``model(prompt, num_inference_steps=...)``.
        concept: Text prompt used for generation.
        epoch: Epoch label used in log messages and the saved filename.
        folder: Directory the overlay figure is written to.
        num_inference_steps: Number of denoising steps (default 50, as before).
    """
    print(f"========= Attention overlays visualisation generation for epoch: {epoch} has started! =======")
    pipe = model
    print("Cleaning any existing hooks or gradient requireing parameters...")
    remove_all_forward_hooks(pipe.unet)

    for param in pipe.unet.parameters():
        param.requires_grad = False

    attention_activations = {}
    num_of_attention_heads = {}
    hook_handles = []
    for name, module in pipe.unet.named_modules():
        # Only whole cross-attention modules ("...attn2"), not their children
        # such as "...attn2.to_q".
        if name.endswith("attn2"):
            hook_handles.append(
                module.register_forward_hook(get_attention_hook(name, attention_activations))
            )
            num_of_attention_heads[name] = module.heads

    try:
        gen_image = pipe(concept, num_inference_steps=num_inference_steps)
    finally:
        # Detach the capture hooks so later forward passes (e.g. training steps
        # or other sampling calls) stop writing into ``attention_activations``.
        for handle in hook_handles:
            handle.remove()

    print("Computing the attention heads activations!")
    attention_heads_activations = split_attention_heads(attention_activations, num_of_attention_heads)
    print("Generating the visualizations for attention overlays")
    plot_all_attention_head_overlays(
        gen_image, attention_heads_activations, num_of_attention_heads, concept, epoch, folder=folder
    )
    return


def generate_intermediate_visaulisations(concepts, pipe, epoch, step, path, n=1, num_inference_steps=50):
    """Generate ``n`` samples per concept and save them as a single image grid.

    Rows correspond to ``concepts`` (each row's first cell is labelled with its
    prompt) and columns to the ``n`` independent samples. The grid is saved to
    ``<path>/visual_evaluations/epoch_<epoch>_step_<step>.png`` and then shown.

    Args:
        concepts: Prompts to generate, one grid row each.
        pipe: Pipeline callable as ``pipe(prompt, num_inference_steps=...)``
            returning an object with ``.images[0]``.
        epoch: Epoch label used in the output filename.
        step: Step label used in the output filename.
        path: Root output directory.
        n: Number of samples (columns) per concept.
        num_inference_steps: Denoising steps per sample (default 50, as before).
    """
    output_dir = os.path.join(path, "visual_evaluations")
    os.makedirs(output_dir, exist_ok=True)
    n_rows = len(concepts)
    n_cols = n

    if n_rows == 0 or n_cols < 1:
        # plt.subplots rejects a zero-sized grid; there is nothing to draw.
        print(f"[WARN] Nothing to visualise for epoch {epoch}, step {step}.")
        return

    # squeeze=False always yields a 2D array of Axes, whatever the grid shape.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    for j in tqdm(range(n_cols), desc=f"Generating images for {concepts}"):
        for i, prompt in enumerate(concepts):
            image = pipe(prompt, num_inference_steps=num_inference_steps).images[0]

            ax = axs[i][j]
            ax.imshow(image)
            # Hide ticks and frame but keep the y-label visible; ax.axis('off')
            # would also hide the prompt label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if j == 0:
                ax.set_ylabel(prompt, fontsize=12)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"epoch_{epoch}_step_{step}.png"))
    plt.show()
    # Release the figure so repeated calls during training don't accumulate them.
    plt.close(fig)
    return


def visualise_FID(fid_scores, epoch, concepts, folder_name):
    """Plot the per-step FID curve of each concept for one epoch and save it.

    Args:
        fid_scores: Nested mapping indexed as ``fid_scores[epoch][concept]``;
            each entry is a sequence of FID values, one per step.
        epoch: Epoch key into ``fid_scores``; also used in the title and path.
        concepts: Concepts to plot, one line each.
        folder_name: Root output folder. The figure is written to
            ``<folder_name>/epoch_<epoch>/fid_scores.png`` (directory created
            if missing).
    """
    epoch_path = os.path.join(folder_name, f"epoch_{epoch}")
    fig = plt.figure(figsize=(10, 5))
    for concept in concepts:
        # One FID value per step of this epoch for the concept.
        step_fid_scores = fid_scores[epoch][concept]
        plt.plot(range(len(step_fid_scores)), step_fid_scores, label=f"{concept}")

    plt.xlabel("Steps")
    plt.ylabel("Average FID Score")
    plt.title(f"FID Scores Over Steps: Epoch {epoch}")
    plt.legend()
    plt.grid(True)
    os.makedirs(epoch_path, exist_ok=True)
    fig.savefig(os.path.join(epoch_path, "fid_scores.png"))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    
def visualise_CLIP_scores(clip_scores, epoch, concepts, folder_name):
    """Plot the per-step CLIP-score curve of each concept for one epoch and save it.

    Args:
        clip_scores: Nested mapping indexed as ``clip_scores[epoch][concept]``;
            each entry is a sequence of CLIP scores, one per step.
        epoch: Epoch key into ``clip_scores``; also used in the title and path.
        concepts: Concepts to plot, one line each.
        folder_name: Root output folder. The figure is written to
            ``<folder_name>/epoch_<epoch>/clip_scores.png`` (directory created
            if missing).
    """
    epoch_path = os.path.join(folder_name, f"epoch_{epoch}")
    fig = plt.figure(figsize=(10, 5))
    for concept in concepts:
        # One CLIP score per step of this epoch for the concept.
        step_clip_scores = clip_scores[epoch][concept]
        plt.plot(range(len(step_clip_scores)), step_clip_scores, label=f"{concept}")

    plt.xlabel("Steps")
    plt.ylabel("Average CLIP Score")
    plt.title(f"CLIP Scores Over Steps: Epoch {epoch}")
    plt.legend()
    plt.grid(True)
    os.makedirs(epoch_path, exist_ok=True)
    fig.savefig(os.path.join(epoch_path, "clip_scores.png"))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    
def plot_variation_of_heads(heads_info, epoch, folder_name):
    """Plot the per-step number of selected query/value/key heads for one epoch.

    Args:
        heads_info: Mapping indexed as ``heads_info[epoch][key]`` for the keys
            ``"Query_heads"``, ``"Value_heads"`` and ``"Key_heads"``; each
            entry is a sequence of counts, one per step.
        epoch: Epoch key into ``heads_info``; also used in the title and path.
        folder_name: Root output folder. The figure is written to
            ``<folder_name>/epoch_<epoch>/heads_variation_plot.png`` (directory
            created if missing).
    """
    epoch_path = os.path.join(folder_name, f"epoch_{epoch}")
    fig = plt.figure(figsize=(10, 5))
    plt.plot(heads_info[epoch]["Query_heads"], label="No. of query heads")
    plt.plot(heads_info[epoch]["Value_heads"], label="No. of value heads")
    plt.plot(heads_info[epoch]["Key_heads"], label="No. of key heads")
    plt.xlabel("Steps")
    plt.ylabel("Number of concept neurons")
    plt.title(f"Number of Concept Neurons vs Steps- Epoch: {epoch}")
    plt.legend()
    plt.grid(True)
    os.makedirs(epoch_path, exist_ok=True)
    fig.savefig(os.path.join(epoch_path, "heads_variation_plot.png"))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def visualise_step_losses(losses, epoch, folder_name, loss_type="Gradient_based"):
    """Plot the per-step loss curve of one epoch and save it.

    Args:
        losses: Mapping indexed as ``losses[epoch]``, giving a sequence of loss
            values, one per step.
        epoch: Epoch key into ``losses``; also used in the title and path.
        folder_name: Root output folder. The figure is written to
            ``<folder_name>/epoch_<epoch>/<loss_type>_loss_plot.png``
            (directory created if missing).
        loss_type: Name used in the legend, axis label, title and filename.
    """
    epoch_path = os.path.join(folder_name, f"epoch_{epoch}")
    fig = plt.figure(figsize=(10, 5))
    plt.plot(losses[epoch], label=f"{loss_type} Loss")
    plt.xlabel("Steps")
    plt.ylabel(f"{loss_type} Loss Value")
    plt.title(f"{loss_type} Loss Over Steps: Epoch {epoch}")
    plt.legend()
    plt.grid(True)
    os.makedirs(epoch_path, exist_ok=True)
    fig.savefig(os.path.join(epoch_path, f"{loss_type}_loss_plot.png"))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    
    
def plot_attention_head_heatmap(
    gen_image,
    split_activations,
    layer_name,
    head_index,
    upsample_factor=8,
):
    """Build an attention-heatmap overlay for one head of one attention layer.

    The head's activation is averaged over its last dimension, reshaped into a
    square grid, min-max normalised to [0, 1], block-upsampled by
    ``upsample_factor``, and alpha-composited over the generated image with a
    ``jet`` colormap. In that colormap the lowest bin is fully transparent and
    the rest use alpha 0.3.

    Args:
        gen_image: Pipeline output whose ``.images[0]`` is a PIL image.
        split_activations: Mapping ``{layer_name: {"head.<i>": tensor}}``.
            Each tensor is assumed to be ``(tokens, features)`` with a
            perfect-square token count — TODO confirm against
            ``split_attention_heads``.
        layer_name: Key of the layer in ``split_activations``.
        head_index: Head index ``i``, looked up as ``"head.<i>"``.
        upsample_factor: Integer block-upsampling factor applied before the
            bilinear resize to the image size.

    Returns:
        ``(overlay, heatmap)``: the RGBA overlay image, and the normalised
        attention map resized to the image size as float32 values in [0, 1].
        A spatially constant activation yields an all-zero (transparent) map.

    Raises:
        KeyError: If the layer or head is missing from ``split_activations``.
        ValueError: If the token count is not a perfect square.
    """
    head_key = f"head.{head_index}"
    activation = split_activations[layer_name][head_key]
    # Average over the feature dimension -> one scalar per spatial token.
    activation = activation.mean(dim=-1)

    tokens = activation.shape[0]
    # round() guards against float sqrt landing just below an exact integer.
    side_len = int(round(tokens ** 0.5))
    if side_len * side_len != tokens:
        raise ValueError(f"Tokens {tokens} not square")
    # detach()/float(): .numpy() rejects tensors that require grad and bfloat16.
    heatmap = activation.detach().float().reshape(side_len, side_len).cpu().numpy()

    h_min, h_max = heatmap.min(), heatmap.max()
    if np.isclose(h_max, h_min):
        # A constant map carries no spatial information. Render it as empty
        # (transparent) rather than substituting random noise, which would
        # fabricate a misleading, nondeterministic pattern.
        heatmap = np.zeros((side_len, side_len), dtype=np.float32)
    else:
        heatmap = (heatmap - h_min) / (h_max - h_min + 1e-8)

    base_image = gen_image.images[0]
    target_size = base_image.size

    # Block-upsample the grid (each token becomes an upsample_factor^2 square).
    heatmap_upsampled = np.kron(heatmap, np.ones((upsample_factor, upsample_factor)))

    # Numeric map at image resolution, back in [0, 1].
    heatmap_resized = Image.fromarray((heatmap_upsampled * 255).astype(np.uint8)).resize(
        target_size, resample=Image.BILINEAR
    )
    heatmap_resized_np = np.array(heatmap_resized).astype(np.float32) / 255.0

    # jet colormap with the lowest bin fully transparent and the rest at alpha 0.3.
    jet = plt.cm.jet
    colors = jet(np.arange(jet.N))
    colors[:1, -1] = 0
    colors[1:, -1] = 0.3
    cmap = ListedColormap(colors)
    heatmap_rgba = cmap(heatmap_upsampled)
    heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8)).resize(
        target_size, resample=Image.BILINEAR
    )

    overlay = Image.alpha_composite(base_image.convert("RGBA"), heatmap_image)

    return overlay, heatmap_resized_np

def plot_all_attention_head_overlays(
    pipe_output,
    split_activations,
    num_of_attention_heads,
    prompt,
    epoch,
    folder,
    upsample_factor=8
):
    """Draw a layer-by-head grid of attention overlays and save it.

    Each row is one layer from ``split_activations``. Each column is a head
    index up to the largest head count in ``num_of_attention_heads``. Cells
    beyond a layer's own head count are left blank. A cell whose overlay
    cannot be built shows "Error" and a warning is printed (best-effort).

    Args:
        pipe_output: Pipeline output passed to ``plot_attention_head_heatmap``
            (its ``.images[0]`` is the base image).
        split_activations: Mapping ``{layer_name: {"head.<i>": tensor}}``.
        num_of_attention_heads: Mapping ``{layer_name: head_count}``.
        prompt: Prompt text, used in the title and output filename.
        epoch: Epoch label used in the output filename.
        folder: Output directory (created if missing). The figure is saved as
            ``Attention_Overlay_Heads_<prompt>_epoch_<epoch>.png``, with path
            separators in ``prompt`` replaced by ``_``.
        upsample_factor: Forwarded to ``plot_attention_head_heatmap``.
    """
    layers = list(split_activations.keys())
    num_layers = len(layers)
    if num_layers == 0 or not num_of_attention_heads:
        # plt.subplots / max() cannot handle an empty grid; nothing to plot.
        print(f"[WARN] No attention activations to plot for epoch {epoch}.")
        return
    max_heads = max(num_of_attention_heads.values())

    # squeeze=False always yields a 2D array of Axes, whatever the grid shape.
    fig, axes = plt.subplots(
        num_layers, max_heads, figsize=(3 * max_heads, 3 * num_layers), squeeze=False
    )

    for i, layer_name in enumerate(layers):
        n_heads = num_of_attention_heads[layer_name]
        head_activations = split_activations[layer_name]

        for j in range(max_heads):
            ax = axes[i][j]

            if j < n_heads:
                try:
                    overlay, _ = plot_attention_head_heatmap(
                        gen_image=pipe_output,
                        split_activations=split_activations,
                        layer_name=layer_name,
                        head_index=j,
                        upsample_factor=upsample_factor,
                    )
                    ax.imshow(overlay)
                    if i == 0:
                        ax.set_title(f"Head {j}", fontsize=10, pad=10)
                    if j == 0:
                        head = f"head.{j}"
                        ax.set_ylabel(
                            layer_name + f"\nTokens: {head_activations[head].shape[0]}",
                            fontsize=8, rotation=0, labelpad=100, va='center'
                        )
                except Exception as e:
                    # Best-effort: one failing head should not abort the whole grid.
                    ax.text(0.5, 0.5, "Error", ha="center", va="center", fontsize=8)
                    print(f"[WARN] Failed for {layer_name}, head {j}: {e}")
            else:
                ax.axis('off')

            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.2, hspace=0.6)
    fig.suptitle(f"Attention Overlay Heatmaps for: '{prompt}'", fontsize=14)

    os.makedirs(folder, exist_ok=True)
    # A "/" (or OS separator) inside the prompt would otherwise be read as a directory.
    safe_prompt = prompt.replace(os.sep, "_").replace("/", "_")
    fig.savefig(os.path.join(folder, f"Attention_Overlay_Heads_{safe_prompt}_epoch_{epoch}.png"))
    print(f"Attention heads overlay for epoch {epoch} computed, and saved!")
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    
    
