from rtpt import RTPT
import argparse
import pandas as pd
import pickle
import torch
from functools import partial

from utils.stable_diffusion import load_sd_components, load_text_components
from utils.wanda import get_input_norms, get_masking_matrices, set_wanda_blocking_hooks

def _select_blocks(layers, num_blocks=16):
    """Translate the ``--layers`` CLI value into a per-block boolean selection.

    Args:
        layers: Either the string ``'all'`` or a string holding an integer
            ``k``; the first ``k`` blocks are selected.
        num_blocks: Total number of blocks the script iterates over
            (this script uses 16).

    Returns:
        A list of ``num_blocks`` booleans, ``True`` for every selected block.

    Raises:
        ValueError: If ``layers`` is not ``'all'``, not an integer, or lies
            outside ``[0, num_blocks]``.
    """
    if layers == 'all':
        return [True] * num_blocks

    max_block_idx = int(layers)
    # Reject out-of-range values up front: a value above ``num_blocks`` would
    # otherwise only fail with an IndexError deep inside the block loop, and a
    # negative value would silently select nothing.
    if not 0 <= max_block_idx <= num_blocks:
        raise ValueError(
            f"--layers must be 'all' or an integer between 0 and {num_blocks}, got {layers!r}"
        )
    return [True] * max_block_idx + [False] * (num_blocks - max_block_idx)


def _build_output_path(output, sparsity, timesteps_used):
    """Return the pickle path derived from ``output`` plus the run settings.

    Only the final extension of ``output`` is stripped, so dots elsewhere in
    the path (``./wanda/x.pkl``, ``runs/v1.4/x.pkl``) are preserved.
    """
    import os

    stem, _ = os.path.splitext(output)
    return f'{stem}_iter_sparsity_{sparsity}_timesteps_{timesteps_used}.pkl'


def main():
    """Iteratively compute Wanda input norms block by block and pickle them.

    For every selected block, the input norms are computed with
    ``get_input_norms``, masking matrices are derived from the norms of that
    block with ``get_masking_matrices``, and ``set_wanda_blocking_hooks`` is
    called on the UNet with those masks before the next block is processed.
    The accumulated unconditional and conditional norms of all processed
    blocks are written as a tuple to a single pickle file.

    Raises:
        ValueError: If ``--layers`` is neither ``'all'`` nor an integer in
            ``[0, 16]``.
    """
    import os

    args = create_parser()

    # Validate the block selection first so that bad arguments fail before
    # any model is loaded.
    blocks = _select_blocks(args.layers)
    num_selected = sum(blocks)

    # load prompts
    prompts = pd.read_csv(args.prompts, sep=';')

    # Load SD components
    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    vae.to(torch_device)
    vae.eval()
    text_encoder.to(torch_device)
    text_encoder.eval()
    unet.to(torch_device)
    unet.eval()

    rtpt = RTPT(args.user, 'Wanda_Input_Norms_Iter', len(prompts) * num_selected)

    uncond_input_norms = {}
    cond_input_norms = {}
    for i, selected in enumerate(blocks):
        if not selected:
            continue

        # Restrict this pass to exactly one block.
        current_block_list = [False] * len(blocks)
        current_block_list[i] = True

        print(f'Calculating input norms for block {i + 1} of {num_selected}')

        uncond_input_norms_curr_block, cond_input_norms_curr_block = get_input_norms(
            prompts=prompts['Caption'].tolist(),
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            guidance_scale=7.0,
            seed=args.seed,
            samples_per_prompt=4,
            num_inference_steps=50,
            blocks=current_block_list,
            rtpt=rtpt,
            early_stopping=args.timesteps_used,
        )
        uncond_input_norms.update(uncond_input_norms_curr_block)
        cond_input_norms.update(cond_input_norms_curr_block)

        # Masks are built from the current block's norms only and applied to
        # the UNet before the next iteration. The return value of
        # set_wanda_blocking_hooks is discarded, so nothing in this script
        # undoes the call afterwards.
        # NOTE(review): presumably the hooks stay registered, so later blocks
        # are measured on the already-masked model - confirm in utils.wanda.
        masking_matrices = get_masking_matrices(
            unet,
            uncond_input_norms_curr_block,
            cond_input_norms_curr_block,
            percentage_of_neurons_to_prune=args.sparsity,
            timesteps_used=args.timesteps_used,
        )
        _ = set_wanda_blocking_hooks(
            unet=unet,
            binary_masks=masking_matrices,
        )

    output_path = _build_output_path(args.output, args.sparsity, args.timesteps_used)
    # Create the target directory if needed so a missing folder does not throw
    # away the results of the whole run at the very last step.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'wb') as f:
        pickle.dump((uncond_input_norms, cond_input_norms), f)

def create_parser(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Help texts use argparse's ``%(default)s`` placeholder so the documented
    default always matches the actual one.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with the attributes ``prompts``,
        ``output``, ``layers``, ``sparsity``, ``timesteps_used``,
        ``version``, ``user`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Calculating Wanda Input Norms')

    parser.add_argument('--prompts',
                        default='prompts/memorized_laion_prompts.csv',
                        type=str,
                        help='The file from which the prompts are loaded to calculate the statistics (default: \'%(default)s\').')
    parser.add_argument('--output',
                        default='wanda/iter_input_norms.pkl',
                        type=str,
                        help='The file to which the input norms are written (default: \'%(default)s\').')
    parser.add_argument('--layers',
                        default='all',
                        type=str,
                        help='Up to which layer we are getting the layer norms. (default: all layers)')
    parser.add_argument('--sparsity',
                        default=0.015,
                        type=float,
                        help='The percentage of neurons to prune (default: %(default)s).')
    parser.add_argument('--timesteps_used',
                        default=1,
                        type=int,
                        help='The number of timesteps used for the masking matrices (default: %(default)s).')

    parser.add_argument('-v',
                        '--version',
                        default='v1-4',
                        type=str,
                        dest="version",
                        help='Stable Diffusion version (default: "%(default)s")')
    parser.add_argument('-u',
                        '--user',
                        default='XX',
                        type=str,
                        dest="user",
                        help='name initials for RTPT (default: "%(default)s")')
    parser.add_argument('-s',
                        '--seed',
                        default=1,
                        type=int,
                        dest="seed",
                        help='seed for random number generator (default: %(default)s)')

    # argv=None makes argparse fall back to sys.argv[1:], which keeps the
    # original zero-argument call working unchanged.
    args = parser.parse_args(argv)
    return args


# Run the computation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()