# mvdream/attention_control.py
import os
import numpy as np
from PIL import Image
from functools import partial
from einops import rearrange, repeat

from typing import List, Dict, Optional
import torch
import torch.nn.functional as F

from .ldm.modules.attention import CrossAttention

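# KNOWN ISSUE: with this version of the controller registered, the qualitative behavior of the
# generated output still differs from the true baseline that runs without any controller.
# Expected behavior: the controller is purely observational and must not affect the output at all,
# even when it is registered and passed along as an artifact - but currently it does.
# NOTE(review): register_attention_control below replaces each CrossAttention.forward with its own
# re-implementation instead of wrapping the original forward, so any difference between the two
# code paths is a candidate cause - unverified, confirm against the original forward.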
class AttentionStore:
    """
    Attention store that saves targeted attention maps to disk as images and
    caches the captured tensors in memory, driven by a configuration dictionary.

    Recognized config keys (all optional):
        'steps':            collection of step indices to save, or the string 'all'.
        'blocks':           collection of block names to save (the hooks in this file
                            pass 'down', 'mid' or 'up'), or the string 'all'.
        'xattn_indices':    collection of per-block cross-attention indices to save,
                            or the string 'all'.
        'max_cached_steps': maximum number of distinct steps kept in
                            `latest_attention_maps`. None (the default) keeps every
                            step, which grows without bound over a sampling run.
    """
    def __init__(self, output_dir: str, config: Optional[Dict] = None):
        """
        Create `<output_dir>/attention_maps` and read the selection criteria.

        Args:
            output_dir: Parent directory; images go to its "attention_maps" subfolder.
            config: Optional overrides, see the class docstring for the keys.
        """
        self.output_dir = os.path.join(output_dir, "attention_maps")
        os.makedirs(self.output_dir, exist_ok=True)

        # --- Set defaults and override with config ---
        # Default: step indices 0 and 49 (the two ends of a 50-step DDIM process)
        self.target_steps = [0, 49]
        # Default: only the down-sampling blocks
        self.target_blocks = ['down']
        # Default: first two cross-attention layers found in the selected blocks
        self.target_xattn_indices = [0, 1]
        # Default: unbounded in-memory cache (original behavior)
        self.max_cached_steps = None

        if config:
            self.target_steps = config.get('steps', self.target_steps)
            self.target_blocks = config.get('blocks', self.target_blocks)
            self.target_xattn_indices = config.get('xattn_indices', self.target_xattn_indices)
            self.max_cached_steps = config.get('max_cached_steps', self.max_cached_steps)

        print(f"AttentionStore configured to save from steps: {self.target_steps}, blocks: {self.target_blocks}, xattn indices: {self.target_xattn_indices}")

        self.current_step = -1
        # Maps step index -> {"<place>_<xattn_idx>": detached attention tensor}
        self.latest_attention_maps = {}
        # Not read anywhere in this class; presumably set/consumed by external code.
        self.mitigation_c_scale = None

    def set_current_step(self, step_idx: int):
        """Updates the current DDIM step."""
        self.current_step = step_idx

    def reset(self):
        """Rewind the step counter to -1 and drop all cached attention maps."""
        self.current_step = -1
        self.latest_attention_maps = {}

    def step(self):
        """Advance the current step index by one."""
        self.current_step += 1

    @staticmethod
    def _is_selected(target, value) -> bool:
        """Return True if *target* is the string 'all' or contains *value*."""
        # isinstance guard: comparing e.g. a numpy array to 'all' with == is not a plain bool.
        if isinstance(target, str) and target == 'all':
            return True
        return value in target

    @staticmethod
    def _token_save_path(save_path: str, token_idx: int) -> str:
        """Return *save_path* with `_tokenNN` inserted before its file extension."""
        # splitext only touches the final extension, so a '.png' inside a directory
        # name is left alone; a missing extension falls back to '.png'.
        root, ext = os.path.splitext(save_path)
        return f"{root}_token{token_idx:02d}{ext or '.png'}"

    @staticmethod
    def _token_maps_to_uint8(attn_map: torch.Tensor) -> np.ndarray:
        """
        Convert an attention tensor into per-token grayscale images.

        The second half of dim 0 is averaged (presumably the conditional half of a
        classifier-free-guidance batch - confirm with the caller), each token column
        is reshaped to a square map and min-max normalized to [0, 255].

        Returns:
            uint8 array of shape (num_tokens, hw, hw).

        Raises:
            ValueError: if the number of query positions is not a perfect square.
        """
        cond_attn_map = attn_map[attn_map.shape[0] // 2:].mean(dim=0)
        num_queries, num_tokens = cond_attn_map.shape
        hw = int(np.sqrt(num_queries))
        if hw * hw != num_queries:
            raise ValueError(
                f"cannot reshape {num_queries} query positions into a square map"
            )

        # Cast to float32 first: numpy cannot represent bfloat16. One device->host
        # transfer for all tokens instead of one per token.
        maps = cond_attn_map.detach().float().cpu().numpy()
        maps = maps.T.reshape(num_tokens, hw, hw)

        lo = maps.min(axis=(1, 2), keepdims=True)
        hi = maps.max(axis=(1, 2), keepdims=True)
        # The epsilon keeps constant maps from dividing by zero.
        maps = (maps - lo) / (hi - lo + 1e-6)
        return (maps * 255).astype(np.uint8)

    @staticmethod
    def save_attention_map(attn_map: torch.Tensor, save_path: str):
        """
        Save one grayscale image per token next to *save_path*.

        Files are named `<save_path stem>_tokenNN<ext>`. Saving is best-effort:
        any failure is printed and swallowed so the caller is not interrupted.
        """
        try:
            token_images = AttentionStore._token_maps_to_uint8(attn_map)
            for token_idx, token_map in enumerate(token_images):
                # Save each token's map to a unique file
                token_path = AttentionStore._token_save_path(save_path, token_idx)
                Image.fromarray(token_map).save(token_path)
        except Exception as e:
            print(f"[AttentionStore ERROR] Failed to save attention map. Error: {e}")

    def _evict_stale_steps(self):
        """Drop the oldest cached steps beyond `max_cached_steps` (never the current one)."""
        if self.max_cached_steps is None:
            return
        cache = self.latest_attention_maps
        excess = len(cache) - max(int(self.max_cached_steps), 1)
        if excess <= 0:
            return
        # Dicts iterate in insertion order, so the first keys are the oldest steps.
        stale_steps = [s for s in cache if s != self.current_step][:excess]
        for stale in stale_steps:
            del cache[stale]

    def __call__(self, attn_probs, place_in_unet: str, xattn_idx: int, layer_name: str):
        """
        Hook entry point: cache the map in memory and, if selected, save it to disk.

        Args:
            attn_probs: Attention probabilities tensor produced by the hooked layer.
            place_in_unet: Block name the layer belongs to.
            xattn_idx: Index of the cross-attention layer within its block type.
            layer_name: Module name, used (with dots replaced) in the file name.
        """
        key = f"{place_in_unet}_{xattn_idx}"

        # 1. Store a detached copy in memory for immediate access
        step_maps = self.latest_attention_maps.setdefault(self.current_step, {})
        step_maps[key] = attn_probs.detach().clone()
        self._evict_stale_steps()

        # 2. Save the map to disk if it meets the criteria from the config
        if (self._is_selected(self.target_steps, self.current_step) and
                self._is_selected(self.target_blocks, place_in_unet) and
                self._is_selected(self.target_xattn_indices, xattn_idx)):

            # Sanitize layer name for filename
            safe_layer_name = layer_name.replace('.', '_')
            save_path = os.path.join(self.output_dir, f"step{self.current_step}_{safe_layer_name}.png")
            self.save_attention_map(attn_probs.detach(), save_path)


def register_attention_control(model, controller: AttentionStore):
    """
    Patch the cross-attention ("attn2") layers of the model's UNet so that their
    attention probabilities are reported to `controller`.

    Every CrossAttention module under `model.model.diffusion_model` whose name
    contains "attn2" gets its `forward` attribute replaced by a re-implemented
    attention pass with this metadata bound to it:
        place_in_unet: 'down', 'mid' or 'up', derived from the module name.
        xattn_idx:     running index of the layer within its place.
        layer_name:    the module's name as yielded by `named_modules()`.
    """

    def ca_forward_hook(self, x, context=None, mask=None, place_in_unet=None, xattn_idx=None, layer_name=None):
        """
        Stand-in for CrossAttention.forward that computes attention explicitly
        and passes the softmax probabilities to the controller before using them.
        """
        num_heads = self.heads

        def split_heads(t):
            # Fold the head dimension into the batch dimension.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=num_heads)

        q = self.to_q(x)
        if context is None:
            # No context given: attend over the input itself.
            context = x
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        if mask is not None:
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            # Right-pad the mask with zeros up to the number of key positions.
            missing = sim.shape[-1] - flat_mask.shape[-1]
            flat_mask = F.pad(flat_mask, (0, missing), 'constant', 0)
            flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=num_heads)
            # Masked-out positions get the most negative finite value before softmax.
            sim.masked_fill_(~flat_mask.bool(), -torch.finfo(sim.dtype).max)

        attn_probs = sim.softmax(dim=-1)

        # --- HOOK LOGIC ---
        if controller is not None:
            controller(attn_probs, place_in_unet, xattn_idx, layer_name)
        # ------------------

        merged = torch.einsum('b i j, b j d -> b i d', attn_probs, v)
        merged = rearrange(merged, '(b h) n d -> b n (h d)', h=num_heads)
        return self.to_out(merged)

    print("Registering attention control hooks for MVDream...")

    # Name fragment -> place label; checked in this order, first match wins.
    place_by_marker = (
        ("input_blocks", "down"),
        ("middle_block", "mid"),
        ("output_blocks", "up"),
    )
    # Next free cross-attention index for each place.
    xattn_counters = dict.fromkeys(("down", "mid", "up"), 0)

    for name, module in model.model.diffusion_model.named_modules():
        if not isinstance(module, CrossAttention) or "attn2" not in name:
            continue

        place = next((label for marker, label in place_by_marker if marker in name), None)
        if place is None:
            continue

        # Take the current index for this place, then advance the counter.
        layer_idx = xattn_counters[place]
        xattn_counters[place] = layer_idx + 1

        # Bind the module itself plus its metadata into the replacement forward.
        module.forward = partial(
            ca_forward_hook,
            module,
            place_in_unet=place,
            xattn_idx=layer_idx,
            layer_name=name,
        )