import os
import time
import torch
import itertools
from PIL import Image
from pathlib import Path
from torchvision import transforms
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from PIL.ImageOps import exif_transpose
from diffusers.optimization import get_scheduler
from bat_dreambooth import BatDreamBoothDataset, collate_fn

#set seed for reproducibility
import random
import numpy as np

def set_seed(seed):
    """Seed every RNG this script relies on and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU generator and
    all CUDA device generators with ``seed``, then disables cuDNN autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

from accelerate import Accelerator
import torch.nn.functional as F

accelerator = Accelerator()

import argparse

parser = argparse.ArgumentParser(description="Model Training Configuration")

parser.add_argument("--pretrained_model_name_or_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the pretrained model or model name")
parser.add_argument("--revision", type=str, default=None, help="Model revision (optional)")
parser.add_argument("--variant", type=str, default=None, help="Model variant (optional)")
parser.add_argument("--instance_path", type=str, default="/models/dog/dog-data/", help="Path to the instance data")
parser.add_argument("--instance_prompt", type=str, default="a photo of sks dog", help="Prompt for the instance data")
parser.add_argument("--class_path", type=str, default="/dog/dog-class-data/", help="Path to the class data")
parser.add_argument("--class_prompt", type=str, default="a photo of dog", help="Prompt for the class data")
parser.add_argument("--bat_path", type=str, default="/laion_dataset_text/metadata.json", help="Path to the bat dataset metadata")
parser.add_argument("--bat_ratio", type=float, default=0.16, help="Ratio for the bat dataset")
parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate for the optimizer")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam optimizer beta1")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="Adam optimizer beta2")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay for Adam optimizer")
parser.add_argument("--adam_epsilon", type=float, default=1e-8, help="Epsilon for Adam optimizer")
parser.add_argument("--num_train_epochs", type=float, default=0, help="Number of training epochs")
parser.add_argument("--max_train_steps", type=int, default=None, help="Maximum number of training steps")
parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of gradient accumulation steps")
parser.add_argument("--total_batch_size", type=int, default=None, help="Total batch size (calculated)")
parser.add_argument("--lr_scheduler_config", type=str, default="constant", help="Learning rate scheduler configuration")
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of warmup steps for learning rate scheduler")
parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of cycles for learning rate scheduler")
# float so fractional powers (e.g. 0.5) can be passed through to get_scheduler's `power`.
parser.add_argument("--lr_power", type=float, default=1, help="Power parameter for learning rate scheduler")
parser.add_argument("--bat", action="store_true", help="Whether to calculate Hessian")
parser.add_argument("--optimal_model_path", type=str, default="/scratch2/paneah/bat/models/dog/09242024", help="Path to the optimal model")
parser.add_argument("--output_dir", type=str, default="/scratch2/paneah/bat/models/dog/09242024", help="Output directory")
parser.add_argument("--save", action="store_true", help="Whether to save the model")
parser.add_argument("--saving_step", type=int, default=50, help="Save a pipeline checkpoint every this many training steps (requires --save)")
parser.add_argument("--result_name", type=str, default="", help="Name for the result file")
parser.add_argument("--data_root_size", type=int, default=5, help="Size of the data root")
parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility")



args = parser.parse_args()

# Reject configurations that would otherwise only fail much later:
# - with --num_train_epochs 0 and no --max_train_steps, max_train_steps stays None
#   and the LR-scheduler setup multiplies it by the process count (TypeError);
# - --saving_step is used as a modulus in the training loops, so 0 would raise
#   ZeroDivisionError on the first optimizer step.
if args.num_train_epochs == 0 and args.max_train_steps is None:
    parser.error("set --max_train_steps or a non-zero --num_train_epochs")
if args.saving_step <= 0:
    parser.error("--saving_step must be a positive integer")

# Set seed at the beginning of the script
# All mains are done with 42
set_seed(args.seed)

pretrained_model_name_or_path = args.pretrained_model_name_or_path
revision = args.revision
variant = args.variant
instance_path = args.instance_path
instance_prompt = args.instance_prompt
class_path = args.class_path
class_prompt = args.class_prompt
bat_path = args.bat_path
bat_ratio = args.bat_ratio
learning_rate = args.learning_rate
adam_beta1 = args.adam_beta1
adam_beta2 = args.adam_beta2
adam_weight_decay = args.adam_weight_decay
adam_epsilon = args.adam_epsilon
num_train_epochs = args.num_train_epochs
train_batch_size = args.train_batch_size
max_train_steps = args.max_train_steps
data_root_size = args.data_root_size
gradient_accumulation_steps = args.gradient_accumulation_steps
total_batch_size = args.total_batch_size
lr_scheduler_config = args.lr_scheduler_config
lr_warmup_steps = args.lr_warmup_steps
lr_num_cycles = args.lr_num_cycles
lr_power = args.lr_power
bat = args.bat
optimal_model_path = args.optimal_model_path
output_dir = args.output_dir
save = args.save
saving_step = args.saving_step

# Results go to <output_dir>/<parent>_<name>_<result_name>, where <parent>/<name> are
# the last two components of --optimal_model_path (".../dog/09242024" -> "dog_09242024").
# normpath drops a trailing slash that would otherwise make <name> empty.
_optimal_model_dir = os.path.normpath(optimal_model_path)
result_name = (
    output_dir + "/"
    + os.path.basename(os.path.dirname(_optimal_model_dir)) + "_"
    + os.path.basename(_optimal_model_dir) + "_"
    + args.result_name
)
# Make sure the output directory exists (exist_ok avoids a check-then-create race).
os.makedirs(result_name, exist_ok=True)


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` to a fixed length and return PyTorch tensors.

    Args:
        tokenizer: a callable tokenizer exposing ``model_max_length``.
        prompt: text (or list of texts) to tokenize.
        tokenizer_max_length: target length; when ``None`` the tokenizer's own
            ``model_max_length`` is used.

    Returns:
        Whatever ``tokenizer(...)`` returns (presumably a HF ``BatchEncoding``).
    """
    max_length = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length
    return tokenizer(
        prompt,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

# Load tokenizer. Use the same --revision as the text encoder / UNet / VAE so that
# all model components come from one consistent snapshot.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=revision,
    use_fast=False,
)

# Load the dataset with backbone data
dataset = BatDreamBoothDataset(
    instance_path,
    tokenizer=tokenizer,
    instance_prompt=instance_prompt,
    class_prompt=class_prompt,
    class_data_root=class_path,
    bat_data_root=bat_path,
    bat_ratio=bat_ratio,
    bat_data_root_size=data_root_size,
)

#forward pass 
from huggingface_hub import model_info
from diffusers.utils.torch_utils import is_compiled_module
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DiffusionPipeline

# Ensure deterministic behavior in DataLoader
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy and ``random`` from ``torch.initial_seed()``.

    ``worker_id`` is unused; it is only required by the ``worker_init_fn`` signature.
    The torch seed is reduced modulo 2**32 so it fits NumPy's seed range.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)

# Create a dataloader: one dataset item per batch, in dataset order, collated with
# prior preservation enabled.
def _collate_with_prior_preservation(examples):
    return collate_fn(examples, with_prior_preservation=True)


dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=_collate_with_prior_preservation,
    worker_init_fn=seed_worker,
)

# A non-zero epoch count overrides --max_train_steps.
max_train_steps = int(len(dataloader) * num_train_epochs) if num_train_epochs != 0 else max_train_steps

#get vae
def model_has_vae(pretrained_model_name_or_path, revision):
    """Return True if the model provides ``vae/<AutoencoderKL.config_name>``.

    A local directory is checked on disk; anything else is treated as a Hub repo id
    and its file listing (from ``model_info``) is searched at ``revision``.
    """
    relative_config = Path("vae", AutoencoderKL.config_name).as_posix()
    if not os.path.isdir(pretrained_model_name_or_path):
        repo_files = model_info(pretrained_model_name_or_path, revision=revision).siblings
        return any(sibling.rfilename == relative_config for sibling in repo_files)
    return os.path.isfile(os.path.join(pretrained_model_name_or_path, relative_config))

# The VAE is optional: load it only when the model ships one, otherwise use None.
vae = (
    AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
    )
    if model_has_vae(pretrained_model_name_or_path, revision)
    else None
)

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named by the model's ``text_encoder`` config.

    Reads ``architectures[0]`` from the config in the ``text_encoder`` subfolder and
    imports the matching class lazily.

    Raises:
        ValueError: if the architecture is not one of CLIPTextModel,
            RobertaSeriesModelWithTransformation or T5EncoderModel.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architecture = encoder_config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    elif architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation as encoder_cls,
        )
    elif architecture == "T5EncoderModel":
        from transformers import T5EncoderModel as encoder_cls
    else:
        raise ValueError(f"{architecture} is not supported.")

    return encoder_cls

#get text encoder
text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision)
text_encoder = text_encoder_cls.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, variant=variant
)
# The text encoder is inference-only here: the optimizer below receives UNet parameters
# only, so `optimizer.zero_grad` never clears text-encoder gradients. Freezing it stops
# every backward pass from computing (and indefinitely accumulating) unused gradients.
text_encoder.requires_grad_(False)

#get unet
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=revision, variant=variant
)

# Detached snapshot of the pretrained UNet weights, used to restore the UNet later.
unet_copy = {key: value.clone().detach() for key, value in unet.state_dict().items()}

optimizer_class = torch.optim.AdamW

# Only the UNet is trained.
params_to_optimize = unet.parameters()

optimizer = optimizer_class(
    params_to_optimize,
    lr=learning_rate,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

weight_dtype = torch.float32
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    """Run ``text_encoder`` on ``input_ids`` and return the first element of its output.

    Args:
        text_encoder: model exposing ``.device`` and callable with
            ``(ids, attention_mask=..., return_dict=False)``.
        input_ids: token ids; moved to the encoder's device.
        attention_mask: mask tensor; only used (and moved) when
            ``text_encoder_use_attention_mask`` is truthy, otherwise ``None`` is passed.
        text_encoder_use_attention_mask: whether to forward the attention mask.

    Returns:
        ``text_encoder(...)[0]`` (presumably the last hidden states).
    """
    device = text_encoder.device
    ids_on_device = input_ids.to(device)
    mask_on_device = attention_mask.to(device) if text_encoder_use_attention_mask else None
    encoder_outputs = text_encoder(ids_on_device, attention_mask=mask_on_device, return_dict=False)
    return encoder_outputs[0]

# Build the LR schedule; warmup and total step counts are both scaled by the
# number of processes.
lr_scheduler = get_scheduler(
    lr_scheduler_config,
    optimizer=optimizer,
    num_training_steps=max_train_steps * accelerator.num_processes,
    num_warmup_steps=lr_warmup_steps * accelerator.num_processes,
    power=lr_power,
    num_cycles=lr_num_cycles,
)

# Hand models, optimizer, dataloader and scheduler to Accelerate for the current device setup.
(
    unet,
    text_encoder,
    optimizer,
    dataloader,
    lr_scheduler,
) = accelerator.prepare(unet, text_encoder, optimizer, dataloader, lr_scheduler)

def unwrap_model(model):
    """Return the bare module behind ``model``.

    Removes the Accelerate wrapper first, then the ``torch.compile`` wrapper
    (``_orig_mod``) if the result is a compiled module.
    """
    base = accelerator.unwrap_model(model)
    if is_compiled_module(base):
        return base._orig_mod
    return base

initial_global_step = 0

from tqdm.auto import tqdm

progress_bar = tqdm(
    range(0, max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine (the original `not True`
    # was always False, so every process drew its own bar).
    disable=not accelerator.is_local_main_process,
)

start = time.time()

from accelerate.logging import get_logger
logger = get_logger(__name__)

hes_start = time.time()

def print_memory_stats(step):
    """Print the CUDA memory currently allocated and reserved by PyTorch, in MB.

    Both values are sampled before anything is printed; each goes on its own line
    prefixed with ``Step <step>``.
    """
    megabyte = 1024 ** 2
    allocated_memory, reserved_memory = torch.cuda.memory_allocated(), torch.cuda.memory_reserved()
    report = (
        f"Step {step} - Allocated Memory: {allocated_memory / megabyte:.2f} MB",
        f"Step {step} - Reserved Memory: {reserved_memory / megabyte:.2f} MB",
    )
    for line in report:
        print(line)

#calculate hessians
# Phase 1 (only with --bat): train on the instance/class batches while periodically
# recording the L2 distance between the current UNet and the reference ("optimal")
# UNet stored under `optimal_model_path`. The recorded distances go into `unet_hessian`.
if bat:

    def _unet_distance_to_optimal():
        """Return ||theta_reference - theta_unet||_2 / dataset.num_instance_images.

        The reference UNet is re-read from `optimal_model_path + "/unet"` on every call.
        The arithmetic runs under torch.no_grad() so the differences do not build an
        autograd graph through the trainable UNet parameters.
        """
        reference_unet = UNet2DConditionModel.from_pretrained(optimal_model_path + "/unet")
        with torch.no_grad():
            reference_params = [p.detach().to(accelerator.device) for p in reference_unet.parameters()]
            del reference_unet

            # Flatten each per-tensor difference so they can be concatenated into one vector.
            differences = [
                (reference_param - current_param).reshape(-1)
                for reference_param, current_param in zip(reference_params, unet.parameters())
            ]
            del reference_params
            torch.cuda.empty_cache()

            differences_concat = torch.cat(differences)
            del differences
            torch.cuda.empty_cache()

            distance = torch.norm(differences_concat).item() / dataset.num_instance_images
            del differences_concat
        return distance

    text_encoder.to("cpu")

    # Initial distance to the reference UNet, before any phase-1 training.
    unet_norm_diff_init = _unet_distance_to_optimal()
    torch.cuda.empty_cache()

    print(f"L2 Norm Difference for Initial UNet parameters: {unet_norm_diff_init}")

    hes_end = time.time()

    hes_total = hes_end - hes_start

    print(f"it took {hes_total} seconds to calculate hessian norm")

    hessian_step = 0
    saved_models = 0

    #prepare the progress bar

    progress_bar = tqdm(
        range(0, max_train_steps),
        initial=hessian_step,
        desc="G Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    unet_hessian = []
    unet_hessian.append(unet_norm_diff_init)

    #since our model has vae
    if vae is not None:
        vae.requires_grad_(False)
        vae.to(accelerator.device, dtype=weight_dtype)

    # Move the model to the GPU
    text_encoder.to(accelerator.device)

    while True:
        for idx, batch in enumerate(dataloader):
            # Of every `num_instance_images + num_bat_images // 2` consecutive batches,
            # only the first `num_instance_images` are trained on in this phase.
            if idx % (dataset.num_instance_images + (dataset.num_bat_images // 2)) < dataset.num_instance_images:
                # Encode the images into scaled VAE latents.
                model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
                noise = torch.randn_like(model_input)

                bsz, channels, height, width = model_input.shape
                #sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                )
                timesteps = timesteps.long()

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                encoder_hidden_states = encode_prompt(
                    text_encoder,
                    batch["input_ids"],
                    batch["attention_mask"],
                    text_encoder_use_attention_mask=False,
                )

                if unwrap_model(unet).config.in_channels == channels * 2:
                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
                class_labels = None

                model_pred = unet(
                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
                )[0]

                # A 6-channel prediction presumably carries a learned variance in its
                # second half; keep only the first half.
                if model_pred.shape[1] == 6:
                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # NOTE(review): assumes collate_fn(with_prior_preservation=True) stacks the
                # instance examples in the first half of the batch and the class (prior)
                # examples in the second half — confirm against bat_dreambooth.collate_fn.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Compute model loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                loss = loss + prior_loss

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = unet.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, 1)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                    progress_bar.update(accelerator.num_processes)
                    hessian_step += accelerator.num_processes

                    # Every `num_instance_images` steps, record the distance to the reference UNet.
                    if hessian_step % dataset.num_instance_images == 0:
                        # Park the VAE and text encoder on the CPU while the reference UNet is on the GPU.
                        vae.to("cpu")
                        text_encoder.to("cpu")

                        unet_norm_diff_init = _unet_distance_to_optimal()
                        unet_hessian.append(unet_norm_diff_init)

                        unet.to("cpu")
                        torch.cuda.empty_cache()

                        # Move the original model back to the GPU
                        text_encoder.to(accelerator.device)
                        vae.to(accelerator.device)
                        unet.to(accelerator.device)

                    if hessian_step != 0 and hessian_step % saving_step == 0 and save:
                        accelerator.wait_for_everyone()
                        if accelerator.is_main_process:
                            pipeline_args = {}

                            pipeline = DiffusionPipeline.from_pretrained(
                                pretrained_model_name_or_path,
                                unet=unwrap_model(unet),
                                revision=revision,
                                variant=variant,
                                **pipeline_args,
                            )

                            # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                            scheduler_args = {}

                            if "variance_type" in pipeline.scheduler.config:
                                variance_type = pipeline.scheduler.config.variance_type

                                if variance_type in ["learned", "learned_range"]:
                                    variance_type = "fixed_small"

                                scheduler_args["variance_type"] = variance_type

                            pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
                            pipeline.save_pretrained(result_name + f"/adapter_model_{hessian_step}")
                            saved_models += 1

                    logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=hessian_step)

            if hessian_step >= max_train_steps:
                break

        if hessian_step >= max_train_steps:
            break

if bat:
    import copy
    unet_hessian_copy = copy.deepcopy(unet_hessian)
    print(f"  unet_hessian for G set {unet_hessian[0:5]}...")
    # Restore the UNet weights snapshotted into `unet_copy` before phase 1. The snapshot
    # was taken from the unwrapped UNet, so load it into the unwrapped module: after
    # `accelerator.prepare` the UNet may be wrapped (e.g. DDP), and a wrapper's
    # state-dict keys carry a "module." prefix that `unet_copy` does not have.
    unwrap_model(unet).load_state_dict(unet_copy)
    # NOTE(review): optimizer moments and LR-scheduler progress from phase 1 are not
    # reset here; confirm that carrying them into the main phase is intended.
    bat_unet_hessian = []
    bat_unet_hessian.append(unet_hessian[0])
else:
    # Without the bat phase the VAE has not been frozen / moved to the GPU yet.
    vae.requires_grad_(False)
    vae.to(accelerator.device, dtype=weight_dtype)


print("***** Running training *****")
print(f"  Num examples = {len(dataset)}")
print(f"  Num batches each epoch = {len(dataloader)}")
print(f"  Num Epochs = {num_train_epochs}")
print(f"  Instantaneous batch size per device = {train_batch_size}")
print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
print(f"  Total optimization steps = {max_train_steps}")

global_step = 0
saved_models = 0

# Reset the progress bar
progress_bar = tqdm(
    range(0, max_train_steps),
    initial=global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

if bat:
    good_data = []
    bad_data = []

while True:
        print("Resuming training")
        for idx, batch in enumerate(dataloader):
            if idx % (dataset.num_instance_images + (dataset.num_bat_images // 2)) < dataset.num_instance_images: 
                loss = 0
                model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()

                model_input = model_input * vae.config.scaling_factor
                noise = torch.randn_like(model_input)

                bsz, channels, height, width = model_input.shape
                #sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                )

                timesteps = timesteps.long()
                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                encoder_hidden_states = encode_prompt(
                                    text_encoder,
                                    batch["input_ids"],
                                    batch["attention_mask"],
                                    text_encoder_use_attention_mask=False,
                                )

                if unwrap_model(unet).config.in_channels == channels * 2:
                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                class_labels = None

                model_pred = unet(
                                noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
                            )[0]

                if model_pred.shape[1] == 6:
                                model_pred, _ = torch.chunk(model_pred, 2, dim=1)
                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Compute model loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                loss = loss + prior_loss

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters())
                    )
                    accelerator.clip_grad_norm_(params_to_clip, 1)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                    #check training 
                    if accelerator.sync_gradients:
                        progress_bar.update(accelerator.num_processes)
                        global_step += accelerator.num_processes

                        if global_step != 0 and global_step % saving_step == 0 and save and bat:
                            accelerator.wait_for_everyone()
                            if accelerator.is_main_process:
                                pipeline_args = {}

                                pipeline = DiffusionPipeline.from_pretrained(
                                    pretrained_model_name_or_path,
                                    unet=unwrap_model(unet),
                                    revision=revision,
                                    variant=variant,
                                    **pipeline_args,
                                )
                                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                                scheduler_args = {}

                                if "variance_type" in pipeline.scheduler.config:
                                    variance_type = pipeline.scheduler.config.variance_type

                                    if variance_type in ["learned", "learned_range"]:
                                        variance_type = "fixed_small"

                                    scheduler_args["variance_type"] = variance_type

                                pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
                                pipeline.save_pretrained(result_name + f"/bat_model_{global_step}")
                                saved_models += 1 

                    logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)

                    
                    #check training
                    if global_step % dataset.num_instance_images == 0 and global_step != 0 and bat:
                        #save the model
                        unet.to("cpu")

                        unet_copy = {key: value.clone().detach() for key, value in unet.state_dict().items()}

                        unet.to(accelerator.device)
                        
                        while True:
                            print("checking for bad data") 
                            #find data
                            added_data = 0
                            for idx, batch in enumerate(dataloader):
                                if idx % (dataset.num_instance_images + (dataset.num_bat_images // 2)) >= dataset.num_instance_images:
                                    loss = 0
                                    model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                                    model_input = model_input * vae.config.scaling_factor
                                    noise = torch.randn_like(model_input)
                                    bsz, channels, height, width = model_input.shape
                                    #sample a random timestep for each image
                                    timesteps = torch.randint(
                                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                                    )
                                    timesteps = timesteps.long()
                                    # Add noise to the model input according to the noise magnitude at each timestep
                                    noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
                                    encoder_hidden_states = encode_prompt(
                                                        text_encoder,
                                                        batch["input_ids"],
                                                        batch["attention_mask"],
                                                        text_encoder_use_attention_mask=False,
                                                    )
                                    if unwrap_model(unet).config.in_channels == channels * 2:
                                        noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
                                    class_labels = None

                                    model_pred = unet(
                                                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
                                                )[0]

                                    if model_pred.shape[1] == 6:
                                                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)
                                    # Get the target for loss depending on the prediction type
                                    if noise_scheduler.config.prediction_type == "epsilon":
                                        target = noise
                                    elif noise_scheduler.config.prediction_type == "v_prediction":
                                        target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                                    else:
                                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                                    target, target_prior = torch.chunk(target, 2, dim=0)
                                    # Compute prior loss
                                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
                                    # Compute model loss
                                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                                    loss = loss + prior_loss

                                    accelerator.backward(loss)
                                    if accelerator.sync_gradients:
                                        params_to_clip = (
                                            itertools.chain(unet.parameters())
                                        )
                                        accelerator.clip_grad_norm_(params_to_clip, 1)
                                        optimizer.step()
                                        lr_scheduler.step()
                                        optimizer.zero_grad(set_to_none=True)
                                        added_data += 1
                                if added_data >= dataset.num_bat_images // 2:
                                    break
                            # Check the norm difference between the original and updated parameters

                            vae.to("cpu")
                            text_encoder.to("cpu")

                            torch.cuda.synchronize()
                            torch.cuda.empty_cache()

                            #initial comparison of Hessian
                            #import the assumed optimal model
                            our_unet = UNet2DConditionModel.from_pretrained(optimal_model_path + "/unet")
                            original_unet_params = [p.clone().to(accelerator.device) for p in our_unet.parameters()]

                            del our_unet

                            # Compute the norm difference between original and updated parameters for both models
                            unet_differences = []
                            for original_param, updated_param in zip(original_unet_params, unet.parameters()):
                                unet_differences.append((original_param - updated_param).view(-1))

                            del original_unet_params

                            # Concatenate all differences into one tensor
                            unet_differences_concat = torch.cat(unet_differences)

                            del unet_differences

                            # Compute the L2 norm of the concatenated tensor
                            new_unet_norm_diff = torch.norm(unet_differences_concat).item() / (dataset.num_instance_images + added_data)

                            del unet_differences_concat
                            unet.to("cpu")
                            torch.cuda.empty_cache()

                            unet_norm_diff = unet_hessian[0]

                            # Print norm differences
                            print(f"L2 Norm Difference for Updated UNet parameters with G: {unet_norm_diff}")

                            print(f"L2 Norm Difference for Updated UNet parameters with K: {new_unet_norm_diff}")

                            # Compare the norm differences
                            unet_norm_diff_ratio = new_unet_norm_diff - unet_norm_diff

                            # If the norm difference ratio is greater than 1, the model has diverged
                            if unet_norm_diff_ratio <= 0:
                                print("the added data has not caused the model to diverge")
                                print(f"UNet Norm Difference Ratio: {unet_norm_diff_ratio}")
                                bat_unet_hessian.append(new_unet_norm_diff)
                                #replace the original model with the new model
                                global_step += added_data
                                progress_bar.update(added_data)
                                #pop unet_hessian and text_encoder_hessian
                                unet_hessian.pop(0)
                                print(f"good data! {dataset.bat_prompts}")
                                good_data.append((f"step: {global_step}", dataset.bat_prompts))
                                dataset.give_new_bat(False)
                                msg = "Good"
                                #save the bat model                        
                                # Move the model back to the GPU
                                unet.to(accelerator.device)
                                vae.to(accelerator.device)
                                text_encoder.to(accelerator.device)
                                if global_step != 0 and global_step % saving_step == 0 and save:
                                    accelerator.wait_for_everyone()
                                    if accelerator.is_main_process:
                                        pipeline_args = {}

                                        pipeline = DiffusionPipeline.from_pretrained(
                                            pretrained_model_name_or_path,
                                            unet=unwrap_model(unet),
                                            revision=revision,
                                            variant=variant,
                                            **pipeline_args,
                                        )
                                        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                                        scheduler_args = {}

                                        if "variance_type" in pipeline.scheduler.config:
                                            variance_type = pipeline.scheduler.config.variance_type

                                            if variance_type in ["learned", "learned_range"]:
                                                variance_type = "fixed_small"

                                            scheduler_args["variance_type"] = variance_type

                                        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
                                        pipeline.save_pretrained(result_name + f"/bat_model_{global_step}")
                                        saved_models += 1
                                        print("saving model")
                                break
                            else:
                                print("the added data has caused the model to diverge")
                                print(f"UNet Norm Difference Ratio: {unet_norm_diff_ratio}")
                                # Go back to the original model
                                unet.load_state_dict(unet_copy)
                                print(f"bad data! {dataset.bat_prompts}")
                                bad_data.append((f"step: {global_step}", dataset.bat_prompts))
                                was_bad = True
                                msg = dataset.give_new_bat(was_bad)
                                if msg == "No more bat images":
                                    break
                                elif isinstance(msg, list):
                                    print("New bat images added")
                                    print(f"New bat images: {msg}")

                            # Move the model back to the GPU
                            unet.to(accelerator.device)
                            vae.to(accelerator.device)
                            text_encoder.to(accelerator.device)

                        if dataset.good_bat:
                            break
            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break


# Stop the wall-clock timer (`start` is set earlier in the script) and keep
# the elapsed seconds in `total`.
end = time.time()
total = end - start

# Persist the bat-filtering bookkeeping as JSON. Nothing is plotted here.
#   - "unet_hessian"     -> unet_hessian_copy (defined earlier; presumably a copy
#                           of the reference norm-difference list - TODO confirm)
#   - "bat_unet_hessian" -> norm differences of accepted bat batches
#   - good_data/bad_data -> ("step: <global_step>", bat prompts) tuples, which
#                           json serialises as 2-element lists
# Only the main accelerate process writes, so several processes do not race on
# the same files.
if bat and accelerator.is_main_process:
    import json

    unet_result = {"unet_hessian": unet_hessian_copy, "bat_unet_hessian": bat_unet_hessian}
    good_data_result = {"good_data": good_data}
    bad_data_result = {"bad_data": bad_data}

    # result_name is otherwise only created by save_pretrained, which may never
    # have run (e.g. save disabled or no checkpoint step reached).
    os.makedirs(result_name, exist_ok=True)
    for file_name, payload in (
        ("unet_hessian.json", unet_result),
        ("good_data.json", good_data_result),
        ("bad_data.json", bad_data_result),
    ):
        with open(os.path.join(result_name, file_name), "w", encoding="utf-8") as outfile:
            json.dump(payload, outfile)
    

# When bat filtering is off, save the final fine-tuned pipeline to output_dir
# once every process has finished training.
if save and not bat:
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        pipeline_args = {}
        pipeline = DiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            unet=unwrap_model(unet),
            revision=revision,
            variant=variant,
            **pipeline_args,
        )

        # Training used the simplified learning objective. If the scheduler was
        # configured for a learned variance, switch it to "fixed_small" so the
        # saved pipeline ignores the variance channels.
        scheduler_config = pipeline.scheduler.config
        scheduler_args = {}
        if "variance_type" in scheduler_config:
            variance_type = scheduler_config.variance_type
            scheduler_args["variance_type"] = (
                "fixed_small" if variance_type in ("learned", "learned_range") else variance_type
            )

        pipeline.scheduler = pipeline.scheduler.from_config(scheduler_config, **scheduler_args)
        pipeline.save_pretrained(output_dir)


# Monitor VRAM usage after the calculations
final_allocated_memory = torch.cuda.memory_allocated()
final_reserved_memory = torch.cuda.memory_reserved()

print(f"Time taken: {end - start:.6f} seconds")

# Print VRAM usage after calculations
print(f"Final CUDA Memory Allocated: {final_allocated_memory / (1024 ** 2):.2f} MB")
print(f"Final CUDA Memory Reserved: {final_reserved_memory / (1024 ** 2):.2f} MB")

# Print peak memory usage
peak_allocated_memory = torch.cuda.max_memory_allocated()
peak_reserved_memory = torch.cuda.max_memory_reserved()

print(f"Peak CUDA Memory Allocated: {peak_allocated_memory / (1024 ** 2):.2f} MB")
print(f"Peak CUDA Memory Reserved: {peak_reserved_memory / (1024 ** 2):.2f} MB")