import argparse
import glob
import torch
from rtpt import RTPT
import pandas as pd
import os
import json
import sys
import pickle
from tqdm import tqdm
import numpy as np

from utils.mitigations import Wanda
from utils.stable_diffusion import load_sd_components, load_text_components, generate_images
from utils.wanda import get_masking_matrices
from utils.wanda import get_input_norms

def _load_input_norms(path):
    """Load an ``(uncond_input_norms, cond_input_norms)`` pair from a pickle file.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only point ``--input_norm_path`` at files you produced yourself.
    """
    with open(path, 'rb') as f:
        uncond_input_norms, cond_input_norms = pickle.load(f)
    return uncond_input_norms, cond_input_norms


def _build_masking_matrices(unet, input_norms, args):
    """Compute the masking matrices handed to ``Wanda`` from a pair of input norms."""
    uncond_input_norms, cond_input_norms = input_norms
    return get_masking_matrices(
        unet,
        uncond_input_norms,
        cond_input_norms,
        percentage_of_neurons_to_prune=args.sparsity,
        timesteps_used=args.timesteps_used,
        verbose=False
    )


@torch.no_grad()
def main():
    """Generate images for every prompt, optionally with Wanda masks applied to the UNet.

    Steps:
      1. Parse the CLI arguments and store them (plus the invoking command)
         as ``config.json`` in a freshly created output folder.
      2. Load the prompts and, if ``--memorization_type`` is given, keep only
         the rows whose type in ``prompts/memorized_laion_prompts.csv`` matches.
      3. Load the Stable Diffusion components and move them to CUDA.
      4. For every batch of prompts, build/apply the masking matrices (unless
         ``--original_images`` is set), generate the images and save them as
         ``img_<prompt index>_<sample index>.jpg``.

    Raises:
        FileExistsError: if the output folder already exists.
        ValueError: if the prompt file and the reference prompt file differ in
            length, or if per-prompt norms are requested with more than one
            prompt per batch.
        FileNotFoundError: if random per-prompt norms are requested but the
            norm folder contains no ``*.pkl`` file.
    """
    args = create_parser()

    # save the args; exist_ok=False deliberately refuses to write into an existing run folder
    os.makedirs(args.output_path, exist_ok=False)
    # work on a copy so that adding the command line does not mutate the parsed namespace
    args_to_save = dict(vars(args))
    args_to_save['command'] = " ".join(sys.argv)
    with open(os.path.join(args.output_path, "config.json"), "w") as outfile:
        json.dump(args_to_save, outfile)

    # load prompts (the coco prompt file is comma separated, all others use ';')
    is_coco = 'coco' in args.prompts.lower()
    prompts_df = pd.read_csv(args.prompts, sep=',' if is_coco else ';')

    # filter for vm or tm prompts; done before anything is sized by len(prompts_df)
    if args.memorization_type is not None:
        df_original_prompts = pd.read_csv('prompts/memorized_laion_prompts.csv', sep=';')
        if len(prompts_df) != len(df_original_prompts):
            raise ValueError(
                f'The prompts file has {len(prompts_df)} rows but the reference prompts file has '
                f'{len(df_original_prompts)} rows; they must match to filter by memorization type'
            )
        prompts_df = prompts_df[df_original_prompts['type'] == args.memorization_type.upper()]
        if len(prompts_df) == 0:
            print(f"No prompts found for the type {args.memorization_type}. Use one of [VM, TM]")
            return
        print(f'Only taking neurons of {args.memorization_type.upper()} prompts, {len(prompts_df)} results remaining')

    # ceil division so that a trailing, partially filled batch is generated as well
    num_batches = (len(prompts_df) + args.batch_size - 1) // args.batch_size

    rtpt = RTPT(args.user, 'Wanda_Gen_Imgs', num_batches)
    rtpt.start()

    # Load SD components
    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    for module in (vae, text_encoder, unet):
        module.to(torch_device)
        module.eval()

    use_wanda = not args.original_images
    norm_per_prompt = use_wanda and args.calculate_norm_per_prompt

    masking_matrices = None
    if use_wanda and not args.calculate_norm_per_prompt:
        # one set of input norms (and therefore masks) shared by all prompts
        masking_matrices = _build_masking_matrices(unet, _load_input_norms(args.input_norm_path), args)

    random_gen = None
    norm_paths = None
    if is_coco and norm_per_prompt:
        print('Calculating input norms for coco images. Sampling random input norms...')
        random_gen = np.random.default_rng(args.seed)
        # glob returns files in arbitrary order; sort so the seeded sampling is reproducible
        norm_paths = sorted(glob.glob(os.path.join(args.input_norm_path, "*.pkl")))
        if not norm_paths:
            raise FileNotFoundError(f'No input norm files (*.pkl) found in {args.input_norm_path}')

    for i in tqdm(range(num_batches)):
        rows = prompts_df.iloc[i * args.batch_size:(i + 1) * args.batch_size]
        prompts = rows['Caption'].to_list()

        if norm_per_prompt:
            if is_coco:
                # if we want to generate coco images just sample random input norms
                norm_path = random_gen.choice(norm_paths)
            else:
                if len(rows) != 1:
                    raise ValueError('You can only calculate the input norms for one prompt at a time')
                norm_path = os.path.join(args.input_norm_path, f"{rows['Index'].values[0]}.pkl")
            masking_matrices = _build_masking_matrices(unet, _load_input_norms(norm_path), args)

        wanda = None
        if use_wanda:
            wanda = Wanda(unet, masking_matrices)
            wanda.apply()
        try:
            generated_imgs = generate_images(
                prompts,
                tokenizer,
                text_encoder,
                vae,
                unet,
                scheduler,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                samples_per_prompt=args.num_samples,
                num_inference_steps=args.steps
            )
        finally:
            # always undo the masking so a failed generation does not leave the unet modified
            if wanda is not None:
                wanda.remove()

        # NOTE(review): assumes generate_images returns the images grouped by prompt,
        # with num_samples consecutive images per prompt - confirm against utils.stable_diffusion
        for j, img in enumerate(generated_imgs):
            prompt_idx = i * args.batch_size + j // args.num_samples
            sample_idx = j % args.num_samples
            img.save(f"{args.output_path}/img_{prompt_idx:04d}_{sample_idx:02d}.jpg")

        rtpt.step()

def create_parser():
    """Build the command line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed arguments. General generation options
        (``prompts``, ``output_path``, ``seed``, ``num_samples``, ``steps``,
        ``guidance_scale``, ``user``, ``version``, ``batch_size``,
        ``original_images``, ``memorization_type``) and the Wanda specific
        options (``calculate_norm_per_prompt``, ``input_norm_path``,
        ``sparsity``, ``timesteps_used``).
    """
    parser = argparse.ArgumentParser(description='Generating images')
    parser.add_argument(
        '--prompts',
        default='prompts/memorized_laion_prompts.csv',
        type=str,
        help='Path to the prompts file (default: \'prompts/memorized_laion_prompts.csv\')'
    )
    parser.add_argument(
        '-o',
        '--output',
        default='generated_images/wanda',
        type=str,
        dest="output_path",
        help='output folder for generated images (default: \'generated_images/wanda\')'
    )
    parser.add_argument(
        '-s',
        '--seed',
        default=2,
        type=int,
        dest="seed",
        help='seed for generated images (default: 2)'
    )
    parser.add_argument(
        '-n',
        '--num_samples',
        default=10,
        type=int,
        dest="num_samples",
        help='number of generated samples for each prompt (default: 10)'
    )
    parser.add_argument(
        '--steps',
        default=50,
        type=int,
        dest="steps",
        help='number of denoising steps (default: 50)'
    )
    parser.add_argument(
        '-g',
        '--guidance_scale',
        default=7,
        type=float,
        dest="guidance_scale",
        help='guidance scale (default: 7)'
    )
    parser.add_argument(
        '-u',
        '--user',
        default='XX',
        type=str,
        dest="user",
        help='name initials for RTPT (default: "XX")'
    )
    parser.add_argument(
        '-v',
        '--version',
        default='v1-4',
        type=str,
        dest="version",
        help='Stable Diffusion version (default: "v1-4")'
    )
    parser.add_argument('-b', '--batch_size', default=1, type=int, help='Number of prompts per batch')
    parser.add_argument('--original_images', action='store_true', default=False, help='Generate the original images')
    parser.add_argument(
        '--memorization_type',
        default=None,
        type=str,
        help='Decide if the neurons of the verbatim or template prompts should be used. [vm, tm]'
    )

    # wanda specific args
    parser.add_argument(
        '--calculate_norm_per_prompt',
        action='store_true',
        default=False,
        help='Calculate the input norms for each prompt separately'
    )
    parser.add_argument(
        '--input_norm_path',
        default='wanda/input_norms_all_layers.pkl',
        type=str,
        help='The file from which the input norms are loaded (default: \'wanda/input_norms_all_layers.pkl\').'
    )
    parser.add_argument(
        '--sparsity',
        default=0.01,
        type=float,
        help='The percentage of neurons to prune (default: 0.01).'
    )
    parser.add_argument(
        '--timesteps_used',
        default=10,
        type=int,
        # the help text used to claim a default of 50, which did not match the actual default
        help='The number of timesteps used for the masking matrices (default: 10).'
    )

    return parser.parse_args()


# Run the image generation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()