import argparse
import csv
import glob
import json
import os
import sys
from rtpt import RTPT
import pandas as pd
import torch
from torchmetrics.functional import multiscale_structural_similarity_index_measure
from tqdm import tqdm
import re

from utils.activation_detection import calculate_max_pairwise_ssim, compute_noise_diff, initial_neuron_selection, neuron_refinement
from utils.adv_embedding import find_adv_text_embeddings
from utils.mitigations import MemMitigation, Nemo, Wanda
from utils.stable_diffusion import generate_images, load_sd_components, load_text_components
from utils.wanda import get_input_norms, get_masking_matrices

class NotMemorizedError(Exception):
    """Signal that a prompt does not appear to be memorized by the model.

    The human-readable reason is kept both as the exception argument and in
    the ``message`` attribute.
    """

    def __init__(self, message="Prompt Not Memorized."):
        super().__init__(message)
        self.message = message

def str_to_list(s):
    """Parse the string representation of a list of integer lists.

    Every bracketed chunk (matched non-greedily, so ``"[[1, 2], [3]]"`` yields
    the chunks ``"[[1, 2]"`` and ``"[3]"``) becomes one inner list containing
    all non-negative integers found in that chunk.
    """
    groups = []
    for chunk in re.findall(r'\[.*?\]', s):
        groups.append([int(token) for token in re.findall(r'\d+', chunk)])
    return groups

def main():
    """Run the iterative mitigation experiment with adversarial embedding search.

    For every selected prompt this:
      1. generates reference images without any mitigation (``no_mitigation``),
      2. repeats ``args.adv_iterations`` times: identify neurons to mitigate
         (from ``args.result_file`` in the first iteration for NeMo), merge them
         with the previously identified ones, apply the mitigation, search an
         adversarial text embedding against the memorized image, and save images
         generated from that embedding (``adversarial_images``) and from the
         original prompt (``mitigated_images``) under the mitigation,
      3. removes the mitigation again before the next prompt.

    For NeMo the blocked neurons of each iteration are logged to a CSV file.
    """
    args = create_parser()

    # Validate before creating output folders or loading the (large) models.
    if args.method not in ('nemo', 'wanda'):
        raise ValueError(f"Method {args.method} not supported. Use one of [nemo, wanda]")

    if args.continue_run:
        print(f"Continuing run from index {args.start_index} to {args.end_index}. Saving the results to {args.output_path}...")
    else:
        # create the output folders; fail if they exist to avoid overwriting a previous run
        os.makedirs(args.output_path, exist_ok=False)
        for sub_dir in ('adversarial_images', 'mitigated_images', 'no_mitigation'):
            os.makedirs(os.path.join(args.output_path, sub_dir), exist_ok=False)
        # save the args; copy them so that adding the command does not mutate ``args``
        with open(os.path.join(args.output_path, "config.json"), "w") as outfile:
            args_to_save = dict(vars(args))
            args_to_save['command'] = " ".join(sys.argv)
            json.dump(args_to_save, outfile)

    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    for module in (vae, text_encoder, unet):
        module.to(torch_device)
        module.requires_grad_(False)

    # load prompts
    mem_prompts_df = pd.read_csv(args.prompts, sep=';')

    # attach the memorized image paths; assumes the sorted file names line up
    # with the row order of the prompts file — TODO confirm
    image_paths = sorted(glob.glob(f'{args.memorized_images}/*.png'))
    mem_prompts_df['image_path'] = image_paths

    # filter for vm or tm prompts
    if args.memorization_type is not None:
        mem_prompts_df = mem_prompts_df[mem_prompts_df['type'] == args.memorization_type.upper()]
        if len(mem_prompts_df) == 0:
            print(f"No prompts found for the type {args.memorization_type}. Use one of [VM, TM]")
            return
        print(f'Only taking neurons of {args.memorization_type.upper()} prompts, {len(mem_prompts_df)} results remaining')

    # ``.loc`` slicing is label-based and includes ``end_index``
    mem_prompts_df = mem_prompts_df.loc[args.start_index:args.end_index]

    rtpt = RTPT(args.user, 'Iterative_Mitigation', len(mem_prompts_df) * args.adv_iterations)
    rtpt.start()

    csv_file = None
    writer = None
    if args.method == 'nemo':
        # create a csv writer to log the blocked neurons
        header = ['Caption', 'Blocked Neurons', 'Total Number of Neurons Blocked', 'Index', 'Adv. Iteration']
        file_path = os.path.join(args.output_path, f'nemo_blocked_neurons_{args.start_index}_{args.end_index}.csv')
        if os.path.exists(file_path):
            raise FileExistsError(f"File {file_path} already exists. Please remove it or change the name.")
        # newline='' as required by the csv module (avoids blank rows on some platforms)
        csv_file = open(file_path, "w", newline="")
        writer = csv.DictWriter(csv_file, delimiter=';', fieldnames=header)
        writer.writeheader()

    try:
        for i, row in tqdm(mem_prompts_df.iterrows(), total=len(mem_prompts_df)):
            prompt = row['Caption']
            memorized_image_path = row['image_path']

            # reference images without any mitigation
            generated_imgs = generate_images(
                prompt,
                tokenizer,
                text_encoder,
                vae,
                unet,
                scheduler,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                samples_per_prompt=args.num_samples,
                num_inference_steps=args.num_steps,
                text_embeddings=None
            )
            for j, img in enumerate(generated_imgs):
                img.save(f"{args.output_path}/no_mitigation/{i:04d}_{j:02d}.png")

            mitigation_method = None
            last_mitigation_method_instance = None
            adv_text_emb = None
            for adv_iter in range(args.adv_iterations):
                try:
                    # identified while the previous (merged) mitigation is still applied
                    mitigation_method: MemMitigation = get_mitigation_method(
                        prompt,
                        unet,
                        tokenizer,
                        text_encoder,
                        scheduler,
                        args,
                        text_embedding=adv_text_emb,
                        from_file=args.result_file if adv_iter == 0 else None
                    )
                except NotMemorizedError as e:
                    print(f"Prompt {prompt} seems to be not memorized, skipping this sample. Prompt: {prompt}\t Adv. Iteration: {adv_iter}\t Error: {e}")
                    # create empty image .png dummy files for the skipped samples
                    for j in range(args.num_samples):
                        open(f"{args.output_path}/adversarial_images/{i:04d}_{j:02d}_{adv_iter:02d}.png", 'w').close()
                        open(f"{args.output_path}/mitigated_images/{i:04d}_{j:02d}_{adv_iter:02d}.png", 'w').close()
                    # keep the RTPT progress/ETA consistent with the announced total
                    rtpt.step()
                    continue

                if last_mitigation_method_instance is not None:
                    # merge the identified neurons with the last mitigation method instance
                    mitigation_method = merge_identified_neurons(mitigation_method, last_mitigation_method_instance)
                    # remove the last mitigation method instance before applying the merged one
                    last_mitigation_method_instance.remove()

                mitigation_method.apply()

                # log how many neurons were blocked to a csv file with the prompt index
                if writer is not None:
                    num_neurons_blocked = sum(len(layer) for layer in mitigation_method.blocked_indices)
                    writer.writerow({
                        'Caption': prompt,
                        'Blocked Neurons': mitigation_method.blocked_indices,
                        'Total Number of Neurons Blocked': num_neurons_blocked,
                        'Index': i,
                        'Adv. Iteration': adv_iter
                    })
                    csv_file.flush()

                # perform adv. embedding search against the mitigated model
                adv_text_emb = find_adv_text_embeddings(
                    img_path=memorized_image_path,
                    unet=unet,
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    vae=vae,
                    scheduler=scheduler,
                    prompt=prompt,
                    num_steps=args.adv_steps,
                    batch_size=args.adv_batch_size,
                    seed=args.seed,
                    lr=args.adv_lr
                )
                # generate images (with mitigation and with the adv. embeddings)
                adv_generated_imgs = generate_images(
                    None,
                    tokenizer,
                    text_encoder,
                    vae,
                    unet,
                    scheduler,
                    guidance_scale=args.guidance_scale,
                    seed=args.seed,
                    samples_per_prompt=args.num_samples,
                    num_inference_steps=args.num_steps,
                    text_embeddings=adv_text_emb,
                )
                # generate images (with mitigation and the original prompt)
                generated_imgs = generate_images(
                    prompt,
                    tokenizer,
                    text_encoder,
                    vae,
                    unet,
                    scheduler,
                    guidance_scale=args.guidance_scale,
                    seed=args.seed,
                    samples_per_prompt=args.num_samples,
                    num_inference_steps=args.num_steps,
                    text_embeddings=None
                )

                # update the last mitigation method instance
                last_mitigation_method_instance = mitigation_method

                for j, (adv_img, img) in enumerate(zip(adv_generated_imgs, generated_imgs)):
                    adv_img.save(f"{args.output_path}/adversarial_images/{i:04d}_{j:02d}_{adv_iter:02d}.png")
                    img.save(f"{args.output_path}/mitigated_images/{i:04d}_{j:02d}_{adv_iter:02d}.png")

                rtpt.step()

            # Remove the mitigation before the next prompt. At this point both names
            # normally refer to the same applied object, so remove it only once.
            if mitigation_method is not None:
                mitigation_method.remove()
            if (last_mitigation_method_instance is not None
                    and last_mitigation_method_instance is not mitigation_method):
                last_mitigation_method_instance.remove()
    finally:
        # close the csv file, also when the run crashes
        if csv_file is not None:
            csv_file.close()


def merge_identified_neurons(method1, method2):
    """Merge the neurons identified by ``method2`` into ``method1``.

    * ``Nemo``: the blocked neuron indices are combined per layer as a sorted
      union, so no index appears twice.
    * ``Wanda``: the masking matrices are combined so that an entry is masked
      (0) if it is masked in either method; the dtype of ``method1``'s matrices
      is kept.

    ``method1`` is modified in place and returned.

    Raises:
        TypeError: if the two methods are not of the same type.
        ValueError: if two ``Nemo`` instances cover a different number of layers.
    """
    if type(method1) is not type(method2):
        raise TypeError(
            f"Methods are not of the same type: {type(method1).__name__} vs {type(method2).__name__}"
        )

    if isinstance(method1, Nemo):
        if len(method1.blocked_indices) != len(method2.blocked_indices):
            raise ValueError(
                f"Cannot merge blocked indices of {len(method1.blocked_indices)} and "
                f"{len(method2.blocked_indices)} layers"
            )
        # per-layer sorted union (the old code called sorted() but discarded its result)
        method1.blocked_indices = [
            sorted(set(layer1) | set(layer2))
            for layer1, layer2 in zip(method1.blocked_indices, method2.blocked_indices)
        ]
    elif isinstance(method1, Wanda):
        new_masking_matrices = {}
        for layer_key, mask1 in method1.masking_matrices.items():
            # keep an entry only if it is unmasked (non-zero) in both masks
            merged_mask = ~((mask1 == 0) | (method2.masking_matrices[layer_key] == 0))
            new_masking_matrices[layer_key] = merged_mask.to(mask1.dtype)
        method1.masking_matrices = new_masking_matrices

    return method1


def get_mitigation_method(
    prompt, 
    unet,
    tokenizer, 
    text_encoder,
    scheduler,
    args,
    text_embedding=None,
    from_file=None,
):
    """Create (but do not apply) the mitigation selected by ``args.method``.

    Args:
        prompt: the (memorized) prompt to mitigate.
        unet, tokenizer, text_encoder, scheduler: Stable Diffusion components.
        args: parsed command line arguments with the method hyperparameters.
        text_embedding: optional text embedding forwarded to the NeMo neuron
            search, e.g. the adversarial embedding of the previous iteration.
        from_file: optional path to a ';'-separated CSV with 'Caption' and
            'Refined Neurons' columns. For NeMo, the neurons are then read from
            this file instead of being searched.

    Returns:
        A ``Nemo`` or ``Wanda`` instance.

    Raises:
        NotMemorizedError: if ``prompt`` is not listed in ``from_file``, or if
            the NeMo search considers it not memorized.
        NotImplementedError: for an unsupported ``args.method``.
    """
    if args.method == 'nemo':
        if from_file is not None:
            print(f"Loading the blocked neurons from the file {from_file}...")
            # load the blocked neurons from the given file
            blocked_neurons_df = pd.read_csv(from_file, sep=';')
            blocked_neurons_df = blocked_neurons_df[blocked_neurons_df['Caption'] == prompt]
            if len(blocked_neurons_df) == 0:
                raise NotMemorizedError(f"Prompt {prompt} seems to be not memorized, no neurons to block found.")
            blocked_neurons = str_to_list(blocked_neurons_df.iloc[0]['Refined Neurons'])

            return Nemo(unet, blocked_indices=blocked_neurons, scaling_factor=args.scaling_factor)

        blocked_neurons = nemo_get_neurons(
            prompt=prompt,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            # neurons are searched with a different seed than the one used for evaluation
            seed=args.seed - 1,
            scaling_factor=args.scaling_factor,
            samples_per_prompt=args.samples_per_prompt,
            guidance_scale=0,
            num_inference_steps=args.num_steps,
            pairwise_ssim_threshold=args.pairwise_ssim_threshold,
            initial_theta=args.initial_theta,
            initial_k=args.initial_k,
            ssim_threshold_initial_selection=args.ssim_threshold_initial_selection,
            ssim_threshold_refinement=args.ssim_threshold_refinement,
            theta_reduction_per_step=args.theta_reduction_per_step,
            min_theta=args.min_theta,
            version=args.version,
            text_embedding=text_embedding
        )
        return Nemo(unet, blocked_indices=blocked_neurons, scaling_factor=args.scaling_factor)
    elif args.method == 'wanda':
        # NOTE(review): ``text_embedding`` (the adversarial embedding) is not
        # forwarded here, so Wanda only ever uses ``prompt`` to compute the
        # input norms. Confirm whether get_input_norms can accept embeddings.
        uncond_input_norms, cond_input_norms = get_input_norms(
            prompts=[prompt],
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            guidance_scale=args.guidance_scale,
            seed=args.seed - 1,
            samples_per_prompt=args.samples_per_prompt,
            num_inference_steps=args.num_steps,
            blocks=[True] * 16,
            early_stopping=args.timesteps_used,
            verbose=False
        )
        masking_matrices = get_masking_matrices(
            unet, 
            uncond_input_norms, 
            cond_input_norms, 
            percentage_of_neurons_to_prune=args.sparsity,
            timesteps_used=args.timesteps_used,
            verbose=False
        )
        return Wanda(unet, masking_matrices)
    else:
        raise NotImplementedError(f"Method {args.method} not implemented")

def nemo_get_neurons(
        prompt,
        tokenizer,
        text_encoder,
        unet,
        scheduler,
        seed,
        scaling_factor,
        samples_per_prompt,
        guidance_scale,
        num_inference_steps,
        pairwise_ssim_threshold,
        initial_theta,
        initial_k,
        ssim_threshold_initial_selection,
        ssim_threshold_refinement,
        theta_reduction_per_step,
        min_theta,
        version,
        text_embedding=None,
):
    """Localize the neurons to block for a memorized ``prompt`` (NeMo).

    1. Compute the unblocked noise differences and keep only the samples whose
       maximum pairwise SSIM exceeds ``pairwise_ssim_threshold``.
    2. Initial selection: call ``initial_neuron_selection`` with decreasing
       ``theta`` (by ``theta_reduction_per_step``, not below ``min_theta``)
       until the max multi-scale SSIM between the unblocked and blocked noise
       differences is at most ``ssim_threshold_initial_selection``. If ``theta``
       reaches ``min_theta`` first, the last reached SSIM becomes the
       refinement threshold.
    3. Refine the selected neurons with ``neuron_refinement``.

    Returns:
        The refined blocked neuron indices as returned by ``neuron_refinement``.

    Raises:
        NotMemorizedError: if no sample exceeds ``pairwise_ssim_threshold``.
    """
    # noise differences without any blocking serve as the reference
    noise_diff_unblocked = compute_noise_diff(
        [prompt], 
        tokenizer, 
        text_encoder, 
        unet, scheduler, 
        seed=seed, 
        blocked_indices=None, 
        scaling_factor=scaling_factor, 
        samples_per_prompt=samples_per_prompt, 
        guidance_scale=guidance_scale, 
        num_inference_steps=num_inference_steps,
        text_embedding=text_embedding
    )

    # get the samples to look at
    max_ssims_per_noise_diff = calculate_max_pairwise_ssim(noise_diff_unblocked)
    sample_indices_to_look_at = max_ssims_per_noise_diff > pairwise_ssim_threshold

    # if no SSIM value is above the threshold, skip this sample as it seems to be not memorized
    if sample_indices_to_look_at.sum() == 0:
        raise NotMemorizedError(
            f"Prompt {prompt} seems to be not memorized, no neurons to block found. "
            f"SSIM values: {max_ssims_per_noise_diff}"
        )
    noise_diff_unblocked = noise_diff_unblocked[sample_indices_to_look_at]

    # perform the initial selection of neurons
    theta = initial_theta
    layer_depth = 7
    k = initial_k
    max_k = 1280  # upper bound on k; note that k itself is not changed in this loop
    refinement_ssim_threshold = ssim_threshold_refinement
    # start above any threshold so at least one selection is always performed
    ssim = float('inf')
    while ssim > ssim_threshold_initial_selection:
        blocked_indices = initial_neuron_selection(
            prompt, 
            tokenizer, 
            text_encoder, 
            unet, 
            scheduler, 
            layer_depth=layer_depth, 
            theta=theta,
            k=k, 
            seed=seed, 
            version=version,
            text_embedding=text_embedding
        )
        noise_diff_blocked = compute_noise_diff(
            [prompt], 
            tokenizer, 
            text_encoder, 
            unet, 
            scheduler, 
            seed=seed, 
            blocked_indices=blocked_indices, 
            scaling_factor=scaling_factor, 
            samples_per_prompt=samples_per_prompt, 
            guidance_scale=guidance_scale, 
            num_inference_steps=num_inference_steps, 
            seed_indices_to_return=sample_indices_to_look_at,
            text_embedding=text_embedding
        )
        ssim = multiscale_structural_similarity_index_measure(
            noise_diff_unblocked, 
            noise_diff_blocked, 
            reduction='none', 
            kernel_size=11, 
            betas=(0.33, 0.33, 0.33)
        ).max()

        if ssim > ssim_threshold_initial_selection:
            if theta > min_theta:
                theta = theta - theta_reduction_per_step
        else:
            print(f'Initial selection of blocked neurons found with theta={theta} and k={k}. SSIM={ssim}')

        # Stop once theta reached its lower bound. Comparing against min_theta with
        # ``<=`` (instead of ``theta == 1``) guarantees termination for any
        # min_theta and for step sizes that do not hit it exactly in floating point.
        if theta <= min_theta or k >= max_k:
            refinement_ssim_threshold = ssim
            break

    # perform the refinement step
    refined_blocking_indices = neuron_refinement(
        prompt, 
        tokenizer, 
        text_encoder, 
        unet, 
        scheduler, 
        input_indices=blocked_indices, 
        scaling_factor=scaling_factor, 
        threshold=refinement_ssim_threshold, 
        rel_threshold=None, 
        samples_per_prompt=samples_per_prompt, 
        guidance_scale=guidance_scale, 
        seed=seed, 
        seeds_to_look_at=sample_indices_to_look_at,
        text_embedding=text_embedding
    )

    return refined_blocking_indices


def create_parser():
    """Parse the command line arguments of the iterative mitigation script.

    Despite its name, this returns the parsed ``argparse.Namespace`` read from
    ``sys.argv``, not the parser. The defaults shown in ``--help`` are filled in
    by argparse (``%(default)s``) so they cannot drift from the actual values.
    """
    parser = argparse.ArgumentParser(description='Generating images')
    parser.add_argument(
        '--prompts',
        default='prompts/memorized_laion_prompts.csv',
        type=str,
        help="Path to the prompts file (default: '%(default)s')"
    )
    parser.add_argument(
        '-o',
        '--output',
        default='generated_images/iterative_mitigation',
        type=str,
        dest="output_path",
        help="output folder for generated images (default: '%(default)s')"
    )
    parser.add_argument('-s',
                        '--seed',
                        default=2,
                        type=int,
                        dest="seed",
                        help='seed for generated images (default: %(default)s)')
    parser.add_argument(
        '-n',
        '--num_samples',
        default=10,
        type=int,
        dest="num_samples",
        help='number of generated samples for each prompt (default: %(default)s)')
    parser.add_argument('--steps',
                        default=50,
                        type=int,
                        dest="num_steps",
                        help='number of denoising steps (default: %(default)s)')
    parser.add_argument('-g',
                        '--guidance_scale',
                        default=7,
                        type=float,
                        dest="guidance_scale",
                        help='guidance scale (default: %(default)s)')
    parser.add_argument('-u',
                        '--user',
                        default='XX',
                        type=str,
                        dest="user",
                        help='name initials for RTPT (default: "%(default)s")')
    parser.add_argument('-v',
                        '--version',
                        default='v1-4',
                        type=str,
                        dest="version",
                        help='Stable Diffusion version (default: "%(default)s")')
    parser.add_argument('--memorized_images',
                        default="prompts/images_memorized_orig_with_mem",
                        type=str,
                        help="Path to the memorized images (default: '%(default)s')"
    )
    parser.add_argument('--memorization_type',
                        default='vm',
                        type=str,
                        help='Decide if the neurons of the verbatim or template prompts should be used. [vm, tm] (default: %(default)s)'
    )
    parser.add_argument('--method', default='nemo', type=str, choices=['nemo', 'wanda'],
                        help='Method to use for neuron blocking [nemo, wanda] (default: %(default)s)')

    # adv. optimization specific arguments
    parser.add_argument('--adv_lr', default=0.1, type=float, help='learning rate for adv embedding optimization (default: %(default)s)')
    parser.add_argument('--adv_steps', default=50, type=int, help='Number of optimization steps for adv embedding (default: %(default)s)')
    parser.add_argument('--adv_batch_size', default=8, type=int, help='Batch size for adv embedding optimization (default: %(default)s)')
    parser.add_argument('--adv_iterations', default=5, type=int, help='Number of iterations of mitigation after adv. search (default: %(default)s)')
    parser.add_argument('--continue_run', action='store_true', help='Continue run from the last started prompt index')
    parser.add_argument('--start_index', default=0, type=int, help='Start index for the run (default: %(default)s)')
    parser.add_argument('--end_index', default=500, type=int, help='End index (inclusive) for the run (default: %(default)s)')

    # nemo specific arguments (samples_per_prompt is also used by wanda)
    parser.add_argument('--scaling_factor', default=0, type=float, help='Scaling factor for the blocking of neurons (default: %(default)s)')
    parser.add_argument('--samples_per_prompt', default=10, type=int, help='Number of samples generated per prompt to identify the neurons (default: %(default)s)')
    parser.add_argument('--pairwise_ssim_threshold', default=0.428, type=float, help='Threshold for the pairwise SSIM for choosing at which initial samples to look at (default: %(default)s)')
    parser.add_argument('--ssim_threshold_initial_selection', default=0.428, type=float, help='SSIM threshold for the initial neuron selection (default: %(default)s)')
    parser.add_argument('--ssim_threshold_refinement', default=0.428, type=float, help='SSIM threshold for the neuron refinement (default: %(default)s)')
    parser.add_argument('--initial_theta', default=5, type=float, help='The initial theta value for the initial neuron selection (default: %(default)s)')
    parser.add_argument('--initial_k', default=0, type=int, help='The initial k value for the initial neuron selection (default: %(default)s)')
    parser.add_argument('--min_theta', default=1, type=float, help='The minimum theta value for the initial neuron selection (default: %(default)s)')
    parser.add_argument('--theta_reduction_per_step', default=0.25, type=float, help='The reduction of theta value per step during the initial neuron selection (default: %(default)s)')
    parser.add_argument('--result_file', default='memorization_statistics_v1_4.csv', type=str, help='path to file with the refined neurons per prompt (default: %(default)s)')

    # wanda specific arguments
    parser.add_argument('--timesteps_used', default=10, type=int, help='The number of timesteps used for the masking matrices (default: %(default)s).')
    parser.add_argument('--sparsity', default=0.01, type=float, help='The fraction of neurons to prune (default: %(default)s).')

    args = parser.parse_args()
    return args

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()