import torch
from tqdm import tqdm
from functools import partial
from PIL import Image

from utils.activation_detection import prepare_diffusion_inputs
from hooks.wanda import input_norm_hook_fn, wanda_blocking_hook_fn


def set_wanda_input_norm_hooks(
    unet,
    layers,
    uncond_input_norm_dict,
    cond_input_norm_dict,
    hook_fn=input_norm_hook_fn,
    num_inference_steps=50,
    current_time_step=0,
    early_stopping=None,
):
    """Register ``hook_fn`` as a forward hook on the selected linear layers.

    Every ``torch.nn.Linear`` module of ``unet`` whose qualified name is
    contained in ``layers`` gets a forward hook. The hook is ``hook_fn`` with
    the following keyword arguments pre-bound via ``functools.partial``:

    * ``uncond_input_norm_list_for_layer`` / ``cond_input_norm_list_for_layer``:
      the entries of the two dictionaries stored under the layer's name,
    * ``t``: ``current_time_step``,
    * ``total_num_timesteps``: ``early_stopping`` when it is not ``None``,
      otherwise ``num_inference_steps``.

    Returns:
        A list of ``(layer_name, handle)`` tuples, one per registered hook, in
        ``unet.named_modules()`` order. The handles are what
        ``register_forward_hook`` returned and can be used to remove the hooks.

    Raises:
        KeyError: if a selected layer name is missing from either dictionary.
    """
    # An early-stopping limit, when given, replaces the full step count.
    if early_stopping is None:
        total_steps = num_inference_steps
    else:
        total_steps = early_stopping

    registered = []
    for layer_name, layer in unet.named_modules():
        if layer_name not in layers or not isinstance(layer, torch.nn.Linear):
            continue

        bound_hook = partial(
            hook_fn,
            uncond_input_norm_list_for_layer=uncond_input_norm_dict[layer_name],
            cond_input_norm_list_for_layer=cond_input_norm_dict[layer_name],
            t=current_time_step,
            total_num_timesteps=total_steps,
        )
        registered.append((layer_name, layer.register_forward_hook(bound_hook)))

    return registered


def set_wanda_blocking_hooks(unet, binary_masks=None, hook_fn=wanda_blocking_hook_fn):
    """Register masking forward hooks on the linear layers named in ``binary_masks``.

    For every ``torch.nn.Linear`` module of ``unet`` whose qualified name is a
    key of ``binary_masks``, ``hook_fn`` is registered as a forward hook with
    ``binary_mask`` pre-bound to that layer's entry.

    Returns:
        A list of ``(layer_name, handle)`` tuples for the registered hooks.
        The list is empty when ``binary_masks`` is ``None``.
    """
    # Without masks there is nothing to register.
    if binary_masks is None:
        return []

    registered = []
    for layer_name, layer in unet.named_modules():
        if layer_name not in binary_masks.keys() or not isinstance(
            layer, torch.nn.Linear
        ):
            continue

        masking_hook = partial(hook_fn, binary_mask=binary_masks[layer_name])
        registered.append((layer_name, layer.register_forward_hook(masking_hook)))

    return registered


class hook_cxt_manager:
    """Context manager that installs hooks on entry and removes them on exit.

    ``hook_fn`` is a zero-argument callable that registers the hooks and
    returns a sequence of entries whose element at index 1 is a handle with a
    ``remove()`` method (for example ``(layer_name, handle)`` tuples).
    """

    def __init__(self, hook_fn):
        # Nothing is registered until the ``with`` block is entered.
        self.block_handles = []
        self.hook_fn = hook_fn

    def __enter__(self):
        """Call ``hook_fn`` and return (and remember) the entries it produced."""
        registered = self.hook_fn()
        self.block_handles = registered
        return registered

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove every remembered handle; never suppresses exceptions."""
        for entry in self.block_handles:
            hook_handle = entry[1]
            hook_handle.remove()

        # Returning False lets an exception raised in the with-body propagate.
        return False


@torch.no_grad()
def get_input_norms(
    prompts,
    tokenizer,
    text_encoder,
    unet,
    scheduler,
    guidance_scale,
    seed,
    samples_per_prompt,
    num_inference_steps=50,
    blocks=[True] * 16,
    rtpt=None,
    early_stopping=None,
    verbose=True,
    text_embeddings_list=None,
):
    """
    Returns the input norms for the unconditioned and conditioned inputs of the UNet for the given prompts.
    The norms are aggregated over all the given samples, but not over the timesteps, which results in input norms
    for each time step (the aggregation itself happens inside ``input_norm_hook_fn``, which is not defined in
    this file).

    For every prompt a denoising loop is run. At each step, forward hooks are registered on the selected
    ``"ff.net.2"`` linear layers of ``unet`` for exactly one UNet forward pass and removed again afterwards.

    Args:
        prompts: Iterable of prompts; each one is processed separately. Ignored (replaced) when
            ``text_embeddings_list`` is given.
        tokenizer: Forwarded to ``prepare_diffusion_inputs``; also called directly when
            ``text_embeddings_list`` is given.
        text_encoder: Forwarded to ``prepare_diffusion_inputs``; also called directly when
            ``text_embeddings_list`` is given.
        unet: Model whose ``torch.nn.Linear`` modules with ``"ff.net.2"`` in their name are hooked.
        scheduler: Object providing ``set_timesteps``, ``timesteps``, ``scale_model_input`` and ``step``.
        guidance_scale: Guidance weight. ``0`` feeds the latents once; any other value feeds them
            twice and combines the two halves of the prediction.
        seed: Forwarded to ``prepare_diffusion_inputs``.
        samples_per_prompt: Forwarded to ``prepare_diffusion_inputs``.
        num_inference_steps: Number of steps passed to ``scheduler.set_timesteps``.
        blocks: 16 flags; the i-th flag decides whether the i-th hooked layer (in sorted name order)
            is kept. The default list is shared between calls but is only read here.
        rtpt: Optional object; ``start()`` is called once and ``step()`` after every prompt.
        early_stopping: If not ``None``, the denoising loop stops once this many steps have run, and
            this value is passed to the hooks as the total number of timesteps.
        verbose: Show a tqdm progress bar over the prompts when ``True``.
        text_embeddings_list: Optional iterable that replaces ``prompts``; each item is passed to the
            UNet directly as ``encoder_hidden_states``.

    Returns:
        ``(uncond_input_norm_dict, cond_input_norm_dict)``: two dicts mapping layer name to a list whose
        entries have been moved to the CPU with ``.cpu()``.
    """
    assert len(blocks) == 16, f"Expected 16 blocks, but got {len(blocks)}"

    # create dictionaries to save the conditional and unconditional activation norms
    # (one empty list per Linear module whose name contains "ff.net.2")
    cond_input_norm_dict = {
        name: []
        for name, module in unet.named_modules()
        if "ff.net.2" in name and isinstance(module, torch.nn.Linear)
    }
    uncond_input_norm_dict = {
        name: []
        for name, module in unet.named_modules()
        if "ff.net.2" in name and isinstance(module, torch.nn.Linear)
    }

    if rtpt is not None:
        rtpt.start()

    # filter the blocks
    # sorted() copies the keys into a list, so deleting entries while looping is safe.
    # NOTE(review): blocks has exactly 16 entries, so a UNet with more than 16 matching layers
    # raises IndexError here; with fewer, the surplus flags are silently ignored.
    for block_idx, block_name in enumerate(sorted(cond_input_norm_dict.keys())):
        if blocks[block_idx]:
            cond_input_norm_dict[block_name] = []
            uncond_input_norm_dict[block_name] = []
        else:
            del cond_input_norm_dict[block_name]
            del uncond_input_norm_dict[block_name]

    # if blocks == 'down':
    #     cond_input_norm_dict = {k: v for k, v in cond_input_norm_dict.items() if 'down_blocks' in k or 'mid_block' in k}
    #     uncond_input_norm_dict = {k: v for k, v in uncond_input_norm_dict.items() if 'down_blocks' in k or 'mid_block' in k}

    # When embeddings are supplied, iterate over them instead of the text prompts.
    if text_embeddings_list is not None:
        prompts = text_embeddings_list

    for prompt in tqdm(prompts, disable=not verbose):
        # NOTE(review): in embedding mode the literal string "text_embeddings_list" is passed as the
        # prompt - presumably a placeholder so that latents are still created; confirm.
        latents, text_embeddings = prepare_diffusion_inputs(
            [prompt] if text_embeddings_list is None else ["text_embeddings_list"],
            tokenizer,
            text_encoder,
            unet,
            guidance_scale=guidance_scale,
            samples_per_prompt=samples_per_prompt,
            seed=seed,
        )
        scheduler.set_timesteps(num_inference_steps)

        if text_embeddings_list is not None:
            # NOTE(review): the unconditional embeddings built and concatenated below are discarded,
            # because text_embeddings is overwritten with the current item right afterwards. Confirm
            # whether the concatenation or the overwrite is the intended behavior.
            # Also note that len(prompts) is the length of text_embeddings_list at this point.
            uncond_input = tokenizer(
                [""] * len(prompts),
                padding="max_length",
                max_length=text_embeddings.shape[-2],
                return_tensors="pt",
            )
            uncond_embeddings = text_encoder(
                uncond_input.input_ids.to(text_encoder.device)
            )[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
            text_embeddings = prompt

        # NOTE(review): the device is hard-coded to CUDA here and in the .cuda() call below.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for i, t in enumerate(scheduler.timesteps):
                if early_stopping is not None and i >= early_stopping:
                    break

                # add the hook to the ff layers
                # The hooks are registered anew for every step so that the current step index i
                # can be bound into them; they are removed when the with-block exits.
                with hook_cxt_manager(
                    partial(
                        set_wanda_input_norm_hooks,
                        unet=unet,
                        layers=cond_input_norm_dict.keys(),
                        uncond_input_norm_dict=uncond_input_norm_dict,
                        cond_input_norm_dict=cond_input_norm_dict,
                        hook_fn=input_norm_hook_fn,
                        num_inference_steps=num_inference_steps,
                        current_time_step=i,
                        early_stopping=early_stopping,
                    )
                ):
                    # With guidance the latents are duplicated so that one forward pass yields both
                    # halves that are split with chunk(2) below.
                    if guidance_scale == 0:
                        latent_model_input = latents
                    else:
                        latent_model_input = torch.cat([latents] * 2)

                    latent_model_input = scheduler.scale_model_input(
                        latent_model_input, t
                    )

                    # Redundant with the @torch.no_grad() decorator on this function, but harmless.
                    with torch.no_grad():
                        noise_pred = unet(
                            latent_model_input.cuda(),
                            t,
                            encoder_hidden_states=text_embeddings,
                            return_dict=False,
                        )[0]

                    # Guidance: first half is treated as the unconditional prediction, second half
                    # as the text-conditioned one.
                    if guidance_scale != 0:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )

                    # Advance the latents by one scheduler step.
                    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[
                        0
                    ]

                    torch.cuda.empty_cache()

        if rtpt is not None:
            rtpt.step()

    # move all the dictionaries to the cpu
    # (both dicts share the same keys, so iterating over one of them covers both)
    for key in uncond_input_norm_dict.keys():
        uncond_input_norm_dict[key] = [
            norm.cpu() for norm in uncond_input_norm_dict[key]
        ]
        cond_input_norm_dict[key] = [norm.cpu() for norm in cond_input_norm_dict[key]]

    return uncond_input_norm_dict, cond_input_norm_dict


@torch.no_grad()
def get_wanda_scores(abs_weight, input_norms):
    """Return the broadcast elementwise product ``abs_weight * input_norms``.

    The multiplication runs with gradient tracking disabled, so the result
    never requires grad.
    """
    scores = abs_weight * input_norms
    return scores


@torch.no_grad()
def get_masking_matrices(
    unet,
    uncond_input_norms,
    cond_input_norms,
    percentage_of_neurons_to_prune=0.01,
    timesteps_used=None,
    verbose=True,
):
    """Build one binary keep-mask per layer from conditional/unconditional scores.

    For every layer and every recorded timestep, scores are computed as
    ``|weight| * input_norm`` for both the conditional and the unconditional
    input norms. A weight entry is marked for pruning at a timestep when

    * it is among the top ``percentage_of_neurons_to_prune`` fraction of its
      row, ranked by conditional score, AND
    * its conditional score is strictly greater than its unconditional score.

    The per-timestep keep-masks are combined with a logical AND, so an entry
    that is pruned at any timestep stays pruned in the final mask.

    Args:
        unet: Model providing ``named_modules()``; modules whose name is a key
            of ``uncond_input_norms`` must expose a ``weight`` tensor.
        uncond_input_norms: Dict mapping layer name to a list with one entry
            per timestep. Each entry must broadcast against the layer's weight
            (presumably one norm per input feature - confirm against the hook
            that records them).
        cond_input_norms: Same structure as ``uncond_input_norms`` for the
            conditional pass; indexed with the same layer names and timesteps.
        percentage_of_neurons_to_prune: Fraction (not percent) of the columns
            of each row that are candidates for pruning.
        timesteps_used: If not ``None``, only timesteps with index below this
            value are considered.
        verbose: Show a tqdm progress bar over the layers when ``True``.

    Returns:
        Dict mapping layer name to a CPU tensor shaped like the layer's weight
        with ``1`` for entries to keep and ``0`` for entries to prune. The mask
        is ``float32`` once at least one timestep was processed; otherwise it
        is all ones in the weight's dtype.
    """
    # get the absolute weights for each of the FF layers
    abs_layer_weights = {
        name: module.weight.detach().abs().cpu()
        for name, module in unet.named_modules()
        if name in uncond_input_norms
    }

    masking_matrices = {}
    for layer_name in tqdm(
        uncond_input_norms.keys(), desc="Creating masking matrices", disable=not verbose
    ):
        abs_weight = abs_layer_weights[layer_name]

        # Start from "keep everything"; each timestep can only remove entries.
        masking_matrix = torch.ones_like(abs_weight)
        for timestep, uncond_norm in enumerate(uncond_input_norms[layer_name]):
            if timesteps_used is not None and timestep >= timesteps_used:
                break

            # get the wanda scores for all the neurons in the current layer for the current time step
            wanda_scores_uncond = get_wanda_scores(abs_weight, uncond_norm)
            wanda_scores_cond = get_wanda_scores(
                abs_weight, cond_input_norms[layer_name][timestep]
            )

            # check for any inf values
            if (
                torch.isinf(wanda_scores_uncond).any()
                or torch.isinf(wanda_scores_cond).any()
            ):
                print("Inf values detected")

            # Rank each row by conditional score and take the top fraction as
            # pruning candidates. Only the conditional ranking is needed: the
            # unconditional scores enter solely through the comparison below,
            # so they are not sorted.
            num_to_prune = int(
                percentage_of_neurons_to_prune * wanda_scores_cond.shape[1]
            )
            _, neuron_indices_cond = torch.sort(
                wanda_scores_cond, dim=1, descending=True
            )
            indices_to_prune_cond = neuron_indices_cond[:, :num_to_prune]

            # create the binary masks for the neurons to prune
            binary_mask_cond = torch.zeros_like(abs_weight)
            binary_mask_cond.scatter_(1, indices_to_prune_cond, 1)

            # Only prune candidates whose conditional score exceeds the unconditional one.
            diff = wanda_scores_cond > wanda_scores_uncond
            binary_mask = diff * binary_mask_cond

            # the neurons that we want to prune now have a 1 in the binary mask
            # therefore, invert the binary mask to keep all other neurons and zero out the memorization neurons
            binary_mask = 1 - binary_mask

            # AND the keep-masks across timesteps (equivalently: OR of the
            # prune-masks). The bool -> float32 cast already yields only 0/1.
            masking_matrix = (binary_mask.to(bool) & masking_matrix.to(bool)).to(
                torch.float32
            )

        masking_matrices[layer_name] = masking_matrix

    return masking_matrices
