from utils.stable_diffusion import load_sd_components, load_text_components, generate_images
import torch
from tqdm.auto import tqdm
from utils.datasets import load_prompts
from hooks.collect_activations import CollectActivationsLinearNoMean, CollectActivationsValueOutputLayer
from hooks.block_activations import RescaleLinearActivations
from torchvision.utils import save_image
import copy
from torchmetrics.functional import total_variation, structural_similarity_index_measure, multiscale_structural_similarity_index_measure, pairwise_cosine_similarity


@torch.no_grad()
def prepare_diffusion_inputs(prompts, tokenizer, text_encoder, unet, guidance_scale, samples_per_prompt, seed, height=512, width=512):
    """Create the initial noise latents and the text embeddings for a denoising run.

    Args:
        prompts: Sequence of prompt strings.
        tokenizer: Callable tokenizer exposing ``model_max_length``; its result
            must provide ``input_ids``.
        text_encoder: Callable encoder exposing ``device``; element ``[0]`` of
            its output is used as the embedding tensor.
        unet: Model whose ``config.in_channels`` sets the latent channel count.
        guidance_scale: When non-zero, the embeddings of empty prompts are
            concatenated in front of the prompt embeddings.
        samples_per_prompt: Each prompt is repeated this many times; the
            repetitions of one prompt are adjacent in the batch.
        seed: Passed to ``torch.manual_seed``. Note that this reseeds the
            global default generator.
        height: Image height in pixels; the latent height is ``height // 8``.
        width: Image width in pixels; the latent width is ``width // 8``.

    Returns:
        Tuple ``(latents, text_embeddings)``; ``latents`` is moved to the
        device of ``text_embeddings``.
    """
    generator = torch.manual_seed(seed)
    if samples_per_prompt > 1:
        prompts = [prompt for prompt in prompts for _ in range(samples_per_prompt)]
    batch_size = len(prompts)

    text_input = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(text_encoder.device))[0]

    # The noise is drawn from the freshly seeded default generator so that the
    # same seed always yields the same starting latents.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )

    if guidance_scale != 0:
        # Pad the empty prompts to the same token length as the real prompts.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(text_encoder.device))[0]
        # Order matters for consumers: unconditional half first, prompt half second.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latents = latents.to(text_embeddings.device)
    return latents, text_embeddings


# Run the denoising process while forward hooks (registered inside this function)
# record the activations of the attn2.to_v layers.
@torch.no_grad()
def collect_activations(prompts, tokenizer, text_encoder, unet, scheduler, num_inference_steps=50, early_stopping=None, seed=1, samples_per_prompt=1):
    """Collect the mean absolute activations of the hooked ``attn2.to_v`` layers.

    Hooks are attached to ``attn2.to_v`` of the first transformer block of
    every attention module in ``unet.down_blocks[0..2]`` (two per block) and
    of ``unet.mid_block.attentions[0]``, i.e. seven layers in total. The
    denoising loop runs without classifier-free guidance (``guidance_scale=0``).

    Args:
        prompts: List of prompt strings.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        num_inference_steps: Number of timesteps configured on the scheduler.
        early_stopping: If not ``None``, the loop stops once this many
            denoising steps have been executed (at least one step always runs).
        seed: Seed for the initial noise latents.
        samples_per_prompt: Number of samples generated per prompt.

    Returns:
        List with one tensor per hooked layer (down blocks first, mid block
        last): ``hook.activations()[0].abs().mean(dim=0)``.
        NOTE(review): presumably entry ``[0]`` is the recording of the first
        forward pass - confirm against ``CollectActivationsLinearNoMean``.
    """
    latents, text_embeddings = prepare_diffusion_inputs(prompts, tokenizer, text_encoder, unet, guidance_scale=0, samples_per_prompt=samples_per_prompt, seed=seed)
    scheduler.set_timesteps(num_inference_steps)

    # Layers to observe: the six down-block value layers followed by the mid block.
    value_layers = [
        unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v
        for down_block in range(3)
        for attention in range(2)
    ]
    value_layers.append(unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v)

    v_hooks = []
    v_handles = []
    try:
        for layer in value_layers:
            v_hook = CollectActivationsLinearNoMean()
            v_handles.append(layer.register_forward_hook(v_hook))
            v_hooks.append(v_hook)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for i, t in enumerate(scheduler.timesteps):
                latent_model_input = scheduler.scale_model_input(latents, t)

                noise_pred = unet(
                    latent_model_input.cuda(),
                    t,
                    encoder_hidden_states=text_embeddings, return_dict=False)[0]

                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                torch.cuda.empty_cache()

                # Stop after `early_stopping` executed steps. The check comes after
                # the step, so values 0 and 1 both result in exactly one step.
                if early_stopping is not None and i + 1 >= early_stopping:
                    break

        return [hook.activations()[0].abs().mean(dim=0) for hook in v_hooks]
    finally:
        # Always detach the hooks, even if the forward pass raised, so the shared
        # UNet is not left instrumented.
        for handle in v_handles:
            handle.remove()

# Run the DeepFloyd pipeline while forward hooks (registered inside this function)
# record the activations of the attention to_v layers.
@torch.no_grad()
def collect_activations_deepfloyd(prompts, pipe, num_inference_steps=1, seed=1, samples_per_prompt=1):
    """Collect the mean absolute activations of the hooked ``to_v`` layers of ``pipe.unet``.

    Hooks are attached to ``to_v`` of the three attention modules in each of
    ``pipe.unet.down_blocks[1..3]`` and of ``pipe.unet.mid_block.attentions[0]``,
    i.e. ten layers in total.

    Args:
        prompts: List of prompt strings, encoded with ``pipe.encode_prompt``.
        pipe: Pipeline object exposing ``unet``, ``encode_prompt`` and ``__call__``.
        num_inference_steps: Number of inference steps passed to the pipeline.
        seed: Passed to ``torch.manual_seed``; the resulting generator is handed
            to the pipeline.
        samples_per_prompt: Passed as ``num_images_per_prompt``.

    Returns:
        List with one tensor per hooked layer (down blocks first, mid block
        last): ``hook.activations()[0].abs().mean(dim=0)``.
    """
    value_layers = [
        pipe.unet.down_blocks[down_block].attentions[attention].to_v
        for down_block in range(1, 4)
        for attention in range(3)
    ]
    value_layers.append(pipe.unet.mid_block.attentions[0].to_v)

    v_hooks = []
    v_handles = []
    try:
        for layer in value_layers:
            v_hook = CollectActivationsLinearNoMean()
            v_handles.append(layer.register_forward_hook(v_hook))
            v_hooks.append(v_hook)

        generator = torch.manual_seed(seed)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            prompt_embeds, negative_embeds = pipe.encode_prompt(prompts)
            # Only the side effect on the hooks matters; the generated images are discarded.
            pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, num_inference_steps=num_inference_steps, num_images_per_prompt=samples_per_prompt, generator=generator)

        return [hook.activations()[0].abs().mean(dim=0) for hook in v_hooks]
    finally:
        # Always detach the hooks, even if the pipeline raised.
        for handle in v_handles:
            handle.remove()



@torch.no_grad()
def initial_neuron_selection(prompt, tokenizer, text_encoder, unet, scheduler, layer_depth, theta, k, seed=1, version='v1-5'):
    """Select candidate neurons per layer via an OOD test plus top-k activations.

    For each of the first ``layer_depth`` layers, a neuron is selected when
    its activation deviates from the stored mean by more than ``theta``
    standard deviations, or when it is among the ``k`` largest absolute
    activations of the layer.

    Args:
        prompt: Single prompt string.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        layer_depth: Number of leading layers to inspect.
        theta: Threshold on the absolute z-score.
        k: Number of top activations added per layer (clamped to the number of
            entries in the layer's statistics).
        seed: Seed for the initial noise latents.
        version: One of the known keys below, otherwise interpreted as a path
            to a statistics file.

    Returns:
        List with one list of neuron indices per collected layer; layers at or
        beyond ``layer_depth`` stay empty. A neuron selected by both criteria
        appears twice.
    """
    # Statistics (mean, std per layer) computed on unmemorized LAION prompts.
    statistics_files = {
        'v1-5': 'statistics/statistics_additional_laion_prompts_v1_5.pt',
        'v1-4': 'statistics/statistics_additional_laion_prompts_v1_4.pt',
        'v2': 'statistics_sdv2_base_v2.pt',
        'iclr': 'statistics_sdv14_iclr_v1_4.pt',
        'DeepFloyd': 'statistics/DeepFloyd.pt',
    }
    # NOTE: torch.load unpickles the file; only pass paths to trusted files as `version`.
    mean_list, std_list = torch.load(statistics_files.get(version, version))

    # Activations of a single denoising step for the given prompt.
    activations_list = collect_activations([prompt], tokenizer, text_encoder, unet, scheduler, num_inference_steps=50, samples_per_prompt=1, early_stopping=1, seed=seed)

    blocking_indices = [[] for _ in activations_list]
    for layer_id in range(layer_depth):
        activations = activations_list[layer_id]

        # OOD detection: absolute z-score against the reference statistics.
        z_scores = (activations.cpu() - mean_list[layer_id]).abs() / std_list[layer_id]
        indices = (z_scores > theta).nonzero().flatten().tolist()

        # Additionally take the k strongest activations of the layer.
        top_k = min(k, len(mean_list[layer_id]))
        indices += activations.abs().topk(k=top_k).indices.tolist()

        blocking_indices[layer_id] = indices

    return blocking_indices


@torch.no_grad()
def initial_neuron_selection_deep_floyd(prompt, pipe, layer_depth, theta, k, seed=1):
    """Select candidate neurons per DeepFloyd layer via an OOD test plus top-k activations.

    For each of the first ``layer_depth`` layers, a neuron is selected when
    its activation deviates from the stored mean by more than ``theta``
    standard deviations, or when it is among the ``k`` largest absolute
    activations of the layer.

    Args:
        prompt: Single prompt string.
        pipe: DeepFloyd pipeline passed on to ``collect_activations_deepfloyd``.
        layer_depth: Number of leading layers to inspect.
        theta: Threshold on the absolute z-score.
        k: Number of top activations added per layer (clamped to the layer width).
        seed: Seed used for the pipeline run.

    Returns:
        List with one list of neuron indices per collected layer; layers at or
        beyond ``layer_depth`` stay empty. A neuron selected by both criteria
        appears twice.
    """
    # NOTE: torch.load unpickles the file; it must come from a trusted source.
    mean_list, std_list = torch.load('statistics/statistics_additional_laion_prompts_deep_floyd.pt')

    # Activations of a single inference step for the given prompt.
    activations_list = collect_activations_deepfloyd([prompt], pipe, num_inference_steps=1, samples_per_prompt=1, seed=seed)

    blocking_indices = [[] for _ in activations_list]
    for layer_id in range(layer_depth):
        activations = activations_list[layer_id]

        # OOD detection: absolute z-score against the reference statistics.
        z_scores = (activations.cpu() - mean_list[layer_id].cpu()).abs() / std_list[layer_id].cpu()
        indices = (z_scores > theta).nonzero().flatten().tolist()

        # Clamp k so that a k larger than the layer width does not make topk raise.
        top_k = min(k, activations.shape[0])
        indices += activations.abs().topk(k=top_k).indices.tolist()

        blocking_indices[layer_id] = indices

    return blocking_indices

@torch.no_grad()
def initial_neuron_selection_output_layer(prompt, tokenizer, text_encoder, unet, scheduler, theta, k, seed=1, layer_indices=(0, 1, 2, 3, 4, 5, 6), blocked_indices=None, version='v1-5'):
    """Select candidate neurons for the given layers via an OOD test plus top-k activations.

    Args:
        prompt: Single prompt string.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        theta: Threshold on the absolute z-score.
        k: Number of top activations added per layer.
        seed: Seed for the initial noise latents.
        layer_indices: Indices (0..6) of the layers to inspect.
        blocked_indices: Unused; kept for interface compatibility.
        version: ``'v1-5'`` or ``'v1-4'``.

    Returns:
        List of seven lists of neuron indices; layers not listed in
        ``layer_indices`` stay empty. A neuron selected by both criteria
        appears twice.

    Raises:
        ValueError: If ``version`` is not supported.
    """
    # Statistics (mean, std per layer) computed on unmemorized LAION prompts.
    # NOTE(review): both supported versions load the same file - confirm this is intended.
    if version in ('v1-5', 'v1-4'):
        # NOTE: torch.load unpickles the file; it must come from a trusted source.
        mean_list, std_list = torch.load('statistics/statistics_additional_laion_prompts.pt')
    else:
        raise ValueError(f"Unsupported version {version!r}; expected 'v1-5' or 'v1-4'.")

    # Activations of a single denoising step for the given prompt.
    activations_list = collect_activations([prompt], tokenizer, text_encoder, unet, scheduler, num_inference_steps=50, samples_per_prompt=1, early_stopping=1, seed=seed)

    blocking_indices = [[] for _ in range(7)]
    for layer_id in layer_indices:
        activations = activations_list[layer_id]

        # OOD detection: absolute z-score against the reference statistics.
        z_scores = (activations.cpu() - mean_list[layer_id]).abs() / std_list[layer_id]
        indices = (z_scores > theta).nonzero().flatten().tolist()

        # Additionally take the k largest activations of the layer.
        indices += activations.topk(k=k).indices.tolist()

        blocking_indices[layer_id] = indices

    return blocking_indices


def calculate_max_pairwise_ssim(noise_diffs):
    """Return, for every sample, its highest MS-SSIM score against any other sample.

    All unordered pairs of samples in ``noise_diffs`` are scored with
    multi-scale SSIM; the result holds one value per sample, namely the
    maximum over all pairs that sample takes part in.
    """
    num_samples = len(noise_diffs)
    pairs = torch.combinations(torch.arange(num_samples), r=2)
    first_members, second_members = pairs[:, 0], pairs[:, 1]

    pair_scores = multiscale_structural_similarity_index_measure(
        noise_diffs[first_members],
        noise_diffs[second_members],
        reduction='none',
        kernel_size=11,
        betas=(0.33, 0.33, 0.33),
    )

    # A pair belongs to a sample if the sample is either of its two members.
    per_sample_max = [
        pair_scores[(pairs == sample_idx).any(dim=-1)].max()
        for sample_idx in range(num_samples)
    ]
    return torch.stack(per_sample_max)


# Denoising pass that returns the min-max scaled difference between the predicted
# noise and the initial noise latents.
@torch.no_grad()
def compute_noise_diff(prompts, tokenizer, text_encoder, unet, scheduler, blocked_indices, guidance_scale, seed, samples_per_prompt, scaling_factor, num_inference_steps=50, early_stopping=1, seed_indices_to_return=None):
    """Compute the scaled difference between predicted noise and the initial latents.

    Args:
        prompts: List of prompt strings.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        blocked_indices: Falsy for an unmodified UNet; otherwise a sequence of
            per-layer neuron index lists. Entries 0..5 are applied to the six
            down-block ``attn2.to_v`` layers and the last entry to the mid block.
        guidance_scale: 0 disables classifier-free guidance.
        seed: Seed for the initial noise latents.
        samples_per_prompt: Number of samples per prompt.
        scaling_factor: Factor handed to ``RescaleLinearActivations``.
        num_inference_steps: Number of timesteps configured on the scheduler.
        early_stopping: The result is taken at the first iteration ``i`` with
            ``i >= early_stopping`` (or at the last timestep).
        seed_indices_to_return: Optional index selecting a subset of samples.

    Returns:
        ``noise_pred - latents``, min-max scaled per sample and channel over
        dimensions 2 and 3, optionally indexed with ``seed_indices_to_return``.
        ``None`` if the scheduler provides no timesteps.
    """
    latents, text_embeddings = prepare_diffusion_inputs(prompts, tokenizer, text_encoder, unet, guidance_scale=guidance_scale, samples_per_prompt=samples_per_prompt, seed=seed)
    scheduler.set_timesteps(num_inference_steps)

    block_handles = []
    try:
        if blocked_indices:
            for down_block in range(3):
                for attention in range(2):
                    block_hook = RescaleLinearActivations(indices=blocked_indices[down_block * 2 + attention], factor=scaling_factor)
                    layer = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v
                    block_handles.append(layer.register_forward_hook(block_hook))
            block_hook = RescaleLinearActivations(indices=blocked_indices[-1], factor=scaling_factor)
            layer = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v
            block_handles.append(layer.register_forward_hook(block_hook))

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # NOTE(review): scheduler.step is never called, so `latents` stays at the
            # initial noise in every iteration; only the timestep changes.
            for i, t in enumerate(scheduler.timesteps):
                if guidance_scale == 0:
                    latent_model_input = latents
                else:
                    # Duplicate the latents for the unconditional and the prompt half.
                    latent_model_input = torch.cat([latents] * 2)

                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                noise_pred = unet(
                    latent_model_input.cuda(),
                    t,
                    encoder_hidden_states=text_embeddings, return_dict=False)[0]

                if guidance_scale != 0:
                    # The embeddings are ordered [unconditional, prompt], so the first
                    # chunk is the unconditional prediction.
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond)

                if i >= early_stopping or i == num_inference_steps - 1:
                    noise_diff = noise_pred - latents
                    # Min-max normalisation over the two spatial dimensions.
                    min_values = noise_diff.amin(dim=[2, 3])
                    max_values = noise_diff.amax(dim=[2, 3])
                    noise_diff_scaled = (noise_diff - min_values.unsqueeze(-1).unsqueeze(-1)) / (max_values - min_values).unsqueeze(-1).unsqueeze(-1)

                    if seed_indices_to_return is not None:
                        return noise_diff_scaled[seed_indices_to_return]

                    return noise_diff_scaled
    finally:
        # Detach the rescaling hooks on every exit path (return, exception, or an
        # empty timestep schedule) so the shared UNet is left unmodified.
        for handle in block_handles:
            handle.remove()

    return None


# Runs the DeepFloyd pipeline up to its first scheduler step and returns the min-max
# scaled difference between the first predicted noise and the initial noise.
@torch.no_grad()
def compute_noise_diff_deep_floyd(prompts, pipe, blocked_indices, guidance_scale, seed, samples_per_prompt, scaling_factor, num_inference_steps=50, early_stopping=None, seed_indices_to_return=None):
    """Compute the scaled difference between the first noise prediction and the initial noise.

    The scheduler's ``step`` is temporarily replaced by a function that stores
    the model output and aborts the pipeline, so only a single UNet prediction
    is computed.

    Args:
        prompts: List of prompt strings, encoded with ``pipe.encode_prompt``.
        pipe: Pipeline exposing ``unet``, ``scheduler``, ``encode_prompt``,
            ``prepare_intermediate_images``, ``device`` and ``__call__``.
        blocked_indices: Falsy for an unmodified UNet; otherwise a sequence of
            per-layer neuron index lists. Entries 0..8 are applied to the
            ``to_v`` layers of ``down_blocks[1..3]`` and the last entry to the
            mid block.
        guidance_scale: Passed to the pipeline.
        seed: Seed used for the pipeline run and for re-creating the initial noise.
        samples_per_prompt: Passed as ``num_images_per_prompt``; also used as the
            batch size of the re-created initial noise.
        scaling_factor: Factor handed to ``RescaleLinearActivations``.
        num_inference_steps: Passed to the pipeline.
        early_stopping: Unused; the pipeline is always aborted at the first step.
        seed_indices_to_return: Optional index selecting a subset of samples.

    Returns:
        The first three channels of the captured model output minus the
        initial noise, min-max scaled per sample and channel over dimensions
        2 and 3, optionally indexed with ``seed_indices_to_return``.
    """
    class _FirstStepCaptured(Exception):
        """Private signal used to abort the pipeline after the first prediction."""

    captured = []

    # Keep the exact parameter list: the pipeline may inspect the signature of
    # `scheduler.step` to decide which extra keyword arguments to pass.
    def step_custom(model_output, timestep, sample, return_dict):
        captured.append(model_output)
        raise _FirstStepCaptured

    block_handles = []
    original_step = pipe.scheduler.step
    try:
        if blocked_indices:
            for down_block in range(1, 4):
                for attention in range(3):
                    block_hook = RescaleLinearActivations(indices=blocked_indices[(down_block - 1) * 3 + attention], factor=scaling_factor)
                    layer = pipe.unet.down_blocks[down_block].attentions[attention].to_v
                    block_handles.append(layer.register_forward_hook(block_hook))
            block_hook = RescaleLinearActivations(indices=blocked_indices[-1], factor=scaling_factor)
            block_handles.append(pipe.unet.mid_block.attentions[0].to_v.register_forward_hook(block_hook))

        pipe.scheduler.step = step_custom

        prompt_embeds, negative_embeds = pipe.encode_prompt(prompts)
        generator = torch.manual_seed(seed)
        try:
            pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt", num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=samples_per_prompt)
        except _FirstStepCaptured:
            pass
    finally:
        # Restore the scheduler and detach the hooks even if the pipeline failed.
        pipe.scheduler.step = original_step
        for handle in block_handles:
            handle.remove()

    # Re-create the initial noise by reseeding and drawing it the same way again.
    # NOTE(review): assumes the pipeline draws its initial noise with exactly this
    # call as the first use of the generator - confirm against the pipeline.
    generator = torch.manual_seed(seed)
    latents = pipe.prepare_intermediate_images(batch_size=samples_per_prompt, num_channels=3, height=64, width=64, dtype=torch.float16, device=pipe.device, generator=generator)

    # Only the first three channels of the model output are compared to the noise.
    noise_pred = captured[0][:, :3]
    noise_diff = noise_pred - latents
    min_values = noise_diff.amin(dim=[2, 3])
    max_values = noise_diff.amax(dim=[2, 3])
    noise_diff_scaled = (noise_diff - min_values.unsqueeze(-1).unsqueeze(-1)) / (max_values - min_values).unsqueeze(-1).unsqueeze(-1)

    if seed_indices_to_return is not None:
        return noise_diff_scaled[seed_indices_to_return]

    return noise_diff_scaled


@torch.no_grad()
def neuron_refinement(prompt, tokenizer, text_encoder, unet, scheduler, input_indices, scaling_factor, metric='ssim', threshold=None, samples_per_prompt=8, guidance_scale=0, seed=1, seeds_to_look_at=None, rel_threshold=None):
    """Prune the candidate neuron set to those whose unblocking has a measurable effect.

    Phase 1 re-activates one whole layer at a time, phase 2 one neuron at a
    time. A layer/neuron is dropped from the blocking set when the comparison
    value obtained with it re-activated is below the threshold.

    Args:
        prompt: Single prompt string.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        input_indices: List of per-layer neuron index lists (not modified).
        scaling_factor: Factor applied to the blocked neurons.
        metric: ``'ssim'`` (MS-SSIM against the unmodified model) or ``'tv'``
            (total variation of the blocked noise difference).
        threshold: Absolute threshold on the maximum metric value.
        samples_per_prompt: Number of seeds evaluated per run.
        guidance_scale: Guidance scale passed to ``compute_noise_diff``.
        seed: Seed for the initial noise latents.
        seeds_to_look_at: Optional index selecting a subset of samples.
        rel_threshold: If given, replaces ``threshold`` and the comparison value
            becomes the maximum relative change with respect to the MS-SSIM
            obtained with all candidates blocked.

    Returns:
        The refined list of per-layer neuron index lists.

    Raises:
        ValueError: If candidates exist but neither threshold is given, or if
            ``metric`` is not supported.
    """
    if rel_threshold is not None:
        print('Using relative threshold')
        threshold = rel_threshold
    if threshold is None and any(len(layer) > 0 for layer in input_indices):
        raise ValueError("Either 'threshold' or 'rel_threshold' must be provided.")

    def run_noise_diff(indices, factor):
        return compute_noise_diff([prompt], tokenizer, text_encoder, unet, scheduler, blocked_indices=indices, scaling_factor=factor, seed=seed, samples_per_prompt=samples_per_prompt, guidance_scale=guidance_scale, num_inference_steps=50, seed_indices_to_return=seeds_to_look_at)

    noise_diff_vanilla = run_noise_diff(None, 1)
    blocking_indices = copy.deepcopy(input_indices)
    num_layers = len(blocking_indices)
    active_layers = set(range(num_layers))

    def ms_ssim_to_vanilla(noise_diff):
        return multiscale_structural_similarity_index_measure(noise_diff_vanilla, noise_diff, reduction='none', kernel_size=11, betas=(0.33, 0.33, 0.33))

    # Reference for the relative threshold: everything blocked.
    # NOTE(review): this reference is always MS-SSIM, also when metric == 'tv'.
    diff_all_blocked = ms_ssim_to_vanilla(run_noise_diff(blocking_indices, scaling_factor))

    def has_no_impact(candidate_indices):
        """Return whether blocking only ``candidate_indices`` stays below the threshold."""
        noise_diff_blocked = run_noise_diff(candidate_indices, scaling_factor)
        if metric == 'ssim':
            diff = ms_ssim_to_vanilla(noise_diff_blocked)
        elif metric == 'tv':
            diff = total_variation(noise_diff_blocked, reduction='none')
        else:
            raise ValueError(f"Unsupported metric {metric!r}; expected 'ssim' or 'tv'.")

        if rel_threshold is not None:
            comparison_value = ((diff - diff_all_blocked) / (diff_all_blocked.abs() + 1e-9)).abs().max()
        else:
            comparison_value = diff.max()
        return comparison_value < threshold

    # 1.) remove all layers with no blocked neurons or neurons without any impact
    total_neurons = 0
    neurons_removed = 0
    for layer_idx in reversed(range(num_layers)):
        layer_blocked_indices = blocking_indices[layer_idx]
        if len(layer_blocked_indices) == 0:
            active_layers.remove(layer_idx)
            continue

        total_neurons += len(layer_blocked_indices)
        # Re-activate the whole layer and keep everything else blocked.
        curr_indices = copy.deepcopy(blocking_indices)
        curr_indices[layer_idx] = []
        if has_no_impact(curr_indices):
            neurons_removed += len(layer_blocked_indices)
            blocking_indices[layer_idx] = []
            active_layers.remove(layer_idx)

    print('Removed the following layers:', set(range(num_layers)) - active_layers, f'with {neurons_removed} neurons.')
    print('Remaining layers:', active_layers, f'with {total_neurons - neurons_removed} neurons.')

    # 2.) check individual neurons in remaining layers
    neurons_removed = 0
    for layer_idx in reversed(range(num_layers)):
        # Iterate over a snapshot because the layer list is modified in the loop.
        for neuron in list(blocking_indices[layer_idx]):
            curr_indices = copy.deepcopy(blocking_indices)
            curr_indices[layer_idx].remove(neuron)
            if has_no_impact(curr_indices):
                neurons_removed += 1
                blocking_indices[layer_idx].remove(neuron)

    print(f'Removed {neurons_removed} neurons.')
    print(f'Remaining neurons: {blocking_indices}')
    return blocking_indices


@torch.no_grad()
def neuron_refinement_deep_floyd(prompt, pipe, input_indices, scaling_factor, metric='ssim', threshold=None, samples_per_prompt=8, guidance_scale=0, seed=1, seeds_to_look_at=None, rel_threshold=None):
    """Prune the DeepFloyd candidate neuron set to those whose unblocking has a measurable effect.

    Phase 1 re-activates one whole layer at a time, phase 2 one neuron at a
    time. A layer/neuron is dropped from the blocking set when the comparison
    value obtained with it re-activated is below the threshold.

    Args:
        prompt: Single prompt string.
        pipe: DeepFloyd pipeline passed on to ``compute_noise_diff_deep_floyd``.
        input_indices: List of per-layer neuron index lists (not modified).
        scaling_factor: Factor applied to the blocked neurons.
        metric: ``'ssim'`` (MS-SSIM against the unmodified model) or ``'tv'``
            (total variation of the blocked noise difference).
        threshold: Absolute threshold on the maximum metric value.
        samples_per_prompt: Number of seeds evaluated per run.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed used for every pipeline run.
        seeds_to_look_at: Optional index selecting a subset of samples.
        rel_threshold: If given, replaces ``threshold`` and the comparison value
            becomes the maximum relative change with respect to the MS-SSIM
            obtained with all candidates blocked.

    Returns:
        The refined list of per-layer neuron index lists.

    Raises:
        ValueError: If candidates exist but neither threshold is given, or if
            ``metric`` is not supported.
    """
    if rel_threshold is not None:
        print('Using relative threshold')
        threshold = rel_threshold
    if threshold is None and any(len(layer) > 0 for layer in input_indices):
        raise ValueError("Either 'threshold' or 'rel_threshold' must be provided.")

    def run_noise_diff(indices, factor):
        return compute_noise_diff_deep_floyd([prompt], pipe, blocked_indices=indices, scaling_factor=factor, seed=seed, samples_per_prompt=samples_per_prompt, guidance_scale=guidance_scale, num_inference_steps=50, seed_indices_to_return=seeds_to_look_at)

    noise_diff_vanilla = run_noise_diff(None, 1)
    blocking_indices = copy.deepcopy(input_indices)
    num_layers = len(blocking_indices)
    active_layers = set(range(num_layers))

    def ms_ssim_to_vanilla(noise_diff):
        return multiscale_structural_similarity_index_measure(noise_diff_vanilla, noise_diff, reduction='none', kernel_size=11, betas=(0.33, 0.33, 0.33))

    # Reference for the relative threshold: everything blocked.
    # NOTE(review): this reference is always MS-SSIM, also when metric == 'tv'.
    diff_all_blocked = ms_ssim_to_vanilla(run_noise_diff(blocking_indices, scaling_factor))

    def has_no_impact(candidate_indices):
        """Return whether blocking only ``candidate_indices`` stays below the threshold."""
        noise_diff_blocked = run_noise_diff(candidate_indices, scaling_factor)
        if metric == 'ssim':
            diff = ms_ssim_to_vanilla(noise_diff_blocked)
        elif metric == 'tv':
            diff = total_variation(noise_diff_blocked, reduction='none')
        else:
            raise ValueError(f"Unsupported metric {metric!r}; expected 'ssim' or 'tv'.")

        if rel_threshold is not None:
            comparison_value = ((diff - diff_all_blocked) / (diff_all_blocked.abs() + 1e-9)).abs().max()
        else:
            comparison_value = diff.max()
        return comparison_value < threshold

    # 1.) remove all layers with no blocked neurons or neurons without any impact
    total_neurons = 0
    neurons_removed = 0
    for layer_idx in reversed(range(num_layers)):
        layer_blocked_indices = blocking_indices[layer_idx]
        if len(layer_blocked_indices) == 0:
            active_layers.remove(layer_idx)
            continue

        total_neurons += len(layer_blocked_indices)
        # Re-activate the whole layer and keep everything else blocked.
        curr_indices = copy.deepcopy(blocking_indices)
        curr_indices[layer_idx] = []
        if has_no_impact(curr_indices):
            neurons_removed += len(layer_blocked_indices)
            blocking_indices[layer_idx] = []
            active_layers.remove(layer_idx)

    # Use the actual layer count so that every removed layer is reported.
    print('Removed the following layers:', set(range(num_layers)) - active_layers, f'with {neurons_removed} neurons.')
    print('Remaining layers:', active_layers, f'with {total_neurons - neurons_removed} neurons.')

    # 2.) check individual neurons in remaining layers
    neurons_removed = 0
    for layer_idx in reversed(range(num_layers)):
        # Iterate over a snapshot because the layer list is modified in the loop.
        for neuron in list(blocking_indices[layer_idx]):
            curr_indices = copy.deepcopy(blocking_indices)
            curr_indices[layer_idx].remove(neuron)
            if has_no_impact(curr_indices):
                neurons_removed += 1
                blocking_indices[layer_idx].remove(neuron)

    print(f'Removed {neurons_removed} neurons.')
    print(f'Remaining neurons: {blocking_indices}')
    return blocking_indices

def calculate_cos_sim_per_layer(activations_unblocked, activations_blocked):
    """Return one mean cosine similarity per layer between unblocked and blocked activations.

    For every layer the pairwise cosine-similarity matrix of the two
    activation sets is computed; its diagonal pairs row ``i`` of the unblocked
    activations with row ``i`` of the blocked ones, and the mean of that
    diagonal is the value reported for the layer.
    """
    layer_similarities = [
        # Mean over the matching rows (presumably the different initial seeds).
        pairwise_cosine_similarity(unblocked, blocked).diag().mean()
        for unblocked, blocked in zip(activations_unblocked, activations_blocked)
    ]
    return torch.stack(layer_similarities)


@torch.no_grad()
def neuron_refinement_output_layer(prompt, tokenizer, text_encoder, unet, scheduler, input_indices, scaling_factor, metric='ssim', layerwise_rel_threshold=(0.014, 0.015, 0.031, 0.013, 0.009, 0.008, 0.005), samples_per_prompt=8, guidance_scale=0, seed=1):
    """Prune the candidate neuron set based on output-layer cosine similarities.

    Phase 1 re-activates one whole layer at a time, phase 2 one neuron at a
    time. A layer/neuron is dropped from the blocking set when re-activating
    it changes the layer's mean cosine similarity (blocked vs. unblocked
    output activations) by less than the layer's relative threshold, measured
    against the similarity obtained with all candidates blocked.

    Args:
        prompt: Single prompt string.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        input_indices: List of per-layer neuron index lists (not modified).
        scaling_factor: Factor applied to the blocked neurons.
        metric: Unused; kept for interface compatibility.
        layerwise_rel_threshold: One relative threshold per layer.
        samples_per_prompt: Number of seeds evaluated per run.
        guidance_scale: Guidance scale passed to the forward runs.
        seed: Seed for the initial noise latents.

    Returns:
        The refined list of per-layer neuron index lists.
    """
    def run_output_activations(indices, factor):
        return compute_output_layer_activation([prompt], tokenizer, text_encoder, unet, scheduler, blocked_indices=indices, scaling_factor=factor, seed=seed, samples_per_prompt=samples_per_prompt, guidance_scale=guidance_scale, num_inference_steps=50)

    output_activations_unblocked = run_output_activations(None, 1)
    blocking_indices = copy.deepcopy(input_indices)
    num_layers = len(blocking_indices)
    active_layers = set(range(num_layers))

    def cos_sim_to_unblocked(indices):
        """Per-layer mean cosine similarity between unblocked and blocked outputs."""
        return calculate_cos_sim_per_layer(output_activations_unblocked, run_output_activations(indices, scaling_factor))

    def has_no_impact(layer_idx, candidate_indices, reference):
        """Return whether the relative change for ``layer_idx`` is below its threshold."""
        cos_sim_current_layer = cos_sim_to_unblocked(candidate_indices)[layer_idx]
        return (cos_sim_current_layer - reference[layer_idx]) / reference[layer_idx].abs() < layerwise_rel_threshold[layer_idx]

    # 1.) remove all layers with no blocked neurons or neurons without any impact
    total_neurons = 0
    neurons_removed = 0
    cos_sim_all_blocked = cos_sim_to_unblocked(blocking_indices)
    for layer_idx in reversed(range(num_layers)):
        layer_blocked_indices = blocking_indices[layer_idx]
        if len(layer_blocked_indices) == 0:
            active_layers.remove(layer_idx)
            continue

        total_neurons += len(layer_blocked_indices)
        # Re-activate the whole layer and keep everything else blocked.
        curr_indices = copy.deepcopy(blocking_indices)
        curr_indices[layer_idx] = []
        if has_no_impact(layer_idx, curr_indices, cos_sim_all_blocked):
            neurons_removed += len(layer_blocked_indices)
            blocking_indices[layer_idx] = []
            active_layers.remove(layer_idx)

    print('Removed the following layers:', set(range(num_layers)) - active_layers, f'with {neurons_removed} neurons.')
    print('Remaining layers:', active_layers, f'with {total_neurons - neurons_removed} neurons.')

    # Recompute the reference without the layers that were just removed.
    cos_sim_all_blocked = cos_sim_to_unblocked(blocking_indices)

    # 2.) check individual neurons in remaining layers
    neurons_removed = 0
    for layer_idx in reversed(range(num_layers)):
        # Iterate over a snapshot because the layer list is modified in the loop.
        for neuron in list(blocking_indices[layer_idx]):
            curr_indices = copy.deepcopy(blocking_indices)
            curr_indices[layer_idx].remove(neuron)
            if has_no_impact(layer_idx, curr_indices, cos_sim_all_blocked):
                neurons_removed += 1
                blocking_indices[layer_idx].remove(neuron)

    print(f'Removed {neurons_removed} neurons.')
    print(f'Remaining neurons: {blocking_indices}')
    return blocking_indices

# denoising process to collect the output-layer activations (hooks on attn2.to_out[0]), optionally with rescaled value-layer neurons.
@torch.no_grad()
def compute_output_layer_activation(prompts, tokenizer, text_encoder, unet, scheduler, blocked_indices, guidance_scale, seed, samples_per_prompt, scaling_factor, num_inference_steps=50, early_stopping=1):
    """Run the first denoising step(s) and return the normalized attn2 output-layer activations.

    Forward hooks are attached to ``attn2.to_out[0]`` of the first transformer
    block of every attention layer in the first three down blocks (two
    attention layers each) and of the mid block, i.e. seven layers in total.
    If ``blocked_indices`` is given, ``RescaleLinearActivations`` hooks are
    additionally attached to the matching ``attn2.to_v`` layers.

    All hooks are removed again before the function returns, also when an
    exception is raised, so the shared ``unet`` is never left with stale hooks.

    Args:
        prompts: Prompts forwarded to ``prepare_diffusion_inputs``.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        blocked_indices: ``None`` to run without intervention. Otherwise a
            sequence with one entry per hooked layer: entries ``0..5`` belong
            to the down-block layers (``down_block * 2 + attention``) and the
            *last* entry belongs to the mid block.
        guidance_scale: ``0`` runs a single batch without unconditional
            embeddings; any other value duplicates the latents to match the
            doubled embedding batch. Only the comparison with zero matters
            here, the guided noise prediction itself is not needed.
        seed: Seed forwarded to ``prepare_diffusion_inputs``.
        samples_per_prompt: Number of samples per prompt, forwarded to
            ``prepare_diffusion_inputs``.
        scaling_factor: Passed as ``factor`` to ``RescaleLinearActivations``.
        num_inference_steps: Number of timesteps configured on the scheduler.
        early_stopping: Loop index ``i`` from which on the function returns
            (``i >= early_stopping``); the last step always returns.

    Returns:
        A list with one tensor per hooked layer (six down-block layers
        followed by the mid block). Each tensor is the hook's activations
        flattened from dim 1 onwards and L2-normalized along dim 1. ``None``
        is returned if no timestep satisfied the stop condition.

    Note:
        The latents are never advanced with ``scheduler.step``; every forward
        pass before the stop uses the same initial latents with a different
        timestep. Requires a CUDA device (``.cuda()`` / CUDA autocast).
    """
    latents, text_embeddings = prepare_diffusion_inputs(prompts, tokenizer, text_encoder, unet, guidance_scale=guidance_scale, samples_per_prompt=samples_per_prompt, seed=seed)
    scheduler.set_timesteps(num_inference_steps)

    # attn2 modules of the six down-block attention layers, followed by the mid block
    attn_modules = [
        unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2
        for down_block in range(3)
        for attention in range(2)
    ]
    attn_modules.append(unet.mid_block.attentions[0].transformer_blocks[0].attn2)
    mid_block_idx = len(attn_modules) - 1

    handles = []
    output_hooks = []
    try:
        if blocked_indices is not None:
            for layer_idx, attn in enumerate(attn_modules):
                # the mid block always uses the last entry of blocked_indices
                indices = blocked_indices[-1] if layer_idx == mid_block_idx else blocked_indices[layer_idx]
                block_hook = RescaleLinearActivations(indices=indices, factor=scaling_factor)
                handles.append(attn.to_v.register_forward_hook(block_hook))

        # TODO: careful!!! at the moment the CollectActivationsValueOutputLayer hook is only implemented for a single step.
        # Doing multiple steps will overwrite the activations and only return the last step.
        for attn in attn_modules:
            output_hook = CollectActivationsValueOutputLayer(unconditional=guidance_scale == 0)
            handles.append(attn.to_out[0].register_forward_hook(output_hook))
            output_hooks.append(output_hook)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for i, t in enumerate(scheduler.timesteps):
                if guidance_scale == 0:
                    latent_model_input = latents
                else:
                    latent_model_input = torch.cat([latents] * 2)

                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                text_embeddings.requires_grad = False
                latent_model_input.requires_grad = False

                # The noise prediction is discarded on purpose: only the activations
                # captured by the forward hooks are of interest here.
                unet(
                    latent_model_input.cuda(),
                    t,
                    encoder_hidden_states=text_embeddings, return_dict=False)

                if i >= early_stopping or i == num_inference_steps - 1:
                    # normalize inside the autocast context to keep the original output dtype
                    return [
                        torch.nn.functional.normalize(hook.activations().flatten(1))
                        for hook in output_hooks
                    ]
    finally:
        # always detach blocking and collection hooks from the shared unet
        for handle in handles:
            handle.remove()

    # no timestep met the stop condition (e.g. the scheduler produced no timesteps)
    return None

# denoising process to collect the magnitude (L2 norm) of the difference between the text-conditioned and the unconditional noise prediction. From the ICLR paper (Wen et al.)
@torch.no_grad()
def compute_noise_norm(prompts, tokenizer, text_encoder, unet, scheduler, seed=1, samples_per_prompt=1, num_inference_steps=50, early_stopping=1):
    """Return the per-sample L2 norm of the text-conditional noise prediction.

    For every sample the UNet is evaluated on the same initial latents once
    with the prompt embedding and once with the empty-prompt embedding. The
    norm of the difference between the two noise predictions is returned
    (metric from the ICLR paper by Wen et al.).

    Args:
        prompts: List of prompt strings. Each prompt is repeated
            ``samples_per_prompt`` times when ``samples_per_prompt > 1``.
        tokenizer, text_encoder, unet, scheduler: Stable Diffusion components.
        seed: Seed for the global torch generator used to draw the latents.
        samples_per_prompt: Number of latent samples per prompt.
        num_inference_steps: Number of timesteps configured on the scheduler.
        early_stopping: Loop index ``i`` from which on the function returns
            (``i >= early_stopping``); the last step always returns.

    Returns:
        A 1-D tensor with one norm per (prompt, sample) pair, i.e.
        ``len(prompts) * samples_per_prompt`` entries for
        ``samples_per_prompt >= 1``. ``None`` is returned if no timestep
        satisfied the stop condition.

    Note:
        The latents are never advanced with ``scheduler.step``; every forward
        pass before the stop uses the same initial latents with a different
        timestep. Requires a CUDA device (``.cuda()`` / CUDA autocast).
    """
    height = 512
    width = 512
    generator = torch.manual_seed(seed)
    if samples_per_prompt > 1:
        prompts = [prompt for prompt in prompts for _ in range(samples_per_prompt)]
    text_input = tokenizer(prompts,
                           padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True,
                           return_tensors="pt")
    text_embeddings = text_encoder(
        text_input.input_ids.to(text_encoder.device))[0]

    latents = torch.randn(
        (len(prompts), unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""] * len(prompts),
                             padding="max_length",
                             max_length=max_length,
                             return_tensors="pt")
    uncond_embeddings = text_encoder(
        uncond_input.input_ids.to(text_encoder.device))[0]
    # batch layout: text-conditioned half first, unconditional half second
    text_embeddings = torch.cat([text_embeddings, uncond_embeddings])

    latents = latents.to(text_embeddings.device)
    scheduler.set_timesteps(num_inference_steps)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        for i, t in enumerate(scheduler.timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            text_embeddings.requires_grad = False
            latent_model_input.requires_grad = False

            noise_pred = unet(
                latent_model_input.cuda(),
                t,
                encoder_hidden_states=text_embeddings, return_dict=False)[0]

            if i >= early_stopping or i == num_inference_steps - 1:
                # Split by the real half-batch size rather than by samples_per_prompt,
                # which only matched the batch layout when a single prompt was passed.
                noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                noise_prediction_text = noise_pred_text - noise_pred_uncond
                noise_norm_conditional = torch.norm(noise_prediction_text.flatten(1), dim=1, p=2)
                return noise_norm_conditional

    # no timestep met the stop condition (e.g. the scheduler produced no timesteps)
    return None
