from functools import partial
from typing import Optional
import torch
from tqdm import tqdm
from PIL import Image

from utils.activation_detection import prepare_diffusion_inputs 
from hooks.block_activations import RescaleLinearActivations
from hooks.wanda import wanda_blocking_hook_fn

class MemMitigation:
    """Base class for mitigations applied to a UNet through registered hooks.

    Subclasses implement ``apply()`` to register hooks on ``self.unet`` and
    store the returned handles in ``self.block_handles`` so that ``remove()``
    can detach them again.

    Args:
        unet: The model the mitigation attaches its hooks to.
    """

    def __init__(self, unet):
        self.unet = unet
        # Handles of currently registered hooks; each must expose ``remove()``.
        self.block_handles = []

    def apply(self):
        """Register the mitigation's hooks on ``self.unet``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("This method should be overridden by subclasses")

    def remove(self):
        """Remove all registered hooks from the model.

        The handle list is reset afterwards. A repeated ``remove()`` is then a
        no-op, and stale handles do not accumulate across apply/remove cycles.
        """
        for handle in self.block_handles:
            handle.remove()
        self.block_handles = []


class Wanda(MemMitigation):
    """Attach Wanda masking hooks to selected ``torch.nn.Linear`` layers of a UNet.

    Args:
        unet: Model whose ``named_modules()`` are searched for layers to hook.
        masking_matrices: Mapping from a module name (as yielded by
            ``unet.named_modules()``) to the binary mask passed to
            ``wanda_blocking_hook_fn``. Zero entries of a mask are counted as
            blocked weights. Names that do not resolve to a
            ``torch.nn.Linear`` module are ignored.
    """

    def __init__(self, unet, masking_matrices):
        super().__init__(unet=unet)
        self.masking_matrices = masking_matrices

    def apply(self):
        """Register one forward hook per masked ``torch.nn.Linear`` module.

        Handles are appended to ``self.block_handles`` rather than replacing
        it. Hooks registered by an earlier ``apply()`` call therefore stay
        removable via ``remove()``.
        """
        new_handles = []
        num_blocked_weights = 0
        for name, module in self.unet.named_modules():
            # Only Linear layers that have a mask are hooked.
            if name not in self.masking_matrices or not isinstance(module, torch.nn.Linear):
                continue
            binary_mask = self.masking_matrices[name]
            hook = partial(wanda_blocking_hook_fn, binary_mask=binary_mask)
            new_handles.append(module.register_forward_hook(hook))
            num_blocked_weights += (binary_mask == 0).sum().item()

        self.block_handles.extend(new_handles)
        print(f'Number of blocked weights using Wanda: {num_blocked_weights}')

class Nemo(MemMitigation):
    """Rescale selected value-neuron activations in the UNet's ``attn2.to_v`` layers.

    A ``RescaleLinearActivations`` forward hook is attached to
    ``transformer_blocks[0].attn2.to_v`` of:
    - each of the first ``_ATTENTIONS_PER_DOWN_BLOCK`` attention modules in
      ``down_blocks[0 .. _NUM_DOWN_BLOCKS - 1]``;
    - ``mid_block.attentions[0]``.

    Args:
        unet: UNet exposing ``down_blocks[i].attentions[j]`` and
            ``mid_block.attentions[0]``.
        blocked_indices: Indexable sequence of per-layer neuron indices. Entry
            ``down_block * _ATTENTIONS_PER_DOWN_BLOCK + attention`` is used for
            the down blocks, and the last entry for the mid block. It must
            therefore contain at least
            ``_NUM_DOWN_BLOCKS * _ATTENTIONS_PER_DOWN_BLOCK + 1`` entries.
        scaling_factor: ``factor`` passed to every ``RescaleLinearActivations``
            hook (0.0 by default).
    """

    # Hooked layout of the down path: blocks 0..2, attention modules 0..1 each.
    _NUM_DOWN_BLOCKS = 3
    _ATTENTIONS_PER_DOWN_BLOCK = 2

    def __init__(self, unet, blocked_indices=None, scaling_factor=0.0):
        super().__init__(unet=unet)
        self.blocked_indices = blocked_indices
        self.scaling_factor = scaling_factor

    def _target_layers(self):
        """Yield ``(to_v module, indices)`` for every hooked attention layer, down blocks first."""
        per_block = self._ATTENTIONS_PER_DOWN_BLOCK
        for down_block in range(self._NUM_DOWN_BLOCKS):
            for attention in range(per_block):
                attn = self.unet.down_blocks[down_block].attentions[attention]
                indices = self.blocked_indices[down_block * per_block + attention]
                yield attn.transformer_blocks[0].attn2.to_v, indices

        mid_attn = self.unet.mid_block.attentions[0]
        yield mid_attn.transformer_blocks[0].attn2.to_v, self.blocked_indices[-1]

    def apply(self):
        """Register a ``RescaleLinearActivations`` forward hook on every target ``to_v``.

        Handles are appended to ``self.block_handles`` rather than replacing
        it, so hooks from an earlier ``apply()`` call stay removable.

        Raises:
            ValueError: If ``blocked_indices`` is ``None`` or has too few
                entries. With one entry too few, the mid block would silently
                reuse the last down-block entry.
        """
        required = self._NUM_DOWN_BLOCKS * self._ATTENTIONS_PER_DOWN_BLOCK + 1
        if self.blocked_indices is None:
            raise ValueError('Nemo.apply() requires blocked_indices, got None')
        if len(self.blocked_indices) < required:
            raise ValueError(
                f'Nemo.apply() requires at least {required} entries in '
                f'blocked_indices, got {len(self.blocked_indices)}'
            )

        # Resolve every target module and build every hook before registering
        # anything. A failing lookup then cannot leave untracked hooks on the model.
        pending = [
            (module, RescaleLinearActivations(indices=indices, factor=self.scaling_factor))
            for module, indices in self._target_layers()
        ]
        self.block_handles.extend(module.register_forward_hook(hook) for module, hook in pending)

        num_blocked = sum(len(hook.indices) for _, hook in pending)
        print(f'Number of blocked value neurons using NeMo: {num_blocked}')