#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
import random
import gc
import logging
import math
import os
import shutil
import argparse
import itertools
import torch.nn.functional as F
from pathlib import Path
from mmengine.config import Config
from functools import partial
import json

import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
from utils import *
# Point the Hugging Face cache root at /tmp/cache for this process.
# NOTE(review): `transformers` is already imported above when this runs — confirm
# the cache location is actually picked up when set at this point.
os.environ["HF_HOME"] = "/tmp/cache"

import diffusers
from diffusers import (
    AutoencoderKL,
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    _set_state_dict_into_text_encoder
)
from diffusers.utils import (
    check_min_version,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.configuration_utils import FrozenDict
from custom_dataset import Text2ImageRGBDataset, process_bbox_info
from custom_pipeline import encode_prompt

# wandb is optional: import it only when diffusers reports it as available.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.31.0.dev0")

# Module-level logger shared by the functions in this file.
logger = get_logger(__name__)
# Process-wide environment settings applied at import time.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB__SERVICE_WAIT"] = "300"
print(f"WANDB__SERVICE_WAIT: {os.environ['WANDB__SERVICE_WAIT']}")

def setup_custom_noise(config, accelerator, logger):
    """Report the noise-related settings and load the optional noise tensors.

    Args:
        config: Configuration object. It is read through ``config.get(...)``
            and through the attributes ``custom_noise_path``,
            ``custom_noise_partial_path`` and ``gamma``.
        accelerator: Object whose ``device`` is passed to ``torch.load`` as
            ``map_location``.
        logger: Logger that receives the ``info`` / ``warning`` messages.

    Returns:
        Tuple ``(noise_custom, fixed_noise_partial)``. An entry is ``None``
        when its path is not configured or when ``torch.load`` raised.
    """
    # ---- describe how noise will be generated ----
    layout_type = "grid"
    if config.get("random_layout", False):
        layout_type = "random"
    logger.info(f"Using {layout_type} layout for noise generation")

    if config.get("gaussian_kernel", False):
        weight, scale_factor = config.get('weight', 0), config.get('scale_factor', 0)
        logger.info(f"Using gaussian kernel: weight={weight}, scale_factor={scale_factor}")

    # A threshold of 2 is the "not configured" default for both filters.
    iou_thresh, contain_thresh = config.get("iou_thresh", 2), config.get("contain_thresh", 2)
    if iou_thresh < 2 or contain_thresh < 2:
        logger.info(f"Using bbox filtering: iou_thresh={iou_thresh}, contain_thresh={contain_thresh}")

    if config.get("bbox_normalized", False):
        beta = config.get('beta', 0)
        logger.info(f"Using bbox_normalized with beta={beta}")

    if config.get("gamma", None):
        logger.info(f"Using gamma={config.gamma} for uniform rectangle noise")

    target_device = None
    noise_custom = None
    fixed_noise_partial = None

    # ---- global noise (best effort: a failure is logged, not raised) ----
    if hasattr(config, 'custom_noise_path'):
        try:
            global_path = config.custom_noise_path
            target_device = accelerator.device
            noise_custom = torch.load(global_path, map_location=target_device)
            from_random_file = 'random' in global_path
            logger.info(
                f"Loaded global noise: shape={noise_custom.shape}, "
                f"dtype={noise_custom.dtype}, random={from_random_file}"
            )
        except Exception as err:
            logger.warning(f"Failed to load global noise: {err}")

    # ---- partial noise (same best-effort policy) ----
    if hasattr(config, 'custom_noise_partial_path'):
        try:
            partial_path = config.custom_noise_partial_path
            target_device = accelerator.device
            fixed_noise_partial = torch.load(partial_path, map_location=target_device)
            logger.info(f"Loaded partial noise from: {partial_path}")
        except Exception as err:
            logger.warning(f"Failed to load partial noise: {err}")

    return noise_custom, fixed_noise_partial

def generate_global_noise(batch, model_input, noise_custom, config):
    """Build a noise tensor by copying one stored pattern per sample.

    Args:
        batch: Mapping whose ``'label'`` entry yields one value per sample;
            each value is converted with ``int``.
        model_input: Tensor whose shape, dtype and device the result copies.
        noise_custom: Stack of patterns indexed along dimension 0.
        config: Object with ``get``; ``random_sample`` selects a pattern
            uniformly at random instead of by label.

    Returns:
        Tensor shaped like ``model_input`` where entry ``i`` holds the chosen
        pattern for sample ``i``.
    """
    pattern_count = noise_custom.shape[0]
    noise = torch.zeros_like(model_input)

    for sample_idx, raw_label in enumerate(batch['label']):
        label = int(raw_label)

        if config.get("random_sample", False):
            # Ignore the label entirely.
            chosen = random.randint(0, pattern_count - 1)
        elif pattern_count > 50:
            # More than 50 patterns: treat them as 50 equally sized groups
            # and draw a random member of the label's group.
            per_class = pattern_count // 50
            offset = random.randint(0, per_class - 1)
            chosen = label * per_class + offset
        else:
            # Few patterns: wrap the label around the available ones.
            chosen = label % pattern_count

        noise[sample_idx] = noise_custom[chosen]

    return noise

def generate_grid_layout_noise(batch, model_input, fixed_noise_partial, config):
    """Draw Gaussian noise and rewrite one cell of a 7 x 7 grid per sample.

    The integer label of each sample picks the cell (``label // 7`` is the
    row, ``label % 7`` the column); the cell side is the last dimension of
    ``model_input`` divided by 7.

    Args:
        batch: Mapping whose ``'label'`` entry yields one value per sample.
        model_input: Tensor the noise is shaped after.
        fixed_noise_partial: Optional tensor; when given, its first entry is
            written into the cell.
        config: Object with ``get``; consulted only when no fixed patch is
            given.

    Returns:
        The noise tensor, shaped like ``model_input``.
    """
    cells_per_side = 7
    side = model_input.shape[-1]
    noise = torch.randn_like(model_input)

    if fixed_noise_partial is not None:
        fixed_noise_partial = fixed_noise_partial.to(dtype=model_input.dtype, device=model_input.device)

    for sample_idx, raw_label in enumerate(batch['label']):
        row, col = divmod(int(raw_label), cells_per_side)
        cell = side // cells_per_side
        left, top = col * cell, row * cell
        right, bottom = left + cell, top + cell
        rows, cols = slice(top, bottom), slice(left, right)

        if fixed_noise_partial is not None:
            # Paste the fixed patch into the cell.
            noise[sample_idx][:, rows, cols] = fixed_noise_partial[0]
            continue

        if config.get("gaussian_kernel", False):
            # Blend a Gaussian kernel centred on the cell into the sample.
            # NOTE(review): both helpers are not defined in this file;
            # presumably they come from `from utils import *`.
            centers, scales = bbox_to_centers_and_scales(
                [[left, top, right, bottom]],
                image_size=(int(side), int(side)),
                scale_factor=config.scale_factor,
            )
            mean_scale = torch.tensor(scales).mean().item()
            blended = generate_uniform_gmm_noise(
                noise=noise[sample_idx].unsqueeze(0),
                centers=centers,
                scale=mean_scale,
                weight=config.weight,
            )
            noise[sample_idx] = blended[0]
            continue

        # Default: resample the cell with noise scaled by 0.1.
        resampled = torch.randn_like(model_input[0]) * 0.1
        noise[sample_idx][:, rows, cols] = resampled[:, rows, cols]

    return noise

def generate_random_layout_noise(batch, model_input, fixed_noise_partial, config):
    """Draw Gaussian noise and hand it to the configured random-layout strategy.

    When ``config.get("gaussian_kernel", False)`` is truthy the noise goes to
    ``_generate_gaussian_kernel_noise``; otherwise it goes to
    ``_generate_bbox_based_noise`` together with ``model_input`` and
    ``fixed_noise_partial``.

    Returns:
        Whatever the selected strategy returns.
    """
    base_noise = torch.randn_like(model_input)

    if not config.get("gaussian_kernel", False):
        return _generate_bbox_based_noise(batch, base_noise, model_input, fixed_noise_partial, config)

    return _generate_gaussian_kernel_noise(batch, base_noise, config)

def _generate_gaussian_kernel_noise(batch, noise, config):
    """Replace each sample's noise with a Gaussian-mixture variant built from its boxes.

    For every entry of ``batch["bbox_info"]`` the ``'bbox'`` values are
    collected and turned into kernel centres and scales; the sample's noise is
    then overwritten in place with the helper's output.

    NOTE(review): the bbox/GMM helpers called here are not defined in this
    file; presumably they come from ``from utils import *``.

    Args:
        batch: Mapping with a ``"bbox_info"`` entry: one iterable of dicts
            (each holding ``'bbox'``) per sample.
        noise: Noise tensor indexed by sample; modified in place.
        config: Object with ``get`` plus the attributes ``resolution``,
            ``scale_factor`` and (isotropic path only) ``weight``.

    Returns:
        The same ``noise`` tensor.
    """
    for sample_idx, box_entries in enumerate(batch["bbox_info"]):
        boxes = [entry['bbox'] for entry in box_entries]

        if not config.get("anisotropic", False):
            # Isotropic path: one shared scale, the mean over all boxes.
            centers, scales = bbox_to_centers_and_scales(
                boxes,
                image_size=(config.resolution, config.resolution),
                scale_factor=config.scale_factor,
            )
            mean_scale = torch.tensor(scales).mean().item()
            mixed = generate_uniform_gmm_noise(
                noise=noise[sample_idx].unsqueeze(0),
                centers=centers,
                scale=mean_scale,
                weight=config.weight,
            )
            noise[sample_idx] = mixed[0]
            continue

        # Anisotropic path. Thresholds of 2 mean "no filtering configured".
        iou_thresh = config.get("iou_thresh", 2)
        contain_thresh = config.get("contain_thresh", 2)
        if iou_thresh < 2 or contain_thresh < 2:
            boxes, _ = keep_subject_boxes_iou_contain(
                boxes, iou_thresh=iou_thresh, contain_thresh=contain_thresh
            )

        centers, scales = bbox_to_centers_and_scales_anisotropic(
            boxes,
            image_size=(config.resolution, config.resolution),
            scale_factor=config.scale_factor,
        )

        if config.get("bbox_normalized", False):
            beta = config.get("beta", 0.4)
            mixed = generate_anisotropic_gmm_noise_bbox_normalized(
                noise=noise[sample_idx].unsqueeze(0),
                centers=centers,
                sigmas=scales,
                weight=config.get("weight", 0.3),
                beta=beta,
                energy_norm=config.get("energy_norm", True),
            )
        else:
            mixed = generate_anisotropic_gmm_noise(
                noise=noise[sample_idx].unsqueeze(0),
                centers=centers,
                sigmas=scales,
                weight=config.get("weight", 0.3),
                energy_norm=config.get("energy_norm", True),
            )
        noise[sample_idx] = mixed[0]

    return noise

def _generate_bbox_based_noise(batch, noise, model_input, fixed_noise_partial, config):
    """Overwrite the latent region of every bounding box in ``noise``.

    Box coordinates are scaled by ``64 / 512`` (one latent cell per 8 pixels)
    and clamped to the spatial size of ``model_input``. Each box region is
    then filled either with a centred crop of ``fixed_noise_partial`` or with
    a per-sample Gaussian draw scaled by ``gamma``.

    Args:
        batch: Mapping with a ``'bbox_info'`` entry: one iterable of dicts
            (each holding ``'bbox'`` as ``x_min, y_min, x_max, y_max``) per
            sample.
        noise: Noise tensor indexed ``[sample][channel, y, x]``; modified in
            place.
        model_input: Latent tensor; supplies dtype, device and the spatial
            extent used for clamping.
        fixed_noise_partial: Optional patch source handed to
            ``get_centered_patch_and_coords``.
        config: Object with ``get``; ``gamma`` (default ``0.1``) scales the
            resampled noise.

    Returns:
        The same ``noise`` tensor.
    """
    if fixed_noise_partial is not None:
        fixed_noise_partial = fixed_noise_partial.to(dtype=model_input.dtype, device=model_input.device)

    # Clamp to the real latent extent. A hard-coded 64 clipped (or emptied)
    # boxes whenever the latent was larger than 64 x 64.
    latent_h, latent_w = model_input.shape[-2], model_input.shape[-1]
    gamma = config.get("gamma", 0.1)

    for batch_index, item in enumerate(batch['bbox_info']):
        # One low-variance draw per sample, shared by all of its boxes.
        noise_index = torch.randn_like(model_input[0]) * gamma

        for box_info in item:
            x_min, y_min, x_max, y_max = box_info['bbox']
            x_min = max(0, min(int(x_min * 64 / 512), latent_w))
            y_min = max(0, min(int(y_min * 64 / 512), latent_h))
            x_max = max(0, min(int(x_max * 64 / 512), latent_w))
            y_max = max(0, min(int(y_max * 64 / 512), latent_h))

            # Empty or inverted boxes cover no latent cell; skipping them also
            # keeps negative sizes away from the patch helper.
            if x_max <= x_min or y_max <= y_min:
                continue

            if fixed_noise_partial is not None:
                # Paste a centred crop of the fixed patch into the box.
                box_h, box_w = y_max - y_min, x_max - x_min
                patch, box_patch_y1, box_patch_y2, box_patch_x1, box_patch_x2 = get_centered_patch_and_coords(
                    fixed_noise_partial, box_h, box_w
                )
                noise[batch_index][:, y_min + box_patch_y1:y_min + box_patch_y2,
                                   x_min + box_patch_x1:x_min + box_patch_x2] = patch
            else:
                # Replace the box region with the scaled resampled noise.
                noise[batch_index][:, y_min:y_max, x_min:x_max] = noise_index[:, y_min:y_max, x_min:x_max]

    return noise

def parse_config(path=None):
    """Load the training configuration and attach launcher metadata.

    Args:
        path: Config file path. When ``None`` it is taken from the single
            positional command-line argument ``config_dir``.

    Returns:
        The object returned by ``Config.fromfile(path)`` with two extra
        attributes: ``config_dir`` (the path) and ``local_rank`` (from
        ``LOCAL_RANK``, else ``OMPI_COMM_WORLD_LOCAL_RANK``, else ``-1``).
    """
    if path is None:
        cli = argparse.ArgumentParser()
        cli.add_argument('config_dir', type=str)
        path = cli.parse_args().config_dir

    config = Config.fromfile(path)
    config.config_dir = path

    # The first launcher variable that is present wins.
    for env_name in ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"):
        if env_name in os.environ:
            config.local_rank = int(os.environ[env_name])
            break
    else:
        config.local_rank = -1

    return config

def load_text_encoders(class_one, class_two):
    """Instantiate both text encoders from the pretrained checkpoint.

    Reads the module-level ``config`` for the checkpoint location, revision,
    variant and optional ``cache_dir``.

    Args:
        class_one: Class whose ``from_pretrained`` loads the ``text_encoder``
            subfolder.
        class_two: Class whose ``from_pretrained`` loads the
            ``text_encoder_2`` subfolder.

    Returns:
        Tuple ``(text_encoder_one, text_encoder_two)``.
    """
    def _load(encoder_cls, subfolder):
        # Both encoders share every argument except the subfolder.
        return encoder_cls.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder=subfolder,
            revision=config.revision,
            variant=config.variant,
            cache_dir=config.get("cache_dir", None),
        )

    first_encoder = _load(class_one, "text_encoder")
    second_encoder = _load(class_two, "text_encoder_2")
    return first_encoder, second_encoder

def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Resolve the text-encoder class named by a checkpoint's configuration.

    The first entry of the configuration's ``architectures`` decides the
    class. The module-level ``config`` supplies the optional ``cache_dir``.

    Args:
        pretrained_model_name_or_path: Checkpoint identifier or path.
        revision: Revision passed to ``PretrainedConfig.from_pretrained``.
        subfolder: Subfolder holding the text-encoder configuration.

    Returns:
        ``CLIPTextModel`` or ``T5EncoderModel``.

    Raises:
        ValueError: If the architecture is neither of the two.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder=subfolder,
        revision=revision,
        cache_dir=config.get("cache_dir", None),
    )
    architecture = encoder_config.architectures[0]

    # Import lazily so only the class that is needed gets loaded.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    elif architecture == "T5EncoderModel":
        from transformers import T5EncoderModel as encoder_cls
    else:
        raise ValueError(f"{architecture} is not supported.")

    return encoder_cls

def unwrap_model(model, accelerator):
    """Return ``model`` without its accelerate wrapper and without its compile wrapper.

    The model is first passed through ``accelerator.unwrap_model``; if the
    result is a compiled module, its ``_orig_mod`` is returned instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped

def save_model_hook_partial(models, weights, output_dir, accelerator, transformer, text_encoder_one):
    """Save the LoRA layers of the given models to ``output_dir``.

    Does nothing unless this is the main process and ``weights`` is
    non-empty. Each model is matched by type against the unwrapped
    ``transformer`` / ``text_encoder_one``, and one entry is popped from
    ``weights`` per model.

    Raises:
        ValueError: If a model matches neither type.
    """
    if not accelerator.is_main_process or len(weights) == 0:
        return

    lora_layers = {"transformer": None, "text_encoder": None}

    for candidate in models:
        if isinstance(candidate, type(unwrap_model(transformer, accelerator))):
            lora_layers["transformer"] = get_peft_model_state_dict(candidate)
        elif isinstance(candidate, type(unwrap_model(text_encoder_one, accelerator))):
            lora_layers["text_encoder"] = get_peft_model_state_dict(candidate)
        else:
            raise ValueError(f"unexpected save model: {candidate.__class__}")

        # Pop the weight so that the corresponding model is not saved again.
        weights.pop()

    FluxPipeline.save_lora_weights(
        output_dir,
        transformer_lora_layers=lora_layers["transformer"],
        text_encoder_lora_layers=lora_layers["text_encoder"],
    )


def load_model_hook_partial(models, input_dir, accelerator, transformer, text_encoder_one):
    """Restore LoRA weights from ``input_dir`` into the models in ``models``.

    ``models`` is emptied while each entry is matched by type against the
    unwrapped ``transformer`` / ``text_encoder_one``. The module-level
    ``config`` decides whether text-encoder weights are restored and whether
    the trainable parameters are upcast afterwards.

    Raises:
        ValueError: If a model matches neither type.
    """
    if not models:
        return

    restored_transformer = None
    restored_text_encoder = None

    # Drain the list, sorting every model into its slot.
    while models:
        candidate = models.pop()
        if isinstance(candidate, type(unwrap_model(transformer, accelerator))):
            restored_transformer = candidate
        elif isinstance(candidate, type(unwrap_model(text_encoder_one, accelerator))):
            restored_text_encoder = candidate
        else:
            raise ValueError(f"unexpected save model: {candidate.__class__}")

    lora_state_dict = FluxPipeline.lora_state_dict(input_dir)

    # Keep only the transformer entries and strip their prefix.
    prefix = "transformer."
    stripped = {
        key.replace(prefix, ""): value
        for key, value in lora_state_dict.items()
        if key.startswith(prefix)
    }
    transformer_state_dict = convert_unet_state_dict_to_peft(stripped)
    load_result = set_peft_model_state_dict(
        restored_transformer, transformer_state_dict, adapter_name="default"
    )

    # Only unexpected keys are reported.
    if load_result is not None:
        unexpected_keys = getattr(load_result, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )

    if config.train_text_encoder:
        # Do we need to call `scale_lora_layers()` here?
        _set_state_dict_into_text_encoder(
            lora_state_dict, prefix="text_encoder.", text_encoder=restored_text_encoder
        )

    # The base models are in `weight_dtype`, so the trainable (LoRA) params
    # are upcast to float32 again. More details:
    # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
    if config.mixed_precision in ["fp16", "bf16"]:
        trainable_models = [restored_transformer]
        if config.train_text_encoder:
            trainable_models.append(restored_text_encoder)
        cast_training_params(trainable_models)


def initialize_all_models(config, accelerator):
    """Load tokenizers, text encoders, VAE and transformer, then attach LoRA adapters.

    Steps, in order:
      1. Load the two tokenizers, the two text encoders, the VAE and the
         Flux transformer from ``config.pretrained_model_name_or_path``
         (the transformer comes from ``config.transformer_varient`` instead
         when that attribute exists).
      2. If ``config.pretrained_lora_dir`` exists, fuse those LoRA weights
         into the transformer and unload the adapter.
      3. Freeze every model and move it to ``accelerator.device`` in
         ``config.weight_dtype``.
      4. Optionally enable gradient checkpointing.
      5. Add a LoRA adapter to the transformer's attention projections and,
         when ``config.train_text_encoder`` is set, to the first text encoder.
      6. Under fp16/bf16 mixed precision, upcast the trainable parameters.

    NOTE: ``import_model_class_from_model_name_or_path`` and
    ``load_text_encoders`` read the module-level ``config``, not the argument
    passed to this function.

    Args:
        config: Configuration object (attribute access plus ``get``).
        accelerator: Object providing the target ``device``.

    Returns:
        Tuple ``(vae, transformer, tokenizers, text_encoders)`` where
        ``tokenizers`` and ``text_encoders`` are two-element lists.
    """

    # Load the tokenizers
    logger.info(f"[INFO] start load tokenizers")
    tokenizer_one = CLIPTokenizer.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=config.revision,
        cache_dir=config.get("cache_dir", None),
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=config.revision,
        cache_dir=config.get("cache_dir", None),
    )
    
    # import correct text encoder classes
    logger.info(f"[INFO] start load text encoders")
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision, subfolder="text_encoder_2"
    )
    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)

    logger.info(f"[INFO] start load vae")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="vae",
        revision=config.revision,
        variant=config.variant,
        cache_dir=config.get("cache_dir", None),
    )
    vae.enable_slicing()

    # When `transformer_varient` is configured, the transformer is loaded from
    # that location with an empty subfolder; otherwise from the "transformer"
    # subfolder of the base checkpoint.
    logger.info(f"[INFO] start load mmdit")
    transformer = FluxTransformer2DModel.from_pretrained(
        config.transformer_varient if hasattr(config, "transformer_varient") else config.pretrained_model_name_or_path, 
        subfolder="" if hasattr(config, "transformer_varient") else "transformer", 
        revision=config.revision, 
        variant=config.variant,
        cache_dir=config.get("cache_dir", None),
    )

    # Load pretrained LoRA weights, fuse them into the transformer, then drop the adapter
    if hasattr(config, "pretrained_lora_dir"):
        lora_state_dict = FluxPipeline.lora_state_dict(config.pretrained_lora_dir)
        FluxPipeline.load_lora_into_transformer(lora_state_dict, None, transformer)
        transformer.fuse_lora(safe_fusing=True)
        transformer.unload_lora() # don't forget to unload the lora params
        logger.info(f"[INFO] fused pretrained lora weights from {config.pretrained_lora_dir}")


    # Freeze all base weights; the adapters added below are created after this point.
    vae.requires_grad_(False)
    transformer.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    # Moves to `accelerator.device` (the log message says "cuda" regardless of the device).
    logger.info(f"[INFO] move models to cuda")
    vae.to(accelerator.device, dtype=config.weight_dtype)
    transformer.to(accelerator.device, dtype=config.weight_dtype)
    text_encoder_one.to(accelerator.device, dtype=config.weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=config.weight_dtype)
    
    if config.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()
        if config.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()

    # now we will add new LoRA weights to the attention layers
    logger.info(f"[INFO] add lora in mmdit")
    target_modules = []
    # transformer_blocks: q/k/v projections plus the output projection
    module_names = ["to_k", "to_q", "to_v", "to_out.0"]
    for name, _ in transformer.transformer_blocks.named_modules():
        if any([name.endswith(n) for n in module_names]):
            target_modules.append("transformer_blocks." + name)
    # single_transformer_blocks: q/k/v projections only
    module_names = ["to_k", "to_q", "to_v"]
    for name, _ in transformer.single_transformer_blocks.named_modules():
        if any([name.endswith(n) for n in module_names]):
            target_modules.append("single_transformer_blocks." + name)

    # lora_alpha is set equal to the rank.
    transformer_lora_config = LoraConfig(
        r=config.rank,
        lora_alpha=config.rank,
        init_lora_weights="gaussian",
        target_modules=target_modules,
    )
    transformer.add_adapter(transformer_lora_config)

    if config.train_text_encoder:
        text_lora_config = LoraConfig(
            r=config.text_encoder_rank,
            lora_alpha=config.text_encoder_rank,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        text_encoder_one.add_adapter(text_lora_config)

    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]

    # Make sure the trainable params are in float32.  
    logger.info(f"[INFO] cast_training_params to fp32")
    if config.mixed_precision in ["fp16", "bf16"]:
        models = [transformer]
        if config.train_text_encoder:
            models.extend([text_encoder_one])
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models)
    
    # Return all models
    return vae, transformer, tokenizers, text_encoders

def get_trainable_params(config, accelerator, transformer, text_encoder_one):
    """Collect the trainable parameters into optimizer parameter groups.

    Args:
        config: Object providing ``learning_rate`` and ``train_text_encoder``;
            when the latter is truthy also ``adam_weight_decay_text_encoder``
            and ``text_encoder_lr``.
        accelerator: Object providing ``is_main_process``; only the main
            process prints the parameter counts.
        transformer: Module whose parameters with ``requires_grad`` form the
            first group.
        text_encoder_one: Module whose parameters with ``requires_grad`` form
            the second group. Only read when ``config.train_text_encoder`` is
            truthy.

    Returns:
        List of parameter-group dicts: the transformer group first, then the
        optional text-encoder group.
    """
    transformer_lora_parameters = [p for p in transformer.parameters() if p.requires_grad]
    params_to_optimize = [{"params": transformer_lora_parameters, "lr": config.learning_rate}]

    if config.train_text_encoder:
        text_encoder_one_params = [p for p in text_encoder_one.parameters() if p.requires_grad]
        params_to_optimize.append(
            {
                "params": text_encoder_one_params,
                "weight_decay": config.adam_weight_decay_text_encoder,
                "lr": config.text_encoder_lr,
            }
        )

    if accelerator.is_main_process:
        for i, param_set in enumerate(params_to_optimize):
            num_params = sum(p.numel() for p in param_set["params"]) / 1e6
            # `.2f` gives two decimals; the former `02f` was only a zero-pad
            # width and printed six decimals.
            print(f"Trainable Params Set {i}: {num_params:.2f}M")

    return params_to_optimize

def get_sigmas(timesteps, accelerator, noise_scheduler_copy, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each timestep and pad it to ``n_dim`` dimensions.

    Args:
        timesteps: Tensor of timesteps; each must occur exactly once in
            ``noise_scheduler_copy.timesteps``.
        accelerator: Object providing the ``device`` the lookup runs on.
        noise_scheduler_copy: Object providing ``sigmas`` and ``timesteps``.
        n_dim: Minimum number of dimensions of the result.
        dtype: Dtype the sigmas are converted to.

    Returns:
        Tensor with one sigma per timestep, with trailing singleton
        dimensions appended until it has at least ``n_dim`` dimensions.
    """
    device = accelerator.device
    all_sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(device)
    timesteps = timesteps.to(device)

    # Position of every requested timestep inside the schedule.
    step_indices = []
    for t in timesteps:
        step_indices.append((schedule_timesteps == t).nonzero().item())

    sigma = all_sigmas[step_indices].flatten()
    # Append singleton dimensions so the result broadcasts against the latents.
    for _ in range(n_dim - sigma.ndim):
        sigma = sigma.unsqueeze(-1)
    return sigma

def log_validation(
    pipeline,
    config,
    accelerator,
    global_step
):
    """Generate validation images for every configured prompt and log them.

    Args:
        pipeline: Callable pipeline; moved to ``accelerator.device`` and
            called once per image with ``prompt``, ``generator``, ``height``
            and ``width``.
        config: Object providing ``num_validation_images``, ``seed``,
            ``validation_prompts`` and ``resolution``.
        accelerator: Object providing ``device`` and ``trackers``.
        global_step: Current training step (not used inside this function).

    Every tracker is expected to be the wandb tracker; the images are logged
    under the key ``"validation"``.
    """
    logger.info(f"Running validation... \n Generating {config.num_validation_images} images per case")
    pipeline = pipeline.to(accelerator.device)
    # pipeline.set_progress_bar_config(disable=True)

    # Compare against None so that a seed of 0 still yields a seeded
    # generator; a plain truthiness check treated 0 as "no seed".
    generator = None
    if config.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(config.seed)

    # run inference
    image_logs = []
    for validation_prompt in config.validation_prompts:
        with torch.autocast(accelerator.device.type, dtype=torch.bfloat16):
            images = [
                pipeline(
                    prompt=validation_prompt,
                    generator=generator,
                    height=config.resolution,
                    width=config.resolution,
                ).images[0]
                for _ in range(config.num_validation_images)
            ]
        image_logs.append(
            {
                "images": images,
                "caption": validation_prompt,
            }
        )

    for tracker in accelerator.trackers:
        # Only wandb logging is implemented below.
        assert tracker.name == "wandb"
        formatted_images = [
            wandb.Image(image, caption=log["caption"])
            for log in image_logs
            for image in log["images"]
        ]
        tracker.log({"validation": formatted_images})

    # Drop the local reference and release cached GPU memory.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def collate_fn(examples):
    """Assemble a list of dataset examples into one batch dictionary.

    Args:
        examples: Sequence of dicts with the keys ``"image_pt"`` (tensor),
            ``"caption"``, ``"label"`` and ``"bbox_info"``.

    Returns:
        Dict with ``"pixel_values"`` (the stacked images as a contiguous
        float tensor) and the per-example lists ``"caption"``, ``"label"``
        and ``"bbox_info"``.
    """
    images = [example["image_pt"] for example in examples]
    stacked = torch.stack(images)
    pixel_values = stacked.to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": pixel_values,
        "caption": [example["caption"] for example in examples],
        "label": [example["label"] for example in examples],
        "bbox_info": [example["bbox_info"] for example in examples],
    }


def custom_collate_fn(batch):
    """Collate a batch after dropping unusable samples.

    A sample is dropped when it is ``None``, equal to ``{}``, or a dict in
    which any value is ``None`` (a warning is printed for the first such
    key).

    Returns:
        The result of torch's default collate on the remaining samples, or
        ``{}`` when nothing remains.
    """
    # Pass 1: discard missing / empty samples.
    candidates = [sample for sample in batch if sample not in ({}, None)]
    if not candidates:
        return {}

    # Pass 2: discard dict samples carrying a None value.
    usable = []
    for position, sample in enumerate(candidates):
        if isinstance(sample, dict):
            none_keys = [key for key, value in sample.items() if value is None]
            if none_keys:
                print(f"Warning: Sample {position}, key '{none_keys[0]}' has None value!")
                continue
        usable.append(sample)

    # Returning an empty dict avoids an error in default_collate.
    if not usable:
        return {}

    return torch.utils.data._utils.collate.default_collate(usable)


def create_custom_noise_dict(num_classes=50, latent_channels=16, height=64, width=64, device="cuda"):
    """Create one reproducible Gaussian noise tensor per class.

    Calls ``torch.manual_seed`` with the module-level ``config.seed`` before
    drawing, so repeated calls produce the same tensors.

    Args:
        num_classes: Number of entries to create.
        latent_channels: Channel count of each tensor.
        height: Height of each tensor.
        width: Width of each tensor.
        device: Device the tensors are created on.

    Returns:
        Dict mapping the keys ``1 .. num_classes`` to tensors of shape
        ``(1, latent_channels, height, width)``.
    """
    # Seed once so every call yields the same patterns.
    torch.manual_seed(config.seed)

    shape = (1, latent_channels, height, width)
    # Keys start at 1, not 0.
    return {
        class_id: torch.randn(shape, device=device)
        for class_id in range(1, num_classes + 1)
    }

def get_centered_patch_and_coords(fixed_noise_partial, box_h, box_w):
    """Crop a centred patch from the fixed noise and locate it inside a box.

    The patch is as large as the box allows, capped by the size of
    ``fixed_noise_partial``. It is taken from the centre of the first entry
    of ``fixed_noise_partial`` and positioned at the centre of the box.

    Args:
        fixed_noise_partial: 4-D tensor; only index 0 of the first dimension
            is used.
        box_h: Box height. Negative values are treated as 0.
        box_w: Box width. Negative values are treated as 0.

    Returns:
        Tuple ``(patch, y1, y2, x1, x2)``: ``patch`` has shape
        ``[C, patch_h, patch_w]`` and the coordinates are offsets relative to
        the box's top-left corner.
    """
    _, _, H, W = fixed_noise_partial.shape

    # A negative size would yield negative slice bounds, which wrap around
    # when the caller adds them to the box origin.
    box_h = max(0, box_h)
    box_w = max(0, box_w)

    # Actual patch size: limited by both the box and the fixed noise.
    patch_h = min(H, box_h)
    patch_w = min(W, box_w)

    # Start and end of the centred region within the box.
    box_patch_y1 = (box_h - patch_h) // 2
    box_patch_x1 = (box_w - patch_w) // 2
    box_patch_y2 = box_patch_y1 + patch_h
    box_patch_x2 = box_patch_x1 + patch_w

    # Start and end of the centred region within fixed_noise_partial.
    fixed_patch_y1 = (H - patch_h) // 2
    fixed_patch_x1 = (W - patch_w) // 2
    fixed_patch_y2 = fixed_patch_y1 + patch_h
    fixed_patch_x2 = fixed_patch_x1 + patch_w

    patch = fixed_noise_partial[0, :, fixed_patch_y1:fixed_patch_y2, fixed_patch_x1:fixed_patch_x2]
    return patch, box_patch_y1, box_patch_y2, box_patch_x1, box_patch_x2

def _checkpoint_step(name):
    """Return the step encoded in a ``checkpoint-<step>`` name, or None if there is none."""
    if not name.startswith("checkpoint"):
        return None
    try:
        return int(name.split("-")[1])
    except (IndexError, ValueError):
        # e.g. "checkpoint-final" (written at the end of a run) or a bare "checkpoint"
        return None


def _prune_numbered_checkpoints(output_dir, total_limit):
    """Delete the oldest numbered checkpoints so one more fits within ``total_limit``."""
    steps = {name: _checkpoint_step(name) for name in os.listdir(output_dir)}
    checkpoints = sorted(
        (name for name, step in steps.items() if step is not None),
        key=steps.__getitem__,
    )

    # before we save the new checkpoint, we need to have at _most_ `total_limit - 1` checkpoints
    if len(checkpoints) < total_limit:
        return

    num_to_remove = len(checkpoints) - total_limit + 1
    removing_checkpoints = checkpoints[0:num_to_remove]

    logger.info(
        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

    for removing_checkpoint in removing_checkpoints:
        shutil.rmtree(os.path.join(output_dir, removing_checkpoint))


def train(
        accelerator, progress_bar, first_epoch, global_step,
        vae, transformer, text_encoder_one, text_encoder_two, text_encoders, tokenizers,
        noise_scheduler_copy, optimizer, lr_scheduler, train_dataloader, config,
    ):
    """Run the flow-matching training loop.

    For every batch the images are encoded with the VAE, noise is produced by
    one of three strategies (global custom noise, random layout, grid layout),
    a timestep is sampled per image, and the transformer is trained with a
    weighted squared error on the flow-matching target. After every
    optimisation step the loop handles checkpoint rotation/saving, logging,
    dumping of the sampled timesteps and (on the main process) validation.

    Reads the module-level globals ``noise_custom`` and ``fixed_noise_partial``.

    Args:
        accelerator: ``accelerate.Accelerator`` driving accumulation, backward,
            clipping, checkpointing and logging.
        progress_bar: tqdm bar advanced once per optimisation step.
        first_epoch: Epoch index to start from.
        global_step: Number of optimisation steps already performed.
        vae: Autoencoder used to encode pixel values into latents.
        transformer: Model being trained.
        text_encoder_one: First text encoder; put in train mode when
            ``config.train_text_encoder`` is set.
        text_encoder_two: Second text encoder; only used to build the
            validation pipeline here.
        text_encoders: Encoders forwarded to ``encode_prompt``.
        tokenizers: Tokenizers forwarded to ``encode_prompt``.
        noise_scheduler_copy: Scheduler providing ``timesteps`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer stepped every batch inside ``accumulate``.
        lr_scheduler: LR scheduler stepped alongside the optimizer.
        train_dataloader: Iterable of batches with ``'pixel_values'`` and
            ``"caption"`` entries.
        config: Run configuration (attribute and ``.get`` access are both used).

    Returns:
        None.
    """
    global noise_custom
    global fixed_noise_partial

    # Track timesteps for current interval only
    current_timesteps = []
    # Record the step number of last checkpoint
    last_checkpoint_step = global_step

    is_deepspeed = accelerator.state.distributed_type == DistributedType.DEEPSPEED

    for epoch in range(first_epoch, config.num_train_epochs):
        transformer.train()

        if config.train_text_encoder:
            text_encoder_one.train()
            # set top parameter requires_grad = True for gradient checkpointing works
            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            if config.train_text_encoder:
                models_to_accumulate += [text_encoder_one]
            with accelerator.accumulate(models_to_accumulate):
                # NOTE(review): a collate function that returns an empty dict for a
                # fully-filtered batch would raise KeyError here — confirm how empty
                # batches are meant to be handled.
                merged_pt = batch['pixel_values'].to(dtype=vae.dtype)

                # For RGB image processing: keep only the first three channels
                pixel_values_vae_input = merged_pt.to(accelerator.device)[:, :3]

                prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                    text_encoders=text_encoders,
                    tokenizers=tokenizers,
                    prompt=batch["caption"],
                )

                # Convert images to latent space
                model_input = vae.encode(pixel_values_vae_input).latent_dist.sample()
                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                model_input = model_input.to(dtype=config.weight_dtype)
                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
                latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                    model_input.shape[0],
                    model_input.shape[2],
                    model_input.shape[3],
                    accelerator.device,
                    config.weight_dtype,
                )

                # Generate noise based on strategy
                if noise_custom is not None:
                    noise = generate_global_noise(batch, model_input, noise_custom, config)
                elif config.get("random_layout", False):
                    noise = generate_random_layout_noise(batch, model_input, fixed_noise_partial, config)
                else:
                    noise = generate_grid_layout_noise(batch, model_input, fixed_noise_partial, config)

                bsz = model_input.shape[0]

                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=config.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=config.logit_mean,
                    logit_std=config.logit_std,
                    mode_scale=config.mode_scale,
                )

                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                if accelerator.is_main_process:
                    current_timesteps.extend(timesteps.cpu().numpy().tolist())

                # Add noise according to flow matching.
                sigmas = get_sigmas(timesteps, accelerator, noise_scheduler_copy, n_dim=model_input.ndim, dtype=model_input.dtype)
                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

                packed_noisy_model_input = FluxPipeline._pack_latents(
                    noisy_model_input,
                    batch_size=model_input.shape[0],
                    num_channels_latents=model_input.shape[1],
                    height=model_input.shape[2],
                    width=model_input.shape[3],
                )

                # handle guidance
                if config.get("guidance_scale", None):
                    guidance = torch.tensor([config.guidance_scale], device=accelerator.device)
                    guidance = guidance.expand(model_input.shape[0])
                else:
                    guidance = None

                # Predict the noise residual
                model_pred = transformer(
                    hidden_states=packed_noisy_model_input,
                    timestep=timesteps / 1000,
                    guidance=guidance,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    return_dict=False,
                )[0]

                model_pred = FluxPipeline._unpack_latents(
                    model_pred,
                    height=int(model_input.shape[2] * vae_scale_factor / 2),
                    width=int(model_input.shape[3] * vae_scale_factor / 2),
                    vae_scale_factor=vae_scale_factor,
                )

                if config.get("precondition_outputs", None):
                    model_pred = model_pred * (-sigmas) + noisy_model_input

                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=config.weighting_scheme, sigmas=sigmas)

                # flow matching loss
                if config.get("precondition_outputs", None):
                    target = model_input
                else:
                    target = noise - model_input

                # Compute loss: per-sample mean of the weighted squared error, then batch mean
                loss = torch.mean(
                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                        if config.train_text_encoder
                        else transformer.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # Under DeepSpeed every rank takes part in save_state; otherwise only the main process saves.
                if accelerator.is_main_process or is_deepspeed:
                    if global_step % config.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if accelerator.is_main_process and config.checkpoints_total_limit is not None:
                            _prune_numbered_checkpoints(config.output_dir, config.checkpoints_total_limit)

                        save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                        if is_deepspeed:
                            extra_kwargs = {'exclude_frozen_parameters': True}
                        else:
                            extra_kwargs = {}
                        accelerator.save_state(save_path, **extra_kwargs)
                        logger.info(f"[Rank{accelerator.process_index}] saved state to {save_path}", main_process_only=False)

                        if accelerator.is_main_process and is_deepspeed:
                            transformer_lora_layers_to_save = get_peft_model_state_dict(accelerator.unwrap_model(transformer))
                            text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(accelerator.unwrap_model(text_encoder_one)) if config.train_text_encoder else None
                            FluxPipeline.save_lora_weights(
                                save_path,
                                transformer_lora_layers=transformer_lora_layers_to_save,
                                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                            )

                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

                if accelerator.is_main_process and global_step % config.checkpointing_steps == 0:
                    # Create checkpoint directory
                    checkpoint_dir = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                    os.makedirs(checkpoint_dir, exist_ok=True)

                    # Save timesteps for current interval
                    timestep_save_path = os.path.join(checkpoint_dir, f"timesteps_{last_checkpoint_step}_to_{global_step}.npy")
                    np.save(timestep_save_path, np.array(current_timesteps))

                    logger.info(f"Saved timesteps from step {last_checkpoint_step} to {global_step} at {timestep_save_path}")

                    # Update last checkpoint step
                    last_checkpoint_step = global_step

                    # Clear timesteps list for next interval
                    current_timesteps = []

                if global_step >= config.max_train_steps:
                    break

                if accelerator.is_main_process:
                    if config.validation_prompts is not None and (global_step % config.validation_steps == 0 or global_step == 1):
                        # create pipeline
                        pipeline = FluxPipeline.from_pretrained(
                            config.pretrained_model_name_or_path,
                            vae=vae,
                            text_encoder=accelerator.unwrap_model(text_encoder_one),
                            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
                            transformer=accelerator.unwrap_model(transformer),
                            revision=config.revision,
                            variant=config.variant,
                            torch_dtype=config.weight_dtype,
                        )
                        log_validation(
                            pipeline=pipeline,
                            config=config,
                            accelerator=accelerator,
                            global_step=global_step,
                        )
                        torch.cuda.empty_cache()
                        gc.collect()

        # The inner `break` only leaves the batch loop; stop iterating epochs too,
        # otherwise each remaining epoch would still run one extra optimisation step.
        if global_step >= config.max_train_steps:
            break

    # Save final timesteps after training ends
    if accelerator.is_main_process and len(current_timesteps) > 0:
        final_dir = os.path.join(config.output_dir, "checkpoint-final")
        os.makedirs(final_dir, exist_ok=True)

        # Save timesteps for final interval
        timestep_save_path = os.path.join(final_dir, f"timesteps_{last_checkpoint_step}_to_final.npy")
        np.save(timestep_save_path, np.array(current_timesteps))

        logger.info(f"Saved final timesteps from step {last_checkpoint_step} to end at {timestep_save_path}")

def main(config):
    """Set up and run a full training job described by ``config``.

    Builds the ``Accelerator``, configures logging and seeding, derives
    ``config.weight_dtype`` from the mixed-precision mode, loads the scheduler
    and models, creates the optimizer / dataloader / LR scheduler, optionally
    restores state from ``config.base_checkpoint`` or
    ``config.resume_from_checkpoint``, calls ``train`` and finally writes a
    ``checkpoint-final`` state under ``config.output_dir``.

    Args:
        config: Run configuration. It is mutated in place (``output_dir``,
            ``weight_dtype``, ``learning_rate``, ``max_train_steps``,
            ``num_train_epochs`` and ``resume_from_checkpoint`` may be rewritten).

    Raises:
        ImportError: ``config.report_to == "wandb"`` but wandb is not installed.
        ValueError: bf16 mixed precision requested while MPS is available.
    """
    if 'basecode_flux' in config.output_dir:
        config.output_dir = config.output_dir.replace('Count-FLUX/basecode_flux', 'workspace/cache')
    print(f"config.output_dir: {config.output_dir}")
    logging_dir = Path(config.output_dir, config.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with=config.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )
    is_deepspeed = accelerator.state.distributed_type == DistributedType.DEEPSPEED

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if config.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if config.seed is not None:
        set_seed(config.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    config.weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        config.weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        config.weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and config.weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    # Load scheduler
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="scheduler",
        cache_dir=config.cache_dir if hasattr(config, "cache_dir") else None
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.scale_lr:
        config.learning_rate = (
            config.learning_rate * config.gradient_accumulation_steps * config.train_batch_size * accelerator.num_processes
        )

    # Initialize all models
    vae, transformer, tokenizers, text_encoders = initialize_all_models(config, accelerator)
    text_encoder_one, text_encoder_two = text_encoders

    # Hooks invoked by accelerator.save_state / load_state
    save_model_hook = partial(
        save_model_hook_partial,
        accelerator=accelerator,
        transformer=transformer,
        text_encoder_one=text_encoder_one,
    )
    load_model_hook = partial(
        load_model_hook_partial,
        accelerator=accelerator,
        transformer=transformer,
        text_encoder_one=text_encoder_one,
    )
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Get trainable parameters
    params_to_optimize = get_trainable_params(
        config,
        accelerator,
        transformer,
        text_encoder_one,
    )

    # Optimizer
    if config.get("optimizer", None) == "prodigy":
        import prodigyopt # type: ignore
        optimizer_class = prodigyopt.Prodigy
        if config.train_text_encoder and config.text_encoder_lr:
            params_to_optimize[1]["lr"] = config.learning_rate
        optimizer = optimizer_class(
            params_to_optimize,
            lr=config.learning_rate,
            betas=(config.adam_beta1, config.adam_beta2),
            beta3=config.prodigy_beta3,
            weight_decay=config.adam_weight_decay,
            eps=config.adam_epsilon,
            decouple=config.prodigy_decouple,
            use_bias_correction=config.prodigy_use_bias_correction,
            safeguard_warmup=config.prodigy_safeguard_warmup,
        )
    else:
        if config.get("use_8bit_adam", None):
            import bitsandbytes as bnb # type: ignore
            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW
        # NOTE(review): no `lr` is passed here — presumably each param group from
        # get_trainable_params carries its own "lr"; confirm.
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(config.adam_beta1, config.adam_beta2),
            weight_decay=config.adam_weight_decay,
            eps=config.adam_epsilon,
        )

    train_dataset = Text2ImageRGBDataset(**config.dataset_cfg)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=config.dataloader_shuffle,
        pin_memory=config.dataloader_pin_memory,
        drop_last=config.dataloader_drop_last,
        collate_fn=collate_fn,
        batch_size=config.train_batch_size,
        num_workers=config.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    if config.max_train_steps is None:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=config.max_train_steps * accelerator.num_processes,
        num_cycles=config.lr_num_cycles,
        power=config.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if is_deepspeed:
        accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = config.train_batch_size # mute annoying deepspeed errors
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )
    if config.train_text_encoder:
        text_encoder_one = accelerator.prepare(text_encoder_one)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    if overrode_max_train_steps:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    config.num_train_epochs = math.ceil(config.max_train_steps / num_update_steps_per_epoch)

    # Initialize trackers, also store our configuration.
    if accelerator.is_main_process:
        tracker_config = dict(copy.deepcopy(config))
        accelerator.init_trackers(
            project_name=config.tracker_project_name,
            config=tracker_config,
            init_kwargs={"wandb": {"name": config.wandb_job_name}},
        )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {config.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {config.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Loading a base checkpoint restores state but leaves the step/epoch counters at zero.
    if config.base_checkpoint:
        accelerator.print(f"Resuming from checkpoint {config.base_checkpoint}")
        if is_deepspeed:
            extra_kwargs = {'load_module_strict': False}
        else:
            extra_kwargs = {}
        accelerator.load_state(config.base_checkpoint, **extra_kwargs)

    # Potentially load in the weights and states from a previous save
    if config.resume_from_checkpoint and config.base_checkpoint is None:
        if config.resume_from_checkpoint != "latest":
            path = os.path.basename(config.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint. Only `checkpoint-<int>` entries are
            # candidates: `checkpoint-final` (written at the end of a run) has no
            # step number and used to crash the integer sort.
            def _numbered_step(name):
                try:
                    return int(name.split("-")[1])
                except (IndexError, ValueError):
                    return None

            dirs = [
                d for d in os.listdir(config.output_dir)
                if d.startswith("checkpoint") and _numbered_step(d) is not None
            ]
            dirs = sorted(dirs, key=_numbered_step)
            path = dirs[-1] if dirs else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            config.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            if is_deepspeed:
                extra_kwargs = {'load_module_strict': False}
            else:
                extra_kwargs = {}
            accelerator.load_state(os.path.join(config.output_dir, path), **extra_kwargs)
            # The step counter is recovered from the directory name.
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, config.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # Initialize noise configuration and load pre-generated noise
    global noise_custom, fixed_noise_partial
    noise_custom, fixed_noise_partial = setup_custom_noise(config, accelerator, logger)

    train(
        accelerator, progress_bar, first_epoch, global_step,
        vae, transformer, text_encoder_one, text_encoder_two, text_encoders, tokenizers,
        noise_scheduler_copy, optimizer, lr_scheduler, train_dataloader, config,
    )

    # finally, save the lora layers
    accelerator.wait_for_everyone()
    # Mirror the in-loop checkpointing: with DeepSpeed every rank must take part in
    # save_state, otherwise only the main process writes the checkpoint.
    if accelerator.is_main_process or is_deepspeed:
        save_path = os.path.join(config.output_dir, "checkpoint-final")
        extra_kwargs = {'exclude_frozen_parameters': True} if is_deepspeed else {}
        accelerator.save_state(save_path, **extra_kwargs)
        logger.info(f"Final checkpoints is saved to {save_path}")

    accelerator.end_training()


# Script entry point: parse the run configuration and start training.
# NOTE: `config` is bound at module scope here, and create_custom_noise_dict looks
# up `config` as a module global, so this assignment must stay a top-level name.
if __name__ == "__main__":
    config = parse_config()
    main(config)
