# mvdream/editing/subspace_pruner.py
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Dict

class SubspacePruner:
    """
    Implements Subspace Pruning based on "Memorized Images... share a Subspace".
    This method identifies and prunes weights in FFN layers that are critical for memorization.

    Args:
        model: Latent-diffusion wrapper. It must expose ``model.diffusion_model``
            (the U-Net whose FFN layers are analysed), ``get_learned_conditioning``,
            ``apply_model``, ``image_size`` and ``dtype``.
        device: Device on which conditioning and latents are created.
        sparsity: Fraction of each layer's input features that are per-row pruning
            candidates. The per-row budget is ``int(in_features * sparsity)``.
    """

    def __init__(self, model, device='cuda', sparsity: float = 0.001):
        self.model = model
        self.device = device
        self.sparsity = sparsity
        self.ffn_layers = self._get_ffn_layers()
        print(f"Found {len(self.ffn_layers)} FFN layers to analyze.")

    def _get_ffn_layers(self) -> Dict[str, nn.Module]:
        """Return the second linear layer of every FFN in the U-Net's transformer blocks.

        Modules are matched by a qualified name ending in ``'ff.net.2'``. They are
        keyed by their qualified name relative to ``model.model.diffusion_model``.
        """
        return {
            name: module
            for name, module in self.model.model.diffusion_model.named_modules()
            if name.endswith('ff.net.2')
        }

    @torch.no_grad()
    def _collect_activations(self, prompts: List[str]) -> Dict[str, torch.Tensor]:
        """Run a single denoising step and return per-prompt FFN input activations.

        Args:
            prompts: Text prompts. Their count is the batch size of the probe pass.

        Returns:
            Mapping from layer name to a CPU tensor ``H`` of shape
            ``(in_features, len(prompts))``. Each column is the layer's input
            averaged over the token dimension for one prompt.

        Raises:
            ValueError: If ``prompts`` is empty.
            RuntimeError: If a hooked layer was not executed during the pass.
        """
        if not prompts:
            raise ValueError("At least one prompt is required to collect activations.")

        batch_size = len(prompts)
        captured: Dict[str, List[torch.Tensor]] = {name: [] for name in self.ffn_layers}

        def make_hook(name):
            def hook(module, inputs, output):
                # inputs[0] is the tensor fed into this layer. Move it to the CPU
                # immediately to conserve device memory.
                captured[name].append(inputs[0].detach().cpu())
            return hook

        handles = [
            layer.register_forward_hook(make_hook(name))
            for name, layer in self.ffn_layers.items()
        ]
        try:
            # One timestep per batch element so t matches the latent batch.
            # Assumes a 1000-step schedule in which 999 is the noisiest step; confirm.
            t = torch.full((batch_size,), 999, device=self.device, dtype=torch.long)
            c = self.model.get_learned_conditioning(prompts).to(self.device)
            # Assumes image_size is in pixels with an 8x VAE downsampling and
            # 4 latent channels; confirm for the model in use.
            latent_hw = self.model.image_size // 8
            latents = torch.randn(
                (batch_size, 4, latent_hw, latent_hw),
                device=self.device,
                dtype=self.model.dtype,
            )
            self.model.apply_model(latents, t, {"context": c})
        finally:
            # Always detach the hooks, even when the forward pass raises.
            # Otherwise stale hooks stay on the model and leak into later passes.
            for handle in handles:
                handle.remove()

        processed_activations = {}
        for name, act_list in captured.items():
            if not act_list:
                raise RuntimeError(
                    f"FFN layer '{name}' was not executed during the forward pass."
                )
            # If a layer fires more than once, only its first invocation is used.
            act_tensor = act_list[0]
            in_dim = act_tensor.shape[-1]
            # Regroup as (batch, tokens, in_dim). This assumes the leading dims are
            # batch-major, which holds for both a (batch, seq, dim) input and a
            # flattened (batch * seq, dim) input. reshape tolerates non-contiguous input.
            per_prompt = act_tensor.reshape(batch_size, -1, in_dim).mean(dim=1)
            # H has one column per prompt: (in_dim, batch_size).
            processed_activations[name] = per_prompt.T

        return processed_activations

    def find_memorization_subspace(self, memorized_prompts: List[str]):
        """Find the shared memorization subspace across a set of prompts.

        For every hooked layer the per-weight score is ``S = |W| * ||H||``.
        ``||H||`` is the L2 norm, taken across prompts, of each input feature's
        token-averaged activation. A weight is selected when both hold:

        - it is among the top ``int(in_features * sparsity)`` memorized-prompt
          scores in its output row;
        - its memorized-prompt score is larger than its score for the same
          number of empty prompts.

        Args:
            memorized_prompts: Prompts known to trigger memorization.

        Returns:
            Dict mapping layer name to a boolean mask with the layer weight's shape
            (``True`` = prune). Layers whose per-row budget rounds to zero are omitted.

        Raises:
            ValueError: If ``memorized_prompts`` is empty.
        """
        print(f"Finding memorization subspace using {len(memorized_prompts)} prompts...")
        mem_activations = self._collect_activations(memorized_prompts)
        null_activations = self._collect_activations([""] * len(memorized_prompts))

        pruning_masks = {}
        for name, layer in self.ffn_layers.items():
            # W shape: (out_features, in_features)
            W = layer.weight.data.float()

            # Per-row budget of candidate weights. Small layers may round to 0 and are skipped.
            s_abs = int(W.shape[1] * self.sparsity)
            if s_abs == 0:
                continue

            # H shape: (in_features, num_prompts). Its norm over prompts has shape
            # (in_features,) and broadcasts across the rows of |W|. Place it on
            # W's device so the product never mixes devices.
            H_mem = mem_activations[name].to(W.device, dtype=torch.float32)
            H_null = null_activations[name].to(W.device, dtype=torch.float32)
            abs_W = W.abs()
            S_mem = abs_W * torch.linalg.norm(H_mem, dim=1)
            S_null = abs_W * torch.linalg.norm(H_null, dim=1)

            # Smallest of each row's top-s_abs scores, kept 2-D (out, 1) to broadcast.
            threshold = torch.topk(S_mem, s_abs, dim=1).values[:, -1:]

            is_top_s = S_mem >= threshold
            pruning_masks[name] = is_top_s & (S_mem > S_null)

        return pruning_masks

    def prune_model_weights(self, pruning_masks: Dict[str, torch.Tensor]):
        """Zero, in place, the weights selected by ``pruning_masks``.

        Args:
            pruning_masks: Mapping from layer name (a key of ``self.ffn_layers``)
                to a boolean mask with that layer's weight shape.
        """
        print("Pruning model weights...")
        for name, mask in pruning_masks.items():
            weight = self.ffn_layers[name].weight
            # Move the mask to the weight's device so masks built elsewhere still index correctly.
            weight.data[mask.to(weight.device)] = 0.0
        print("Model pruning complete.")