import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

# Attention maps harvested by the forward hooks, keyed by module name.
attn_maps = dict()

# Query / key tensors harvested by the forward hooks, keyed by module name.
qs = dict()
ks = dict()


def hook_fn(name):
    """Build a forward hook that moves cached tensors off a module's processor.

    On every forward call, the returned hook inspects ``module.processor`` for
    the attributes ``q``, ``k`` and ``attn_map`` (in that order). Each one that
    is present is stored in the matching module-level cache (``qs``, ``ks``,
    ``attn_maps``) under ``name``, and is then deleted from the processor.

    Args:
        name: Key under which the harvested tensors are stored.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
    """
    def forward_hook(module, _inputs, _output):
        processor = module.processor
        for attr, store in (("q", qs), ("k", ks), ("attn_map", attn_maps)):
            if not hasattr(processor, attr):
                continue
            store[name] = getattr(processor, attr)
            # Remove the attribute so a stale tensor is not picked up on the next call.
            delattr(processor, attr)

    return forward_hook


def register_attention_hook(unet, attn_key):
    """Attach a capture hook to every submodule whose leaf name ends with ``attn_key``.

    The leaf name is the last dot-separated component of the name reported by
    ``unet.named_modules()``. Each matching module gets ``hook_fn(name)`` as a
    forward hook, and the number of registrations is printed.

    Args:
        unet: Module whose submodules are scanned.
        attn_key: Suffix to match against each submodule's leaf name.

    Returns:
        The same ``unet`` object (hooks are registered in place).
    """
    matched = [
        (name, module)
        for name, module in unet.named_modules()
        if name.rsplit(".", 1)[-1].endswith(attn_key)
    ]
    for name, module in matched:
        module.register_forward_hook(hook_fn(name))
    print(f"Register {len(matched)} {attn_key}s with Hook")

    return unet


def upscale(attn_map, target_size):
    """Fold a flattened spatial attention map back to 2-D and resize it to ``target_size``.

    ``attn_map`` is read as ``[channels, seq_len]``, where ``seq_len`` is a
    flattened ``h x w`` grid. The grid is found by trying
    ``h, w = H // (8 * 2**i), W // (8 * 2**i)`` for ``i`` in ``0..4`` and taking
    the first one with ``h * w == seq_len``. The factor 8 is presumably the
    VAE latent downsampling factor; that is an assumption, so confirm it
    against the model in use.

    Args:
        attn_map: Tensor whose total size is ``channels * seq_len``,
            laid out as ``[channels, seq_len]``. It may be non-contiguous.
        target_size: ``(H, W)`` output resolution.

    Returns:
        float32 tensor ``[channels, H, W]``, bilinearly resized and then
        softmax-normalised over the channel dimension.

    Raises:
        AssertionError: If no candidate grid holds exactly ``seq_len`` elements.
    """
    seq_len = attn_map.shape[1]
    temp_size = None

    for i in range(0, 5):
        factor = 8 * 2**i
        candidate = (target_size[0] // factor, target_size[1] // factor)
        # Test the grid actually used below. Comparing (H//s)*(W//s) against
        # seq_len*64 could pass with a grid that does not hold seq_len elements,
        # and could reject grids that do.
        if candidate[0] * candidate[1] == seq_len:
            temp_size = candidate
            break

    assert temp_size is not None, (
        f"upscale: seq_len={seq_len} does not match any (H // (8 * 2**i), W // (8 * 2**i)) "
        f"grid for target_size={tuple(target_size)}"
    )

    # reshape (not view) so inputs whose strides forbid a view, e.g. transposed maps, still work.
    attn_map = attn_map.reshape(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32), size=target_size, mode='bilinear', align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_qk(batch_size=32, instance_or_negative=False, detach=True, layerwise_average=False):
    """Return the module-level query and key caches as a ``(qs, ks)`` pair.

    The keyword arguments mirror those of ``get_net_attn_map`` but are
    currently unused. The live dictionaries are returned, not copies.
    """
    captured = (qs, ks)
    return captured

def get_net_attn_map(image_size, batch_size=32, instance_or_negative=False, detach=True, layerwise_average=False):
    """Collect every cached attention map, split per sample, and upscale it to ``image_size``.

    Each entry of the module-level ``attn_maps`` is split along dim 0 into
    ``batch_size`` chunks. Only one half of those chunks is kept.

    Each map is assumed to be laid out as ``[bs, num_heads, seq_len, num_tokens]``
    with one sample per chunk (TODO: confirm against the attention processor).

    Args:
        image_size: ``(H, W)`` passed to ``upscale``.
        batch_size: Number of chunks each map is split into.
        instance_or_negative: If True, keep the second half of the chunks
            (instance); otherwise keep the first half (negative).
        detach: If True, detach each map from the autograd graph and move it
            to CPU before processing.
        layerwise_average: If True, average over the token dim, giving one map
            per head. Otherwise average over heads, giving one map per token.

    Returns:
        A list with one tensor per cached layer. Each tensor is shaped
        ``[num_selected, C, H, W]``, where ``C`` is heads or tokens depending
        on ``layerwise_average``.
    """
    net_attn_maps = []

    # Instance maps live in the second half of the batch, negative maps in the first half.
    half = batch_size // 2
    selected_indices = range(half, batch_size) if instance_or_negative else range(0, half)

    for attn_map in attn_maps.values():
        if detach:
            # Honour the flag's name: drop the autograd graph as well as moving to CPU.
            attn_map = attn_map.detach().cpu()

        # Split once per layer rather than once per selected index.
        chunks = torch.chunk(attn_map, batch_size)

        frame_attn_maps = []
        for i in selected_indices:
            # Drop only the leading per-sample dim. A bare squeeze() would also
            # drop singleton head/token dims and break the mean/permute below.
            chunk = chunks[i].squeeze(0)  # [num_heads, seq_len, num_tokens]

            if layerwise_average:
                attn_map_avg = torch.mean(chunk, dim=2)  # [num_heads, seq_len]
            else:
                attn_map_avg = torch.mean(chunk, dim=0).permute(1, 0)  # [num_tokens, seq_len]

            frame_attn_maps.append(upscale(attn_map_avg, image_size))

        net_attn_maps.append(torch.stack(frame_attn_maps))

    return net_attn_maps


def attnmaps2images(net_attn_maps):
    """Convert attention-map tensors to 8-bit grayscale PIL images.

    Each map is min-max scaled to ``[0, 255]`` and cast to ``uint8``. A
    constant map, or one whose min/max are NaN, has no usable range and is
    rendered as all zeros.

    Args:
        net_attn_maps: Iterable of torch tensors whose array form is accepted
            by ``PIL.Image.fromarray``. Presumably 2-D ``[H, W]``; confirm
            with callers.

    Returns:
        ``(images, total_attn_scores)``. ``images`` is a list of PIL images in
        input order. ``total_attn_scores`` maps each input index to the mean
        of the raw map, taken before normalisation.
    """
    total_attn_scores = {}
    images = []

    for idx, attn_map in enumerate(net_attn_maps):
        # detach() so maps that still require grad can be converted to numpy.
        attn_map = attn_map.detach().cpu().numpy()
        total_attn_scores[idx] = attn_map.mean().item()

        lo, hi = np.min(attn_map), np.max(attn_map)
        if hi > lo:
            normalized_attn_map = (attn_map - lo) / (hi - lo) * 255
        else:
            # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
            normalized_attn_map = np.zeros_like(attn_map)

        images.append(Image.fromarray(normalized_attn_map.astype(np.uint8)))

    return images, total_attn_scores


def is_torch2_available():
    """Report whether ``torch.nn.functional`` provides ``scaled_dot_product_attention``."""
    try:
        F.scaled_dot_product_attention
    except AttributeError:
        return False
    return True


def get_generator(seed, device):
    """Create seeded ``torch.Generator`` objects on ``device``.

    Args:
        seed: ``None``, a single seed, or a list of seeds.
        device: Device passed to ``torch.Generator``.

    Returns:
        ``None`` when ``seed`` is ``None``; a list of generators, one per
        seed, when ``seed`` is a list; otherwise a single generator.
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(s) for s in seed]
    return torch.Generator(device).manual_seed(seed)
