import json
import os
import sys
from tqdm import tqdm
from rtpt import RTPT
import argparse
import pandas as pd
import pickle

from utils.stable_diffusion import load_sd_components, load_text_components
from utils.wanda import get_input_norms


def _build_block_mask(layers, num_blocks=16):
    """Translate the ``--layers`` option into a per-block boolean mask.

    Args:
        layers: Either the string ``'all'`` or something convertible to an
            integer N.
        num_blocks: Length of the returned mask (this script uses 16).

    Returns:
        A list of ``num_blocks`` booleans: all ``True`` for ``'all'``,
        otherwise the first N entries ``True`` and the rest ``False``.

    Raises:
        ValueError: If ``layers`` is neither ``'all'`` nor an integer in
            ``[0, num_blocks]``.
    """
    if layers == 'all':
        return [True] * num_blocks

    try:
        max_block_idx = int(layers)
    except ValueError as err:
        raise ValueError(f"--layers must be 'all' or an integer, got {layers!r}") from err

    # Without this check a negative or too large value would silently yield a
    # mask whose length differs from num_blocks.
    if not 0 <= max_block_idx <= num_blocks:
        raise ValueError(f"--layers must be between 0 and {num_blocks}, got {max_block_idx}")

    return [True] * max_block_idx + [False] * (num_blocks - max_block_idx)


def main():
    """Compute Wanda input norms for a set of prompts and pickle the results.

    Steps:
      1. Parse the command line and read the ';'-separated prompt CSV.
      2. Optionally keep only rows whose 'type' column matches
         ``--memorization_type`` and/or draw a random sample of the rows.
      3. Load the Stable Diffusion components and move them to CUDA.
      4. Create the output directory and store the run configuration in
         ``config.json``.
      5. Call ``get_input_norms`` either once per prompt (one
         ``<Index>.pkl`` per row) or once for all prompts
         (``input_norms.pkl``).

    Raises:
        FileExistsError: If the output directory already exists.
        ValueError: If ``--layers`` is invalid.
    """
    args = create_parser()

    # load prompts
    prompts = pd.read_csv(args.prompts, sep=';')

    # filter for vm or tm prompts
    if args.memorization_type is not None:
        memorization_type = args.memorization_type.upper()
        prompts = prompts[prompts['type'] == memorization_type]
        if len(prompts) == 0:
            print(f"No prompts found for the type {args.memorization_type}. Use one of [VM, TM]")
            return
        print(f'Only taking neurons of {memorization_type} prompts, {len(prompts)} results remaining')

    if args.sample_prompts is not None:
        prompts = prompts.sample(args.sample_prompts, random_state=args.seed)
        print(f'Sampling {args.sample_prompts} prompts')

    # Create the progress tracker only now, so its iteration count matches the
    # prompts that are actually processed after filtering and sampling.
    rtpt = RTPT(args.user, 'Wanda_Input_Norms', len(prompts))

    # Validate the cheap inputs before the expensive model loading.
    blocks = _build_block_mask(args.layers)
    if os.path.exists(args.output):
        raise FileExistsError(f"Output directory '{args.output}' already exists")

    # Load SD components
    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    for module in (vae, text_encoder, unet):
        module.to(torch_device)
        module.eval()

    # save the args
    os.makedirs(args.output, exist_ok=False)
    # Copy the namespace dict: vars(args) is the namespace's own __dict__, so
    # adding a key to it directly would also add an attribute to args.
    config = dict(vars(args))
    config['command'] = " ".join(sys.argv)
    with open(os.path.join(args.output, "config.json"), "w") as outfile:
        json.dump(config, outfile)

    if args.norms_per_prompt:
        print('Calculating input norms per prompt')
        rtpt.start()
        for _, row in tqdm(prompts.iterrows(), total=len(prompts)):
            input_norms = get_input_norms(
                prompts=[row['Caption']],
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                unet=unet,
                scheduler=scheduler,
                guidance_scale=args.guidance_scale,
                seed=args.seed,
                samples_per_prompt=args.samples_per_prompt,
                num_inference_steps=args.num_steps,
                blocks=blocks,
                verbose=False
            )
            rtpt.step()

            # One result file per prompt, named after the row's 'Index' column.
            with open(os.path.join(args.output, f'{row["Index"]}.pkl'), 'wb') as f:
                pickle.dump(input_norms, f)
    else:
        # The progress tracker is handed to get_input_norms in this mode.
        input_norms = get_input_norms(
            prompts=prompts['Caption'].tolist(),
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            guidance_scale=args.guidance_scale,
            seed=args.seed,
            samples_per_prompt=args.samples_per_prompt,
            num_inference_steps=args.num_steps,
            blocks=blocks,
            rtpt=rtpt
        )

        with open(os.path.join(args.output, 'input_norms.pkl'), 'wb') as f:
            pickle.dump(input_norms, f)


def create_parser(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding all parsed options.
    """
    parser = argparse.ArgumentParser(description='Calculating Wanda Input Norms')

    # Defaults are rendered through %(default)s so the help text cannot drift
    # away from the actual default values.
    parser.add_argument('--prompts',
                        default='prompts/memorized_laion_prompts.csv',
                        type=str,
                        help='The file from which the prompts are loaded to calculate the statistics (default: \'%(default)s\').')
    parser.add_argument('--output',
                        default='wanda/input_norms.pkl',
                        type=str,
                        help='The directory to which the config and the input norms are written; it must not exist yet (default: \'%(default)s\').')
    parser.add_argument('--layers',
                        default='all',
                        type=str,
                        help='Either \'all\' or an integer N to only use the first N of the 16 blocks (default: %(default)s)')
    parser.add_argument('--norms_per_prompt',
                        action='store_true',
                        help='Calculate norms per prompt (default: False)')
    parser.add_argument('-v',
                        '--version',
                        default='v1-4',
                        type=str,
                        dest="version",
                        help='Stable Diffusion version (default: "%(default)s")')
    parser.add_argument('-u',
                        '--user',
                        default='XX',
                        type=str,
                        dest="user",
                        help='name initials for RTPT (default: "%(default)s")')
    parser.add_argument('--memorization_type',
                        default=None,
                        type=str,
                        help='Decide if the neurons of the verbatim or template prompts should be used. [vm, tm]')
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        dest="seed",
                        help='seed for random number generator (default: %(default)s)')
    parser.add_argument('--guidance_scale',
                        default=7.0,
                        type=float,
                        dest="guidance_scale",
                        help='guidance scale (default: %(default)s)')
    parser.add_argument('--samples_per_prompt',
                        default=8,
                        type=int,
                        dest="samples_per_prompt",
                        help='number of generated samples for each prompt (default: %(default)s)')
    parser.add_argument('--num_steps',
                        default=50,
                        type=int,
                        dest="num_steps",
                        help='number of denoising steps (default: %(default)s)')
    parser.add_argument('--sample_prompts',
                        default=None,
                        type=int,
                        dest="sample_prompts",
                        help='number of prompts to sample at random (default: %(default)s)')

    return parser.parse_args(argv)


# Run the script only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()