import argparse
import glob
from rtpt import RTPT
import pandas as pd
import os
import json
import sys
import pickle
from tqdm import tqdm

from utils.adv_embedding import find_adv_text_embeddings
from utils.mitigations import Wanda
from utils.stable_diffusion import generate_images, load_sd_components, load_text_components
from utils.wanda import get_masking_matrices


def _load_masking_matrices(unet, norm_file, args, **kwargs):
    """Load pickled input norms from *norm_file* and build the Wanda masking matrices.

    The pickle is expected to hold a 2-tuple ``(uncond_input_norms,
    cond_input_norms)``. Extra keyword arguments are forwarded unchanged to
    ``get_masking_matrices`` (e.g. ``verbose=False``).
    """
    # NOTE(security): pickle.load executes arbitrary code embedded in the file;
    # only point --input_norm_path at files you created or trust.
    with open(norm_file, 'rb') as f:
        uncond_input_norms, cond_input_norms = pickle.load(f)

    return get_masking_matrices(
        unet,
        uncond_input_norms,
        cond_input_norms,
        percentage_of_neurons_to_prune=args.sparsity,
        timesteps_used=args.timesteps_used,
        **kwargs,
    )


def main():
    """Generate images from adversarial text embeddings on a Wanda-pruned UNet.

    Steps:
      1. Parse CLI args, create the output folder and dump the config to
         ``config.json`` inside it.
      2. Load the prompts CSV and attach one memorized image path per row.
      3. Optionally keep only the rows of one memorization type.
      4. For every remaining prompt: apply the Wanda masks to the UNet,
         optimize adversarial text embeddings against the memorized image,
         generate images from those embeddings, remove the masks again and
         save the images as JPEGs.

    Raises:
        FileExistsError: if the output folder already exists.
        ValueError: if the number of memorized images, or the number of rows
            in the reference prompt file, does not match the prompts CSV.
    """
    args = create_parser()

    # save the args
    # exist_ok=False: refuse to write into an already existing output folder.
    os.makedirs(args.output_path, exist_ok=False)
    with open(os.path.join(args.output_path, "config.json"), "w") as outfile:
        # Copy the dict so adding 'command' does not mutate the parsed namespace.
        args_to_save = dict(vars(args))
        args_to_save['command'] = " ".join(sys.argv)
        json.dump(args_to_save, outfile)

    # load prompts
    prompts_df = pd.read_csv(args.prompts, sep=';')

    # Images are paired with prompts purely by position.
    # NOTE(review): assumes the sorted file names follow the CSV row order -- confirm.
    image_paths = sorted(glob.glob(f'{args.memorized_images}/*.png'))
    if len(image_paths) != len(prompts_df):
        raise ValueError(
            f"Found {len(image_paths)} .png files in '{args.memorized_images}' "
            f"but {len(prompts_df)} prompts in '{args.prompts}'; they must match one-to-one."
        )
    prompts_df['image_path'] = image_paths

    # filter for vm or tm prompts
    # Done before RTPT/model setup so the progress total reflects the filtered
    # set and an empty selection exits without loading any model.
    if args.memorization_type is not None:
        df_original_prompts = pd.read_csv('prompts/memorized_laion_prompts.csv', sep=';')
        if len(prompts_df) != len(df_original_prompts):
            raise ValueError(
                f"Prompt file has {len(prompts_df)} rows but the reference prompt file has "
                f"{len(df_original_prompts)}; cannot filter by memorization type."
            )
        prompts_df = prompts_df[df_original_prompts['type'] == args.memorization_type.upper()]
        if prompts_df.empty:
            print(f"No prompts found for the type {args.memorization_type}. Use one of [VM, TM]")
            return
        print(f'Only taking neurons of {args.memorization_type.upper()} prompts, {len(prompts_df)} results remaining')

    num_prompts = len(prompts_df)
    rtpt = RTPT(args.user, 'Wanda_Gen_Attack_Imgs', num_prompts)
    rtpt.start()

    # Load SD components
    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    for module in (vae, text_encoder, unet):
        module.to(torch_device)
        module.eval()

    # Global input norms: the masks are identical for every prompt, build them once.
    if not args.calculate_norm_per_prompt:
        masking_matrices = _load_masking_matrices(unet, args.input_norm_path, args)

    for i in tqdm(range(num_prompts), total=num_prompts):
        row = prompts_df.iloc[i]
        prompt = row['Caption']
        memorized_image_path = row['image_path']

        if args.calculate_norm_per_prompt:
            # Per-prompt norms: --input_norm_path is a folder holding one
            # '<Index>.pkl' file per prompt.
            masking_matrices = _load_masking_matrices(
                unet,
                os.path.join(args.input_norm_path, f"{row['Index']}.pkl"),
                args,
                verbose=False,
            )

        wanda = Wanda(unet, masking_matrices)
        wanda.apply()
        try:
            text_emb = find_adv_text_embeddings(
                img_path=memorized_image_path,
                unet=unet,
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                vae=vae,
                scheduler=scheduler,
                # With --randomized_start no prompt is handed to the optimizer.
                prompt=None if args.randomized_start else prompt,
                num_steps=args.optim_steps,
                batch_size=args.adv_batch_size,
                seed=args.seed,
                lr=args.lr,
            )

            generated_imgs = generate_images(
                None,
                tokenizer,
                text_encoder,
                vae,
                unet,
                scheduler,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                samples_per_prompt=args.num_samples,
                num_inference_steps=args.steps,
                text_embeddings=text_emb,
            )
        finally:
            # Always undo the mitigation so the UNet is not left modified
            # when optimization or generation fails.
            wanda.remove()

        # File name encodes the position of the prompt in the (filtered) list
        # and the sample number.
        for j, img in enumerate(generated_imgs):
            img.save(f"{args.output_path}/img_{i:04d}_{j%args.num_samples:02d}.jpg")

        rtpt.step()

def create_parser(argv=None):
    """Build the command-line interface and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Generating images')
    parser.add_argument(
        '--prompts',
        default='prompts/memorized_laion_prompts.csv',
        type=str,
        help='Path to the prompts file (default: \'prompts/memorized_laion_prompts.csv\')'
    )
    parser.add_argument(
        '-o',
        '--output',
        default='generated_images/wanda',
        type=str,
        dest="output_path",
        help='output folder for generated images (default: \'generated_images/wanda\')'
    )
    parser.add_argument(
        '-s',
        '--seed',
        default=2,
        type=int,
        dest="seed",
        help='seed for generated images (default: 2)'
    )
    parser.add_argument(
        '-n',
        '--num_samples',
        default=10,
        type=int,
        dest="num_samples",
        help='number of generated samples for each prompt (default: 10)'
    )
    parser.add_argument(
        '--steps',
        default=50,
        type=int,
        dest="steps",
        help='number of denoising steps (default: 50)'
    )
    parser.add_argument(
        '-g',
        '--guidance_scale',
        default=7,
        type=float,
        dest="guidance_scale",
        help='guidance scale (default: 7)'
    )
    parser.add_argument(
        '-u',
        '--user',
        default='XX',
        type=str,
        dest="user",
        help='name initials for RTPT (default: "XX")'
    )
    parser.add_argument(
        '-v',
        '--version',
        default='v1-4',
        type=str,
        dest="version",
        help='Stable Diffusion version (default: "v1-4")'
    )
    parser.add_argument(
        '--memorized_images',
        default="prompts/images_memorized_orig_with_mem",
        type=str,
        help='Path to the memorized images (default: \'prompts/images_memorized_orig_with_mem\')'
    )
    parser.add_argument(
        '--memorization_type',
        default=None,
        type=str,
        help='Decide if the neurons of the verbatim or template prompts should be used. [vm, tm]'
    )

    # adversarial embedding optimization args
    parser.add_argument('--lr', default=0.1, type=float, help='The learning rate for the text embedding optimization (default: 0.1).')
    parser.add_argument('--optim_steps', default=50, type=int, help='The number of optimization steps (default: 50).')
    parser.add_argument('--adv_batch_size', default=8, type=int, help='The batch size for the optimization (default: 8).')
    parser.add_argument('--randomized_start', action='store_true', default=False, help='If True, the optimization starts from a random point in the latent space. If False, it starts from the memorized image (default: False).')

    # wanda specific args
    parser.add_argument('--calculate_norm_per_prompt', action='store_true', default=False, help='Calculate the input norms for each prompt.')
    parser.add_argument(
        '--input_norm_path',
        default='wanda/input_norms_all_layers_seed_1_500_prompts.pkl',
        type=str,
        help='The file from which the input norms are loaded; with --calculate_norm_per_prompt, '
             'a folder containing one \'<Index>.pkl\' file per prompt '
             '(default: \'wanda/input_norms_all_layers_seed_1_500_prompts.pkl\').'
    )
    parser.add_argument('--sparsity', default=0.01, type=float, help='The percentage of neurons to prune (default: 0.01).')
    parser.add_argument('--timesteps_used', default=10, type=int, help='The number of timesteps used for the masking matrices (default: 10).')

    # argv=None makes argparse fall back to sys.argv[1:], i.e. the previous behavior.
    return parser.parse_args(argv)


# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()