# Copyright (c) 2025, [NAME] [NAME]. All rights reserved.
#
# This work is licensed under APACHE LICENSE, VERSION 2.0
# You should have received a copy of the license along with this
# work. If not, see [URL]

"""Distill DiT-based diffusion/flow-matching models using the SiD few-step techniques described in the
paper "[NAME] [NAME] of Flow"."""


"""Main training loop."""

import os
import time
import copy
import json
import pickle
import psutil
from typing import Optional
import numpy as np
import torch
import dnnlib
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
from metrics import sid_metric_main as metric_main
from training.sid_dit_util import load_dit, sid_dit_generate, sid_dit_denoise
from functools import partial
# Import ShardedGradScaler for gradient scaling with autocast+fp16.
# Note: By default, this code uses autocast with bfloat16 (bf16), so ShardedGradScaler is not actively used.
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
import gc

from training.precompute_latents import (
    process_and_save_latents_with_pipeline_multigpu,
)

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    FullOptimStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
)

# Data loading setup:
# Multiple dataloaders are supported:
#   - Training with text-only datasets
#   - Training with paired text+image datasets
#   - Evaluation mode (e.g., COCO2014 dataset for FID/CLIP metrics)
# The appropriate loader is selected based on the training/evaluation configuration.
from training.precompute_latents import Text2ImageDataset, TextDataset

# Default NCCL logging to warnings only; a level already chosen through the
# NCCL_DEBUG environment variable is left untouched.
os.environ.setdefault('NCCL_DEBUG', 'WARN')
    
#----------------------------------------------------------------------------

def setup_snapshot_image_grid(training_set, random_seed=0):
    """Select a reproducible set of dataset items to fill a snapshot image grid.

    The number of grid columns/rows is ``3840 // resolution`` and
    ``2160 // resolution`` (using ``training_set.resolution``), clamped to
    4..32 columns and 2..32 rows. Dataset indices are shuffled with a
    ``np.random.RandomState`` seeded by ``random_seed`` and reused cyclically
    when the dataset holds fewer items than the grid has cells.

    Args:
        training_set: Indexable dataset with ``resolution`` and ``__len__``;
            every item must unpack into an ``(image, context)`` pair.
        random_seed: Seed for the index shuffle.

    Returns:
        ``((gw, gh), images, contexts)``: the grid size, the ``np.stack`` of the
        selected images, and a tuple with the matching contexts.
    """
    num_cols = np.clip(3840 // training_set.resolution, 4, 32)
    num_rows = np.clip(2160 // training_set.resolution, 2, 32)

    shuffled = list(range(len(training_set)))
    np.random.RandomState(random_seed).shuffle(shuffled)

    # Walk the grid cells in order, wrapping around the shuffled indices.
    picked_images = []
    picked_contexts = []
    for cell in range(num_cols * num_rows):
        image, context = training_set[shuffled[cell % len(shuffled)]]
        picked_images.append(image)
        picked_contexts.append(context)

    return (num_cols, num_rows), np.stack(picked_images), tuple(picked_contexts)

from itertools import islice
def split_list(lst, split_sizes):
    """
    Cut a list into consecutive chunks.

    Parameters:
    - lst (list): The list to be split.
    - split_sizes (list or int): An integer requests equally sized chunks of that
                                 length (the final chunk holds the remainder and may
                                 be shorter). A list of integers gives the length of
                                 each chunk explicitly.

    Returns:
    - list of lists: The chunks, in order. Items beyond the requested sizes are
                     dropped; a chunk is shorter than requested when the list runs out.
    """
    if isinstance(split_sizes, int):
        # Expand the single size into an explicit list: N full chunks plus an
        # optional trailing chunk for the leftover items.
        full_chunks, leftover = divmod(len(lst), split_sizes)
        split_sizes = [split_sizes] * full_chunks
        if leftover != 0:
            split_sizes.append(leftover)

    remaining = iter(lst)
    chunks = []
    for size in split_sizes:
        chunks.append(list(islice(remaining, size)))
    return chunks

#----------------------------------------------------------------------------
# Helper methods

def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of images into one grid picture and save it with PIL.

    Args:
        img: Array-like batch; converted with ``np.asarray`` and unpacked as
            ``(N, C, H, W)``.
        fname: Destination handed to ``PIL.Image.save``.
        drange: ``(lo, hi)`` value range that is mapped linearly onto 0..255
            before rounding and clipping to ``uint8``.
        grid_size: ``(gw, gh)`` number of grid columns and rows.

    If the batch holds fewer than ``gw * gh`` images it is padded with black
    images; surplus images are dropped. Only 1- or 3-channel images are
    accepted (enforced with an ``assert``).
    """
    low, high = drange
    pixels = np.asarray(img, dtype=np.float32)
    pixels = (pixels - low) * (255 / (high - low))
    pixels = np.rint(pixels).clip(0, 255).astype(np.uint8)

    cols, rows = grid_size
    count, channels, height, width = pixels.shape
    num_cells = rows * cols
    missing = num_cells - count

    if missing > 0:
        print(f"[save_image_grid] Padding with {missing} black images to reach {num_cells}")
        filler = np.zeros((missing, channels, height, width), dtype=np.uint8)
        pixels = np.concatenate([pixels, filler], axis=0)
    elif missing < 0:
        print(f"[save_image_grid] Trimming {count - num_cells} extra images")
        pixels = pixels[:num_cells]

    # (rows, cols, C, H, W) -> (rows, H, cols, W, C) -> a single (rows*H, cols*W, C) canvas.
    canvas = pixels.reshape(rows, cols, channels, height, width).transpose(0, 3, 1, 4, 2)
    canvas = canvas.reshape(rows * height, cols * width, channels)

    from PIL import Image
    assert channels in [1, 3], f"Unsupported channel count: {channels}"
    if channels == 1:
        picture = Image.fromarray(canvas[:, :, 0], 'L')
    else:
        picture = Image.fromarray(canvas, 'RGB')
    picture.save(fname)

def save_data(data, fname):
    """Pickle ``data`` into the file ``fname`` (opened in binary write mode, so existing content is replaced)."""
    with open(fname, 'wb') as sink:
        pickle.dump(data, sink)

def save_fp16_pt(pt, fname, dtype=torch.float16):
    """
    Save a PyTorch state dictionary to disk with its floating-point tensors cast to ``dtype``.

    Nested dictionaries are converted recursively (the copies are plain ``dict``
    objects). Only floating-point tensors are cast: integer, boolean and complex
    tensors are written unchanged, because casting them to a half-precision float
    would corrupt their values (for example an int64 counter of 70000 overflows
    to ``inf`` in float16). ``None`` and every other non-tensor value are stored
    as they are. The input dictionary itself is not modified.

    Args:
        pt (dict): State dictionary to save.
        fname (str): File path to save the state dictionary.
        dtype (torch.dtype): Desired data type for floating-point tensors
            (default: torch.float16).
    """
    def convert(value):
        if isinstance(value, dict):
            return {key: convert(item) for key, item in value.items()}
        if torch.is_tensor(value) and value.is_floating_point():
            return value.to(dtype=dtype)
        # None, non-floating tensors and arbitrary Python objects pass through.
        return value

    torch.save({key: convert(value) for key, value in pt.items()}, fname)


def append_line(jsonl_line, fname):
    """Append ``jsonl_line`` followed by a newline to the text file ``fname`` (created if missing)."""
    record = jsonl_line + '\n'
    with open(fname, 'at') as sink:
        sink.write(record)


import contextlib
@contextlib.contextmanager
def fsdp_sync(module, sync):
    """Run the ``with`` body either directly or inside ``module.no_sync()``.

    Args:
        module: A ``torch.nn.Module`` (asserted). It must provide ``no_sync()``
            when ``sync`` is falsy.
        sync: When truthy the body runs without any extra context; otherwise
            it runs inside ``module.no_sync()`` (presumably to skip gradient
            synchronization on accumulation steps; confirm against callers).
    """
    assert isinstance(module, torch.nn.Module)
    guard = contextlib.nullcontext() if sync else module.no_sync()
    with guard:
        yield


def get_rank():
    """Return ``torch.distributed.get_rank()``, or 0 when torch.distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
        
# def move_model_to_device(model, device):
#     if isinstance(model, torch.nn.Module):
#         if model.device != device:
#             model.to(device)
#     elif isinstance(model, list):
#         for m in model:
#             move_model_to_device(m, device)
#     elif isinstance(model, dict):   
#         for k, v in model.items():
#             model[k] = move_model_to_device(v, device)
#     return model

import contextlib
@contextlib.contextmanager
def temporary_model_on_device(model, device="cuda"):
    """Temporarily move a model to a device and move it back after use (if needed).

    Args:
        model: Object exposing a ``device`` attribute and a ``to(device)``
            method that moves it in place (plain ``torch.nn.Module`` has no
            ``device`` attribute, so presumably a diffusers model or pipeline;
            confirm against callers).
        device: Target device; anything accepted by ``torch.device``.

    Yields:
        The same ``model`` object. If its device differed from the target,
        ``model.to(target)`` has been called before the body runs and
        ``model.to(original)`` is called afterwards, even if the body raises.

    Note:
        Devices are compared with ``!=``, so ``cuda`` and ``cuda:0`` count as
        different devices and trigger a (harmless) move.
    """
    target_device = torch.device(device)
    original_device = model.device

    moved = original_device != target_device
    if moved:
        model.to(target_device)
        # Release cached blocks left behind by whatever occupied the GPU before.
        torch.cuda.empty_cache()

    try:
        yield model
    finally:
        if moved:
            model.to(original_device)
            gc.collect()
            if original_device.type == "cuda":
                torch.cuda.empty_cache()


    
from datasets import load_dataset, load_from_disk
from filelock import FileLock
def load_dataset_distributed_safe(
    file_path: str,
    actual_datapart: int | None = None,
    dataset_name: str = "processed_dataset",
    dataset_format: str = "parquet",
    split: str = "train",
    cache_latents_dir: str | None = None,
    run_dir: str | None = None,
):
    """
    Load a Hugging Face ``datasets`` dataset on every rank of an initialized process group.

    Rank 0 converts ``file_path`` into an on-disk dataset (``save_to_disk``)
    under ``cache_latents_dir`` unless that directory entry already exists;
    after a barrier every rank loads the dataset with ``load_from_disk``.

    Args:
        file_path: Data file handed to ``datasets.load_dataset``.
        actual_datapart: Partition number used in the cache name. When omitted
            it is parsed from a file name of the form ``train_<digits>.parquet``.
        dataset_name: Prefix of the cache entry (``<dataset_name>_<part>``).
        dataset_format: Builder name passed as ``path`` to ``load_dataset``.
        split: Split name used both as the ``data_files`` key and the split to load.
        cache_latents_dir: Directory holding the cached dataset. Defaults to
            ``data/tmp`` two levels above this file's directory, with the last
            component of ``run_dir`` appended when ``run_dir`` is given.
        run_dir: Optional run directory; only its final path component is used,
            to namespace the default cache directory.

    Returns:
        The dataset object returned by ``datasets.load_from_disk``.

    Raises:
        AssertionError: If ``torch.distributed`` is not initialized.
    """
    import re
    if actual_datapart is None:
        # Match on the file name only: re.match is anchored at the start of the
        # string, so a path with a directory prefix would never match.
        match = re.match(r"train_(\d+)\.parquet", os.path.basename(file_path))
        if match:
            actual_datapart = int(match.group(1))
        # NOTE(review): if the name does not match, actual_datapart stays None and
        # every such file shares the cache entry "<dataset_name>_None".
    if cache_latents_dir is None:
        BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
        cache_latents_dir = os.path.join(BASE_DIR, "data", "tmp")
        if run_dir is not None:
            # Use the last part of run_dir to name the cache directory
            run_name = os.path.basename(os.path.normpath(run_dir))
            cache_latents_dir = os.path.join(cache_latents_dir, run_name)

    os.makedirs(cache_latents_dir, exist_ok=True)

    assert torch.distributed.is_initialized(), "torch.distributed must be initialized"
    rank = get_rank()

    shared_path = os.path.join(cache_latents_dir, f"{dataset_name}_{actual_datapart}")
    lock_path = shared_path + ".lock"

    if rank == 0:
        # The file lock guards against another process writing the same cache entry.
        with FileLock(lock_path):
            if not os.path.exists(shared_path):
                print(f"[Rank 0] Loading and saving dataset to {shared_path}...")
                hf_dataset = load_dataset(
                    path=dataset_format,
                    data_files={split: file_path},
                    split=split,
                )
                hf_dataset.save_to_disk(shared_path)
            else:
                print(f"[Rank 0] Dataset already exists at {shared_path}")
    # Wait until rank 0 has finished writing before any rank reads.
    torch.distributed.barrier()
    hf_dataset = load_from_disk(shared_path)
    torch.distributed.barrier()
    return hf_dataset

       
#----------------------------------------------------------------------------

def stream_data_sampler(
    datapart,
    text_image_pair_path,
    cache_latents_dir,
    text_encoding_pipeline,
    vae,
    resolution,
    train_diffusiongan,
    uncond_embeds,
    uncond_attention_mask,
    batch_gpu,
    seed,
    data_loader_kwargs,
    precomputed_latents_path = None,
):
    """Build the infinite data iterators for one partition of the text-image dataset.

    The partition index is wrapped modulo 124 and mapped to the file
    ``train_<NNN>.parquet`` (zero-padded to three digits) inside
    ``text_image_pair_path``. When ``precomputed_latents_path`` is given but the
    file does not exist yet, the latents are first computed and saved by
    ``process_and_save_latents_with_pipeline_multigpu`` and all ranks
    synchronize on a barrier.

    Args:
        datapart: Current data partition number (wrapped modulo 124).
        text_image_pair_path: Directory containing the parquet partitions.
        cache_latents_dir: Cache directory forwarded to
            ``load_dataset_distributed_safe``.
        text_encoding_pipeline: Forwarded to the latent precomputation.
        vae: Forwarded to the latent precomputation.
        resolution: Forwarded to the datasets and to the latent precomputation.
        train_diffusiongan: When truthy an additional text+image iterator is
            built; it also selects the precomputation batch size.
        uncond_embeds: Not used in this function body.
        uncond_attention_mask: Not used in this function body.
        batch_gpu: Batch size per GPU for the data loaders.
        seed: Seed for the ``misc.InfiniteSampler`` instances.
        data_loader_kwargs: Extra keyword arguments for ``torch.utils.data.DataLoader``.
        precomputed_latents_path: Optional path of the precomputed latents file.

    Returns:
        tuple: ``(dataset_latents_iterator, dataset_latents_iterator_text, hf_dataset)``.
        The first entry is ``None`` unless ``train_diffusiongan`` is truthy;
        ``hf_dataset`` is ``None`` when the precomputed latents file already
        existed and the raw dataset was therefore not loaded.
    """
    def latents_file_missing():
        # Deliberately re-checked on every call so each decision reflects the
        # state of the filesystem at that moment.
        return precomputed_latents_path is not None and not os.path.exists(precomputed_latents_path)

    def make_infinite_iterator(dataset):
        # One infinite, rank-sharded sampler per dataset, wrapped in a DataLoader iterator.
        sampler = misc.InfiniteSampler(
            dataset=dataset,
            rank=get_rank(),
            num_replicas=dist.get_world_size(),
            seed=seed,
        )
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            sampler=sampler,
            batch_size=batch_gpu,
            **data_loader_kwargs,
        )
        return iter(loader)

    dist.print0(f"Switching to new data part: {datapart}")
    actual_datapart = datapart % 124
    parquet_name = f"train_{str(actual_datapart).zfill(3)}.parquet"
    file_path = os.path.join(text_image_pair_path, parquet_name)

    # Reclaim memory held by objects from a previous partition before loading.
    gc.collect()
    torch.cuda.empty_cache()

    # The raw dataset is only needed when no precomputed latents file is available.
    hf_dataset = None
    if precomputed_latents_path is None or latents_file_missing():
        hf_dataset = load_dataset_distributed_safe(
            file_path=file_path,
            actual_datapart=actual_datapart,
            cache_latents_dir=cache_latents_dir,
        )

    if latents_file_missing():
        # NOTE(review): device placement of text_encoding_pipeline / vae is not
        # handled here; presumably the caller takes care of it - confirm.
        if train_diffusiongan:
            encode_batch_size = int(batch_gpu*2.5)
        else:
            encode_batch_size = batch_gpu*12
        process_and_save_latents_with_pipeline_multigpu(
            hf_dataset=hf_dataset,
            text_encoding_pipeline=text_encoding_pipeline,
            vae=vae,
            latents_path=precomputed_latents_path,
            resolution=resolution,
            world_size=dist.get_world_size(),
            rank=get_rank(),
            train_diffusiongan=train_diffusiongan,
            batch_size=encode_batch_size,
        )
        # All ranks wait for the final file before proceeding
        torch.distributed.barrier()
        gc.collect()
        torch.cuda.empty_cache()

    dataset_latents_iterator = None
    if train_diffusiongan:
        paired_dataset = Text2ImageDataset(
            hf_dataset=hf_dataset,
            resolution=resolution,
            precomputed_latents_path=precomputed_latents_path,
        )
        dataset_latents_iterator = make_infinite_iterator(paired_dataset)

    text_dataset = TextDataset(
        hf_dataset=hf_dataset,
        resolution=resolution,
        precomputed_latents_path=precomputed_latents_path,
    )
    dataset_latents_iterator_text = make_infinite_iterator(text_dataset)

    dist.print0(f"Loaded dataset from: {file_path}, dataset size: {len(text_dataset)}")
    dist.print0(f"Switched to dataset part: {datapart}")
    torch.distributed.barrier()

    # Best-effort cleanup of temporaries created while switching partitions.
    gc.collect()
    torch.cuda.empty_cache()

    return dataset_latents_iterator, dataset_latents_iterator_text, hf_dataset

def load_checkpoint(resume_training, fake_score, fake_score_optimizer_kwargs, G, g_optimizer_kwargs, FSDP, fsdp_kwargs, dnnlib, dist):
    import torch
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig, FullOptimStateDictConfig

    def convert_to_fp32(state_dict):
        for k, v in state_dict.items():
            if isinstance(v, torch.Tensor) and (v.dtype == torch.float16 or v.dtype == torch.bfloat16):
                state_dict[k] = v.to(dtype=torch.float32)
        return state_dict

    def convert_optimizer_state_to_fp32(opt_state):
        for state in opt_state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and (v.dtype == torch.float16 or v.dtype == torch.bfloat16):
                    state[k] = v.to(dtype=torch.float32)

    if resume_training is None:
        fake_score_fsdp = FSDP(fake_score, **fsdp_kwargs)
        G_fsdp = FSDP(G, **fsdp_kwargs)

        fake_score_optimizer = dnnlib.util.construct_class_by_name(params=fake_score_fsdp.parameters(), **fake_score_optimizer_kwargs)
        g_optimizer = dnnlib.util.construct_class_by_name(params=G_fsdp.parameters(), **g_optimizer_kwargs)
    else:
        dist.print0('checkpoint path:', resume_training)
        model_checkpoint = torch.load(resume_training, map_location=torch.device('cpu'), weights_only=False)

        if model_checkpoint.get('fake_score_state') is None:
            resume_training_fake = resume_training + '_fake'
            model_checkpoint = torch.load(resume_training_fake, map_location=torch.device('cpu'), weights_only=False)
        fake_score.load_state_dict(convert_to_fp32(model_checkpoint['fake_score_state']))
        fake_score_fsdp = FSDP(fake_score, **fsdp_kwargs)

        if model_checkpoint.get('fake_score_optimizer_state') is None:
            resume_training_fake_optim = resume_training + '_fake_optim'
            model_checkpoint = torch.load(resume_training_fake_optim, map_location=torch.device('cpu'), weights_only=False)

        FSDP.set_state_dict_type(
            fake_score_fsdp,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=False),
            FullOptimStateDictConfig(rank0_only=False),
        )
        fake_score_optimizer = dnnlib.util.construct_class_by_name(params=fake_score_fsdp.parameters(), **fake_score_optimizer_kwargs)
        optim_state_dict = FSDP.optim_state_dict_to_load(
            fake_score_fsdp, fake_score_optimizer, model_checkpoint['fake_score_optimizer_state']
        )
        fake_score_optimizer.load_state_dict(optim_state_dict)
        convert_optimizer_state_to_fp32(fake_score_optimizer.state)
        del optim_state_dict

        if model_checkpoint.get('G_state') is None:
            resume_training_G = resume_training + '_G'
            model_checkpoint = torch.load(resume_training_G, map_location=torch.device('cpu'), weights_only=False)
        G.load_state_dict(convert_to_fp32(model_checkpoint['G_state']))
        G_fsdp = FSDP(G, **fsdp_kwargs)

        if model_checkpoint.get('g_optimizer_state') is None:
            resume_training_G_optim = resume_training + '_G_optim'
            model_checkpoint = torch.load(resume_training_G_optim, map_location=torch.device('cpu'), weights_only=False)

        FSDP.set_state_dict_type(
            G_fsdp,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=False),
            FullOptimStateDictConfig(rank0_only=False),
        )
        g_optimizer = dnnlib.util.construct_class_by_name(params=G_fsdp.parameters(), **g_optimizer_kwargs)
        optim_state_dict = FSDP.optim_state_dict_to_load(
            G_fsdp, g_optimizer, model_checkpoint['g_optimizer_state']
        )
        g_optimizer.load_state_dict(optim_state_dict)
        convert_optimizer_state_to_fp32(g_optimizer.state)
        del optim_state_dict

        del model_checkpoint
        torch.distributed.barrier()
        dist.print0('Loading checkpoint completed')

    return fake_score_fsdp, fake_score_optimizer, G_fsdp, g_optimizer

def training_loop(
    run_dir             = '.',      # Output directory.
    dataset_kwargs      = {},       # Options for training set.
    data_loader_kwargs  = {},       # Options for torch.utils.data.DataLoader.
    network_kwargs      = {},       # Options for model and preconditioning.
    fake_score_optimizer_kwargs   = {},       # Options for fake score network optimizer.
    g_optimizer_kwargs    = {},     # Options for generator optimizer.
    seed                = 0,        # Global random seed.
    batch_size          = 512,      # Total batch size for one training iteration.
    batch_gpu           = None,     # Limit batch size per GPU, None = no limit.
    total_kimg          = 2000,   # Training duration, measured in thousands of training images.
    ema_halflife_kimg   = 0,      # Half-life of the exponential moving average (EMA) of model weights.
    loss_scaling        = 1,        # Loss scaling factor for reducing FP16 under/overflows.
    loss_scaling_G      = 100,       # Loss scaling factor of G for reducing FP16 under/overflows.
    kimg_per_tick       = 2,       # Interval of progress prints.
    snapshot_ticks      = 25,       # How often to save network snapshots, None = disable.
    state_dump_ticks    = 25,      # How often to dump training state, None = disable.
    resume_training     = None,     # Resume training from the given network snapshot.
    resume_kimg         = 0,        # Start from the given training progress.
    alpha               = 1,         # loss = L2-alpha*L1
    cudnn_benchmark     = True,     # Enable torch.backends.cudnn.benchmark?
    device              = torch.device('cuda'),
    metrics             = None,
    init_timestep       = 999,
    metric_pt_path      = None,
    metric_open_clip_path            = 'clipvitg14.pkl',
    metric_clip_path                 = None,
    pretrained_model_name_or_path    = "Efficient-Large-Model/Sana_600M_512px_diffusers",
    pretrained_vae_model_name_or_path= "Efficient-Large-Model/Sana_600M_512px_diffusers",
    dataset_prompt_text_kwargs       = {},
    cfg_train_fake                   = 4.5,
    cfg_eval_fake                    = 4.5,
    cfg_eval_real                    = 4.5,
    num_steps                        = 4,
    resolution                       = None,
    loss_scaling_D                   = 1,
    loss_scaling_G_gan               = 1,
    dataset_latents_kwargs           = None,
    sid_model                        = None,
    pooling_type                     = 'spatial',
    text_device                      = torch.device('cuda'),
    vae_device                       = torch.device('cuda'),
    cpu_offload                      = False,
    time_scale                       = 1,
    gradient_checkpointing           = True,
    text_image_pair_path             = None,
    save_best_and_last               = True,
    weighting_scheme                 = "1_minus_sigma",  # Weighting scheme for loss computation
    noise_type                       = "fresh",          # Noise type: "fresh", "fixed", "ddim"
    train_diffusiongan               = False,
    precompute_latents               = False, 
    dataset_name           = "processed_dataset",
):
    
    
    if text_image_pair_path is not None:
        precompute_latents = True
        dist.print0("precompute latents is set to True")
        #precompute latents as of now only supports text-image pair dataset
    #precompute_latents = False
    #BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))

    BASE_DIR = os.environ.get("IRISCTL_SHARED_ARTIFACT_DIR")
    if BASE_DIR is None or BASE_DIR == '':
        print('>'*100, 'no shared dir env found')
        BASE_DIR = "/mnt/shared"

    # Use the last part of run_dir to name the cache directory
    run_name = os.path.basename(os.path.normpath(run_dir))
    cache_latents_dir = os.path.join(BASE_DIR, "data", "tmp", run_name)
    os.makedirs(cache_latents_dir, exist_ok=True)

    if save_best_and_last:
        if get_rank() == 0:
            previous_pt_filename = None #for saving the last checkpoint
            best_fid = float('inf')
            best_clip_score = float('-inf')
            current_fid = float('inf')
            current_clip_score = float('-inf')
            best_fid_step = None
            best_clip_step = None
    

    num_steps_eval = num_steps
    guidance_scale = 0
 
    use_fsdp = True
    use_ddp = not use_fsdp
    
    true_score_fsdp = True
    
    if ema_halflife_kimg==0:
        ema_device=device
    
    if resume_training is not None:
        seed = (seed+resume_kimg)% (1 << 31)

    cur_nimg = resume_kimg * 1000
        
    use_context_dropout_train_fake = False


    use_train_eval = False #whether switching between train and eval modes for G and fake score net
    use_grad_turn_on_off=True #we can keep train modes for both G and fake score net, but turning on and off gradients instead
     

   
    if network_kwargs.use_fp16:
        dtype= torch.float16
    else:
        dtype=torch.bfloat16 if network_kwargs.use_bf16 else torch.float32
        
    use_autocast = True if dtype==torch.float32 else False
    dtype_autocast=torch.bfloat16 if network_kwargs.autocast_bf16 else torch.float16

    
    fsdp_kwargs = {
        "cpu_offload": CPUOffload(offload_params=cpu_offload),
        "auto_wrap_policy": size_based_auto_wrap_policy,
        'device_id': torch.cuda.current_device(), 
    }
    if fsdp_kwargs["cpu_offload"].offload_params:
        fsdp_kwargs["sync_module_states"] = False
    else:
        fsdp_kwargs["sync_module_states"] = True
    
    
    if use_autocast and dtype_autocast==torch.float16:
        scaler =ShardedGradScaler()
        scaler_G = ShardedGradScaler()
    
    if get_rank() == 0:
        vae, dit, noise_scheduler, text_encoding_pipeline = load_dit(
            pretrained_model_name_or_path=pretrained_model_name_or_path, 
            weight_dtype=dtype, 
            num_steps=num_steps,
            train_diffusiongan=train_diffusiongan,
            device=device,
        )
    
    # Ensure rank 0 finishes loading before others continue
    torch.distributed.barrier()

    # Now all ranks load from cache
    if get_rank() != 0:
        vae, dit, noise_scheduler, text_encoding_pipeline = load_dit(
            pretrained_model_name_or_path=pretrained_model_name_or_path, 
            weight_dtype=dtype, 
            num_steps=num_steps,
            train_diffusiongan=train_diffusiongan,
            device=device,
        )            

    #text_device                      = 'cpu'
    #vae_device                       = 'cpu'
    if text_device == 'cpu':
        text_encoding_pipeline=text_encoding_pipeline.to(torch.device('cpu'))
    if vae_device == 'cpu':
        vae=vae.to(torch.device('cpu'))


    torch.distributed.barrier()

    negative_prompt = [""] *batch_gpu
    with temporary_model_on_device(text_encoding_pipeline, device):
        uncond_embeds, uncond_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
            prompt=negative_prompt,
            complex_human_instruction=False,
            do_classifier_free_guidance=False,
        )

    grid_size = None
    grid_z = None
    grid_c = None

    latent_img_channels = 32
    latent_resolution = resolution//32
    
    # Load dataset.
    dist.print0('Loading dataset...')
    # if precompute_latents:
    #     dataset_kwargs["text_encoding_pipeline"] = text_encoding_pipeline
    #     dataset_kwargs["uncond_embeds"] = uncond_embeds
    #     dataset_kwargs["uncond_attention_mask"] = uncond_attention_mask
    dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
    resolution=dataset_obj.resolution
    
    dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=get_rank(), num_replicas=dist.get_world_size(), seed=seed)
    dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))


    if use_fsdp or get_rank() == 0: 
        torch.manual_seed(2024)
        grid_size, images_real, contexts = setup_snapshot_image_grid(training_set=dataset_obj)
        grid_z = torch.randn([images_real.shape[0], latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
        grid_z = grid_z.split(batch_gpu)
        grid_c = split_list(contexts, batch_gpu)
        for c in grid_c:
            dist.print0(c)


    # Initialize.
    start_time = time.time()
    np.random.seed((seed * dist.get_world_size() + get_rank()) % (1 << 31))
    torch.manual_seed(np.random.randint(1 << 31))
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    

    iteration = 0
    
    # Select batch size per GPU.
    batch_gpu_total = batch_size // dist.get_world_size()
    if batch_gpu is None or batch_gpu > batch_gpu_total:
        batch_gpu = batch_gpu_total
    num_accumulation_rounds = batch_gpu_total // batch_gpu
    assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()

    

    if text_image_pair_path is None:
        # Data-free setting: use any prompts you wish (e.g., for distillation from text only)
        # if precompute_latents:
        #     dataset_prompt_text_kwargs["text_encoding_pipeline"] = text_encoding_pipeline
        #     dataset_prompt_text_kwargs["uncond_embeds"] = uncond_embeds
        #     dataset_prompt_text_kwargs["uncond_attention_mask"] = uncond_attention_mask
        dataset_prompt_text_obj = dnnlib.util.construct_class_by_name(**dataset_prompt_text_kwargs) # subclass of training.dataset.Dataset
        dataset_prompt_text_sampler = misc.InfiniteSampler(dataset=dataset_prompt_text_obj, rank=get_rank(), num_replicas=dist.get_world_size(), seed=seed)
        dataset_prompt_text_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_prompt_text_obj, sampler=dataset_prompt_text_sampler, batch_size=batch_gpu, **data_loader_kwargs))
    else:
        #if train_DiffusionGAN or Use_RealImagePrompts_Train_G:
        # Using prompts from real/synthetic text-image pairs
        # If GAN is used, we will also use the images to train the fake score net (part of it is reused as discriminator)
        # We follow SANA to use the following (synthetic) dataset:
        # [URL]
        
        total_images = 986000
        images_per_zip = total_images / 50
        datapart = cur_nimg // images_per_zip
        actual_datapart = datapart % 124
        if precompute_latents:
            precomputed_latents_path = os.path.join(
                    cache_latents_dir, f"{dataset_name}_{actual_datapart}.{'text_image_latents' if train_diffusiongan else 'text_latents'}.pt"
                )
        else:
            precomputed_latents_path = None
        #rank_device = torch.device(f"cuda:{get_rank() % dist.get_world_size()}")
        rank_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
        context_manager_vae = (
            temporary_model_on_device(vae, rank_device)
            if train_diffusiongan and (precomputed_latents_path is not None and not os.path.exists(precomputed_latents_path)) else
            contextlib.nullcontext(vae)
        )
        context_manager_text_encoding_pipeline = (
            temporary_model_on_device(text_encoding_pipeline, rank_device)
            if precomputed_latents_path is not None and not os.path.exists(precomputed_latents_path) else
            contextlib.nullcontext(text_encoding_pipeline)
        )
        with context_manager_text_encoding_pipeline:
            with context_manager_vae:
                dataset_latents_iterator, dataset_latents_iterator_text, hf_dataset = stream_data_sampler(
                    datapart=int(datapart),
                    text_image_pair_path=text_image_pair_path,
                    cache_latents_dir=cache_latents_dir,
                    #dataset_name=processed_dataset_name,
                    text_encoding_pipeline=text_encoding_pipeline,
                    vae=vae,
                    resolution=resolution,
                    train_diffusiongan=train_diffusiongan,
                    uncond_embeds=uncond_embeds,
                    uncond_attention_mask=uncond_attention_mask,
                    batch_gpu=batch_gpu,
                    seed=seed,
                    data_loader_kwargs=data_loader_kwargs,
                    precomputed_latents_path=precomputed_latents_path,
                )

        dist.print0("sanity check for text-image pair dataset")
        
        if train_diffusiongan:
            batch = next(dataset_latents_iterator)
            dist.print0(batch["text"])
            latents = batch["image"].to(dtype=vae.dtype).to(device)


            #rank_device = torch.device(f"cuda:{get_rank() % dist.get_world_size()}")
            rank_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
            context_manager_vae = (
                temporary_model_on_device(vae, rank_device)
                if train_diffusiongan and not precompute_latents else
                contextlib.nullcontext(vae)
            )
            with context_manager_vae:
                if not precompute_latents:
                    dist.print0(latents.shape)
                    with torch.no_grad():
                        latents = vae.encode(latents).latent*vae.config.scaling_factor #0.41407
                dist.print0(latents.shape)
            
                if get_rank() == 0:    
                    images_data = []
                    images = []
                    for z, c in zip(grid_z, grid_c):
                        batch = next(dataset_latents_iterator)
                        actual_batch_size = batch["image"].shape[0]
                        if len(c) > actual_batch_size:
                            dist.print0(f"Warning: batch too small. Requested {len(c)}, got {actual_batch_size}. Truncating.")
                            c = c[:actual_batch_size]
                        dist.print0(f"texts: {batch['text'][:len(c)]}")
                        image = batch["image"][:len(c)].to(dtype=vae.dtype).to(device)
                        #dist.print0(f"[DEBUG] image shape: {image.shape}, dtype: {image.dtype}")
                        if not precompute_latents:
                            with torch.no_grad():
                                encoded = vae.encode(image).latent * vae.config.scaling_factor
                            #dist.print0(f"[DEBUG] encoded shape (after vae.encode): {encoded.shape}, dtype: {encoded.dtype}")
                        else:
                            encoded = image
                            #dist.print0(f"[DEBUG] encoded shape (precomputed): {encoded.shape}, dtype: {encoded.dtype}")
                        decoded = vae.decode(encoded / vae.config.scaling_factor, return_dict=False)[0]
                        #dist.print0(f"[DEBUG] decoded shape: {decoded.shape}, dtype: {decoded.dtype}")
                        
                        images_data.append(image)
                        images.append(decoded)
                    images_data = torch.cat(images_data).to(torch.float32).cpu().numpy()
                    images = torch.cat(images).to(torch.float32).cpu().numpy()
                    #dist.print0(f"[DEBUG] images_data (cat) shape: {images_data.shape}, images (cat) shape: {images.shape}")
                    if not precompute_latents:
                        save_image_grid(img=images_data, fname=os.path.join(run_dir, 'true_mj.png'), drange=[-1,1], grid_size=grid_size)
                    save_image_grid(img=images, fname=os.path.join(run_dir, 'true_mj_vae_encode_decode.png'), drange=[-1,1], grid_size=grid_size)
                
             
           
    for _i in range(2):
        dist.print0(f'prompt for SiD:{_i}')
        if text_image_pair_path is not None:
            contexts = next(dataset_latents_iterator_text)["text"]
        else:
            _, contexts  = next(dataset_prompt_text_iterator)
        dist.print0(contexts)
    

    # true_score: The teacher model (frozen, used for distillation targets).
    true_score = dit
    true_score.eval().requires_grad_(False).to(device)     
    
    # fake_score: The auxiliary model to be trained to learn the score of the generator (we alternate between training the generator and the fake score net)
    # Deep copy the teacher model to initialize the fake_score network, ensuring it starts with the same weights.
    fake_score = copy.deepcopy(true_score).train().requires_grad_(True).to(device=device)

    # G: Generator model, either initialized from the teacher or loaded from a checkpoint.
    if sid_model is None:
        G = copy.deepcopy(true_score).train().requires_grad_(True).to(device=device)
    else:
        dist.print0(f'Loading network from "{sid_model}"...')
        with dnnlib.util.open_url(sid_model, verbose=(get_rank() == 0)) as f:
            # Load the generator weights: prefer the checkpoint's 'ema' entry and fall back to 'G' if it is absent.
            # Note: although the key is named 'ema', EMA is actually disabled when running with FSDP, as in this code.
            loaded_net = pickle.load(f)
            if 'ema' in loaded_net:
                G = loaded_net['ema'].to(device=device, dtype=dtype)
            elif 'G' in loaded_net:
                G = loaded_net['G'].to(device=device, dtype=dtype)
            else:
                raise KeyError("Neither 'ema' nor 'G' found in the loaded checkpoint.")
            G.train().requires_grad_(True)
            del loaded_net
        dist.print0(f'Loaded network from "{sid_model}"...')
    

    # Barrier to synchronize all processes
    torch.distributed.barrier()

    # Wrap the true_score model with FSDP for distributed training.
    true_score = FSDP(true_score, **fsdp_kwargs)

    fake_score_fsdp, fake_score_optimizer, G_fsdp, g_optimizer = load_checkpoint(resume_training, fake_score, fake_score_optimizer_kwargs, G, g_optimizer_kwargs, FSDP, fsdp_kwargs, dnnlib, dist)
    
        

    if use_ddp:
        if use_train_eval:
            fake_score_ddp.eval().requires_grad_(False)
            G_ddp.eval().requires_grad_(False)
        else:
            fake_score_ddp.train().requires_grad_(True)
            G_ddp.train().requires_grad_(True)                                                  
    else:         
        G_ema = None                                  
        if use_train_eval:
            fake_score_fsdp.eval().requires_grad_(False)
            G_fsdp.eval().requires_grad_(False)
        else:
            fake_score_fsdp.train().requires_grad_(True)
            G_fsdp.train().requires_grad_(True)      
            
    torch.distributed.barrier()
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_jsonl = None
    stats_metrics = dict()
    
    torch.distributed.barrier()
    
    if resume_training is None:
        if use_fsdp or get_rank() == 0:
            dist.print0('Exporting sample images...')
            
            # dist.print0(images_real[0])
            if get_rank() == 0:
                save_image_grid(img=images_real, fname=os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
                del images_real
        
            dist.print0(grid_c[0])
            #rank_device = torch.device(f"cuda:{get_rank() % dist.get_world_size()}")
            rank_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
            with temporary_model_on_device(text_encoding_pipeline, rank_device):
                with temporary_model_on_device(vae, rank_device):
                    with torch.no_grad():
                        with torch.autocast(device_type="cuda",dtype=dtype_autocast,enabled=use_autocast):
                            images = [sid_dit_generate(dit=G_fsdp,latents=z,contexts=c,
                                        init_timesteps=init_timestep * torch.ones((len(c),), device=device, dtype=torch.long),
                                        noise_scheduler=noise_scheduler,
                                            text_encoding_pipeline = text_encoding_pipeline,
                                            resolution=resolution,dtype=dtype,return_images=True, vae=vae,num_steps=num_steps,train_sampler=False,num_steps_eval=num_steps,
                                            time_scale=time_scale,noise_type=noise_type) for z, c in zip(grid_z, grid_c)]
                            
            if use_ddp:
                if G_ema.device!=ema_device:
                    G_ema=G_ema.to(ema_device)
            if get_rank() == 0:
                images = torch.cat(images).to(torch.float32).cpu().numpy()
                dist.print0(contexts[0])
                #print(images[0])
                save_image_grid(img=images, fname=os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
            del images
        
        
    
    # Barrier to synchronize all processes
    torch.distributed.barrier()

    
    if use_ddp:
        if use_train_eval:
            fake_score_ddp.train().requires_grad_(True)
            G_ddp.eval().requires_grad_(False)
            
        if use_grad_turn_on_off:
            G_ddp.requires_grad_(False)
            fake_score_ddp.requires_grad_(True)
            
        if gradient_checkpointing:
            G_ddp.module.enable_gradient_checkpointing()
            fake_score_ddp.module.enable_gradient_checkpointing()                                             
    else:                     
        #dist.print0(f"Iteration{iteration}  Prepare 1")
        if use_train_eval:
            fake_score_fsdp.train().requires_grad_(True)
            G_fsdp.eval().requires_grad_(False)
            
        #dist.print0(f"Iteration{iteration}  Prepare 2")
        if use_grad_turn_on_off:
            G_fsdp.requires_grad_(False)
            fake_score_fsdp.requires_grad_(True)
                
        #dist.print0(f"Iteration{iteration}  Prepare 3")
        if gradient_checkpointing:
            G_fsdp.enable_gradient_checkpointing()
            fake_score_fsdp.enable_gradient_checkpointing()
        #dist.print0(f"Iteration{iteration}  Prepare 4")                                             
    if use_ddp:
        if G_ema.device!=ema_device:
            G_ema=G_ema.to(ema_device)
        
    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    full_optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)

    with FSDP.state_dict_type(
        fake_score_fsdp,
        StateDictType.FULL_STATE_DICT,
        full_state_dict_config,
        full_optim_state_dict_config
    ):
        fake_score_state = fake_score_fsdp.state_dict()
        fake_score_optimizer_state_dict = FSDP.optim_state_dict(fake_score_fsdp, fake_score_optimizer)

    if not fake_score_optimizer_state_dict:
        dist.print0("Optimizer state dictionary is empty.")
    else:
        dist.print0("Optimizer state dictionary contains data.")



    torch.distributed.barrier() 
    torch.cuda.empty_cache()
    while True:
        if (train_diffusiongan or text_image_pair_path is not None) and cur_nimg > images_per_zip:
            new_datapart = int(cur_nimg // images_per_zip) 
            if datapart < new_datapart:
                datapart = new_datapart
                
                # Wait for all pending CUDA operations to finish before tearing down the old dataset
                torch.cuda.synchronize()
                
                # Delete old iterators and dataset
                if dataset_latents_iterator is not None:
                    del dataset_latents_iterator
                if dataset_latents_iterator_text is not None:
                    del dataset_latents_iterator_text
                if hf_dataset is not None:
                    hf_dataset.cleanup_cache_files()
                    hf_dataset.set_format(None)
                    del hf_dataset
                
                # Force garbage collection
                gc.collect()
                torch.cuda.empty_cache()
                
                if precompute_latents:
                    precomputed_latents_path = os.path.join(
                            cache_latents_dir, f"{dataset_name}_{datapart}.{'text_image_latents' if train_diffusiongan else 'text_latents'}.pt"
                        )
                else:
                    precomputed_latents_path = None
                #rank_device = torch.device(f"cuda:{get_rank() % dist.get_world_size()}")
                rank_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
                context_manager_vae = (
                    temporary_model_on_device(vae, rank_device)
                    if train_diffusiongan and (precomputed_latents_path is not None and not os.path.exists(precomputed_latents_path)) else
                    contextlib.nullcontext(vae)
                )
                context_manager_text_encoding_pipeline = (
                    temporary_model_on_device(text_encoding_pipeline, rank_device)
                    if precomputed_latents_path is not None and not os.path.exists(precomputed_latents_path) else
                    contextlib.nullcontext(text_encoding_pipeline)
                )
                with context_manager_text_encoding_pipeline:
                    with context_manager_vae:
                        dataset_latents_iterator, dataset_latents_iterator_text, hf_dataset = stream_data_sampler(
                            datapart=datapart,
                            text_image_pair_path=text_image_pair_path,
                            cache_latents_dir=cache_latents_dir,
                            #dataset_name=processed_dataset_name,
                            text_encoding_pipeline=text_encoding_pipeline,
                            vae=vae,
                            resolution=resolution,
                            train_diffusiongan=train_diffusiongan,
                            uncond_embeds=uncond_embeds,
                            uncond_attention_mask=uncond_attention_mask,
                            batch_gpu=batch_gpu,
                            seed=seed,
                            data_loader_kwargs=data_loader_kwargs,
                            precomputed_latents_path=precomputed_latents_path,
                        )
                dist.print0(f"Switched to dataset part: {datapart}, cur_nimg: {cur_nimg}")
                
                # Ensure all processes are synchronized
                torch.distributed.barrier()
                
                # Clear memory again after loading
                gc.collect()
                torch.cuda.empty_cache()
                

                
        num_steps_random = torch.randint(1, num_steps+1, (1,)).item() 
        
        # Ensure all ranks use the same random number of steps per iteration
        if get_rank() == 0:
            num_steps_random_tensor = torch.randint(1, num_steps + 1, (1,), device="cuda", dtype=torch.int)
        else:
            num_steps_random_tensor = torch.empty(1, device="cuda", dtype=torch.int)
        torch.distributed.broadcast(num_steps_random_tensor, src=0)
        num_steps_random = num_steps_random_tensor.item()
        torch.distributed.barrier()
        
        #dist.print0(f"Iteration{iteration}  num_steps_random: {num_steps_random}")
        
        #torch.cuda.empty_cache()
        iteration =iteration+1
        #torch.distributed.barrier() 
        # We only wrap true score with FSDP in FSDP and true_score_fsdp mode 
       
        if use_ddp:           
            if fake_score_ddp.device != device:
                fake_score_ddp.to(device=device)
        

        if use_ddp:
            if use_grad_turn_on_off:
                G_ddp.train()
                fake_score_ddp.train()
                G_ddp.requires_grad_(False)
                fake_score_ddp.requires_grad_(True)
                
        else:
            
            if use_grad_turn_on_off:
                G_fsdp.train()
                fake_score_fsdp.train()
                G_fsdp.requires_grad_(False)
                fake_score_fsdp.requires_grad_(True)
                
        fake_score_optimizer.zero_grad(set_to_none=True)        

        Len_noise=0

        #rank_device = torch.device(f"cuda:{get_rank() % dist.get_world_size()}")
        rank_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
        context_manager_vae = (
            temporary_model_on_device(vae, rank_device)
            if train_diffusiongan and not precompute_latents else
            contextlib.nullcontext(vae)
        )
        with context_manager_vae:
            for round_idx in range(num_accumulation_rounds):
                with torch.autocast(device_type="cuda",dtype=dtype_autocast,enabled=use_autocast):
                    #dist.print0(f"Iteration{iteration}  Prepare 7")
                    if train_diffusiongan:
                        if 0: # use_lmdb:
                            images_real,contexts_real = next(dataset_latents_iterator)
                            images_real=images_real.to(device) #.to(dtype)
                            images_real,contexts_real = next(dataset_latents_iterator)
                            images_real=images_real.to(device) #.to(dtype)
                        else:
                            #
                            batch = next(dataset_latents_iterator)
                            images_real = batch["image"].to(dtype=torch.float32, device=device)
                            contexts_real = batch["text"]
                            if not precompute_latents:
                                with torch.no_grad():
                                    images_real = vae.encode(images_real).latent*vae.config.scaling_factor
                                images_real=images_real.to(device)
                                prompt_embeds_real = None
                                prompt_attention_mask_real = None
                            else:
                                prompt_embeds_real = batch["prompt_embeds"].to(device=device)
                                prompt_attention_mask_real = batch["prompt_attention_mask"].to(device=device)
                        batch= next(dataset_latents_iterator_text)
                        contexts = batch["text"]
                        if precompute_latents:
                            prompt_embeds = batch["prompt_embeds"].to(device=device)
                            prompt_attention_mask = batch["prompt_attention_mask"].to(device=device)
                        else:
                            prompt_embeds = None
                            prompt_attention_mask = None
                            
                    else:
                        if text_image_pair_path is not None:
                            batch = next(dataset_latents_iterator_text)
                            contexts = batch["text"]
                            prompt_embeds = batch["prompt_embeds"].to(device=device) if precompute_latents else None
                            prompt_attention_mask = batch["prompt_attention_mask"].to(device=device) if precompute_latents else None
                        else:
                            _, contexts  = next(dataset_prompt_text_iterator)
                            

                    if use_context_dropout_train_fake:
                        bool_tensor = torch.rand(batch_gpu) < 0.1
                        if train_diffusiongan:
                            contexts_real = ["" if flag else caption for flag, caption in zip(bool_tensor.tolist(), contexts_real)]
                            contexts = ["" if flag else caption for flag, caption in zip(bool_tensor.tolist(), contexts)]
                        else:
                            contexts = ["" if flag else caption for flag, caption in zip(bool_tensor.tolist(), contexts)]

                    # initialize latents z and noise
                    z = torch.randn([batch_gpu, latent_img_channels, latent_resolution, latent_resolution], device=device)
                    noise = torch.randn_like(z)

                    init_timesteps = init_timestep * torch.ones((batch_gpu,), device=device, dtype=torch.long)

                    # if use_ddp:
                    #     sync_context = misc.ddp_sync
                    #     G_fsdp=G_ddp
                    # else:
                    #     sync_context = fsdp_sync

                
                            
                    #dist.print0(f"Iteration{iteration}  Start FSDP computation 0")
                    with fsdp_sync(fake_score_fsdp, (round_idx == num_accumulation_rounds - 1)):
                        with torch.no_grad():
                            images = sid_dit_generate(
                                dit=G_fsdp,
                                latents=z,
                                contexts=contexts,
                                init_timesteps=init_timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoding_pipeline=text_encoding_pipeline,
                                resolution=resolution,
                                dtype=dtype,
                                return_images=False,
                                vae=None,
                                num_steps=num_steps,
                                train_sampler=False,
                                num_steps_eval=num_steps_random,
                                time_scale=time_scale,
                                uncond_embeds=uncond_embeds,
                                uncond_attention_mask=uncond_attention_mask,
                                noise_type=noise_type,
                                prompt_embeds=prompt_embeds,
                                prompt_attention_mask=prompt_attention_mask,
                            )
                                
                            
                            
                        logit_mean = torch.log(torch.tensor(2.0))
                        logit_std = 1.6
                        timesteps = torch.nn.functional.sigmoid(
                                torch.normal(mean=logit_mean, std=logit_std, size=(batch_gpu,), device=device)
                            )
                        
                        if train_diffusiongan:
                            target = noise-images.detach()
                            
                            
                            output, logit_fake = sid_dit_denoise(
                                dit=fake_score_fsdp,
                                images=images.detach(),
                                noise=noise,
                                contexts=contexts,
                                timesteps=timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoding_pipeline=text_encoding_pipeline,
                                resolution=resolution,
                                dtype=dtype,
                                predict_x0=False,
                                guidance_scale=cfg_train_fake,
                                time_scale=time_scale,
                                uncond_embeds=uncond_embeds,
                                uncond_attention_mask=uncond_attention_mask,
                                return_flag='encoder_decoder',
                                prompt_embeds=prompt_embeds,
                                prompt_attention_mask=prompt_attention_mask,
                            )
                            
                        else:
                            
                            target = noise-images.detach()
                            output = sid_dit_denoise(
                                dit=fake_score_fsdp,
                                images=images.detach(),
                                noise=noise,
                                contexts=contexts,
                                timesteps=timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoding_pipeline=text_encoding_pipeline,
                                resolution=resolution,
                                dtype=dtype,
                                predict_x0=False,
                                guidance_scale=cfg_train_fake,
                                time_scale=time_scale,
                                uncond_embeds=uncond_embeds,
                                uncond_attention_mask=uncond_attention_mask,
                                prompt_embeds=prompt_embeds,
                                prompt_attention_mask=prompt_attention_mask,
                            )


                        nan_mask = torch.isnan(output).flatten(start_dim=1).any(dim=1)
                        nan_mask = nan_mask | torch.isnan(target).flatten(start_dim=1).any(dim=1)

                        if train_diffusiongan:
                            nan_mask = nan_mask | torch.isnan(logit_fake).flatten(start_dim=1).any(dim=1)

                        # Check if there are any NaN values present
                        if nan_mask.any():
                            # Invert the nan_mask to get a mask of samples without NaNs
                            non_nan_mask = ~nan_mask
                            # Filter out samples with NaNs from target and output (and logit_fake when train_diffusiongan is set)
                            target = target[non_nan_mask]
                            output = output[non_nan_mask]
                            if train_diffusiongan:
                                logit_fake = logit_fake[non_nan_mask]
                            del non_nan_mask,nan_mask
                    
                        loss = torch.nn.functional.mse_loss(target, output, reduction="sum")
                        loss=loss.sum().mul(loss_scaling / batch_gpu_total)  

                        del images, contexts

                        if train_diffusiongan:
                            if noise.shape != images_real.shape:
                                dist.print0(f"Warning: noise shape {noise.shape} does not match images_real shape {images_real.shape}. Regenerating noise.")
                                noise = torch.randn_like(images_real)
                            
                            logit_real = sid_dit_denoise(
                                dit=fake_score_fsdp,
                                images=images_real,
                                noise=noise,
                                contexts=contexts_real,
                                timesteps=timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoding_pipeline=text_encoding_pipeline,
                                resolution=resolution,
                                dtype=dtype,
                                predict_x0=False,
                                guidance_scale=cfg_train_fake,
                                time_scale=time_scale,
                                uncond_embeds=uncond_embeds,
                                uncond_attention_mask=uncond_attention_mask,
                                return_flag='encoder',
                                prompt_embeds=prompt_embeds_real,
                                prompt_attention_mask=prompt_attention_mask_real,
                            )
                            real_labels = torch.ones_like(logit_real)
                            fake_labels = torch.zeros_like(logit_fake)

                            sigmas=timesteps
                            #weight = 1/(sigmas**2)
                            #weight = weight.view(-1,1,1,1)
                            #weight = ((1-sigmas)**2).clip(min=0.00001).view(-1,1,1,1) 
                            weight = 1

                            bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
                            loss_real = bce_loss(logit_real.clamp(-10, 10), real_labels)
                            loss_fake = bce_loss(logit_fake.clamp(-10, 10), fake_labels)
                            loss_real = loss_real.view(loss_real.shape[0], -1).mean(dim=1, keepdim=True)
                            loss_fake = loss_fake.view(loss_fake.shape[0], -1).mean(dim=1, keepdim=True) 
                            loss_D = weight*(loss_real + loss_fake) / 2
                            loss_D = loss_D *images_real.shape[1]*images_real.shape[2]*images_real.shape[3]
                            loss_D = loss_D.sum().mul(loss_scaling_D / batch_gpu_total)
                            
                            del images_real, contexts_real, prompt_embeds_real, prompt_attention_mask_real


                        Len_noise = Len_noise+len(noise)
                        loss_fake_score_print=0
                        lossD_print=0
                        if len(noise) > 0:
                            if use_autocast and dtype_autocast==torch.float16:
                                if train_diffusiongan:
                                    scaler.scale(loss+loss_D).backward()
                                else:
                                    scaler.scale(loss).backward()
                            else:
                                if train_diffusiongan:
                                    (loss+loss_D).backward()
                                else:
                                    loss.backward()
                    

            loss_fake_score_print = loss.item()
            lossD_print = loss_D.item() if train_diffusiongan else 0

             
        training_stats.report('fake_score_Loss/loss', loss_fake_score_print)
        training_stats.report('D_Loss/loss', lossD_print)
            
        if use_ddp:
            fake_score_fsdp=fake_score_ddp.module
        
        if Len_noise>0:
            if use_autocast and dtype_autocast==torch.float16:
                scaler.unscale_(fake_score_optimizer)
                
            #for param in fake_score.parameters():
            for param in fake_score_fsdp.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                    #torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
                    
            #torch.nn.utils.clip_grad_value_(fake_score_fsdp.parameters(), 1)
            
        if use_autocast and dtype_autocast==torch.float16:
            scaler.step(fake_score_optimizer)
            scaler.update()
        else:
            fake_score_optimizer.step()

        try:
            del z, noise,  loss
        except NameError as e:
            dist.print0(f"No such variable to delete: {e}")
        
        torch.cuda.empty_cache()
        
        fake_score_optimizer.zero_grad(set_to_none=True)   

        if use_train_eval:
            if use_ddp:
                G_ddp.train().requires_grad_(True)
                fake_score_ddp.eval().requires_grad_(False)
            else:
                G_fsdp.train().requires_grad_(True)
                fake_score_fsdp.eval().requires_grad_(False)
        #----------------------------------------------------------------------------------------------
        # [NAME] [NAME]-[NAME] [NAME] Network
        #if use_train_eval:              
        #    G_ddp.train().requires_grad_(True)

        if use_grad_turn_on_off:
            if use_ddp:
                fake_score_ddp.requires_grad_(False)
                G_ddp.requires_grad_(True)                                          
            else:                                         
                fake_score_fsdp.requires_grad_(False)
                G_fsdp.requires_grad_(True)
#                     if guidance_scale==100:
#                         fake_score_fsdp.module.apply(disable_bn_dropout)
#                         G_fsdp.module.apply(lambda m: restore_bn_dropout(m, name=str(m)))

            # for p in fake_score_ddp.module.parameters():
            #     p.requires_grad = False    
            # for p in G_ddp.module.parameters():
            #     p.requires_grad = True

        g_optimizer.zero_grad(set_to_none=True)
        Len_noise=0
        
        for round_idx in range(num_accumulation_rounds):
            with torch.autocast(device_type="cuda",dtype=dtype_autocast,enabled=use_autocast):
                
                if train_diffusiongan:
                    batch = next(dataset_latents_iterator_text)
                    contexts_real = batch["text"]
                    contexts=contexts_real
                    prompt_embeds_real = batch["prompt_embeds"].to(device=device) if precompute_latents else None
                    prompt_attention_mask_real = batch["prompt_attention_mask"].to(device=device) if precompute_latents else None
                    prompt_embeds = prompt_embeds_real
                    prompt_attention_mask = prompt_attention_mask_real
                else:
                    if text_image_pair_path is not None:
                        batch = next(dataset_latents_iterator_text)
                        contexts = batch["text"]
                        prompt_embeds = batch["prompt_embeds"].to(device=device) if precompute_latents else None
                        prompt_attention_mask = batch["prompt_attention_mask"].to(device=device) if precompute_latents else None
                    else:
                        _, contexts  = next(dataset_prompt_text_iterator)           
                
                z = torch.randn([batch_gpu, latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
                noise = torch.randn_like(z)

                # initialize timesteps
                init_timesteps = init_timestep * torch.ones((batch_gpu,), device=device, dtype=torch.long)
                #timesteps = torch.randint(20, 980, (batch_gpu,), device=device, dtype=torch.long)
                    
                logit_mean = torch.log(torch.tensor(2.0))
                logit_std = 1.6
            
                timesteps = torch.nn.functional.sigmoid(
                        torch.normal(mean=logit_mean, std=logit_std, size=(batch_gpu,), device=device)
                    )
                sigmas=timesteps

                
                
                with fsdp_sync(G_fsdp,False):
                    latent_model_input, t, prompt_embeds, prompt_attention_mask, latents = sid_dit_generate(
                        dit=G_fsdp,
                        latents=z,
                        contexts=contexts,
                        init_timesteps=init_timesteps,
                        noise_scheduler=noise_scheduler,
                        text_encoding_pipeline=text_encoding_pipeline,
                        resolution=resolution,
                        dtype=dtype,
                        return_images=False,
                        vae=None,
                        num_steps=num_steps,
                        train_sampler=True,
                        num_steps_eval=num_steps_random,
                        time_scale=time_scale,
                        uncond_embeds=uncond_embeds,
                        uncond_attention_mask=uncond_attention_mask,
                        noise_type=noise_type,
                        prompt_embeds=prompt_embeds,
                        prompt_attention_mask=prompt_attention_mask,
                    )
                    latent_model_input = latent_model_input.detach()
                    latents = latents.detach()

                with fsdp_sync(G_fsdp, (round_idx == num_accumulation_rounds - 1)):
                    
                    images = sid_dit_generate(
                        dit=G_fsdp,
                        latents=latents,
                        contexts=contexts,
                        init_timesteps=init_timesteps,
                        noise_scheduler=noise_scheduler,
                        text_encoding_pipeline=text_encoding_pipeline,
                        resolution=resolution,
                        dtype=dtype,
                        return_images=False,
                        vae=None,
                        num_steps=num_steps,
                        train_sampler=True,
                        num_steps_eval=num_steps_random,
                        time_scale=time_scale,
                        uncond_embeds=uncond_embeds,
                        uncond_attention_mask=uncond_attention_mask,
                        latent_model_input=latent_model_input,
                        prompt_embeds=prompt_embeds,
                        prompt_attention_mask=prompt_attention_mask,
                        t=t,   
                        noise_type=noise_type,
                    )                                        
                        

                    if train_diffusiongan:
                       
                        

                        y_fake,y_D = sid_dit_denoise(
                                dit=fake_score_fsdp,
                                images=images,
                                noise=noise,
                                contexts=contexts,
                                timesteps=timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoding_pipeline=text_encoding_pipeline,
                                resolution=resolution,
                                dtype=dtype,
                                guidance_scale=cfg_eval_fake,
                                time_scale=time_scale,
                                uncond_embeds=uncond_embeds,
                                uncond_attention_mask=uncond_attention_mask,
                                return_flag='encoder_decoder',
                                prompt_embeds=prompt_embeds,
                                prompt_attention_mask=prompt_attention_mask,
                            )
                    else:
                        
                        y_fake = sid_dit_denoise(
                                dit=fake_score_fsdp,
                                images=images,
                                noise=noise,
                                contexts=contexts,
                                timesteps=timesteps,
                                noise_scheduler=noise_scheduler,
                                text_encoding_pipeline=text_encoding_pipeline,
                                resolution=resolution,
                                dtype=dtype,
                                guidance_scale=cfg_eval_fake,
                                time_scale=time_scale,
                                uncond_embeds=uncond_embeds,
                                uncond_attention_mask=uncond_attention_mask,
                                prompt_embeds=prompt_embeds,
                                prompt_attention_mask=prompt_attention_mask,
                            )
                    
                    y_real = sid_dit_denoise(
                            dit=true_score,
                            images=images,
                            noise=noise,
                            contexts=contexts,
                            timesteps=timesteps,
                            noise_scheduler=noise_scheduler,
                            text_encoding_pipeline=text_encoding_pipeline,
                            resolution=resolution,
                            dtype=dtype,
                            guidance_scale=cfg_eval_real,
                            time_scale=time_scale,
                            uncond_embeds=uncond_embeds,
                            uncond_attention_mask=uncond_attention_mask,
                            prompt_embeds=prompt_embeds,
                            prompt_attention_mask=prompt_attention_mask,
                        )
                    

                    nan_mask_images = torch.isnan(images).flatten(start_dim=1).any(dim=1)
                    nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
                    nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
                    nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake
                    if train_diffusiongan:
                        nan_mask_y_D = torch.isnan(y_D).flatten(start_dim=1).any(dim=1)
                        nan_mask = nan_mask  | nan_mask_y_D


                    # Check if there are any NaN values present
                    if nan_mask.any():
                        # Invert the nan_mask to get a mask of samples without NaNs
                        non_nan_mask = ~nan_mask
                        # Drop the NaN-containing samples from images, y_real, y_fake and sigmas (and y_D when train_diffusiongan is set)
                        images = images[non_nan_mask]
                        y_real = y_real[non_nan_mask]
                        y_fake = y_fake[non_nan_mask]
                        sigmas = sigmas[non_nan_mask]
                        if train_diffusiongan:
                            y_D = y_D[non_nan_mask]

                    with torch.no_grad():

                        if weighting_scheme == "sid_legacy":
                            if images is None or y_real is None:
                                raise ValueError("images and y_real required for sid_legacy weighting_scheme")
                            weight_factor = abs(images.to(torch.float32) - y_real.to(torch.float32)).mean(dim=[1, 2, 3], keepdim=True).clip(min=0.00001)
                            scale_factor = 1/weight_factor
                        elif weighting_scheme == "snr_sqrt":
                            # snr_sqrt: scale_factor = 1 / (sigmas / (1 - sigmas)), with the ratio clipped to >= 1e-5
                            weight_factor = ((sigmas/(1-sigmas))).clip(min=0.00001).view(-1,1,1,1)
                            scale_factor = 1/weight_factor
                        elif weighting_scheme == "snr":
                            # snr: scale_factor = 1 / (sigmas / (1 - sigmas))**2, with the squared ratio clipped to >= 1e-5
                            weight_factor = ((sigmas/(1-sigmas))**2).clip(min=0.00001).view(-1,1,1,1)
                            scale_factor = 1/weight_factor
                        elif weighting_scheme == "1_over_sigma2":
                            # 1/sigma^2
                            weight_factor = ((sigmas)**2).clip(min=0.00001).view(-1,1,1,1)
                            scale_factor = 1/weight_factor
                        elif weighting_scheme == "1_over_sigma":
                            # 1/sigma
                            weight_factor = sigmas.clip(min=0.00001).view(-1,1,1,1)
                            scale_factor = 1/weight_factor
                        elif weighting_scheme == "1_minus_sigma_squared":
                            # (1-sigma)^2
                            #weight_factor = (1/(1-sigmas)**2).clip(min=0.00001).view(-1,1,1,1)
                            scale_factor = ((1-sigmas)**2).view(-1,1,1,1)
                        elif weighting_scheme == "1_minus_sigma":
                            # (1-sigma) - default for SANA distillation
                            #weight_factor = (1/(1-sigmas)).clip(min=0.00001).view(-1,1,1,1)
                            scale_factor = (1-sigmas).view(-1,1,1,1)
                        else:
                            raise ValueError(f"Unknown weighting weighting_scheme: {weighting_scheme}. Available: sid_legacy, snr_sqrt, snr, 1_over_sigma2, 1_over_sigma, 1_minus_sigma_squared, 1_minus_sigma")

                    if alpha==1:
                        loss = (y_real - y_fake) * (y_fake - images) *scale_factor #/ weight_factor
                    else:
                        loss = (y_real - y_fake) * ((y_real - images) - alpha * (y_real - y_fake)) *scale_factor #/ weight_factor

                    loss=loss.sum().mul(loss_scaling_G / batch_gpu_total)
                    lossG_print = loss.item()
                    loss_gan=0
                    lossG_gan_print=0    

                    if train_diffusiongan:
                        y_D_labels = torch.ones_like(y_D)
                        bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
                        loss_gan = bce_loss(y_D.clamp(-10, 10),y_D_labels)
                        loss_gan = loss_gan.view(loss_gan.shape[0], -1).mean(dim=1, keepdim=True)*scale_factor.view(loss_gan.shape[0], -1)
                        loss_gan = loss_gan*images.shape[1]*images.shape[2]*images.shape[3]
                        loss_gan = loss_gan.sum().mul(loss_scaling_G_gan / batch_gpu_total)
                        lossG_gan_print = loss_gan.item()

                    if len(y_real) > 0:
                        Len_noise=Len_noise+len(y_real)
                        if train_diffusiongan:
                            #if cur_nimg>200*1000:
                            if cur_nimg>40*1000:
                                if use_autocast and dtype_autocast==torch.float16:
                                    scaler_G.scale(0.5*loss+0.5*loss_gan).backward()
                                else:
                                    (0.5*loss+0.5*loss_gan).backward()
                            else:
                                if use_autocast and dtype_autocast==torch.float16:
                                    scaler_G.scale(loss).backward()
                                else:
                                    loss.backward()
                        else:
                            if use_autocast and dtype_autocast==torch.float16:
                                scaler_G.scale(loss).backward()
                            else:
                                loss.backward()                            

        training_stats.report('G_Loss/loss', lossG_print)
        training_stats.report('G_gan_Loss/loss', lossG_gan_print)
        
        if use_ddp:
            G_fsdp=G_ddp.module
                                                
        
        if ((not train_diffusiongan) or cur_nimg>20*1000) and Len_noise > 0:
            if use_autocast and dtype_autocast==torch.float16:
                #if cur_nimg>100*1000 or sid_model is None:  
                scaler_G.unscale_(g_optimizer)
        
            for param in G_fsdp.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
            
            #torch.nn.utils.clip_grad_value_(G_fsdp.parameters(), 1) 
            
            if use_autocast and dtype_autocast==torch.float16:
                scaler_G.step(g_optimizer)
                scaler_G.update()
            else:
                g_optimizer.step()
        
        try:
            del z,noise,y_fake,y_real,loss,weight_factor, scale_factor 
        except NameError:
            pass
        torch.cuda.empty_cache()
        
        g_optimizer.zero_grad(set_to_none=True)    

        if use_train_eval:
            fake_score_fsdp.train().requires_grad_(True)
            G_fsdp.eval().requires_grad_(False)

        if use_grad_turn_on_off:
            G_fsdp.requires_grad_(False)
            fake_score_fsdp.requires_grad_(True)
            if guidance_scale==100:
                G_fsdp.module.apply(disable_bn_dropout)
                fake_score_fsdp.module.apply(lambda m: restore_bn_dropout(m, name=str(m)))

        if use_ddp:
            G_ema = G 
        
        # Perform maintenance tasks once per tick.
        cur_nimg += batch_size
        done = (cur_nimg >= total_kimg * 1000)

        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        fields += [f"loss_fake_score {training_stats.report0('fake_score_Loss/loss', loss_fake_score_print):<6.2f}"]
        fields += [f"loss_G {training_stats.report0('G_Loss/loss', lossG_print):<6.2f}"]
        fields += [f"loss_G_gan {training_stats.report0('G_gan_Loss/loss', lossG_gan_print):<6.2f}"]
        fields += [f"loss_D {training_stats.report0('D_Loss/loss', lossD_print):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        dist.print0(' '.join(fields))

        # Check for abort.
        if (not done) and dist.should_stop():
            done = True
            dist.print0()
            dist.print0('Aborting...')

        if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0 or cur_tick in [10,25,50,100]):

            if  use_ddp or not true_score_fsdp:
                if ema_device!=device:    
                    true_score=true_score.to(device=ema_device)
            
            
         
            G_ema = None

            #rank_device = torch.device(f"cuda:{get_rank() % dist.get_world_size()}")
            rank_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
            with temporary_model_on_device(text_encoding_pipeline, rank_device):
                with temporary_model_on_device(vae, rank_device):
                    
                    torch.distributed.barrier()
                    dist.print0('Exporting sample images...')
                    if (1 and get_rank() == 0) or (use_fsdp): # and dtype!=torch.float32:
                        # images = [sdxl_sample(unet=G_ema, latents=z, contexts=c, timesteps=init_timestep * torch.ones((1,), device=device, dtype=torch.long), 
                        #                       noise_scheduler=noise_scheduler, return_images=True) for z, c in zip(grid_z, grid_c)]
                        #for num_steps_eval in [1,2,3,4]:
                        List = [1,num_steps_eval] if num_steps_eval>1 else [1]
                        
                        for num_steps_sampler in List:
                        #if 1:
                            with torch.no_grad():
                                with torch.autocast(device_type="cuda",dtype=dtype_autocast,enabled=use_autocast):
                                    # Clear memory before generating images
                                    gc.collect()
                                    torch.cuda.empty_cache()
                                    
                                    images = []
                                    for z, c in zip(grid_z, grid_c):
                                        # Generate images in smaller batches to save memory
                                        batch_images = sid_dit_generate(
                                            dit=G_fsdp,
                                            latents=z,
                                            contexts=c,
                                            init_timesteps=init_timestep * torch.ones((len(c),), device=device, dtype=torch.long),
                                            noise_scheduler=noise_scheduler,
                                            text_encoding_pipeline=text_encoding_pipeline,
                                            resolution=resolution,
                                            dtype=dtype,
                                            return_images=True,
                                            vae=vae,
                                            num_steps=num_steps_sampler,
                                            train_sampler=False,
                                            num_steps_eval=num_steps_sampler,
                                            time_scale=time_scale,
                                            uncond_embeds=uncond_embeds,
                                            uncond_attention_mask=uncond_attention_mask,
                                            noise_type=noise_type,
                                        )
                                        images.append(batch_images)
                                        
                                        # Clean up intermediate tensors
                                        del batch_images
                                        torch.cuda.empty_cache()
                                    
                                if get_rank() == 0:
                                    # Concatenate and process images on rank 0
                                    images = torch.cat(images).to(torch.float32).cpu().numpy()
                                    save_image_grid(img=images, fname=os.path.join(run_dir, f'fakes_{alpha:03f}_{cur_nimg//1000:06d}_{num_steps_sampler:d}.png'), drange=[-1,1], grid_size=grid_size)
                                    
                                # Clean up after saving
                                del images
                                gc.collect()
                                torch.cuda.empty_cache()
                                torch.cuda.synchronize()
                        
                    if cur_tick>0:    
                        dist.print0('Evaluating metrics...')
                        dist.print0(metric_pt_path)
                        dist.print0(dist.get_world_size())
                        dist.print0(get_rank())
                        dist.print0('Evaluating metrics...')
                        #List = [num_steps] 
                        #for num_steps_sampler in List:
                        num_steps_sampler = num_steps
                        if metrics is not None:
                            for metric in metrics:
                                # Clear memory before each metric calculation
                                gc.collect()
                                torch.cuda.empty_cache()
                                
                                with torch.no_grad():
                                    with torch.autocast(device_type="cuda",dtype=dtype_autocast,enabled=use_autocast):
                                        result_dict = metric_main.calc_metric(
                                            metric=metric,
                                            metric_pt_path=metric_pt_path,
                                            metric_open_clip_path=metric_open_clip_path,
                                            metric_clip_path=metric_clip_path,
                                            G=partial(
                                                sid_dit_generate,
                                                dit=G_fsdp,
                                                noise_scheduler=noise_scheduler,
                                                text_encoding_pipeline=text_encoding_pipeline,
                                                resolution=resolution,
                                                dtype=dtype,
                                                return_images=True,
                                                vae=vae,
                                                num_steps=num_steps_sampler,
                                                train_sampler=False,
                                                num_steps_eval=num_steps_sampler,
                                                time_scale=time_scale,
                                                uncond_embeds=uncond_embeds,
                                                uncond_attention_mask=uncond_attention_mask,
                                                noise_type=noise_type,
                                            ),
                                            init_timestep=init_timestep,
                                            dataset_kwargs=dataset_kwargs,
                                            num_gpus=dist.get_world_size(),
                                            rank=get_rank(),
                                            local_rank=dist.get_local_rank(),
                                            device=device
                                        )
                                        current_fid = result_dict.results.fid10k_full
                                        current_clip_score = result_dict.results.open_clipscore_10k
                                        
                                # Clean up after each metric
                                
                                

                                if get_rank() == 0:
                                    dist.print0(result_dict.results)
                                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=f'fakes_{alpha:03f}_{cur_nimg//1000:06d}_{num_steps_sampler}step.png', alpha=alpha, num_steps_eval=num_steps_sampler)    
                                stats_metrics.update(result_dict.results)
                                del result_dict
                                gc.collect()
                                torch.cuda.empty_cache()
                                torch.cuda.synchronize()
                    if G_ema is not None:
                        if G_ema.device!=ema_device:
                            G_ema=G_ema.to(ema_device)
                            torch.cuda.empty_cache()
                    
                    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                    full_optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
                    with FSDP.state_dict_type(
                        G_fsdp,
                        StateDictType.FULL_STATE_DICT,
                        full_state_dict_config,
                        full_optim_state_dict_config
                    ):
                        state_dict= G_fsdp.state_dict()
                    if get_rank() == 0:
                        if use_ddp:
                            data = dict(ema=G_ema)
                        else:
                            G_ema_fsdp = type(G)(**G.config).eval().requires_grad_(False)  
                            #G_ema_fsdp.load_state_dict(state_dict)
                            G_ema_fsdp.load_state_dict(state_dict, strict=False)
                            data = dict(ema = G_ema_fsdp)

                        for key, value in data.items():
                            if isinstance(value, torch.nn.Module):
                                with torch.no_grad():
                                    # 1) Remember the original device (likely GPU)
                                    original_device = next(value.parameters()).device

                                    # 2) Move the *original* module to CPU to avoid GPU overhead during deepcopy
                                    value.cpu()

                                    # 3) Deepcopy on CPU
                                    value_cpu_copy = copy.deepcopy(value).to(torch.float16)
                                    
                                    #value_cpu_copy = value.detach().clone().to(torch.float16)
                                    
                                    #value.detach().clone().to(torch.float16)

                                    # 4) Finalize the copy for inference (eval mode + no grad)
                                    value_cpu_copy.eval().requires_grad_(False)

                                    # 5) Move original back to GPU (or whatever device it was on)
                                    value.to(original_device)

                                    # 6) Store the CPU-side snapshot in data
                                    data[key] = value_cpu_copy

                            # Optionally, delete the local variable to free references.
                            del value
                        
                        if not save_best_and_last:
                            save_data(data=data, fname=os.path.join(run_dir, f'network-snapshot-{alpha:03f}-{cur_nimg//1000:06d}.pkl'))
                        else:
                            step_str = f'{alpha:03f}-{cur_nimg//1000:06d}'
                            checkpoint_path = os.path.join(run_dir, f'network-snapshot-{step_str}.pkl')

                            # Save best FID
                            if current_fid < best_fid:
                                # Save new checkpoint
                                save_data(data=data, fname=checkpoint_path)
                        
                                # Remove previous best FID checkpoint if it's not also best CLIP
                                if best_fid_step is not None and best_fid_step != best_clip_step:
                                    old_fid_path = os.path.join(run_dir, f'network-snapshot-{best_fid_step}.pkl')
                                    if os.path.exists(old_fid_path):
                                        try:
                                            os.remove(old_fid_path)
                                        except OSError as e:
                                            dist.print0(f"Error removing old FID checkpoint: {e}")
                        
                                # Update best FID
                                best_fid = current_fid
                                best_fid_step = step_str
                        
                            # Save best CLIP
                            if current_clip_score > best_clip_score:
                                # Save new checkpoint
                                save_data(data=data, fname=checkpoint_path)
                        
                                # Remove previous best CLIP checkpoint if it's not also best FID
                                if best_clip_step is not None and best_clip_step != best_fid_step:
                                    old_clip_path = os.path.join(run_dir, f'network-snapshot-{best_clip_step}.pkl')
                                    if os.path.exists(old_clip_path):
                                        try:
                                            os.remove(old_clip_path)
                                        except OSError as e:
                                            dist.print0(f"Error removing old CLIP checkpoint: {e}")
                        
                                # Update best CLIP
                                best_clip_score = current_clip_score
                                best_clip_step = step_str
                        del data
                        
        

        if use_ddp or not true_score_fsdp:    
            if true_score.device!=device:
                true_score=true_score.to(device=device) #,dtype=dtype)
        
        if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0:
            
            
            torch.cuda.empty_cache()
            torch.distributed.barrier()
            
            if use_ddp:
                if get_rank() == 0:
                    dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt_G')
                    save_fp16_pt(pt=dict(G_state=G_ddp.module.state_dict() if use_ddp else G_state, g_optimizer_state=g_optimizer.state_dict() if use_ddp else g_optimizer_state), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt_G'),dtype=torch.float16)
                    dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt_fake')
                    save_fp16_pt(pt=dict(fake_score_state=fake_score_ddp.module.state_dict() if use_ddp else fake_score_state, fake_score_optimizer_state=fake_score_optimizer.state_dict() if use_ddp else fake_score_optimizer_state), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt_fake'),dtype=torch.float16)
                    dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt')
                    save_fp16_pt(pt=dict(G_ema_state=G_ema.state_dict()), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'),dtype=torch.float16)
            else:
                torch.distributed.barrier()
                
                full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
                full_optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)

                with FSDP.state_dict_type(
                    G_fsdp,
                    StateDictType.FULL_STATE_DICT,
                    full_state_dict_config,
                    full_optim_state_dict_config
                ):
                    G_state = G_fsdp.state_dict()
                    g_optimizer_state_dict = FSDP.optim_state_dict(G_fsdp, g_optimizer)

                
                if get_rank() == 0:
                    dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt_G')
                    save_fp16_pt(pt=dict(G_state= G_state), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt_G'),dtype=torch.float16)
                    
                if get_rank()==(1%dist.get_world_size()):
                    print(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt_G_optim')
                    save_fp16_pt(pt=dict(g_optimizer_state=g_optimizer_state_dict), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt_G_optim'),dtype=torch.float16)
                
                torch.cuda.empty_cache()
                
                
                full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
                full_optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)

                with FSDP.state_dict_type(
                    fake_score_fsdp,
                    StateDictType.FULL_STATE_DICT,
                    full_state_dict_config,
                    full_optim_state_dict_config
                ):
                    fake_score_state = fake_score_fsdp.state_dict()
                    fake_score_optimizer_state_dict = FSDP.optim_state_dict(fake_score_fsdp, fake_score_optimizer)

                if not fake_score_optimizer_state_dict:
                    dist.print0("Optimizer state dictionary is empty.")
                else:
                    dist.print0("Optimizer state dictionary contains data.")

                                
                if get_rank() == (2%dist.get_world_size()):
                    print(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt_fake')
                    save_fp16_pt(pt=dict(fake_score_state=fake_score_state), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt_fake'),dtype=torch.float16)
                    
                if get_rank() == (3%dist.get_world_size()):
                    print(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt_fake_optim')
                    save_fp16_pt(pt=dict(fake_score_optimizer_state=fake_score_optimizer_state_dict), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt_fake_optim'),dtype=torch.float16)
                    
                #del fake_score_state, fake_score_optimizer_state_dict 
                torch.distributed.barrier()
                torch.cuda.empty_cache()


                if get_rank() == 0:
                    dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt')
                    save_fp16_pt(pt=dict(G_ema_state=None), #G_ema.state_dict()), 
                                            fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'),dtype=torch.float16)
                #del G_state
                
                if save_best_and_last is True and get_rank() == 0:
                    try:
                        if previous_pt_filename is not None:
                            #if os.path.exists(previous_pt_filename):
                            #    os.remove(previous_pt_filename)
                            # Define checkpoint suffixes
                            suffixes = ["", "_fake", "_fake_optim", "_G", "_G_optim"]
                            # Construct full paths
                            filenames = [f"{previous_pt_filename}{suffix}" for suffix in suffixes]
                            # Remove any existing checkpoint files
                            for filename in filenames:
                                if os.path.exists(filename):
                                    os.remove(filename)
                                    dist.print0(f"Removed: {filename}")
                                else:
                                    dist.print0(f"Not found (skipped): {filename}")

                    except OSError as e:
                        dist.print0(f"Error removing previous checkpoint: {e}")
        
                    previous_pt_filename = os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')
                
                torch.cuda.empty_cache()
                

            
        # Update logs.
        training_stats.default_collector.update()
        if get_rank() == 0:
            if stats_jsonl is None:
                append_line(jsonl_line=json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n', fname=os.path.join(run_dir, f'stats_{alpha:03f}.jsonl'))

        dist.update_progress(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    dist.print0('Exiting...')

#----------------------------------------------------------------------------