# mvdream/attention_control.py
import os
import numpy as np
from PIL import Image
from functools import partial
from einops import rearrange
from typing import Dict, Optional
import torch

from .ldm.modules.attention import CrossAttention


class AttentionController:
    """
    Unified controller for cross-attention layers.

    It always records the attention probabilities it is handed (in memory, and
    optionally as PNG images on disk), and it can additionally rewrite the
    attention logits before softmax when a mitigation strategy is configured.

    Its behavior is determined by the `mitigation_params` at initialization.
    - If `mitigation_params` is None (or has no 'type'), it's a PASSIVE observer.
    - If `mitigation_params` provides a 'type', it's an ACTIVE mitigator.
    """
    def __init__(self, output_dir: str, config: Optional[Dict] = None, mitigation_params: Optional[Dict] = None):
        """
        Args:
            output_dir: Directory that receives saved attention images. It is
                created if it does not exist.
            config: Selects which maps are written to disk. Recognised keys are
                'steps', 'blocks' and 'xattn_indices' (a list of indices, or the
                string 'all'). When None, a built-in default is used. A supplied
                dict is used as-is and is NOT merged with the defaults.
            mitigation_params: Optional dict with 'type' (strategy name) and
                'c_scale' (multiplier used by the 'ca_entropy' strategy,
                default 1.25).
        """
        # --- Configuration for PASSIVE saving of attention maps (for metrics) ---
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # Default configuration for which attention maps to save
        default_config = {
            'steps': [0, 49],
            'blocks': ['down'],
            'xattn_indices': [0, 1, 2, 3]
        }
        self.config = config if config is not None else default_config

        # --- Configuration for ACTIVE mitigation ---
        if mitigation_params is None:
            mitigation_params = {}
        self.mitigation_type = mitigation_params.get('type', None)
        self.mitigation_c_scale = mitigation_params.get('c_scale', 1.25)  # Example for 'ca_entropy'

        if self.mitigation_type:
            print(f"[AttentionController] Initialized in ACTIVE mode. Type: {self.mitigation_type}")
        else:
            print("[AttentionController] Initialized in PASSIVE mode. Only recording attention maps.")

        self.reset()

    def reset(self):
        """Resets the state of the controller for a new generation."""
        self.current_step = 0
        # Maps step index -> {"<place>_<idx>": attention probabilities on CPU}.
        self.latest_attention_maps = {}

    def set_current_step(self, step_idx: int):
        """Updates the current DDIM step."""
        self.current_step = step_idx

    def apply_mitigation(self, attn_logits: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Applies the configured mitigation strategy to the attention logits.
        This method is only called by the forward hook if mitigation is active.

        Args:
            attn_logits: Pre-softmax attention scores. The last axis is indexed
                as the key/token axis. The tensor is modified IN PLACE.
            context: Conditioning tensor; only `context.shape[1]` is read.

        Returns:
            The same `attn_logits` tensor object, after modification. Unknown
            mitigation types leave it untouched.
        """
        if self.mitigation_type == 'ca_entropy':
            # As per Ren et al., scale the <bos> token and mask out summary tokens
            attn_logits[:, :, 0] *= self.mitigation_c_scale
            num_tokens = context.shape[1]
            # NOTE(review): the hook in this file computes the logits' key axis
            # from `context` itself, so there this condition looks like it can
            # never be true -- confirm where the real prompt length should come from.
            if num_tokens < attn_logits.shape[-1]:
                attn_logits[:, :, num_tokens:] = -torch.finfo(attn_logits.dtype).max

        # Other mitigation types could be added here with `elif`

        return attn_logits

    def record_attention(self, attn_probs: torch.Tensor, place: str, idx: int, name: str):
        """
        Handles the passive part: caching the latest attention maps in memory
        and saving them to disk if they meet the configured criteria.

        Args:
            attn_probs: Post-softmax attention probabilities.
            place: Block group label, matched against config['blocks'].
            idx: Index of the layer within its group, matched against
                config['xattn_indices'].
            name: Layer name; dots are replaced by underscores for the filename.
        """
        # One device-to-host copy, shared by the in-memory cache and the image writer.
        cpu_probs = attn_probs.detach().cpu()

        # 1. Cache the latest map in memory for metric calculation
        key = f"{place}_{idx}"
        self.latest_attention_maps.setdefault(self.current_step, {})[key] = cpu_probs

        # 2. Save the map to disk if it meets the criteria from the config
        if self.current_step not in self.config.get('steps', []):
            return
        if place not in self.config.get('blocks', []):
            return
        wanted_indices = self.config.get('xattn_indices', [])
        if wanted_indices != 'all' and idx not in wanted_indices:
            return

        safe_layer_name = name.replace('.', '_')
        save_path = os.path.join(self.output_dir, f"step{self.current_step}_{safe_layer_name}.png")
        self._save_attention_map_image(cpu_probs, save_path)

    @staticmethod
    def _save_attention_map_image(attn_map: torch.Tensor, save_path: str):
        """
        Helper function to process and save attention maps as images.

        One grayscale image is written per token (last axis), named
        `<stem>_tokenNN<ext>` next to `save_path`. Failures are reported on
        stdout and otherwise ignored, so image saving never aborts generation.

        Args:
            attn_map: Attention probabilities indexed as [batch, query, token];
                the query axis must have a perfect-square length.
            save_path: Base file path used to derive the per-token filenames.
        """
        try:
            # We typically visualize the conditional map's attention
            # (assumes the second half of the leading axis is the conditional
            # branch -- TODO confirm against the sampler's batch layout).
            # Cast to float32 first: NumPy has no bfloat16, so `.numpy()` would
            # raise on bf16 maps and no image would ever be written.
            cond_attn_map = attn_map[attn_map.shape[0] // 2:].detach().float().mean(dim=0)
            hw = int(np.sqrt(cond_attn_map.shape[0]))
            num_tokens = cond_attn_map.shape[1]

            # Convert once instead of once per token.
            cond_np = cond_attn_map.cpu().numpy()
            # Split on the real extension so a '.png' elsewhere in the path
            # (e.g. in a directory name) is left alone.
            stem, ext = os.path.splitext(save_path)

            for i in range(num_tokens):
                token_map = cond_np[:, i].reshape(hw, hw)
                # Min-max normalise to [0, 1]; the epsilon avoids 0/0 on flat maps.
                token_map_normalized = (token_map - token_map.min()) / (token_map.max() - token_map.min() + 1e-6)
                Image.fromarray((token_map_normalized * 255).astype(np.uint8)).save(f"{stem}_token{i:02d}{ext}")
        except Exception as e:
            print(f"[AttentionController ERROR] Failed to save attention map image. Error: {e}")


def register_attention_control(model, controller: AttentionController):
    """
    Replaces the forward methods of CrossAttention layers with a new, unified hook
    that handles both passive recording and active mitigation.

    Only `CrossAttention` modules under `model.model.diffusion_model` whose name
    contains "attn2" are patched. Each one is labelled 'down', 'mid' or 'up'
    depending on whether its name contains "input_blocks", "middle_block" or
    "output_blocks", and numbered in traversal order within that group.

    Args:
        model: Object exposing `model.model.diffusion_model.named_modules()`.
        controller: Receives every attention map and, when its
            `mitigation_type` is set, rewrites the logits before softmax.
    """
    def ca_forward_hook(self, x, context=None, mask=None, place_in_unet=None, xattn_idx=None, layer_name=None):
        """
        A wrapper for the original CrossAttention.forward method that uses the
        unified AttentionController.

        `place_in_unet`, `xattn_idx` and `layer_name` are bound at registration
        time and forwarded to the controller when recording.
        """
        h = self.heads
        q = self.to_q(x)
        # Without an explicit context the layer attends over its own input.
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)

        # Fold the head axis into the batch axis: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # 1. Calculate the raw attention logits
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        # 2. (ACTIVE) Apply mitigation to logits if controller is in active mode
        if controller.mitigation_type is not None:
            sim = controller.apply_mitigation(sim, context)

        # 3. Honour the key mask instead of silently dropping it. Applied after
        #    mitigation so that masked positions stay masked.
        #    Assumes True = "may attend" and a leading batch axis, as in the
        #    reference latent-diffusion CrossAttention -- confirm against
        #    .ldm.modules.attention.
        if mask is not None:
            key_mask = mask.reshape(mask.shape[0], -1).bool()
            # Repeat each batch row once per head to line up with the (b h) axis,
            # then add a broadcastable query axis.
            key_mask = key_mask.repeat_interleave(h, dim=0).unsqueeze(1)
            sim = sim.masked_fill(~key_mask, -torch.finfo(sim.dtype).max)

        # 4. Calculate final attention probabilities
        attn_probs = sim.softmax(dim=-1)

        # 5. (PASSIVE) Record the final probabilities for metric calculation
        controller.record_attention(attn_probs, place_in_unet, xattn_idx, layer_name)

        # 6. Compute the output and unfold the heads back into the channel axis
        out = torch.einsum('b i j, b j d -> b i d', attn_probs, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

    print("Registering unified attention control hooks...")
    # Running per-group counters give each patched layer a stable index.
    xattn_counters = {'down': 0, 'mid': 0, 'up': 0}

    for name, module in model.model.diffusion_model.named_modules():
        if isinstance(module, CrossAttention) and "attn2" in name:
            if "input_blocks" in name:
                place = "down"
            elif "middle_block" in name:
                place = "mid"
            elif "output_blocks" in name:
                place = "up"
            else:
                continue

            current_xattn_idx = xattn_counters[place]
            xattn_counters[place] += 1

            # Bind all necessary metadata to the new forward function. Assigning
            # on the instance shadows the class-level forward for this module only.
            module.forward = partial(ca_forward_hook, module,
                                     place_in_unet=place,
                                     xattn_idx=current_xattn_idx,
                                     layer_name=name)
    print("Attention control hooks registered successfully.")
