# mvdream/editing/nemo.py
import torch
import torch.nn as nn
from torchmetrics.functional import structural_similarity_index_measure
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
import numpy as np

class NeMoEditor:
    """
    Implements the NeMo algorithm to find and deactivate memorization neurons
    for a single prompt at inference time. Based on "Finding Nemo..." (Hintersdorf et al., 2024).
    """
    def __init__(self,
                 model,
                 non_mem_prompts: List[str],
                 device='cuda',
                 ssim_threshold: float = 0.5,
                 initial_theta: float = 6.0,
                 min_theta: float = 1.2,
                 theta_step: float = 0.25,
                 initial_k: int = 5,
                 k_step: int = 1,
                 num_ssim_seeds: int = 10):
        """Store the search hyperparameters and build the reference statistics.

        Construction is not free: it locates the value layers of ``model`` and
        runs ``_precompute_activation_stats`` on ``non_mem_prompts``.

        Args:
            model: Object the other methods query through
                ``model.model.diffusion_model``, ``get_learned_conditioning``,
                ``apply_model`` and ``dtype``.
            non_mem_prompts: Prompts used to compute the reference activation
                mean/std per value layer.
            device: Device on which conditioning, latents and timesteps are
                created.
            ssim_threshold: Target similarity score to fall below. Lower
                values make the edit stronger.
            initial_theta: Initial z-score cut-off for out-of-distribution
                neuron detection. Lower values make the edit stronger.
            min_theta: Minimum z-score cut-off to reach before the search
                stops. Lower values make the edit stronger.
            theta_step: Amount by which theta is decreased per iteration.
                Larger values make the edit stronger.
            initial_k: Initial number of top-activated neurons selected per
                layer. Higher values make the edit stronger.
            k_step: Amount by which k is increased per iteration. Larger
                values make the edit stronger.
            num_ssim_seeds: Number of seeds used when measuring the SSIM
                score. Higher values give a more robust measurement; this
                does not directly control edit strength.
        """
        self.model, self.device = model, device
        # Handles of the neuron-blocking forward hooks that are currently live.
        self.hooks = []

        # Search hyperparameters; find_neurons() falls back to these.
        self.ssim_threshold, self.num_ssim_seeds = ssim_threshold, num_ssim_seeds
        self.initial_theta, self.min_theta, self.theta_step = initial_theta, min_theta, theta_step
        self.initial_k, self.k_step = initial_k, k_step

        # The value layers must be known before any activation can be recorded.
        self.value_layers = self._get_value_layers()
        print("Pre-computing activation statistics on non-memorized prompts for NeMo...")
        self.non_mem_stats = self._precompute_activation_stats(non_mem_prompts)
        print("NeMoEditor initialized.")

    def _get_value_layers(self) -> Dict[str, nn.Module]:
        """Map qualified module names to the cross-attention value projections.

        A sub-module of ``self.model.model.diffusion_model`` is selected when
        its qualified name contains ``'attn2'`` and ends with ``'to_v'``.

        Returns:
            Dict keyed by the qualified module name, in ``named_modules()``
            order.
        """
        unet = self.model.model.diffusion_model
        return {
            qualified_name: submodule
            for qualified_name, submodule in unet.named_modules()
            if qualified_name.endswith('to_v') and 'attn2' in qualified_name
        }

    @torch.no_grad()
    def _collect_activations(self, prompts: List[str], block_neurons: Optional[Dict] = None) -> Dict[str, torch.Tensor]:
        """Run one forward pass and record the output of every value layer.

        The model is called once through ``apply_model`` with random latents
        of shape ``(len(prompts), 4, 32, 32)`` and a timestep tensor holding
        ``999``.

        Args:
            prompts: Prompts to condition the forward pass on.
            block_neurons: Optional mapping passed to ``register_hooks``; when
                non-empty the listed neurons are zeroed for this pass, and the
                recorded activations already reflect that.

        Returns:
            Dict from value-layer name to that layer's recorded outputs,
            concatenated along dim 0.

        Raises:
            RuntimeError: If a value layer produced no output during the pass.
        """
        recorded = {name: [] for name in self.value_layers}

        def make_recorder(name):
            def recorder(module, inputs, output):
                recorded[name].append(output.detach())
            return recorder

        handles = []
        try:
            # Blocking hooks go first so they run before the recorders:
            # forward hooks fire in registration order.
            if block_neurons:
                self.register_hooks(block_neurons)
            for name, layer in self.value_layers.items():
                handles.append(layer.register_forward_hook(make_recorder(name)))

            c = self.model.get_learned_conditioning(prompts).to(self.device)
            t = torch.tensor([999], device=self.device)
            latents = torch.randn((len(prompts), 4, 32, 32), device=self.device, dtype=self.model.dtype)
            self.model.apply_model(latents, t, {"context": c})
        finally:
            # Always detach the hooks, even when the forward pass fails;
            # otherwise they would stay attached to the model.
            for handle in handles:
                handle.remove()
            if block_neurons:
                self.remove_hooks()

        activations = {}
        for name, act_list in recorded.items():
            if not act_list:
                raise RuntimeError(
                    f"Value layer '{name}' produced no activations during the forward pass."
                )
            activations[name] = torch.cat(act_list, dim=0)
        return activations

    def _precompute_activation_stats(self, prompts: List[str]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Summarise value-layer activations over a set of non-memorized prompts.

        Args:
            prompts: Hold-out prompts forwarded through ``_collect_activations``.

        Returns:
            Dict from layer name to ``{'mean': ..., 'std': ...}``, both reduced
            over dim 0 of the per-prompt mean absolute activations.
        """
        reference = {}
        for layer_name, layer_acts in self._collect_activations(prompts).items():
            # Following the paper: average the absolute activations over the
            # token dimension (dim 1) before taking statistics across prompts.
            per_prompt = layer_acts.abs().mean(dim=1)
            # NOTE: torch's default std is the unbiased estimator, so a single
            # prompt yields NaN here.
            reference[layer_name] = {
                'mean': per_prompt.mean(dim=0),
                'std': per_prompt.std(dim=0),
            }
        return reference
    
    @torch.no_grad()
    def _get_noise_diffs(self, prompt: str, num_seeds: int, neurons_to_block: Optional[Dict] = None) -> torch.Tensor:
        """Compute the initial noise differences (delta) for multiple seeds.

        For each seed ``i`` in ``range(num_seeds)`` a latent ``x_T`` of shape
        ``(1, 4, 32, 32)`` is drawn from a generator seeded with ``i`` and the
        difference ``apply_model(x_T, t, cond) - x_T`` is recorded, with the
        timestep tensor holding ``999``.

        Args:
            prompt: Prompt to condition the noise prediction on.
            num_seeds: Number of seeds (and therefore model calls) to use.
            neurons_to_block: Optional mapping passed to ``register_hooks``;
                when non-empty those neurons are zeroed during the passes.

        Returns:
            The per-seed differences concatenated along dim 0.
        """
        try:
            if neurons_to_block:
                self.register_hooks(neurons_to_block)

            cond = {"context": self.model.get_learned_conditioning([prompt]).to(self.device)}
            t = torch.tensor([999], device=self.device)

            noise_diffs = []
            for seed in range(num_seeds):
                generator = torch.Generator(device=self.device).manual_seed(seed)
                x_T = torch.randn((1, 4, 32, 32), generator=generator, device=self.device, dtype=self.model.dtype)
                # The paper uses just the conditional noise prediction vs the
                # initial noise, so no unconditional pass is run.
                noise_diffs.append(self.model.apply_model(x_T, t, cond) - x_T)
        finally:
            # Detach the blocking hooks even when a forward pass fails.
            if neurons_to_block:
                self.remove_hooks()
        return torch.cat(noise_diffs, dim=0)

    def _compute_memorization_score(self, prompt: str, neurons_to_block: Optional[Dict] = None) -> float:
        """Return the maximum pairwise SSIM between per-seed noise differences.

        Noise differences are produced by ``_get_noise_diffs`` for
        ``self.num_ssim_seeds`` seeds, each min-max normalised to ``[0, 1]``,
        and every unordered pair is compared with SSIM.

        Args:
            prompt: Prompt whose memorization score is measured.
            neurons_to_block: Optional mapping forwarded to
                ``_get_noise_diffs`` so the score is measured with those
                neurons deactivated.

        Returns:
            The largest pairwise SSIM seen, never below ``0.0``; ``0.0`` when
            fewer than two seeds are used (there are no pairs to compare).
        """
        noise_diffs = self._get_noise_diffs(prompt, num_seeds=self.num_ssim_seeds, neurons_to_block=neurons_to_block)

        # SSIM expects values in [0, 1] or [-1, 1]. Normalise every difference
        # exactly once instead of once per pair.
        normalized = []
        for diff in noise_diffs:
            lowest, highest = diff.min(), diff.max()
            # A flat difference has zero range; clamp it so the division gives
            # zeros rather than NaN, which max() below would silently discard.
            value_range = (highest - lowest).clamp_min(torch.finfo(diff.dtype).eps)
            normalized.append(((diff - lowest) / value_range).unsqueeze(0))

        max_ssim = 0.0
        for i, first in enumerate(normalized):
            for second in normalized[i + 1:]:
                ssim = structural_similarity_index_measure(first, second, data_range=1.0)
                max_ssim = max(max_ssim, ssim.item())
        return max_ssim
    
    def register_hooks(self, neurons_to_block: Dict):
        """Attach forward hooks that zero the selected neurons of value layers.

        Hooks registered by an earlier call are removed first. A hook is only
        attached to layers in ``self.value_layers`` that appear in
        ``neurons_to_block`` with a non-empty index collection.

        Args:
            neurons_to_block: Mapping from value-layer name to the indices
                (along the third output dimension) that are overwritten with
                ``0.0``.
        """
        self.remove_hooks()

        for layer_name, layer in self.value_layers.items():
            if layer_name not in neurons_to_block:
                continue
            blocked = neurons_to_block[layer_name]
            if len(blocked) == 0:
                continue

            # Bind the indices through a default argument so each hook keeps
            # its own selection regardless of later loop iterations.
            def zero_selected(module, inputs, output, _blocked=blocked):
                output[:, :, _blocked] = 0.0
                return output

            self.hooks.append(layer.register_forward_hook(zero_selected))

    def remove_hooks(self):
        """Detach every hook tracked in ``self.hooks`` and reset the list."""
        stale, self.hooks = self.hooks, []
        for handle in stale:
            handle.remove()

    # The search hyperparameters below default to the values stored on the instance;
    # pass an explicit argument to override one of them for a single call.
    def find_neurons(self, prompt: str,
                     ssim_threshold: Optional[float] = None,
                     initial_theta: Optional[float] = None,
                     min_theta: Optional[float] = None,
                     initial_k: Optional[int] = None):
        """Select candidate memorization neurons for ``prompt``.

        Implements the initial-selection stage of NeMo in simplified form; the
        paper's refinement stage is not performed. Per value layer, a neuron
        becomes a candidate when its z-score against the non-memorized
        reference statistics exceeds ``theta`` or it is among the ``k`` most
        strongly activated neurons. The candidates are blocked and the SSIM
        memorization score is measured. While the score stays at or above the
        threshold, ``theta`` is lowered by ``self.theta_step`` (never below
        the minimum) and ``k`` raised by ``self.k_step``.

        The search stops once the score drops below the threshold or the
        iteration at the minimum ``theta`` has been evaluated. With a
        non-positive ``self.theta_step`` only the SSIM criterion can end it.

        Args:
            prompt: Prompt to analyse.
            ssim_threshold: Overrides ``self.ssim_threshold`` when given.
            initial_theta: Overrides ``self.initial_theta`` when given.
            min_theta: Overrides ``self.min_theta`` when given.
            initial_k: Overrides ``self.initial_k`` when given.

        Returns:
            Dict from value-layer name to the sorted list of candidate neuron
            indices from the last iteration.
        """
        print(f"Finding NeMo neurons for prompt: '{prompt[:50]}...'")

        # Instance attributes are the defaults; explicit arguments win.
        ssim_thresh = ssim_threshold if ssim_threshold is not None else self.ssim_threshold
        theta = initial_theta if initial_theta is not None else self.initial_theta
        min_th = min_theta if min_theta is not None else self.min_theta
        k = initial_k if initial_k is not None else self.initial_k

        # --- Step 1: Initial Candidate Selection ---
        # Get activations for the target prompt
        prompt_activations = self._collect_activations([prompt])

        # Neither the prompt activations nor the reference statistics change
        # between iterations, so the per-layer scores are computed only once.
        layer_scores = {}
        for name in self.value_layers:
            mean_abs_act = torch.abs(prompt_activations[name]).mean(dim=1).squeeze()
            reference = self.non_mem_stats[name]
            z_scores = (mean_abs_act - reference['mean']) / (reference['std'] + 1e-6)
            layer_scores[name] = (mean_abs_act, z_scores)

        while True:
            candidates = {}
            for name, (mean_abs_act, z_scores) in layer_scores.items():
                # OOD detection via z-score
                ood_indices = (z_scores > theta).nonzero().squeeze(-1).tolist()

                # Top-k detection
                top_k_indices = torch.topk(mean_abs_act, k=min(k, len(mean_abs_act))).indices.tolist()

                candidates[name] = sorted(set(ood_indices + top_k_indices))

            # Check if memorization is mitigated
            current_ssim = self._compute_memorization_score(prompt, neurons_to_block=candidates)
            print(f"  > Iteration: theta={theta:.2f}, k={k}. Current SSIM: {current_ssim:.4f} (Threshold: {ssim_thresh})")

            if current_ssim < ssim_thresh or theta <= min_th:
                break

            # Lower theta but never past the configured floor: an unclamped
            # step could overshoot and run an iteration below min_theta.
            theta = max(min_th, theta - self.theta_step)
            k += self.k_step

        # --- Step 2: Refinement ---
        # The paper describes a greedy process of re-activating neurons/layers and checking
        # if the SSIM score stays low. This is computationally intensive.
        # For this implementation, we will consider the initial candidates as the final set.
        print(f"Found {sum(len(v) for v in candidates.values())} candidate neurons. Refinement step skipped for this implementation.")

        return candidates