import os
import random

import torch
from PIL import Image
from torch import autocast
from tqdm.auto import tqdm
from utils.stable_diffusion import generate_images
from utils.stable_diffusion import load_sd_components, load_text_components
import argparse
from utils.datasets import load_prompts
from rtpt import RTPT
import pandas as pd
import json
import re
import sys
from random import sample 
from utils.adv_embedding import find_adv_text_embeddings
from hooks.block_activations import RescaleLinearActivations
import glob

def str_to_list(s):
    """Parse every bracketed ``[...]`` group in *s* into a list of ints.

    Matching is non-greedy, so for a nested string such as
    ``"[[1, 2], [3], []]"`` each innermost group yields one sublist:
    ``[[1, 2], [3], []]``. Only runs of digits are extracted; any other
    characters (commas, spaces, minus signs) are ignored.

    Returns an empty list when *s* contains no bracketed group.
    """
    groups = re.findall(r'\[.*?\]', s)
    parsed = []
    for group in groups:
        parsed.append([int(token) for token in re.findall(r'\d+', group)])
    return parsed

def set_hooks(unet, blocked_indices, scaling_factor):
    """Attach ``RescaleLinearActivations`` forward hooks to the UNet.

    Hooks are registered on ``transformer_blocks[0].attn2.to_v`` of
    (presumably the cross-attention value projection in diffusers' UNet):
      * the two attention modules of each of the first three down blocks,
        using ``blocked_indices[0]`` .. ``blocked_indices[5]`` in order
        (down block d, attention a -> slot ``d * 2 + a``), and
      * the first attention module of the mid block, using
        ``blocked_indices[-1]``.

    Args:
        unet: UNet whose layers receive the hooks.
        blocked_indices: sequence of neuron-index lists, one per hooked
            layer (at least 6 entries; the last one is used for the mid block).
        scaling_factor: factor passed to every hook (presumably the
            multiplier applied to the selected neurons -- see
            ``RescaleLinearActivations``).

    Returns:
        List of hook handles, in registration order; callers must call
        ``.remove()`` on each to detach the hooks.
    """
    handles = []

    # Enumerating (down_block, attention) in row-major order makes the
    # enumeration index equal to down_block * 2 + attention.
    layer_slots = [(d, a) for d in range(3) for a in range(2)]
    for slot, (down_block, attention) in enumerate(layer_slots):
        hook = RescaleLinearActivations(indices=blocked_indices[slot], factor=scaling_factor)
        value_proj = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v
        handles.append(value_proj.register_forward_hook(hook))

    mid_hook = RescaleLinearActivations(indices=blocked_indices[-1], factor=scaling_factor)
    mid_value_proj = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v
    handles.append(mid_value_proj.register_forward_hook(mid_hook))

    return handles

def _reference_image_path(index, use_training_images):
    """Return the image path the adversarial embedding for row *index* targets.

    With ``use_training_images`` the first (sorted, for determinism) file in
    ``images_old/memorized_images/`` whose name starts with the zero-padded
    index is used; otherwise a fixed path to the first sample generated for
    that prompt is returned.

    Raises:
        FileNotFoundError: if ``use_training_images`` is set and no file matches.
    """
    if use_training_images:
        pattern = f'images_old/memorized_images/{index:04d}*.png'
        # glob returns files in arbitrary order; sort so the choice is stable.
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f'No training image matches {pattern!r}')
        return matches[0]
    return f'images_old/original_images_v1_4_10_samples_per_prompt/img_{index:04d}_00.jpg'


def main():
    """Generate images from optimized text embeddings while rescaling neurons.

    Reads the ';'-separated CSV ``args.result_file``. For each row it:
      1. parses the row's ``Refined Neurons`` column and registers rescaling
         hooks on the UNet (``set_hooks``),
      2. optimizes text embeddings for the row's ``Caption`` towards a
         reference image (``find_adv_text_embeddings``),
      3. generates ``args.num_samples`` images from those embeddings,
      4. removes the hooks and saves the images to ``args.output_path`` as
         ``img_<row:04d>_<sample:02d>.jpg``.

    The hooks are active during both optimization and generation. Rows that
    raise are reported and skipped; the hooks are removed in every case. The
    run configuration is written to ``<output_path>/config.json``.
    """
    args = create_parser()

    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    # Move every model to the GPU and freeze its weights.
    for model in (vae, text_encoder, unet):
        model.to(torch_device)
        model.requires_grad_(False)

    # exist_ok=False: refuse to overwrite the results of an earlier run.
    os.makedirs(args.output_path, exist_ok=False)

    # Save the full configuration plus the exact command line for
    # reproducibility. Copy vars(args) so that adding 'command' does not
    # mutate the parsed-arguments namespace itself.
    args_to_save = dict(vars(args))
    args_to_save['command'] = " ".join(sys.argv)
    with open(os.path.join(args.output_path, "config.json"), "w") as outfile:
        json.dump(args_to_save, outfile)

    # load csv file
    df = pd.read_csv(args.result_file, sep=';')

    rtpt = RTPT(args.user, 'image generation', len(df))
    rtpt.start()

    for i in tqdm(range(len(df)), total=len(df)):
        rows = df.iloc[i:i + 1]
        prompts = rows['Caption'].to_list()

        # Block the memorization neurons identified for this prompt.
        blocked_indices = str_to_list(rows.iloc[0]['Refined Neurons'])
        handles = set_hooks(unet, blocked_indices, args.scaling_factor)

        try:
            # Resolving the reference image lives inside the try so that a
            # missing file is reported and skipped like any other row failure.
            img_path = _reference_image_path(i, args.use_training_images)
            adv_embeddings = find_adv_text_embeddings(
                img_path, unet=unet, tokenizer=tokenizer, text_encoder=text_encoder,
                vae=vae, scheduler=scheduler, prompt=prompts, num_steps=args.adv_steps,
                batch_size=args.adv_batch_size, seed=args.seed, lr=args.adv_learning_rate)
            images = generate_images(
                None, tokenizer, text_encoder, vae, unet, scheduler,
                text_embeddings=adv_embeddings, num_inference_steps=args.num_steps,
                guidance_scale=args.guidance_scale, samples_per_prompt=args.num_samples,
                seed=args.seed)
        except Exception as e:
            # Best effort: report the failing row and continue with the next one.
            print(f"Error generating images for prompt {i}: {e}")
            continue
        finally:
            # Always detach the hooks (success, handled error, or interrupt)
            # so they never leak into the next row or the caller.
            for handle in handles:
                handle.remove()

        for j, image in enumerate(images):
            image.save(f"{args.output_path}/img_{i + j // args.num_samples:04d}_{j % args.num_samples:02d}.jpg")
        rtpt.step()

def create_parser(argv=None):
    """Define the command-line interface and return the *parsed* arguments.

    Despite its name, this returns an ``argparse.Namespace``, not the parser.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (default ``None`` keeps the CLI behavior).

    Returns:
        argparse.Namespace with the options defined below.
    """
    parser = argparse.ArgumentParser(description='Generating images')
    parser.add_argument(
        '-f',
        '--result_file',
        default='results/memorization_statistics_v1_4.csv',
        type=str,
        dest="result_file",
        help='path to file with image descriptions (default: results/memorization_statistics_v1_4.csv)')
    parser.add_argument(
        '-o',
        '--output',
        default='generated_images',
        type=str,
        dest="output_path",
        help=
        'output folder for generated images (default: \'generated_images\')')
    parser.add_argument('-s',
                        '--seed',
                        default=2,
                        type=int,
                        dest="seed",
                        help='seed for generated images (default: 2)')
    parser.add_argument(
        '-n',
        '--num_samples',
        default=10,
        type=int,
        dest="num_samples",
        help='number of generated samples for each prompt (default: 10)')
    parser.add_argument('--steps',
                        default=50,
                        type=int,
                        dest="num_steps",
                        help='number of denoising steps (default: 50)')
    # argparse applies `type` only to string defaults, so float options
    # need float defaults to always yield floats.
    parser.add_argument('-g',
                        '--guidance_scale',
                        default=7.0,
                        type=float,
                        dest="guidance_scale",
                        help='guidance scale (default: 7)')
    parser.add_argument('-u',
                        '--user',
                        default='XX',
                        type=str,
                        dest="user",
                        help='name initials for RTPT (default: "XX")')
    parser.add_argument('-v',
                        '--version',
                        default='v1-4',
                        type=str,
                        dest="version",
                        help='Stable Diffusion version (default: "v1-4")')
    parser.add_argument('--scaling_factor', default=0.0, type=float, help='Scaling factor for the blocking of neurons')

    parser.add_argument('--adv_learning_rate', default=0.1, type=float, help='learning rate for adv embedding optimization')
    parser.add_argument('--adv_steps', default=50, type=int, help='Number of optimization steps for adv embedding')
    parser.add_argument('--adv_batch_size', default=8, type=int, help='Batch size for adv embedding optimization')
    parser.add_argument('--use_training_images', action='store_true', help='Use training images for adv embedding optimization')

    args = parser.parse_args(argv)
    return args


if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not on import.
    main()