import time
import torch
import itertools
from PIL import Image
from pathlib import Path
from torchvision import transforms
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from PIL.ImageOps import exif_transpose
from diffusers.optimization import get_scheduler
from bat_lycoris_locon import BatLycorisDataset, collate_fn
import json
import random
import numpy as np
from lycoris import create_lycoris, LycorisNetwork
from lycoris.modules.locon import LoConModule
from itertools import chain

def set_seed(seed):
    """Seed every RNG used by the script and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices with the same ``seed``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic kernels and no autotuning, so repeated runs pick the same algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

from accelerate import Accelerator
import torch.nn.functional as F

accelerator = Accelerator()

import argparse
# `os` is needed right here (output directory creation); the file's other
# `import os` appears further down, after this code has already executed.
import os

parser = argparse.ArgumentParser(description="Model Training Configuration")

parser.add_argument("--pretrained_model_name_or_path", type=str, default="", help="Path to the pretrained model or model name")
parser.add_argument("--revision", type=str, default=None, help="Model revision (optional)")
parser.add_argument("--variant", type=str, default=None, help="Model variant (optional)")
parser.add_argument("--instance_path", type=str, default="", help="Path to the instance data")
parser.add_argument("--instance_prompt", type=str, default="", help="Prompt for the instance data")
parser.add_argument("--class_path", type=str, default=None, help="Path to the class data")
parser.add_argument("--class_prompt", type=str, default=None, help="Prompt for the class data")
parser.add_argument("--bat_path", type=str, default="", help="Path to the bat dataset metadata")
parser.add_argument("--bat_ratio", type=float, default=0.10, help="Ratio for the bat dataset")
parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate for the optimizer")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam optimizer beta1")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="Adam optimizer beta2")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay for Adam optimizer")
parser.add_argument("--adam_epsilon", type=float, default=1e-8, help="Epsilon for Adam optimizer")
parser.add_argument("--num_train_epochs", type=float, default=0, help="Number of training epochs")
parser.add_argument("--max_train_steps", type=int, default=200, help="Maximum number of training steps")
parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of gradient accumulation steps")
parser.add_argument("--total_batch_size", type=int, default=None, help="Total batch size (calculated)")
parser.add_argument("--lr_scheduler_config", type=str, default="constant", help="Learning rate scheduler configuration")
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of warmup steps for learning rate scheduler")
parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of cycles for learning rate scheduler")
# float (not int) so fractional powers such as 0.5 are accepted; integers still parse.
parser.add_argument("--lr_power", type=float, default=1.0, help="Power parameter for learning rate scheduler")
parser.add_argument("--bat", action="store_true", help="Whether to calculate Hessian")
parser.add_argument("--optimal_model_path", type=str, default="", help="Path to the optimal model")
parser.add_argument("--output_dir", type=str, default="", help="Output directory")
parser.add_argument("--save", action="store_true", help="Whether to save the model")
parser.add_argument("--save_hessian", action="store_true", help="Whether to save Hessian results")
parser.add_argument("--saving_step", type=int, default=10, help="Save a merged pipeline every N optimization steps (requires --save)")
parser.add_argument("--result_name", type=str, default="", help="Name for the result file")
parser.add_argument("--data_root_size", type=int, default=10, help="Size of the data root")

args = parser.parse_args()

pretrained_model_name_or_path = args.pretrained_model_name_or_path
revision = args.revision
variant = args.variant
instance_path = args.instance_path
instance_prompt = args.instance_prompt
class_path = args.class_path
class_prompt = args.class_prompt
bat_path = args.bat_path
bat_ratio = args.bat_ratio
learning_rate = args.learning_rate
adam_beta1 = args.adam_beta1
adam_beta2 = args.adam_beta2
adam_weight_decay = args.adam_weight_decay
adam_epsilon = args.adam_epsilon
num_train_epochs = args.num_train_epochs
train_batch_size = args.train_batch_size
max_train_steps = args.max_train_steps
data_root_size = args.data_root_size
gradient_accumulation_steps = args.gradient_accumulation_steps
total_batch_size = args.total_batch_size
lr_scheduler_config = args.lr_scheduler_config
lr_warmup_steps = args.lr_warmup_steps
lr_num_cycles = args.lr_num_cycles
lr_power = args.lr_power
bat = args.bat
optimal_model_path = args.optimal_model_path
output_dir = args.output_dir
save = args.save
saving_step = args.saving_step

# Results are written to <output_dir>/<parent>_<leaf>_<result_name>, where <parent>/<leaf>
# are the last two components of --optimal_model_path. Trailing slashes are stripped so
# "a/b/" names the same directory as "a/b" instead of producing an empty leaf name.
_optimal_path_parts = optimal_model_path.rstrip("/").split("/")
if len(_optimal_path_parts) < 2:
    parser.error(
        "--optimal_model_path must contain at least two path components "
        f"(got {optimal_model_path!r})"
    )
result_name = output_dir + "/" + _optimal_path_parts[-2] + "_" + _optimal_path_parts[-1] + "_" + args.result_name

# Make sure the output directory exists (exist_ok avoids a check-then-create race).
os.makedirs(result_name, exist_ok=True)

def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` into fixed-length PyTorch tensors.

    Args:
        tokenizer: a Hugging Face-style tokenizer (callable, with ``model_max_length``).
        prompt: text (or batch of texts) to tokenize.
        tokenizer_max_length: explicit target length; when ``None`` the
            tokenizer's own ``model_max_length`` is used.

    Returns:
        The tokenizer's output for ``prompt``, truncated and padded to the
        chosen length, with ``return_tensors="pt"``.
    """
    length = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length
    return tokenizer(
        prompt,
        max_length=length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

# Load the tokenizer at the same --revision as the text encoder / VAE / UNet loads,
# so every component comes from one checkpoint snapshot.
tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=revision,
            use_fast=False,
)

# Load the dataset: instance images plus the BAT ("backbone") data.
# NOTE(review): --class_path / --class_prompt are parsed but not forwarded here
# (class data is always disabled) — confirm this is intended.
dataset = BatLycorisDataset(instance_path,
                            tokenizer=tokenizer,
                            instance_prompt=instance_prompt,
                            bat_data_root=bat_path,
                            bat_ratio=bat_ratio,
                            bat_data_root_size=data_root_size,
                            class_data_root=None,
                            class_prompt=None,
)

#forward pass 
import os
from huggingface_hub import model_info
from diffusers.utils.torch_utils import is_compiled_module
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DiffusionPipeline

# DataLoader worker initializer for reproducible per-worker randomness.
def seed_worker(worker_id):
    """Seed Python's and NumPy's RNGs from the worker's torch seed.

    ``worker_id`` is unused; the seed is derived from ``torch.initial_seed()``
    and reduced into NumPy's 32-bit seed range.
    """
    derived_seed = torch.initial_seed() % 2**32
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Sequential (shuffle=False) loader yielding one example per batch. The training
# loops pick instance vs. BAT samples by batch index, so batch_size is presumably
# meant to stay 1 — confirm before changing it.
def _collate_without_prior(examples):
    """Collate ``examples`` with prior preservation turned off."""
    return collate_fn(examples, with_prior_preservation=False)


dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    collate_fn=_collate_without_prior,
    shuffle=False,
    worker_init_fn=seed_worker,
)

# A non-zero --num_train_epochs overrides --max_train_steps.
if num_train_epochs:
    max_train_steps = int(num_train_epochs * len(dataloader))

#get vae
def model_has_vae(pretrained_model_name_or_path, revision):
    """Return True if the checkpoint contains a ``vae/`` config file.

    For a local directory the file system is checked; otherwise the Hub repo's
    file listing (at ``revision``) is searched for the VAE config path.
    """
    relative_config = Path("vae", AutoencoderKL.config_name).as_posix()
    if not os.path.isdir(pretrained_model_name_or_path):
        repo_files = model_info(pretrained_model_name_or_path, revision=revision).siblings
        return any(entry.rfilename == relative_config for entry in repo_files)
    return os.path.isfile(os.path.join(pretrained_model_name_or_path, relative_config))

# Load the VAE only when the checkpoint ships one; otherwise leave it as None.
vae = (
    AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
    )
    if model_has_vae(pretrained_model_name_or_path, revision)
    else None
)

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named in the checkpoint's ``text_encoder`` config.

    Reads ``architectures[0]`` from the config stored under the ``text_encoder``
    subfolder and imports the matching class lazily, so each library is only
    imported when its class is actually needed.

    Raises:
        ValueError: if the architecture is not one of the supported classes.
    """

    def _load_clip():
        from transformers import CLIPTextModel

        return CLIPTextModel

    def _load_roberta():
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation

    def _load_t5():
        from transformers import T5EncoderModel

        return T5EncoderModel

    loaders = {
        "CLIPTextModel": _load_clip,
        "RobertaSeriesModelWithTransformation": _load_roberta,
        "T5EncoderModel": _load_t5,
    }

    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architecture = encoder_config.architectures[0]

    loader = loaders.get(architecture)
    if loader is None:
        raise ValueError(f"{architecture} is not supported.")
    return loader()

# Text encoder: resolve the concrete class from the saved config, then load its weights.
text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision)
text_encoder = text_encoder_cls.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="text_encoder",
    revision=revision,
    variant=variant,
)

# UNet: reseed first (presumably so anything random during loading is reproducible).
set_seed(42)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="unet",
    revision=revision,
    variant=variant,
)

initial_global_step = 0

# Everything runs in full precision.
weight_dtype = torch.float32
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    """Run ``text_encoder`` on ``input_ids`` and return its first output.

    Args:
        text_encoder: model exposing ``.device`` and accepting
            ``(input_ids, attention_mask=..., return_dict=False)``.
        input_ids: token ids; moved to the encoder's device.
        attention_mask: only used (and moved to the encoder's device) when
            ``text_encoder_use_attention_mask`` is truthy; otherwise ignored.
        text_encoder_use_attention_mask: whether to pass the mask through.

    Returns:
        Element 0 of the encoder's tuple output.
    """
    ids_on_device = input_ids.to(text_encoder.device)
    mask_on_device = attention_mask.to(text_encoder.device) if text_encoder_use_attention_mask else None
    outputs = text_encoder(ids_on_device, attention_mask=mask_on_device, return_dict=False)
    return outputs[0]

# Report the process count and the resolved step budget, then reseed.
print(
    f"Accelerator num_processes: {accelerator.num_processes}\n"
    f"Max train steps: {max_train_steps}"
)
set_seed(42)


# LyCORIS LoCon preparation.
# The preset selects which UNet modules get adapters: by module type
# (target_module) and by module name (target_name); conv layers are enabled.
preset = {
            "enable_conv": True, 
            "target_module": [
                "Linear", 
                "Conv1d", 
                "Conv2d", 
                "Conv3d", 
            ],
            "target_name": [
                "conv_in",
                "conv_out",
                "time_embedding.linear_1",
                "time_embedding.linear_2",
            ],
        
        }

LycorisNetwork.apply_preset(preset)
# LoCon adapters with rank/alpha 16/8.0 for linear layers and 8/1.0 for conv layers.
lycoris_net = create_lycoris(unet, 1.0, linear_dim=16, linear_alpha=8.0, conv_dim=8, conv_alpha=1.0, algo="locon")
lycoris_net.apply_to()
# lycoris_net is never passed to accelerator.prepare(), so nothing else moves its
# adapter weights. Without this they stay on the CPU, where they were created
# alongside the not-yet-placed UNet, while the UNet itself is moved to the
# accelerator device. Move them explicitly. This happens before the optimizer is
# built so it references the on-device parameters.
lycoris_net.to(accelerator.device)

# AdamW over the LoCon adapter parameters *and* every UNet parameter.
# NOTE(review): the full UNet is trainable here too, not only the adapter — confirm intended.
optimizer_class = torch.optim.AdamW

params_to_optimize = itertools.chain(lycoris_net.parameters(), unet.parameters())

optimizer = optimizer_class(
    params_to_optimize,
    betas=(adam_beta1, adam_beta2),
    eps=adam_epsilon,
    lr=learning_rate,
    weight_decay=adam_weight_decay,
)

# Step counts are scaled by the number of processes. This is presumably because
# the accelerate-prepared scheduler advances once per process per step — confirm.
lr_scheduler = get_scheduler(
    lr_scheduler_config,
    optimizer=optimizer,
    num_training_steps=max_train_steps * accelerator.num_processes,
    num_warmup_steps=lr_warmup_steps * accelerator.num_processes,
    num_cycles=lr_num_cycles,
    power=lr_power,
)
 

# Hand models, optimizer, dataloader and scheduler to accelerate for wrapping/placement.
prepared = accelerator.prepare(unet, text_encoder, optimizer, dataloader, lr_scheduler)
unet, text_encoder, optimizer, dataloader, lr_scheduler = prepared
del prepared

# Detached copy of the prepared UNet's weights; later used to reset the UNet via load_state_dict.
unet_copy = {name: tensor.clone().detach() for name, tensor in unet.state_dict().items()}

def unwrap_model(model):
    """Return the bare model under accelerate's wrappers and, if compiled, ``torch.compile``'s wrapper."""
    bare = accelerator.unwrap_model(model)
    if is_compiled_module(bare):
        return bare._orig_mod
    return bare

from tqdm.auto import tqdm

# NOTE(review): `disable` is always False, so the bar is shown on every process
# (the usual pattern is `disable=not accelerator.is_local_main_process`).
progress_bar = tqdm(
    range(max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    disable=False,
)

start = time.time()

from accelerate.logging import get_logger

logger = get_logger(__name__)

# Start of the timed parameter-distance computation in the BAT branch.
hes_start = time.time()

def print_memory_stats(step):
    """Print the CUDA allocator's allocated and reserved memory (in MB) for ``step``."""
    bytes_per_mb = 1024 ** 2
    readings = (
        ("Allocated", torch.cuda.memory_allocated()),
        ("Reserved", torch.cuda.memory_reserved()),
    )
    for label, num_bytes in readings:
        print(f"Step {step} - {label} Memory: {num_bytes / bytes_per_mb:.2f} MB")


#calculate hessians
# NOTE(review): despite the "hessian" naming, nothing below computes second derivatives.
# Each "hessian" value is the L2 norm of (reference "optimal" UNet params - current UNet
# params), divided by dataset.num_instance_images. This branch trains on the instance
# ("G") samples only and appends that distance to `unet_hessian` every
# num_instance_images steps. It optionally saves merged pipelines every `saving_step`
# steps.
if bat:
    # Move text encoder to CPU (presumably to free accelerator memory while the reference UNet is loaded)
    text_encoder.to("cpu")

    # Initial distance: reference ("optimal") UNet vs. the current backbone UNet.
    # Load the UNet model along with its configuration automatically from the subfolder
    our_unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, 
        subfolder="unet"
    )

    # Load the saved weights (.pth) and apply them to the model.
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    unet_state_dict = torch.load(os.path.join(optimal_model_path, "locon_adm", "unet_model.pth"))
    our_unet.load_state_dict(unet_state_dict)

    # Clone the model parameters (move to the device using Accelerator)
    optimal_unet_params = [p.clone().to(accelerator.device) for p in our_unet.parameters()]

    # No longer need the model loaded just for comparison
    del our_unet

    # Reset the trained UNet to the snapshot taken after accelerator.prepare, then
    # collect flattened per-parameter differences.
    # NOTE(review): zip pairs parameters purely by iteration order, assuming both UNets
    # enumerate parameters identically. It is also computed with autograd enabled;
    # torch.no_grad() would avoid building a graph here.
    unet_differences = []
    unet.load_state_dict(unet_copy)
    for original_param, updated_param in zip(optimal_unet_params, unet.parameters()):
        unet_differences.append((original_param - updated_param).view(-1))  # Flatten each tensor before concatenating

    del optimal_unet_params
    torch.cuda.empty_cache()

    # Concatenate all differences into one tensor
    unet_differences_concat = torch.cat(unet_differences)

    del unet_differences
    torch.cuda.empty_cache()

    # Compute the L2 norm of the concatenated tensor, scaled by 1 / num_instance_images
    unet_norm_diff_init = torch.norm(unet_differences_concat).item()  / dataset.num_instance_images
    
    del unet_differences_concat

    torch.cuda.empty_cache()

    print(f"L2 Norm Difference for Initial UNet parameters: {unet_norm_diff_init}")

    hes_end = time.time()

    # Seconds elapsed since hes_start (set just before this branch).
    hes_total = hes_end - hes_start

    print(f"it took {hes_total} seconds to calculate hessian norm")

    # Optimization-step counter and saved-pipeline counter for this branch.
    hessian_step = 0 
    saved_models = 0

    #prepare the progress bar

    progress_bar = tqdm(
        range(0, max_train_steps),
        initial=hessian_step,
        desc="G Steps",
        # NOTE(review): `not True` is always False, so the bar is shown on every process.
        disable=not True,
    )

    # Distance history; element 0 is the distance before any G-set training.
    unet_hessian = []
    unet_hessian.append(unet_norm_diff_init)
    print("optimal model versus backbone")

    # Freeze the VAE and place it on the device in the training dtype (if present).
    if vae is not None:
        vae.requires_grad_(False)
        vae.to(accelerator.device, dtype=weight_dtype)

    # Move the model to the GPU.
    # NOTE(review): the vae.to(...) below is not guarded by `vae is not None`; this branch
    # (like the loop below, which calls vae.encode) assumes the checkpoint has a VAE.
    unet.to(accelerator.device)
    vae.to(accelerator.device) 
    # Move the model to the GPU
    text_encoder.to(accelerator.device)

    # Train on G-set samples until max_train_steps optimization steps are taken,
    # re-iterating the dataloader as many times as needed.
    while True:
        for idx, batch in enumerate(dataloader):
            # In each period of (num_instance_images + num_bat_images // 2) batches, only the
            # first num_instance_images are trained on here; the remaining batches are skipped.
            # NOTE(review): assumes the dataset yields instance samples first within each
            # period — confirm against BatLycorisDataset.
            if idx % (dataset.num_instance_images + (dataset.num_bat_images // 2)) < dataset.num_instance_images:
                loss = 0
                pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
                # Encode images to latents and apply the VAE scaling factor.
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input *= vae.config.scaling_factor
                noise = torch.randn_like(model_input)

                # One random diffusion timestep per example; noise the latents accordingly.
                bsz, channels, height, width = model_input.shape
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
                
                # NOTE(review): the text encoder is not frozen, so backward() below also
                # accumulates gradients on its parameters. These are never cleared (they are
                # not in the optimizer); consider requires_grad_(False) or torch.no_grad().
                encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"], batch["attention_mask"], text_encoder_use_attention_mask=False)
                encoder_hidden_states = encoder_hidden_states.to(accelerator.device)
                
                # Duplicate the latents along channels when the UNet expects twice as many inputs.
                if unwrap_model(unet).config.in_channels == channels * 2:
                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
                
                # Noise-prediction objective: MSE between the UNet output and the true noise.
                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    # Gradient clipping
                    params_to_clip = itertools.chain(lycoris_net.parameters(), unet.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, 1)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                    #check training 
                    # (redundant: already inside an identical sync_gradients check)
                    if accelerator.sync_gradients:
                        progress_bar.update(accelerator.num_processes)
                        hessian_step += accelerator.num_processes

                        # Every num_instance_images steps, re-measure the distance to the
                        # reference UNet and append it to unet_hessian.
                        # NOTE(review): the reference checkpoint is re-read from disk each time;
                        # it could be loaded once and cached on the device.
                        if hessian_step % dataset.num_instance_images == 0:
                            vae.to("cpu")
                            text_encoder.to("cpu")

                            our_unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
                            unet_state_dict = torch.load(os.path.join(optimal_model_path, "locon_adm", "unet_model.pth"))
                            our_unet.load_state_dict(unet_state_dict)

                            original_unet_params = [p.clone().detach().to(accelerator.device) for p in our_unet.parameters()]
                            del our_unet

                            unet_differences = []
                            for original_param, updated_param in zip(original_unet_params, unet.parameters()):
                                unet_differences.append((original_param - updated_param).view(-1))
                                
                            del original_unet_params
                            torch.cuda.empty_cache()

                            # Concatenate all flattened differences
                            unet_differences_concat = torch.cat(unet_differences)

                            del unet_differences
                            torch.cuda.empty_cache()
                            
                            # Compute the L2 norm of the concatenated tensor, scaled by 1 / num_instance_images
                            unet_norm_diff = torch.norm(unet_differences_concat).item() / dataset.num_instance_images
                            unet_hessian.append(unet_norm_diff)
                            del unet_differences_concat
                            torch.cuda.empty_cache()

                            # Move models back to GPU
                            unet.to(accelerator.device)
                            vae.to(accelerator.device)
                            text_encoder.to(accelerator.device)

                        # Save model: every saving_step steps (when --save is set), merge the LoCon
                        # weights into the UNet and save a full pipeline under result_name.
                        # NOTE(review): merge_to(1.0) runs on every save while the adapter remains
                        # applied — confirm whether repeated merges compound the adapter delta.
                        if hessian_step != 0 and hessian_step % saving_step == 0 and save:
                            lycoris_net.merge_to(1.0)
                            accelerator.wait_for_everyone()
                            if accelerator.is_main_process:
                                pipeline_args = {}

                                pipeline = DiffusionPipeline.from_pretrained(
                                    pretrained_model_name_or_path,
                                    unet=unwrap_model(unet),
                                    revision=revision,
                                    variant=variant,
                                    **pipeline_args,
                                )
                                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                                scheduler_args = {}

                                if "variance_type" in pipeline.scheduler.config:
                                    variance_type = pipeline.scheduler.config.variance_type

                                    if variance_type in ["learned", "learned_range"]:
                                        variance_type = "fixed_small"

                                    scheduler_args["variance_type"] = variance_type

                                pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
                                pipeline.save_pretrained(result_name + f"/adapter_model_{hessian_step}")
                                saved_models += 1

                    logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=hessian_step)

            # Stop as soon as the step budget is reached (checked after every batch, trained or skipped).
            if hessian_step >= max_train_steps:
                break

        if hessian_step >= max_train_steps:
            break


if bat:
    import copy

    # Keep an untouched copy of the G-set distance history.
    unet_hessian_copy = copy.deepcopy(unet_hessian)
    print(f"unet_hessian for G set {unet_hessian[:5]}...")
    print("We train only adapter G set in model using Locon")
    # Reset the UNet to the snapshot captured right after accelerator.prepare().
    unet.load_state_dict(unet_copy)
    # The BAT distance history starts from the pre-training distance.
    bat_unet_hessian = [unet_hessian[0]]
else:
    # Without the BAT branch the VAE has not been frozen/placed yet; do it here.
    vae.requires_grad_(False)
    vae.to(accelerator.device, dtype=weight_dtype)
    
# Summary of the run configuration.
print(
    "***** Running training *****\n"
    f"  Num examples = {len(dataset)}\n"
    f"  Num batches each epoch = {len(dataloader)}\n"
    f"  Num Epochs = {num_train_epochs}\n"
    f"  Instantaneous batch size per device = {train_batch_size}\n"
    f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\n"
    f"  Gradient Accumulation steps = {gradient_accumulation_steps}\n"
    f"  Total optimization steps = {max_train_steps}"
)

# Fresh counters for the main training phase.
global_step = hessian_step = saved_models = 0

# New progress bar for the main phase.
# NOTE(review): `disable` is always False, so the bar is shown on every process.
progress_bar = tqdm(
    range(max_train_steps),
    initial=global_step,
    desc="Steps",
    disable=False,
)

if bat:
    # Records of BAT data accepted ("good") and, presumably, rejected ("bad") during training.
    good_data, bad_data = [], []

# Detach and then re-attach the LoCon adapter on the UNet.
lycoris_net.restore()
lycoris_net.apply_to()

while True:
        print("Resuming training")
        for idx, batch in enumerate(dataloader):
            if idx % (dataset.num_instance_images + (dataset.num_bat_images // 2)) < dataset.num_instance_images:
                loss = 0
                # Ensure models and data are on the correct device
                vae.to(accelerator.device)
                unet.to(accelerator.device)

                pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
                noise = torch.randn_like(model_input)

                bsz, channels, height, width = model_input.shape
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device).long()
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                noisy_model_input = noisy_model_input.to(accelerator.device)
                timesteps = timesteps.to(accelerator.device)

                encoder_hidden_states = encode_prompt(
                    text_encoder,
                    batch["input_ids"],
                    batch["attention_mask"],
                    text_encoder_use_attention_mask=False,
                ).to(accelerator.device)

                if unwrap_model(unet).config.in_channels == channels * 2:
                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                # Forward pass through the UNet model
                model_pred = unet(
                    noisy_model_input, timesteps, encoder_hidden_states).sample

                # Compute the loss
                loss = F.mse_loss(model_pred, noise)

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(lycoris_net.parameters(), unet.parameters())
                    )
                    accelerator.clip_grad_norm_(params_to_clip, 1)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                
                    #check training 
                    if accelerator.sync_gradients:
                        progress_bar.update(accelerator.num_processes)
                        global_step += accelerator.num_processes

                        if global_step != 0 and global_step % saving_step == 0 and save:
                            lycoris_net.merge_to(1.0)
                            accelerator.wait_for_everyone()
                            if accelerator.is_main_process:
                                pipeline_args = {}

                                pipeline = DiffusionPipeline.from_pretrained(
                                    pretrained_model_name_or_path,
                                    unet=unwrap_model(unet),
                                    revision=revision,
                                    variant=variant,
                                    **pipeline_args,
                                )
                                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                                scheduler_args = {}

                                if "variance_type" in pipeline.scheduler.config:
                                    variance_type = pipeline.scheduler.config.variance_type

                                    if variance_type in ["learned", "learned_range"]:
                                        variance_type = "fixed_small"

                                    scheduler_args["variance_type"] = variance_type

                                pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
                                pipeline.save_pretrained(result_name + f"/bat_model_{global_step}")
                                saved_models += 1 

                    logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)
                        
                    #check training
                    if global_step % dataset.num_instance_images == 0 and global_step != 0 and bat:

                        #save the model
                        unet.to("cpu")
                        unet_copy = {key: value.clone().detach() for key, value in unet.state_dict().items()}
                        unet.to(accelerator.device)
                        
                        while True:
                            print("checking for bad data")
                            added_data = 0  # to track how many new data points have been added
                            for idx, batch in enumerate(dataloader):
                                if idx % (dataset.num_instance_images + (dataset.num_bat_images // 2)) >= dataset.num_instance_images:
                                    loss = 0 
                                    pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
                                    model_input = vae.encode(pixel_values).latent_dist.sample()
                                    model_input = model_input * vae.config.scaling_factor
                                    noise = torch.randn_like(model_input)
                                    bsz, channels, height, width = model_input.shape

                                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device).long()
                                    noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                                    encoder_hidden_states = encode_prompt(
                                        text_encoder,
                                        batch["input_ids"],
                                        batch["attention_mask"],
                                        text_encoder_use_attention_mask=False,
                                    )
                                    encoder_hidden_states = encoder_hidden_states.to(accelerator.device)

                                    if unwrap_model(unet).config.in_channels == channels * 2:
                                        noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                                    noisy_model_input = noisy_model_input.to(accelerator.device)
                                    model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample

                                    # Directly compute the loss from model output
                                    loss = F.mse_loss(model_pred, noise)

                                    accelerator.backward(loss)
                                    
                                    params_to_clip = (
                                            itertools.chain(lycoris_net.parameters(), unet.parameters())
                                        )
                                    accelerator.clip_grad_norm_(params_to_clip, 1)
                                    optimizer.step()
                                    lr_scheduler.step()
                                    optimizer.zero_grad(set_to_none=True)
                                    added_data += 1

                                # Break the loop if enough new data points have been processed
                                if added_data >= dataset.num_bat_images // 2:
                                    break

                            
                            # Check the norm difference between the original and updated parameters
                            torch.cuda.synchronize()
                            torch.cuda.empty_cache()

                            # Initial comparison of Hessian
                            # Load the UNet model along with its configuration automatically from the subfolder
                            our_unet = UNet2DConditionModel.from_pretrained(
                                pretrained_model_name_or_path, 
                                subfolder="unet"
                            )

                            # Load the saved weights (.pth) and apply them to the model
                            unet_state_dict = torch.load(os.path.join(optimal_model_path, "locon_adm", "unet_model.pth"))
                            our_unet.load_state_dict(unet_state_dict)

                            # Clone the model parameters (move to the device using Accelerator)
                            original_unet_params = [p.clone().to(accelerator.device) for p in our_unet.parameters()]

                            # No longer need the model loaded just for comparison
                            del our_unet

                            # Compute the norm difference between original and updated parameters for both models
                            unet_differences = []
                            for original_param, updated_param in zip(original_unet_params, unet.parameters()):
                                unet_differences.append((original_param - updated_param).view(-1))

                            del original_unet_params

                            # Concatenate all differences into one tensor
                            unet_differences_concat = torch.cat(unet_differences)

                            del unet_differences

                            # Compute the L2 norm of the concatenated tensor
                            new_unet_norm_diff = torch.norm(unet_differences_concat).item() / (dataset.num_instance_images + added_data )

                            del unet_differences_concat
                            unet.to("cpu")
                            torch.cuda.empty_cache()

                            unet_norm_diff = unet_hessian[0]

                            # Print norm differences
                            print(f"L2 Norm Difference for Updated UNet parameters with G: {unet_norm_diff}")

                            print(f"L2 Norm Difference for Updated UNet parameters with K: {new_unet_norm_diff}")

                            # Compare the norm differences
                            unet_norm_diff_ratio = new_unet_norm_diff - unet_norm_diff

                            # If the norm difference ratio is greater than 1, the model has diverged
                            if unet_norm_diff_ratio <= 0:
                                print("the added data has not caused the model to diverge")
                                print(f"UNet Norm Difference Ratio: {unet_norm_diff_ratio}")
                                bat_unet_hessian.append(new_unet_norm_diff)
                                #replace the original model with the new model
                                global_step += added_data
                                progress_bar.update(added_data)
                                #pop unet_hessian and text_encoder_hessian
                                unet_hessian.pop(0)
                                print(f"good data! {dataset.bat_prompts}")
                                good_data.append((f"step: {global_step}", dataset.bat_prompts))
                                dataset.give_new_bat(False)
                                msg = "Good"
                                #save the bat model                        
                                # Move the model back to the GPU
                                unet.to(accelerator.device)
                                vae.to(accelerator.device)
                                text_encoder.to(accelerator.device)
                            
                                if global_step != 0 and global_step % saving_step == 0 and save:
                                    lycoris_net.merge_to(1.0)
                                    accelerator.wait_for_everyone()
                                    if accelerator.is_main_process:
                                        pipeline_args = {}

                                        pipeline = DiffusionPipeline.from_pretrained(
                                            pretrained_model_name_or_path,
                                            unet=unwrap_model(unet),
                                            revision=revision,
                                            variant=variant,
                                            **pipeline_args,
                                        )
                                        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
                                        scheduler_args = {}

                                        if "variance_type" in pipeline.scheduler.config:
                                            variance_type = pipeline.scheduler.config.variance_type

                                            if variance_type in ["learned", "learned_range"]:
                                                variance_type = "fixed_small"

                                            scheduler_args["variance_type"] = variance_type

                                        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
                                        pipeline.save_pretrained(result_name + f"/bat_model_{global_step}")
                                        saved_models += 1
                                        print("saving model")
                                break
                            else:
                                print("the added data has caused the model to diverge")
                                print(f"UNet Norm Difference Ratio: {unet_norm_diff_ratio}")
                                # Go back to the original model
                                unet.load_state_dict(unet_copy)
                                print(f"bad data! {dataset.bat_prompts}")
                                bad_data.append((f"step: {global_step}", dataset.bat_prompts))
                                was_bad = True
                                msg = dataset.give_new_bat(was_bad)
                                if msg == "No more bat images":
                                    break
                                elif isinstance(msg, list):
                                    print("New bat images added")
                                    print(f"New bat images: {msg}")

                            # Move the model back to the GPU
                            unet.to(accelerator.device)
                            vae.to(accelerator.device)
                            text_encoder.to(accelerator.device)

                
                        if dataset.good_bat:
                            break
                else :
                    print("sync_gradients are None")
            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break


end = time.time()
total = end - start

print(f"{bat}")

# Persist the BAT bookkeeping collected during training (this saves data; it does
# not plot anything). `bat_unet_hessian` holds the normalized L2 norms of UNet
# parameter differences that the training loop accepted, despite the "hessian"
# name. NOTE(review): the exact contents of `unet_hessian_copy` are set up
# outside this view — presumably a snapshot of the reference norm list; confirm.
#
# Only the main process writes. Under a multi-process Accelerate launch every
# rank runs this script, and concurrent writes to the same files would race.
if bat and accelerator.is_main_process:
    unet_result = {"unet_hessian": unet_hessian_copy, "bat_unet_hessian": bat_unet_hessian}
    good_data_result = {"good_data": good_data}
    bad_data_result = {"bad_data": bad_data}

    # Individual files plus one combined result.json, same payloads as before.
    bat_outputs = {
        "unet_hessian.json": unet_result,
        "good_data.json": good_data_result,
        "bad_data.json": bad_data_result,
        "result.json": {
            "unet_result": unet_result,
            "good_data_result": good_data_result,
            "bad_data_result": bad_data_result,
        },
    }

    # Use os.path.join for every path: an empty `result_name` then maps to the
    # current directory instead of the filesystem root.
    os.makedirs(result_name or ".", exist_ok=True)
    for bat_filename, bat_payload in bat_outputs.items():
        with open(os.path.join(result_name, bat_filename), "w", encoding="utf-8") as outfile:
            json.dump(bat_payload, outfile)



# Non-BAT runs export one final pipeline after training ends. All ranks
# synchronize first; only the main process builds and writes the pipeline.
# NOTE(review): the periodic BAT save inside the training loop calls
# `lycoris_net.merge_to(1.0)` before exporting, but this path does not —
# confirm the LyCORIS weights are already reflected in `unet` at this point.
if save and not bat:
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        trained_unet = unwrap_model(unet)
        pipeline = DiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            unet=trained_unet,
            revision=revision,
            variant=variant,
        )

        # Training uses the simplified learning objective. A scheduler configured
        # to use a learned variance must fall back to a fixed one instead.
        scheduler_config = pipeline.scheduler.config
        scheduler_args = {}
        if "variance_type" in scheduler_config:
            configured_variance = scheduler_config.variance_type
            is_learned = configured_variance in ("learned", "learned_range")
            scheduler_args["variance_type"] = "fixed_small" if is_learned else configured_variance

        pipeline.scheduler = pipeline.scheduler.from_config(scheduler_config, **scheduler_args)
        pipeline.save_pretrained(output_dir)


# Monitor VRAM usage after the calculations
final_allocated_memory = torch.cuda.memory_allocated()
final_reserved_memory = torch.cuda.memory_reserved()

print(f"Time taken: {end - start:.6f} seconds")

# Print VRAM usage after calculations
print(f"Final CUDA Memory Allocated: {final_allocated_memory / (1024 ** 2):.2f} MB")
print(f"Final CUDA Memory Reserved: {final_reserved_memory / (1024 ** 2):.2f} MB")

# Print peak memory usage
peak_allocated_memory = torch.cuda.max_memory_allocated()
peak_reserved_memory = torch.cuda.max_memory_reserved()

print(f"Peak CUDA Memory Allocated: {peak_allocated_memory / (1024 ** 2):.2f} MB")
print(f"Peak CUDA Memory Reserved: {peak_reserved_memory / (1024 ** 2):.2f} MB")