import dataclasses
import torch
import torch.distributed as dist

import os
# Debug switches, read once at import time from the process environment.
# Each name is None when its variable is unset and the raw string otherwise.
# The code in this file only tests them for truthiness, so any non-empty
# value (even "0") turns the corresponding switch on.
DISTRL_DEBUG_NAN = os.getenv("DISTRL_DEBUG_NAN")
DISTRL_DEBUG_ALIGN_SFT_RL = os.getenv("DISTRL_DEBUG_ALIGN_SFT_RL")
DISTRL_DEBUG_ALIGN_SFT_RL_DOSFT = os.getenv("DISTRL_DEBUG_ALIGN_SFT_RL_DOSFT")
DISTRL_DEBUG_ALIGN_SFT_RL_SFT_WITH_FAKE_RL_LOSS = os.getenv("DISTRL_DEBUG_ALIGN_SFT_RL_SFT_WITH_FAKE_RL_LOSS")
DISTRL_DEBUG_ALIGN_SFT_RL_SFT_SDE_DRIFT = os.getenv("DISTRL_DEBUG_ALIGN_SFT_RL_SFT_SDE_DRIFT")

import numpy as np
def get_random_indices(num_indices, sample_size):
    """Draw `sample_size` distinct indices from ``range(num_indices)``.

    Args:
        num_indices (int): Size of the index population; candidates are
            ``0 .. num_indices - 1``.
        sample_size (int): How many indices to draw.

    Returns:
        A numpy array holding `sample_size` indices drawn without
        replacement, so no index appears twice.
    """
    # replace=False is what guarantees the returned indices are unique.
    chosen = np.random.choice(a=num_indices, size=sample_size, replace=False)
    return chosen

@dataclasses.dataclass
class TrainPolicyFuncData:
    """Mutable accumulator for statistics gathered during policy training.

    Instances are passed to the training functions as `tpfdata`, which add to
    the fields in place.
    """
    # Running scalar totals (the training functions add per-step
    # contributions already divided by `policy_steps`).
    tot_p_loss: float = 0
    tot_ratio: float = 0
    tot_kl: float = 0
    # Not written by any function visible in this file.
    tot_grad_norm: float = 0
    tot_adv_mean: float = 0
    tot_adv_std: float = 0
    # Per-sample histories; default_factory gives each instance its own list.
    unclipped_ratio_values: list = dataclasses.field(default_factory=list)
    ratio_values: list = dataclasses.field(default_factory=list)
    adv_values: list = dataclasses.field(default_factory=list)

def _train_policy_func(
    args,
    state_dict,
    pipe,
    unet_copy,
    is_ddp,
    count,
    policy_steps,
    accelerator,
    tpfdata
):
    """Run one policy-gradient backward pass on a batch of stored transitions.

    A batch of `args.p_batch_size` entries is drawn from `state_dict`; every
    entry carries several samples along dimension 1. The log-probability of
    each sample is recomputed with `pipe.forward_calculate_logprob`, the
    per-entry importance ratio is formed from the summed log-probabilities,
    clipped, multiplied by the (optionally normalised) reward and
    back-propagated through `accelerator.backward`. No optimizer step is
    taken here.

    Args:
        args: Namespace providing `use_non_repeating_samples`, `p_batch_size`,
            `num_inference_steps`, `grpo_flag`, `ratio_clip`, `reward_weight`,
            `kl_warmup`, `kl_weight` and `gradient_accumulation_steps`.
        state_dict: Dictionary of collected experiences with the keys
            "state", "next_state", "timestep", "final_reward",
            "unconditional_prompt_embeds", "guided_prompt_embeds" and
            "log_prob"; optionally "policy_indices" and
            "policy_index_position" for sequential (non-repeating) sampling.
            "policy_index_position" is advanced in place.
        pipe: Diffusion pipeline exposing `forward_calculate_logprob`.
        unet_copy: Reference UNet forwarded to the pipeline for the KL term.
        is_ddp: Whether training is distributed; if true, advantage statistics
            are all-reduced across processes.
        count: Current training step, compared against `args.kl_warmup`.
        policy_steps: Number of policy steps the logged totals are averaged
            over.
        accelerator: Accelerator object used for `backward` and for its
            `device`.
        tpfdata: TrainPolicyFuncData-like object whose totals and value
            histories are updated in place.

    Returns:
        None. Gradients are accumulated on the policy model and statistics on
        `tpfdata`.
    """

    with torch.no_grad():
        if args.use_non_repeating_samples and "policy_indices" in state_dict:
            # Walk the pre-computed index list sequentially, wrapping around
            # so that exactly p_batch_size indices are always produced.
            start_idx = state_dict["policy_index_position"]
            policy_indices = state_dict["policy_indices"]
            policy_indices_len = len(policy_indices)
            indices = torch.tensor([
                policy_indices[(start_idx + offset) % policy_indices_len]
                for offset in range(args.p_batch_size)
            ])

            # Remember where the next call should continue.
            state_dict["policy_index_position"] = (start_idx + args.p_batch_size) % policy_indices_len
        else:
            # Fall back to uniform sampling without replacement.
            indices = get_random_indices(state_dict["state"].shape[0], args.p_batch_size)

        # Gather the batch; dimension 0 is the batch entry and, for the
        # per-sample tensors, dimension 1 is the sample index.
        batch_state = state_dict["state"][indices]
        batch_next_state = state_dict["next_state"][indices]
        batch_timestep = state_dict["timestep"][indices]
        batch_final_reward = state_dict["final_reward"][indices]
        batch_unconditional_prompt_embeds = state_dict["unconditional_prompt_embeds"][indices]
        batch_guided_prompt_embeds = state_dict["guided_prompt_embeds"][indices]
        batch_log_prob = state_dict["log_prob"][indices]

    log_probs = []
    kl_regularizers = []
    num_samples = batch_state.shape[1]

    # One forward pass per sample index; all batch entries share the timestep.
    for sample_idx in range(num_samples):
        sample_state = batch_state[:, sample_idx]
        sample_next_state = batch_next_state[:, sample_idx]
        # Unconditional embeddings first, guided second, stacked on dim 0.
        sample_prompt_embeds = torch.cat([
            batch_unconditional_prompt_embeds[:, sample_idx],
            batch_guided_prompt_embeds[:, sample_idx],
        ])

        sample_log_prob, sample_kl = pipe.forward_calculate_logprob(
            prompt_embeds=sample_prompt_embeds.cuda(),
            latents=sample_state.cuda(),
            next_latents=sample_next_state.cuda(),
            ts=batch_timestep.cuda(),
            unet_copy=unet_copy,
            num_inference_steps=args.num_inference_steps,
            is_ddp=is_ddp,
        )

        log_probs.append(sample_log_prob)
        kl_regularizers.append(sample_kl)

    # Stack along a new sample dimension: [p_batch_size, num_samples].
    log_prob = torch.stack(log_probs, dim=1)
    kl_regularizer = torch.stack(kl_regularizers, dim=1)

    with torch.no_grad():
        # One reward per batch entry, shaped as a column for broadcasting.
        adv = batch_final_reward.cuda().reshape([args.p_batch_size, 1])

        # Any non-zero grpo_flag turns on (reward - mean) / std normalisation.
        if args.grpo_flag != 0:
            if is_ddp:
                # Combine count, sum and sum of squares from every process so
                # all ranks normalise with the same global statistics.
                batch_size_tensor = torch.tensor([adv.numel()], device=accelerator.device)
                adv_sum_tensor = torch.tensor([adv.sum().item()], device=accelerator.device)
                adv_sq_sum_tensor = torch.tensor([(adv ** 2).sum().item()], device=accelerator.device)

                dist.all_reduce(batch_size_tensor, op=dist.ReduceOp.SUM)
                dist.all_reduce(adv_sum_tensor, op=dist.ReduceOp.SUM)
                dist.all_reduce(adv_sq_sum_tensor, op=dist.ReduceOp.SUM)

                global_batch_size = batch_size_tensor.item()
                global_mean = adv_sum_tensor.item() / global_batch_size
                global_var = (adv_sq_sum_tensor.item() / global_batch_size) - (global_mean ** 2)
                # E[x^2] - mean^2 can come out slightly negative through
                # rounding; clamp it so the square root cannot produce NaN.
                global_var = max(global_var, 0.0)

                adv_mean = global_mean
                adv_std = torch.sqrt(torch.tensor(global_var) + 1e-8).item()
            else:
                adv_mean = adv.mean()
                # The sample std of a single element is NaN; treat it as zero
                # spread so the normalised advantage becomes 0 instead.
                adv_std = adv.std() if adv.numel() > 1 else torch.zeros_like(adv_mean)

            # The epsilon keeps the division finite when the spread is zero.
            adv = (adv - adv_mean) / (adv_std + 1e-8)

    # Importance ratio from log-probabilities summed over the sample dimension.
    ratio = torch.exp(log_prob.sum(-1) - batch_log_prob.cuda().sum(-1))

    # Record the raw ratios BEFORE clamping so the two histories really differ.
    tpfdata.unclipped_ratio_values.extend(ratio.detach().cpu().flatten().tolist())

    # Clip the ratio to prevent extreme policy updates.
    ratio = torch.clamp(ratio, 1.0 - args.ratio_clip, 1.0 + args.ratio_clip)

    tpfdata.ratio_values.extend(ratio.detach().cpu().flatten().tolist())
    tpfdata.adv_values.extend(adv.detach().cpu().flatten().tolist())

    # Policy loss. There is deliberately no leading minus: per the original
    # author's note the reward is a quantity to be lowered (FID).
    loss = (
        args.reward_weight
        * adv.detach().float()
        * ratio.float().reshape([args.p_batch_size, 1])
    ).mean()

    # Add the KL regulariser only once the warm-up period is over.
    if count > args.kl_warmup:
        loss += args.kl_weight * kl_regularizer.mean()

    # Scale for gradient accumulation, then back-propagate.
    loss = loss / (args.gradient_accumulation_steps)
    accelerator.backward(loss)

    # Accumulate per-step averages for logging.
    tpfdata.tot_ratio += ratio.mean().item() / policy_steps
    tpfdata.tot_kl += kl_regularizer.mean().item() / policy_steps
    tpfdata.tot_p_loss += loss.item() / policy_steps

def _train_policy_func_eachimg(
    args,
    state_dict,
    pipe,
    unet_copy,
    is_ddp,
    count,
    policy_steps,
    accelerator,
    tpfdata
):
    """Run policy-gradient backward passes treating every sample individually.

    `args.p_num_groups` entries are drawn from `state_dict`, their sample
    dimension is flattened into one list of samples, and that list is
    processed in chunks of at most `args.p_batch_size`. Each chunk does its
    own forward pass, advantage normalisation, ratio clipping and
    `accelerator.backward`. No optimizer step is taken here.

    Note that advantage statistics are computed per chunk (and all-reduced
    across processes when `is_ddp` is true), not over the whole flattened
    batch.

    Args:
        args: Namespace providing `use_non_repeating_samples`, `p_num_groups`,
            `p_batch_size`, `num_inference_steps`, `grpo_flag`, `ratio_clip`,
            `reward_weight`, `kl_warmup`, `kl_weight` and
            `gradient_accumulation_steps`.
        state_dict: Dictionary of collected experiences with the keys
            "state", "next_state", "timestep", "final_reward",
            "unconditional_prompt_embeds", "guided_prompt_embeds" and
            "log_prob"; optionally "policy_indices" and
            "policy_index_position" for sequential (non-repeating) sampling.
            "policy_index_position" is advanced in place.
        pipe: Diffusion pipeline exposing `forward_calculate_logprob` and
            `unet`.
        unet_copy: Reference UNet forwarded to the pipeline for the KL term.
        is_ddp: Whether training is distributed.
        count: Current training step, compared against `args.kl_warmup`.
        policy_steps: Number of policy steps the logged totals are averaged
            over.
        accelerator: Accelerator object used for `backward`, `print`,
            `no_sync` and its `device`.
        tpfdata: Object whose `tot_ratio`, `tot_kl` and `tot_p_loss` are
            updated in place.

    Returns:
        None. Gradients are accumulated on the policy model and statistics on
        `tpfdata`.
    """
    # Function-scope import so this block does not rely on the file header.
    from contextlib import nullcontext

    with torch.no_grad():
        if args.use_non_repeating_samples and "policy_indices" in state_dict:
            # Walk the pre-computed index list sequentially, wrapping around
            # so that exactly p_num_groups indices are always produced.
            start_idx = state_dict["policy_index_position"]
            policy_indices = state_dict["policy_indices"]
            policy_indices_len = len(policy_indices)
            indices = torch.tensor([
                policy_indices[(start_idx + offset) % policy_indices_len]
                for offset in range(args.p_num_groups)
            ])

            # Remember where the next call should continue.
            state_dict["policy_index_position"] = (start_idx + args.p_num_groups) % policy_indices_len
        else:
            # Fall back to uniform sampling without replacement.
            indices = get_random_indices(state_dict["state"].shape[0], args.p_num_groups)

        # Gather the groups; dimension 0 is the group and, for the per-sample
        # tensors, dimension 1 is the sample index.
        batch_state = state_dict["state"][indices]
        batch_next_state = state_dict["next_state"][indices]
        batch_timestep = state_dict["timestep"][indices]
        batch_final_reward = state_dict["final_reward"][indices]
        batch_unconditional_prompt_embeds = state_dict["unconditional_prompt_embeds"][indices]
        batch_guided_prompt_embeds = state_dict["guided_prompt_embeds"][indices]
        batch_log_prob = state_dict["log_prob"][indices]

    p_num_groups = batch_state.shape[0]
    num_samples = batch_state.shape[1]

    # Timestep and reward are stored once per group; repeat them per sample.
    batch_timestep = batch_timestep.unsqueeze(1).expand(-1, num_samples)
    batch_final_reward = batch_final_reward.unsqueeze(1).expand(-1, num_samples)

    # Merge the group and sample dimensions: [p_num_groups * num_samples, ...].
    flat_state = batch_state.reshape(-1, *batch_state.shape[2:])
    flat_next_state = batch_next_state.reshape(-1, *batch_next_state.shape[2:])
    flat_timestep = batch_timestep.reshape(-1)
    flat_final_reward = batch_final_reward.reshape(-1)
    flat_unconditional = batch_unconditional_prompt_embeds.reshape(-1, *batch_unconditional_prompt_embeds.shape[2:])
    flat_guided = batch_guided_prompt_embeds.reshape(-1, *batch_guided_prompt_embeds.shape[2:])
    flat_log_prob = batch_log_prob.reshape(-1)

    total_flat_samples = p_num_groups * num_samples
    # Number of chunks (ceil division); used so that the chunk losses sum to
    # one batch-level loss.
    num_chunks = (total_flat_samples + args.p_batch_size - 1) // args.p_batch_size

    def process_batch(batch_flat_state, batch_flat_next_state, batch_flat_timestep,
                     batch_flat_final_reward, batch_flat_unconditional, batch_flat_guided,
                     batch_flat_log_prob, current_batch_size):
        """Forward, compute the loss and back-propagate for one chunk."""
        # Unconditional embeddings first, guided second, stacked on dim 0.
        batch_flat_prompt_embeds = torch.cat([batch_flat_unconditional, batch_flat_guided])

        batch_log_prob_out, batch_kl = pipe.forward_calculate_logprob(
            prompt_embeds=batch_flat_prompt_embeds.cuda(),
            latents=batch_flat_state.cuda(),
            next_latents=batch_flat_next_state.cuda(),
            ts=batch_flat_timestep.cuda(),
            unet_copy=unet_copy,
            num_inference_steps=args.num_inference_steps,
            is_ddp=is_ddp,
        )

        with torch.no_grad():
            adv = batch_flat_final_reward.cuda().reshape([current_batch_size, 1])

            # Mean raw reward, kept only for the progress print-out.
            adv_before_norm = adv.mean().item()

            # Any non-zero grpo_flag turns on (reward - mean) / std normalisation.
            if args.grpo_flag != 0:
                if is_ddp:
                    # Combine count, sum and sum of squares from every process
                    # so all ranks normalise with the same statistics.
                    batch_size_tensor = torch.tensor([adv.numel()], device=accelerator.device)
                    adv_sum_tensor = torch.tensor([adv.sum().item()], device=accelerator.device)
                    adv_sq_sum_tensor = torch.tensor([(adv ** 2).sum().item()], device=accelerator.device)

                    dist.all_reduce(batch_size_tensor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(adv_sum_tensor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(adv_sq_sum_tensor, op=dist.ReduceOp.SUM)

                    global_batch_size = batch_size_tensor.item()
                    if global_batch_size == 0:
                        accelerator.print("[WARNING] global_batch_size is 0!")
                        # Neutral statistics: leave the advantage unchanged.
                        adv_mean = 0.0
                        adv_std = 1.0
                    else:
                        global_mean = adv_sum_tensor.item() / global_batch_size
                        global_var = (adv_sq_sum_tensor.item() / global_batch_size) - (global_mean ** 2)

                        # Rounding can push E[x^2] - mean^2 below zero.
                        if global_var < 0:
                            accelerator.print(f"[WARNING] global_var is negative: {global_var}")
                            global_var = 0.0

                        adv_mean = global_mean
                        adv_std = torch.sqrt(torch.tensor(global_var) + 1e-8).item()
                else:
                    adv_mean = adv.mean()
                    # The sample std of a single element is NaN (possible for
                    # a short last chunk); treat it as zero spread instead.
                    adv_std = adv.std() if adv.numel() > 1 else torch.zeros_like(adv_mean)

                # The epsilon keeps the division finite when the spread is zero.
                adv = (adv - adv_mean) / (adv_std + 1e-8)

        # Importance sampling ratio for this chunk, as a column vector.
        batch_flat_log_prob = batch_flat_log_prob.cuda().reshape([current_batch_size, 1])
        ratio = torch.exp(batch_log_prob_out.reshape([current_batch_size, 1]) - batch_flat_log_prob)

        # Clip the ratio to prevent extreme policy updates.
        ratio = torch.clamp(ratio, 1.0 - args.ratio_clip, 1.0 + args.ratio_clip)

        # Policy loss. There is deliberately no leading minus: per the
        # original author's note the reward is a quantity to be lowered (FID).
        loss = (
            args.reward_weight
            * adv.detach().float()
            * ratio.float().reshape([current_batch_size, 1])
        ).mean()

        # Add the KL regulariser only once the warm-up period is over.
        if count > args.kl_warmup:
            loss += args.kl_weight * batch_kl.mean()

        # Scale for gradient accumulation and for the number of chunks.
        loss = loss / (args.gradient_accumulation_steps * num_chunks)

        accelerator.backward(loss)

        return ratio, batch_kl, loss, adv_before_norm, adv

    # Sample-weighted sums of the per-chunk means, and the sum of chunk losses.
    total_ratio = 0
    total_kl = 0
    total_p_loss = 0

    for i in range(0, total_flat_samples, args.p_batch_size):
        # The last chunk may be shorter than p_batch_size.
        end_idx = min(i + args.p_batch_size, total_flat_samples)
        current_batch_size = end_idx - i

        is_last_step = end_idx >= total_flat_samples
        # FIXME: temporarily force gradient synchronisation on every step.
        is_last_step = True

        # Gradient all-reduce is delayed with no_sync on every chunk but the
        # last; while the FIXME override is in place it is never delayed.
        if is_ddp and not is_last_step:
            sync_context = accelerator.no_sync(pipe.unet)
        else:
            sync_context = nullcontext()

        with sync_context:
            ratio, batch_kl, loss, adv_before_norm, adv = process_batch(
                flat_state[i:end_idx],
                flat_next_state[i:end_idx],
                flat_timestep[i:end_idx],
                flat_final_reward[i:end_idx],
                flat_unconditional[i:end_idx],
                flat_guided[i:end_idx],
                flat_log_prob[i:end_idx],
                current_batch_size,
            )

        # Weight the chunk means by chunk size so a short last chunk does not
        # count as much as a full one.
        total_ratio += ratio.mean().item() * current_batch_size
        total_kl += batch_kl.mean().item() * current_batch_size
        total_p_loss += loss.item()

        # Progress print-out for this chunk.
        accelerator.print(f"[{i}/{total_flat_samples}] GPU memory usage:")
        accelerator.print(f"  loss: {loss.item()}")
        accelerator.print(f"  adv_before_norm: {adv_before_norm}")
        accelerator.print(f"  adv: {adv.mean().item()}")

    # With no samples there is nothing to log (and nothing to divide by).
    if total_flat_samples > 0:
        # The ratio and KL sums are sample-weighted, so divide by the sample
        # count to recover a per-sample mean.
        tpfdata.tot_ratio += total_ratio / (total_flat_samples * policy_steps)
        tpfdata.tot_kl += total_kl / (total_flat_samples * policy_steps)
        # Each chunk loss is already divided by num_chunks, so their sum is
        # the batch-level loss; only the per-policy-step average remains.
        tpfdata.tot_p_loss += total_p_loss / policy_steps

def _train_policy_func_flat(
    args,
    state_dict,
    pipe,
    unet_copy,
    is_ddp,
    count,
    policy_steps,
    accelerator,
    tpfdata
):
    """Run one policy-gradient backward pass on a flat batch of transitions.

    A batch of `args.p_batch_size` entries is drawn from `state_dict`, the
    log-probability of each stored transition is recomputed with
    `pipe.forward_calculate_logprob`, and the clipped importance ratio times
    the (optionally normalised) reward is back-propagated through
    `accelerator.backward`. No optimizer step is taken here.

    The module-level `DISTRL_DEBUG_*` switches enable extra behaviour:
    `DISTRL_DEBUG_NAN` records intermediates and dumps them when the loss is
    NaN; the `DISTRL_DEBUG_ALIGN_SFT_RL*` switches compute (and optionally
    train on) an SFT-style or "fake RL" loss instead of the policy loss.

    Args:
        args: Namespace providing `use_non_repeating_samples`, `p_batch_size`,
            `num_inference_steps`, `grpo_flag`, `ratio_clip`, `reward_weight`,
            `kl_warmup`, `kl_weight` and `gradient_accumulation_steps`.
        state_dict: Dictionary of collected experiences with the keys
            "state", "next_state", "timestep", "final_reward",
            "unconditional_prompt_embeds", "guided_prompt_embeds" and
            "log_prob"; optionally "policy_indices" and
            "policy_index_position" for sequential (non-repeating) sampling.
            "policy_index_position" is advanced in place.
        pipe: Diffusion pipeline exposing `forward_calculate_logprob` (and
            `unet` / `sampler` for the alignment debug paths).
        unet_copy: Reference UNet forwarded to the pipeline for the KL term.
        is_ddp: Whether training is distributed; if true, advantages are
            gathered from all processes before normalisation.
        count: Current training step, compared against `args.kl_warmup`.
        policy_steps: Number of policy steps the logged totals are averaged
            over.
        accelerator: Accelerator object used for `backward`, `print`,
            `autocast` and its `device`.
        tpfdata: TrainPolicyFuncData-like object whose totals and value
            histories are updated in place.

    Returns:
        None. Gradients are accumulated on the policy model and statistics on
        `tpfdata`.
    """
    # Snapshots of intermediates; only filled when DISTRL_DEBUG_NAN is set.
    debug_vars = {}

    with torch.no_grad():
        if args.use_non_repeating_samples and "policy_indices" in state_dict:
            # Walk the pre-computed index list sequentially, wrapping around
            # so that exactly p_batch_size indices are always produced.
            start_idx = state_dict["policy_index_position"]
            policy_indices = state_dict["policy_indices"]
            policy_indices_len = len(policy_indices)
            indices = torch.tensor([
                policy_indices[(start_idx + offset) % policy_indices_len]
                for offset in range(args.p_batch_size)
            ])

            # Remember where the next call should continue.
            state_dict["policy_index_position"] = (start_idx + args.p_batch_size) % policy_indices_len
        else:
            # Fall back to uniform sampling without replacement.
            indices = get_random_indices(state_dict["state"].shape[0], args.p_batch_size)

        # Gather the batch along dimension 0.
        batch_state = state_dict["state"][indices]
        batch_next_state = state_dict["next_state"][indices]
        batch_timestep = state_dict["timestep"][indices]
        batch_final_reward = state_dict["final_reward"][indices]
        batch_unconditional_prompt_embeds = state_dict["unconditional_prompt_embeds"][indices]
        batch_guided_prompt_embeds = state_dict["guided_prompt_embeds"][indices]
        # Unconditional embeddings first, guided second, stacked on dim 0.
        batch_prompt_embeds = torch.cat([batch_unconditional_prompt_embeds, batch_guided_prompt_embeds])
        batch_log_prob = state_dict["log_prob"][indices]

        if DISTRL_DEBUG_NAN:
            debug_vars["indices"] = indices
            debug_vars["batch_final_reward"] = batch_final_reward.detach().cpu().clone()
            debug_vars["batch_log_prob"] = batch_log_prob.detach().cpu().clone()

    # Recompute the log-probability of the stored transitions and the KL term.
    log_prob, kl_regularizer = pipe.forward_calculate_logprob(
        prompt_embeds=batch_prompt_embeds.cuda(),
        latents=batch_state.cuda(),
        next_latents=batch_next_state.cuda(),
        ts=batch_timestep.cuda(),
        unet_copy=unet_copy,
        num_inference_steps=args.num_inference_steps,
        is_ddp=is_ddp,
    )

    if DISTRL_DEBUG_NAN:
        debug_vars["log_prob"] = log_prob.detach().cpu().clone()
        debug_vars["kl_regularizer"] = kl_regularizer.detach().cpu().clone()
        debug_vars["log_prob_has_nan"] = torch.isnan(log_prob).any().item()
        debug_vars["kl_has_nan"] = torch.isnan(kl_regularizer).any().item()

    with torch.no_grad():
        # One reward per batch entry, shaped as a column for broadcasting.
        adv = batch_final_reward.cuda().reshape([args.p_batch_size, 1])

        if DISTRL_DEBUG_NAN:
            debug_vars["adv_initial"] = adv.detach().cpu().clone()
            debug_vars["adv_initial_has_nan"] = torch.isnan(adv).any().item()

        # grpo_flag 1 or 3 turns on (reward - mean) / std normalisation;
        # flag 3 additionally de-duplicates rewards before taking statistics.
        if args.grpo_flag in [1, 3]:
            if DISTRL_DEBUG_NAN:
                debug_vars["adv_mean"] = adv.mean().item()
                debug_vars["adv_std"] = adv.std().item()

            if is_ddp:
                # Pool the rewards of every process so that all ranks
                # normalise with the same statistics.
                world_size = dist.get_world_size()
                gathered_adv_list = [torch.zeros_like(adv) for _ in range(world_size)]
                dist.all_gather(gathered_adv_list, adv)
                pooled_adv = torch.cat(gathered_adv_list, dim=0)
            else:
                pooled_adv = adv

            if args.grpo_flag == 3:
                # Drop duplicate reward values (this also flattens the tensor).
                pooled_adv = torch.unique(pooled_adv)

            adv_mean = pooled_adv.mean().item()
            # The sample std of a single element is NaN, which can happen
            # after de-duplication when all rewards are equal; use zero
            # spread so the normalised advantage becomes 0 instead of NaN.
            adv_std = pooled_adv.std().item() if pooled_adv.numel() > 1 else 0.0

            if is_ddp and DISTRL_DEBUG_NAN:
                debug_vars["gathered_adv_shape"] = pooled_adv.shape
                debug_vars["global_mean"] = adv_mean
                debug_vars["global_std"] = adv_std
                debug_vars["world_size"] = world_size

            # Track advantage mean and std for logging.
            tpfdata.tot_adv_mean += adv_mean / policy_steps
            tpfdata.tot_adv_std += adv_std / policy_steps

            # The epsilon keeps the division finite when the spread is zero.
            adv = (adv - adv_mean) / (adv_std + 1e-8)

            if DISTRL_DEBUG_NAN:
                debug_vars["adv_normalized"] = adv.detach().cpu().clone()
                debug_vars["adv_normalized_has_nan"] = torch.isnan(adv).any().item()

    # Importance sampling ratio between the current and the stored policy.
    ratio = torch.exp(log_prob - batch_log_prob.cuda())

    if DISTRL_DEBUG_NAN:
        debug_vars["ratio_before_clip"] = ratio.detach().cpu().clone()
        debug_vars["ratio_before_clip_has_nan"] = torch.isnan(ratio).any().item()

    # Record the raw ratios before clamping.
    tpfdata.unclipped_ratio_values.extend(ratio.detach().cpu().flatten().tolist())

    # Clip the ratio to prevent extreme policy updates.
    ratio = torch.clamp(ratio, 1.0 - args.ratio_clip, 1.0 + args.ratio_clip)

    if DISTRL_DEBUG_NAN:
        debug_vars["ratio_after_clip"] = ratio.detach().cpu().clone()
        debug_vars["ratio_after_clip_has_nan"] = torch.isnan(ratio).any().item()

    tpfdata.ratio_values.extend(ratio.detach().cpu().flatten().tolist())
    tpfdata.adv_values.extend(adv.detach().cpu().flatten().tolist())

    # The policy loss is skipped only when the alignment block below will
    # supply the loss itself (both switches set). DOSFT on its own would
    # otherwise leave `loss` undefined and crash with a NameError.
    if not (DISTRL_DEBUG_ALIGN_SFT_RL and DISTRL_DEBUG_ALIGN_SFT_RL_DOSFT):
        loss = -(
            args.reward_weight
            * adv.detach().float()
            * ratio.float().reshape([args.p_batch_size, 1])
        ).mean()

    if DISTRL_DEBUG_ALIGN_SFT_RL:
        # Debug path: run a forward pass the way SFT does and build its loss.
        import torch.nn.functional as F

        # Uniform time grid over [0, 1 - last_step_size].
        last_step_size = 0.04
        t_steps = torch.linspace(0, 1 - last_step_size, args.num_inference_steps, device=accelerator.device)
        dt = t_steps[1] - t_steps[0]

        if not DISTRL_DEBUG_ALIGN_SFT_RL_DOSFT:
            # NOTE(review): this value is not read anywhere below (the later
            # branch recomputes `fake_rl_loss`); it is kept so the debug path
            # still evaluates it exactly as before.
            fake_rl_loss = -2 * torch.log(ratio) * (kl_regularizer.dist_.std.view(ratio.shape) ** 2 + 1e-6) + (
                (batch_next_state.cuda() - (batch_state.cuda() + kl_regularizer.dist_.drift * kl_regularizer.dist_.dt.cuda()).detach()) ** 2
            ).view(ratio.shape[0], -1).mean(dim=1)

        # Time value of each batch entry, looked up by its stored step index.
        t = t_steps[batch_timestep]

        # Index of the largest entry along the last embedding dimension, used
        # as the label input `y` of the UNet.
        # NOTE(review): presumably the embeddings are one-hot class encodings
        # here - confirm against the caller.
        y = batch_guided_prompt_embeds.argmax(dim=-1)

        with accelerator.autocast():
            model_output = pipe.unet(x=batch_state.cuda(), t=t.cuda(), y=y.cuda())

        if not DISTRL_DEBUG_ALIGN_SFT_RL_SFT_WITH_FAKE_RL_LOSS:
            if DISTRL_DEBUG_ALIGN_SFT_RL_SFT_SDE_DRIFT:
                model_kwargs = dict(y=y.cuda())
                pred = pipe.sampler.sde.Euler_Maruyama_step_for_sft_with_sde_drift(
                    batch_state.cuda(), t.cuda(), pipe.unet.forward, **model_kwargs
                )
            else:
                # Plain Euler step: x_next = x + u * dt.
                pred = batch_state.cuda() + model_output * dt
            targets = batch_next_state.cuda()

            # Mean squared error per batch entry, then averaged over the batch.
            sft_loss_per_sample = F.mse_loss(pred, targets, reduction="none")
            sft_loss_per_sample = sft_loss_per_sample.view(sft_loss_per_sample.size(0), -1).mean(dim=1)
            sft_loss = sft_loss_per_sample.mean()

            if DISTRL_DEBUG_ALIGN_SFT_RL_DOSFT:
                loss = sft_loss

        else:
            fake_rl_loss = -2 * torch.log(ratio) * (kl_regularizer.dist_.std.view(ratio.shape) ** 2 + 1e-6) + (
                (batch_next_state.cuda() - (batch_state.cuda() + model_output.detach() * kl_regularizer.dist_.dt.cuda()).detach()) ** 2
            ).view(ratio.shape[0], -1).mean(dim=1)
            loss = fake_rl_loss.mean()

    if DISTRL_DEBUG_NAN:
        debug_vars["loss_before_kl"] = loss.detach().cpu().clone()
        debug_vars["loss_before_kl_has_nan"] = torch.isnan(loss).any().item()

    # Add the KL regulariser only when it is weighted and warm-up is over.
    if args.kl_weight > 0 and count > args.kl_warmup:
        loss += args.kl_weight * kl_regularizer.mean()

        if DISTRL_DEBUG_NAN:
            debug_vars["loss_after_kl"] = loss.detach().cpu().clone()
            debug_vars["loss_after_kl_has_nan"] = torch.isnan(loss).any().item()

    # Scale for gradient accumulation.
    loss = loss / (args.gradient_accumulation_steps)

    if DISTRL_DEBUG_NAN:
        debug_vars["final_loss"] = loss.detach().cpu().clone()
        debug_vars["final_loss_has_nan"] = torch.isnan(loss).any().item()

        # Dump the recorded intermediates only when the loss is NaN.
        if torch.isnan(loss).any().item():
            accelerator.print("=" * 80)
            accelerator.print("NaN detected in loss! Debug information:")
            for key, value in debug_vars.items():
                if isinstance(value, torch.Tensor):
                    if value.numel() <= 10:
                        # Small tensors are printed in full.
                        accelerator.print(f"  {key}: {value}")
                    else:
                        # Larger tensors are summarised. Cast to float first:
                        # mean()/std() raise on integer tensors such as the
                        # recorded indices.
                        stats = value.float()
                        accelerator.print(f"  {key}: shape={value.shape}, "
                                        f"min={stats.min().item():.6f}, "
                                        f"max={stats.max().item():.6f}, "
                                        f"mean={stats.mean().item():.6f}, "
                                        f"std={stats.std().item():.6f}")
                else:
                    accelerator.print(f"  {key}: {value}")
            accelerator.print("=" * 80)

    accelerator.backward(loss)

    # Accumulate per-step averages for logging.
    tpfdata.tot_ratio += ratio.mean().item() / policy_steps
    tpfdata.tot_kl += kl_regularizer.mean().item() / policy_steps
    tpfdata.tot_p_loss += loss.item() / policy_steps