import torch
import torchvision
from torchvision import transforms
from rtpt import RTPT
import torch
from PIL import Image
from utils.stable_diffusion import load_sd_components, load_text_components, compute_text_embedding, generate_images
from utils.adv_embedding import find_adv_text_embeddings
from tqdm import tqdm
import torch.nn.functional as F
import copy
from torchvision.datasets import CocoCaptions, ImageFolder
from utils.datasets import CaptionDataset, get_image_transforms
import argparse
import os
import sys
import json
from itertools import cycle
import random
from accelerate import Accelerator
from PIL import ImageFile
# Let PIL load image files whose data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Cap the number of CPU threads PyTorch uses for intra-op parallelism.
torch.set_num_threads(8)

def main():
    """Fine-tune the Stable Diffusion UNet against adversarial text embeddings.

    Steps performed, in order:

    1. Parse the command line, seed the RNGs and create the output folder
       (with ``checkpoints`` and ``image_samples`` sub-folders); the parsed
       configuration is stored as ``config.json``.
    2. Build the memorized, non-memorized and mitigation datasets and load
       the Stable Diffusion components. The VAE and the text encoder are
       frozen; only the UNet is optimized (AdamW).
    3. For every memorized image in every epoch, search an adversarial text
       embedding (starting from the image caption in even epochs and with
       ``prompt=None`` in odd epochs), then run ``args.steps`` optimizer
       updates on batches that concatenate mitigation images (conditioned on
       the adversarial embedding) with non-memorized images (conditioned on
       their captions). The loss is the MSE between predicted and sampled
       noise.
    4. After each epoch, optionally save a checkpoint and run
       ``evaluate_memozation``.

    Raises:
        FileExistsError: if the output folder (or a sub-folder) already exists.
        ValueError: if the memorized dataset is empty, if the mitigation
            dataset size is not a multiple of the memorized dataset size, or
            if there are fewer mitigation images per memorized image than
            ``args.batch_size_mitigation``.
    """
    args = create_parser()

    # Apply --seed before any shuffling DataLoader iterator is created;
    # previously the flag was parsed but never used.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # create output directory (exist_ok=False: refuse to overwrite an earlier run)
    os.makedirs(args.output_path, exist_ok=False)
    os.makedirs(os.path.join(args.output_path, 'checkpoints'), exist_ok=False)
    os.makedirs(os.path.join(args.output_path, 'image_samples'), exist_ok=False)
    with open(os.path.join(args.output_path, "config.json"), "w") as outfile:
        # copy the namespace dict so that adding 'command' does not mutate args
        args_to_save = dict(vars(args))
        args_to_save['command'] = " ".join(sys.argv)
        json.dump(args_to_save, outfile)

    # load datasets
    dataset_mem = CaptionDataset(args.img_path_memorized, transform=None, load_captions=True, return_img_path=True)
    dataset_non_mem = CaptionDataset(args.img_path_non_memorized, get_image_transforms(hflip=True), load_captions=True)
    dataset_mitigation = CaptionDataset(args.img_path_mitigated, get_image_transforms(hflip=True), load_captions=False)

    # validate dataset sizes with explicit errors (asserts vanish under -O)
    if len(dataset_mem) == 0:
        raise ValueError("The memorized dataset is empty.")
    if len(dataset_mitigation) % len(dataset_mem) != 0:
        raise ValueError(
            f"The mitigation dataset size ({len(dataset_mitigation)}) must be a multiple "
            f"of the memorized dataset size ({len(dataset_mem)}).")
    sample_ratio = len(dataset_mitigation) // len(dataset_mem)
    print(f"Sample ratio: {sample_ratio}")
    if sample_ratio < args.batch_size_mitigation:
        raise ValueError(
            f"Only {sample_ratio} mitigation images per memorized image, but "
            f"batch_size_mitigation is {args.batch_size_mitigation}.")

    dataloader_mem = torch.utils.data.DataLoader(dataset_mem, batch_size=1, shuffle=True, num_workers=1, drop_last=False)
    dataloader_iter_non_mem = iter(torch.utils.data.DataLoader(dataset_non_mem, batch_size=args.batch_size_non_mem, shuffle=True, num_workers=8, drop_last=True))

    def make_mitigation_iter(subset):
        # Always yield full batches of exactly batch_size_mitigation images:
        # the adversarial embedding is repeated that many times and the loss
        # is split at that index further below.
        return iter(torch.utils.data.DataLoader(subset, batch_size=args.batch_size_mitigation, shuffle=True, num_workers=8, drop_last=True))

    # load Stable Diffusion components
    vae, unet, scheduler = load_sd_components(args.version)
    tokenizer, text_encoder = load_text_components(args.version)

    torch_device = "cuda"
    vae.to(torch_device)
    vae.requires_grad_(False)
    text_encoder.to(torch_device)
    text_encoder.requires_grad_(False)
    unet.to(torch_device)

    # define optimizer
    optimizer = torch.optim.AdamW(unet.parameters(), lr=args.lr)

    # initialize RTPT
    rtpt = RTPT(args.user, 'adv. finetuning', args.epochs)
    rtpt.start()

    # baseline evaluation before any fine-tuning (stored under epoch index 0)
    evaluate_memozation(unet, tokenizer, text_encoder, vae, scheduler, 0, args)

    for epoch in range(args.epochs):
        running_total_loss = 0.0
        running_loss_mitigation = 0.0
        running_loss_utility = 0.0
        num_updates = 0  # optimizer steps taken so far in this epoch

        for index, (mem_img_path, mem_prompt) in enumerate(dataloader_mem):
            # batch_size=1, so unwrap the single path / caption
            mem_img_path = mem_img_path[0]
            mem_prompt = mem_prompt[0]

            # The file stem is parsed as an integer index; the mitigation
            # images of memorized image k occupy the contiguous index range
            # [k * sample_ratio, (k + 1) * sample_ratio).
            sample_idx = int(os.path.basename(mem_img_path).split('.')[0])
            mitigation_indices = list(range(sample_idx * sample_ratio, (sample_idx + 1) * sample_ratio))
            subset_mitigation = torch.utils.data.Subset(dataset_mitigation, mitigation_indices)
            dataloader_iter_mitigation = make_mitigation_iter(subset_mitigation)

            # search the adversarial embedding with a frozen UNet
            adv_seed = epoch * len(dataset_mem) + index
            unet.eval()
            unet.requires_grad_(False)
            # even epochs start the search from the caption, odd epochs pass prompt=None
            start_prompt = mem_prompt if epoch % 2 == 0 else None
            text_embedding_unlearning = find_adv_text_embeddings(mem_img_path, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, scheduler=scheduler, prompt=start_prompt, num_steps=args.adv_steps, batch_size=args.adv_bs, seed=adv_seed, lr=args.adv_learning_rate, fp16=args.fp16)

            unet.train()
            unet.requires_grad_(True)
            for step in range(args.steps):
                # draw mitigation samples, restarting the iterator when exhausted
                try:
                    imgs_mitigation = next(dataloader_iter_mitigation)
                except StopIteration:
                    dataloader_iter_mitigation = make_mitigation_iter(subset_mitigation)
                    imgs_mitigation = next(dataloader_iter_mitigation)
                imgs_mitigation = imgs_mitigation.to(torch_device)

                # draw non-memorized samples; only exhaustion triggers a
                # restart, any other loading error is allowed to surface
                try:
                    imgs_non_mem, prompts_non_mem = next(dataloader_iter_non_mem)
                except StopIteration:
                    # A partial last batch is acceptable here because the text
                    # embeddings are computed from the prompts of the same batch.
                    dataloader_iter_non_mem = iter(torch.utils.data.DataLoader(dataset_non_mem, batch_size=args.batch_size_non_mem, shuffle=True, num_workers=8))
                    imgs_non_mem, prompts_non_mem = next(dataloader_iter_non_mem)

                imgs_non_mem = imgs_non_mem.to(torch_device)

                # bfloat16 autocast when --fp16 is set, otherwise autocast is disabled
                with torch.autocast(device_type=torch_device, dtype=torch.bfloat16, enabled=args.fp16):
                    # encode images into latents (mitigation images first)
                    batch_images = torch.cat([imgs_mitigation, imgs_non_mem], dim=0)
                    batch_images_latents = vae.encode(batch_images).latent_dist.sample()
                    batch_images_latents = batch_images_latents * vae.config.scaling_factor

                    # encode text into latents, in the same order as the images
                    text_embeddings_non_mem = compute_text_embedding(prompts_non_mem, tokenizer, text_encoder)
                    text_embedding_unlearning_rep = torch.repeat_interleave(text_embedding_unlearning, dim=0, repeats=args.batch_size_mitigation)
                    batch_text_embeddings = torch.cat([text_embedding_unlearning_rep, text_embeddings_non_mem], dim=0)

                    # sample noise and timesteps
                    noise = torch.randn_like(batch_images_latents)
                    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_images_latents.shape[0],), device=batch_images_latents.device)
                    timesteps = timesteps.long()
                    noisy_latents = scheduler.add_noise(batch_images_latents, noise, timesteps)

                    model_pred = unet(noisy_latents, timesteps, batch_text_embeddings, return_dict=False)[0]

                    # per-element loss so both halves can be reported separately
                    loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction='none')
                    loss_mitigation = loss[:args.batch_size_mitigation].mean()
                    loss_utility = loss[args.batch_size_mitigation:].mean()
                    loss = loss.mean()

                loss.backward()

                optimizer.step()
                optimizer.zero_grad()

                running_total_loss += loss.item()
                running_loss_mitigation += loss_mitigation.item()
                running_loss_utility += loss_utility.item()
                num_updates += 1

            if (index + 1) % 10 == 0:
                # report epoch-to-date averages per optimizer step
                denom = max(num_updates, 1)
                print(f"Epoch {epoch + 1}/{args.epochs}, Step {index+1}/{len(dataset_mem)}, Loss: {running_total_loss/denom:.4f}, Loss Mitigation: {running_loss_mitigation/denom:.4f}, Loss Utility: {running_loss_utility/denom:.4f}")

        # report epoch progress to RTPT
        rtpt.step(f"Epoch {epoch + 1} of {args.epochs}")

        # freeze the model for checkpointing and evaluation
        unet.eval()
        unet.requires_grad_(False)

        if (epoch + 1) % args.checkpoint_interval == 0:
            checkpoint_path = os.path.join(args.output_path, 'checkpoints', f"checkpoint_epoch_{epoch + 1}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': unet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # mean loss per optimizer step of this epoch
                'loss': running_total_loss / max(num_updates, 1),
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

        evaluate_memozation(unet, tokenizer, text_encoder, vae, scheduler, epoch+1, args)

def evaluate_memozation(unet, tokenizer, text_encoder, vae, scheduler, epoch, args):
    """Write evaluation image samples for the current UNet to disk.

    For every (image path, caption) item of the evaluation dataset, an
    adversarial text embedding is searched three times with different
    starting prompts (the item's caption, ``None``, and the fixed prompt
    'a photo of a cat'), using twice ``args.adv_steps`` search steps and a
    fixed seed of 1. Three images are generated from each embedding.
    Afterwards three images are generated for each of four fixed text
    prompts. All images are saved as JPEG files in
    ``<args.output_path>/image_samples`` with ``epoch`` in the file name.

    The UNet is switched to eval mode with gradients disabled and is left
    in that state on return.
    """
    unet.eval()
    unet.requires_grad_(False)

    samples_dir = os.path.join(args.output_path, 'image_samples')
    eval_items = CaptionDataset(args.img_path_eval, transform=None, load_captions=True, return_img_path=True)

    # model components handed unchanged to every search / generation call
    components = dict(unet=unet, tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, scheduler=scheduler)

    # (file-name prefix, rule mapping the item's caption to the starting prompt)
    search_modes = (
        ("memorized_prompt", lambda caption: caption),                # start from the caption
        ("memorized_random", lambda caption: None),                   # no starting prompt
        ("memorized_unrelated", lambda caption: 'a photo of a cat'),  # unrelated fixed prompt
    )

    # one full pass over the evaluation set per search mode
    for file_prefix, choose_prompt in search_modes:
        for idx, (img_path, caption) in enumerate(eval_items):
            adv_embedding = find_adv_text_embeddings(
                img_path,
                prompt=choose_prompt(caption),
                num_steps=2*args.adv_steps,
                batch_size=args.adv_bs,
                seed=1,
                lr=args.adv_learning_rate,
                fp16=args.fp16,
                **components,
            )
            generated = generate_images(
                prompts=None,
                text_embeddings=adv_embedding,
                samples_per_prompt=3,
                guidance_scale=7,
                **components,
            )
            for i, image in enumerate(generated):
                image.save(os.path.join(samples_dir, f"{file_prefix}_{epoch}_{idx}_{i}.jpg"))

    # fixed prompts used for plain text-to-image generation
    eval_prompts_non_mem = [
        "A portrait of a cyborg in a golden suit, D&D sci-fi, artstation, concept art, highly detailed illustration",
        'A photo of a cute cat',
        'A photo of New York City skyline at night',
        "Sony Won't Release <i>The Interview</i> on VOD",
    ]

    for idx, eval_prompt in enumerate(eval_prompts_non_mem):
        generated = generate_images(
            prompts=[eval_prompt],
            samples_per_prompt=3,
            guidance_scale=7,
            **components,
        )
        for i, image in enumerate(generated):
            image.save(os.path.join(samples_dir, f"non_memorized_{epoch}_{idx}_{i}.jpg"))


def draw_samples(dataloader, num_samples):
    """
    Draw ``num_samples`` distinct random items from a dataset.

    Args:
        dataloader: either an indexable dataset (supports ``len()`` and
            ``[]``) or an object without ``__getitem__`` that exposes the
            dataset through a ``dataset`` attribute.
        num_samples: number of items to draw without replacement.

    Returns:
        A list with ``num_samples`` items of the dataset, in random order.

    Raises:
        ValueError: if ``num_samples`` is larger than the dataset
            (raised by ``random.sample``).
    """
    # The original body referenced an undefined name ``dataset``. Resolve it
    # from the argument: use it directly when it is indexable, otherwise
    # fall back to its ``dataset`` attribute.
    dataset = dataloader if hasattr(dataloader, "__getitem__") else dataloader.dataset
    indices = random.sample(range(len(dataset)), num_samples)
    return [dataset[i] for i in indices]


def _str_to_bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool."""
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def create_parser(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: optional list of argument strings. When ``None`` (the
            default), the arguments are read from ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options. Options without a
        default (``output_path`` and the ``img_path_*`` options) are
        ``None`` when they are not supplied.
    """
    parser = argparse.ArgumentParser(description='Generating images')
    parser.add_argument('-s',
                        '--seed',
                        default=1,
                        type=int,
                        dest="seed",
                        help='seed for training (default: 1)')
    parser.add_argument('-o',
                        '--output_path',
                        type=str,
                        help='path to output folder')
    parser.add_argument('-u',
                        '--user',
                        default='XX',
                        type=str,
                        dest="user",
                        help='name initials for RTPT (default: "XX")')
    parser.add_argument('-v',
                        '--version',
                        default='v1-4',
                        type=str,
                        dest="version",
                        help='Stable Diffusion version (default: "v1-4")')
    parser.add_argument('-lr',
                        '--learning_rate',
                        default=1e-5,
                        type=float,
                        dest="lr",
                        help='learning rate for training (default: 1e-5)')
    parser.add_argument('--batch_size_non_mem',
                        default=4,
                        type=int,
                        dest="batch_size_non_mem",
                        help='batch size for clean images (default: 4)')
    parser.add_argument('--batch_size_mitigation',
                        default=4,
                        type=int,
                        dest="batch_size_mitigation",
                        help='batch size for adv images (default: 4)')
    parser.add_argument('-e',
                        '--epochs',
                        default=1,
                        type=int,
                        dest="epochs",
                        help='number of epochs for training (default: 1)')
    parser.add_argument('--steps',
                        default=1,
                        type=int,
                        dest="steps",
                        help='number of optimization steps per embedding (default: 1)')
    parser.add_argument('-adv_steps',
                        '--adv_steps',
                        default=50,
                        type=int,
                        dest="adv_steps",
                        help='number of steps for adv training (default: 50)')
    parser.add_argument('-adv_lr',
                        '--adv_learning_rate',
                        default=1e-1,
                        type=float,
                        dest="adv_learning_rate",
                        help='learning rate for adv training (default: 1e-1)')
    parser.add_argument('-adv_bs',
                        '--adv_batch_size',
                        default=8,
                        type=int,
                        dest="adv_bs",
                        help='batch size for adv training (default: 8)')
    parser.add_argument('--img_path_memorized',
                        type=str,
                        dest="img_path_memorized",
                        help='path to the folder with memorized training images')
    parser.add_argument('--img_path_non_memorized',
                        type=str,
                        dest="img_path_non_memorized",
                        help='path to the folder with non-memorized training images')
    parser.add_argument('--img_path_mitigated',
                        type=str,
                        dest="img_path_mitigated",
                        help='path to the folder with generated images after mitigation')
    parser.add_argument('--img_path_eval',
                        type=str,
                        dest="img_path_eval",
                        help='path to the folder with memorized images for evaluation')
    # type=bool would turn any non-empty string (including "False") into
    # True, so the value is parsed explicitly. A bare ``--fp16`` enables it.
    parser.add_argument('--fp16',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False,
                        dest="fp16",
                        help='use mixed precision for training; accepts true/false, '
                             'a bare --fp16 means true (default: False)')
    parser.add_argument('--grad_acc_steps',
                        type=int,
                        default=1,
                        dest="grad_acc_steps",
                        help='number of gradient accumulation steps (default: 1)')
    parser.add_argument('--checkpoint_interval',
                        type=int,
                        default=1,
                        dest="checkpoint_interval",
                        help='interval for saving checkpoints (default: 1)')
    args = parser.parse_args(argv)
    return args


# Run the training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()