# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import traceback
import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Optional

import numpy as np
import ray
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm

from recipe.spin import core_algos
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.metric_utils import (
    compute_throughout_metrics,
    compute_timing_metrics,
    process_validation_metrics,
    reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import Role
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger

WorkerType = type[Worker]


class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"


@dataclass
class ResourcePoolManager:
    """
    Define a resource pool specification. Resource pool will be initialized first.
    Mapping
    """

    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.
            # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different
            # WorkerGroup for different models
            resource_pool = RayResourcePool(
                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
            )
            self.resource_pool_dict[resource_pool_name] = resource_pool

        self._check_resource_available()

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker_cls"""
        return self.resource_pool_dict[self.mapping[role]]

    def get_n_gpus(self) -> int:
        """Get the number of gpus in this cluster."""
        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

    def _check_resource_available(self):
        """Check if the resource pool can be satisfied in this ray cluster."""
        node_available_resources = ray.state.available_resources_per_node()
        node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}

        # check total required gpus can be satisfied
        total_available_gpus = sum(node_available_gpus.values())
        total_required_gpus = sum(
            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
        )
        if total_available_gpus < total_required_gpus:
            raise ValueError(
                f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}"
            )

        # check each resource pool can be satisfied, O(#resource_pools * #nodes)
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
            for node, available_gpus in node_available_gpus.items():
                if available_gpus >= num_gpus:
                    node_available_gpus[node] -= num_gpus
                    num_nodes -= 1
                    if num_nodes == 0:
                        break
            if num_nodes > 0:
                raise ValueError(
                    f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this "
                    f"ray cluster"
                )


def _compute_response_info(batch: DataProto) -> dict[str, Any]:
    """Placeholder: Computes prompt and response lengths."""
    try:
        # Assuming 'prompts' and 'responses' keys exist after generation/union
        prompt_len = batch.batch["prompts"].shape[1]
        resp_len = batch.batch["responses"].shape[1]
        # This is simplified - real implementation might use attention masks
        # to get actual lengths per sample.
        batch_size = batch.batch.batch_size[0]
        prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device)
        response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device)

        # Try getting actual lengths from attention mask if possible (more accurate)
        if "response_mask" in batch.batch:
            response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float()
            # if "attention_mask" in batch.batch and "response_mask" in batch.batch:
            # full_mask = batch.batch["attention_mask"]
            # resp_mask = batch.batch["response_mask"]
            # Infer prompt mask length based on where response mask starts or total length
            # This logic depends heavily on how your masks are constructed.
            # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor
            # Fallback to using prompt shape if mask logic is complex:
            prompt_lengths_tensor = torch.tensor(
                [batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device
            )

        return {
            "prompt_length": prompt_lengths_tensor,
            "response_length": response_lengths_tensor,
            "max_response_length": resp_len,
            "max_prompt_length": prompt_len,  # Or from config if fixed padding
        }
    except KeyError as e:
        print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.")
        # Return default/dummy values if keys are missing
        b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1
        max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0
        max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0
        return {
            "prompt_length": torch.zeros(b_size),
            "response_length": torch.zeros(b_size),
            "max_response_length": max_resp,
            "max_prompt_length": max_prompt,
        }


# --- Modified Metric Function ---
def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
    """
    Computes and returns metrics relevant for the DPO-like process.
    Assumes 'batch' contains results after generation and preference marking,
    potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc.
    Removes PPO-specific advantage/return/critic metrics.
    """
    print("---- [DEBUG] Computing DPO Data Metrics ----")
    metrics = {}
    try:
        # --- Scores and Rewards (from reward_fn) ---
        if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None:
            sequence_score = batch.batch["token_level_scores"].sum(-1)
            metrics.update(
                {
                    "reward/score/mean": torch.mean(sequence_score).item(),
                    "reward/score/max": torch.max(sequence_score).item(),
                    "reward/score/min": torch.min(sequence_score).item(),
                }
            )
        else:
            print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.")

        if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None:
            sequence_reward = batch.batch["token_level_rewards"].sum(-1)
            metrics.update(
                {
                    "reward/rewards/mean": torch.mean(sequence_reward).item(),
                    "reward/rewards/max": torch.max(sequence_reward).item(),
                    "reward/rewards/min": torch.min(sequence_reward).item(),
                }
            )
        else:
            print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.")

        # --- DPO Specific Metrics (if stored previously) ---
        if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None:
            metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item()
        else:
            print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.")

        if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None:
            metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item()
        else:
            print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.")

        if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None:
            metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item()
        else:
            print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.")

        # Add metrics based on the 'preferences' mask if available
        # if "preferences" in batch.batch and batch.batch["preferences"] is not None:
        # prefs_mask = batch.batch["preferences"]  # Shape [batch_size * n]
        # Calculate accuracy based on RM scores (assuming higher score -> True in mask)
        # Requires chosen/rejected scores to be available or recalculated
        # This is complex here, better calculated in the main loop or update function

        # --- Length Metrics ---
        response_info = _compute_response_info(batch)
        prompt_length = response_info["prompt_length"]
        response_length = response_info["response_length"]
        max_response_length = response_info["max_response_length"]
        max_prompt_length = response_info["max_prompt_length"]  # Use calculated or from config

        metrics.update(
            {
                "response_length/mean": torch.mean(response_length).item(),
                "response_length/max": torch.max(response_length).item(),
                "response_length/min": torch.min(response_length).item(),
                "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(),
                "prompt_length/mean": torch.mean(prompt_length).item(),
                "prompt_length/max": torch.max(prompt_length).item(),
                "prompt_length/min": torch.min(prompt_length).item(),
                # Prompt clip ratio might need adjustment based on how max_prompt_length is defined
                "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(),
            }
        )

    except KeyError as e:
        print(f"ERROR in compute_dpo_data_metrics: Missing key {e}")
    except Exception as e:
        print(f"ERROR in compute_dpo_data_metrics: {e}")
        traceback.print_exc()

    print(f"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----")
    return metrics


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
    responses = data.batch["responses"]
    response_length = responses.size(1)
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]
    attention_mask = data.batch["attention_mask"]
    response_mask = attention_mask[:, -response_length:]

    # compute kl between ref_policy and current policy
    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
    kld = core_algos.kl_penalty(
        data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
    )  # (batch_size, response_length)
    kld = kld * response_mask
    beta = kl_ctrl.value

    token_level_rewards = token_level_scores - beta * kld

    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()

    # according to XXXX
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    data.batch["token_level_rewards"] = token_level_rewards

    metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}

    return data, metrics


def compute_response_mask(data: DataProto):
    responses = data.batch["responses"]
    response_length = responses.size(1)
    attention_mask = data.batch["attention_mask"]
    return attention_mask[:, -response_length:]


def compute_onlineDPO_pref(data: DataProto):
    """
    Wrapper to compute DPO preference and add it to the DataProto batch.
    Includes debugging prints.
    """
    # print(f"\n---- [DEBUG] Entering compute_onlineDPO_pref ----")
    # print(f"  Input batch keys: {list(data.batch.keys())}")

    # Check inputs
    rewards_tensor = data.batch.get("token_level_rewards")
    mask_tensor = data.batch.get("response_mask")

    if rewards_tensor is None or mask_tensor is None:
        print("  ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!")
        # Handle error case - maybe return original data or raise?
        # Returning original data for now to potentially allow skipping
        return data

    try:
        preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor)
        # Store the result
        data.batch["preferences"] = preferences

    except AttributeError:
        print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!")
        # Assign dummy value or raise error
        data.batch["preferences"] = None  # Indicate failure
    except Exception as e_pref:
        print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}")
        import traceback

        traceback.print_exc()
        data.batch["preferences"] = None  # Indicate failure

    # print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----")
    return data


@contextmanager
def _timer(name: str, timing_raw: dict[str, float]):
    with Timer(name=name, logger=None) as timer:
        yield
    timing_raw[name] = timer.last


class RaySPINTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role have individual ray_worker_group_cls,
    # i.e., support different backend of different role
    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
        processor=None,
        reward_fn=None,
        val_reward_fn=None,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        collate_fn=None,
        train_sampler: Optional[Sampler] = None,
        device_name=None,
    ):
        # assert get_torch_device().is_available(), 'cuda must be available on driver'

        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, "Currently, only support hybrid engine"

        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reference_policy = Role.RefPolicy in role_worker_mapping
        self.use_rm = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls
        self.validation_generations_logger = ValidationGenerationsLogger()
        self.async_rollout_mode = False
        self.device_name = device_name if device_name else self.config.trainer.device

        # define in-reward KL control
        # kl loss control currently not suppoorted
        if config.algorithm.use_kl_in_reward:
            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)

        self.use_critic = False
        self._validate_config()
        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

    def _validate_config(self):
        config = self.config
        # number of GPUs total
        n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

        # 1. Check total batch size for data correctness
        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
        assert real_train_batch_size % n_gpus == 0, (
            f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
        )

        # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
        # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
            settings = {
                "actor_rollout_ref.actor": "micro_batch_size",
                "critic": "micro_batch_size",
                "reward_model": "micro_batch_size",
                "actor_rollout_ref.ref": "log_prob_micro_batch_size",
                "actor_rollout_ref.rollout": "log_prob_micro_batch_size",
            }

            if name in settings:
                param = settings[name]
                param_per_gpu = f"{param}_per_gpu"

                if mbs is None and mbs_per_gpu is None:
                    raise ValueError(
                        f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
                    )

                if mbs is not None and mbs_per_gpu is not None:
                    raise ValueError(
                        f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
                        f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported "
                        f"(the former is deprecated)."
                    )

        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
            check_mutually_exclusive(
                config.actor_rollout_ref.actor.ppo_micro_batch_size,
                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
                "actor_rollout_ref.actor",
            )

            if self.use_reference_policy:
                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
                check_mutually_exclusive(
                    config.actor_rollout_ref.ref.log_prob_micro_batch_size,
                    config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
                    "actor_rollout_ref.ref",
                )

            #  The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
            check_mutually_exclusive(
                config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
                config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
                "actor_rollout_ref.rollout",
            )

        if self.use_critic and not config.critic.use_dynamic_bsz:
            # Check for critic micro-batch size conflicts
            check_mutually_exclusive(
                config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
            )

        # Check for reward model micro-batch size conflicts
        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
            check_mutually_exclusive(
                config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
            )

        # Actor
        # check if train_batch_size is larger than ppo_mini_batch_size
        # if NOT dynamic_bsz, we must ensure:
        #    ppo_mini_batch_size is divisible by ppo_micro_batch_size
        #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus
        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
            sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
                assert (
                    config.actor_rollout_ref.actor.ppo_mini_batch_size
                    % config.actor_rollout_ref.actor.ppo_micro_batch_size
                    == 0
                )
                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

        assert config.actor_rollout_ref.actor.loss_agg_mode in [
            "token-mean",
            "seq-mean-token-sum",
            "seq-mean-token-mean",
        ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"

        if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
            print("NOTICE: You have both enabled in-reward kl and kl loss.")

        # critic
        if self.use_critic and not config.critic.use_dynamic_bsz:
            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
            sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
            if config.critic.ppo_micro_batch_size is not None:
                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

        # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
            if (
                config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
                or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
            ):
                assert config.actor_rollout_ref.model.use_remove_padding, (
                    "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
                )

        if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
            if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
                assert config.critic.model.use_remove_padding, (
                    "When using sequence parallelism for critic, you must enable `use_remove_padding`."
                )

        if config.data.get("val_batch_size", None) is not None:
            print(
                "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines "
                "as a whole batch, which will schedule the memory themselves."
            )

        # check eval config
        if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
            assert config.actor_rollout_ref.rollout.temperature > 0, (
                "validation gen temperature should be greater than 0 when enabling do_sample"
            )

        print("[validate_config] All configuration checks passed successfully!")

    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
        """
        Creates the train and validation dataloaders.
        """
        # TODO: we have to make sure the batch size is divisible by the dp size
        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler

        if train_dataset is None:
            train_dataset = create_rl_dataset(
                self.config.data.train_files, self.config.data, self.tokenizer, self.processor
            )
        if val_dataset is None:
            val_dataset = create_rl_dataset(
                self.config.data.val_files, self.config.data, self.tokenizer, self.processor
            )
        self.train_dataset, self.val_dataset = train_dataset, val_dataset

        if train_sampler is None:
            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
        if collate_fn is None:
            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

            collate_fn = default_collate_fn

        self.train_dataloader = StatefulDataLoader(
            dataset=self.train_dataset,
            batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
            num_workers=self.config.data.get("dataloader_num_workers", 8),
            drop_last=True,
            collate_fn=collate_fn,
            sampler=train_sampler,
        )

        val_batch_size = self.config.data.val_batch_size  # Prefer config value if set
        if val_batch_size is None:
            val_batch_size = len(self.val_dataset)

        self.val_dataloader = StatefulDataLoader(
            dataset=self.val_dataset,
            batch_size=val_batch_size,
            num_workers=self.config.data.get("dataloader_num_workers", 8),
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn,
        )

        assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
        assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"

        print(
            f"Size of train dataloader: {len(self.train_dataloader)}, "
            f"Size of val dataloader: {len(self.val_dataloader)}"
        )

        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        try:
            OmegaConf.set_struct(self.config, True)
            with open_dict(self.config):
                if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
                if OmegaConf.select(self.config, "critic.optim"):
                    self.config.critic.optim.total_training_steps = total_training_steps
        except Exception as e:
            print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")

    def _maybe_log_val_generations(self, inputs, outputs, scores):
        """Log a table of validation samples to the configured logger (wandb or swanlab)"""

        generations_to_log = self.config.trainer.log_val_generations

        if generations_to_log == 0:
            return

        import numpy as np

        # Create tuples of (input, output, score) and sort by input text
        samples = list(zip(inputs, outputs, scores, strict=True))
        samples.sort(key=lambda x: x[0])  # Sort by input text

        # Use fixed random seed for deterministic shuffling
        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        # Take first N samples after shuffling
        samples = samples[:generations_to_log]

        # Log to each configured logger
        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)

    def _validate(self):
        data_source_lst = []
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)

        # Lists to collect samples for the table
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)

            # repeat test batch
            test_batch = test_batch.repeat(
                repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
            )

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
                return {}

            # Store original inputs
            input_ids = test_batch.batch["input_ids"]
            # TODO: Can we keep special tokens except for padding tokens?
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
            if "multi_modal_inputs" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
            if "raw_prompt" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("raw_prompt")
            if "tools_kwargs" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("tools_kwargs")
            test_gen_batch = test_batch.pop(
                batch_keys=batch_keys_to_pop,
                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
            )

            test_gen_batch.meta_info = {
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "recompute_log_prob": False,
                "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
                "validate": True,
            }
            print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

            # pad to be divisible by dp_size
            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            if not self.async_rollout_mode:
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
            else:
                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)

            # unpad
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print("validation generation end")

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)

            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using reward_function
            result = self.val_reward_fn(test_batch, return_dict=True)
            reward_tensor = result["reward_tensor"]
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_extra_infos_dict["reward"].extend(scores)
            if "reward_extra_info" in result:
                for key, lst in result["reward_extra_info"].items():
                    reward_extra_infos_dict[key].extend(lst)

            data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        # dump generations
        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir:
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
            )

        for key_info, lst in reward_extra_infos_dict.items():
            assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"

        data_sources = np.concatenate(data_source_lst, axis=0)
        print(f"DEBUG: Data sources shape: {data_sources.shape}")  # Added Print
        print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}")  # Added Print

        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
        print(
            f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}"
        )  # Added Print
        metric_dict = {}
        for data_source, var2metric2val in data_src2var2metric2val.items():
            core_var = "acc" if "acc" in var2metric2val else "reward"
            for var_name, metric2val in var2metric2val.items():
                n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
                for metric_name, metric_val in metric2val.items():
                    if (
                        (var_name == core_var)
                        and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
                        and (f"@{n_max}" in metric_name)
                    ):
                        metric_sec = "val-core"
                    else:
                        metric_sec = "val-aux"
                    pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
                    metric_dict[pfx] = metric_val

        return metric_dict

    def init_workers(self):
        """Init resource pool and worker group"""
        self.resource_pool_manager.create_resource_pool()

        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.ActorRollout],
                config=self.config.actor_rollout_ref,
                role="actor_rollout",
            )
            self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.use_critic:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
            self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(
                self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref"
            )
            self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

        # create a reward model if reward_fn is None
        if self.use_rm:
            # we create a RM here
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
            self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls

        # initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, which can support different
        # parallel size,
        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to
        # different worker groups.
        # See XXXX for more information.
        all_wg = {}
        self.wg_dicts = []
        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
        if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
            wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
        wg_kwargs["device_name"] = self.device_name

        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(
                resource_pool=resource_pool,
                ray_cls_with_init=worker_dict_cls,
                **wg_kwargs,
            )
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)
            # keep the referece of WorkerDict to support ray >= 2.31. Ref: XXXX
            self.wg_dicts.append(wg_dict)

        if self.use_critic:
            self.critic_wg = all_wg["critic"]
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg["ref"]
            self.ref_policy_wg.init_model()

        if self.use_rm:
            self.rm_wg = all_wg["rm"]
            self.rm_wg.init_model()

        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
        self.actor_rollout_wg = all_wg["actor_rollout"]
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self):
        # path: given_path + `/global_step_{global_steps}` + `/actor`
        local_global_step_folder = os.path.join(
            self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
        )

        print(f"local_global_step_folder: {local_global_step_folder}")
        actor_local_path = os.path.join(local_global_step_folder, "actor")

        actor_remote_path = (
            None
            if self.config.trainer.default_hdfs_dir is None
            else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
        )

        remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
        if remove_previous_ckpt_in_save:
            print(
                "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and "
                "max_critic_ckpt_to_keep=1 instead"
            )
        max_actor_ckpt_to_keep = (
            self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
        )
        max_critic_ckpt_to_keep = (
            self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
        )

        self.actor_rollout_wg.save_checkpoint(
            actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep
        )

        if self.use_critic:
            critic_local_path = os.path.join(local_global_step_folder, "critic")
            critic_remote_path = (
                None
                if self.config.trainer.default_hdfs_dir is None
                else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
            )
            self.critic_wg.save_checkpoint(
                critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
            )

        # save dataloader
        dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
        dataloader_state_dict = self.train_dataloader.state_dict()
        torch.save(dataloader_state_dict, dataloader_local_path)

        # latest checkpointed iteration tracker (for atomic usage)
        local_latest_checkpointed_iteration = os.path.join(
            self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
        )
        with open(local_latest_checkpointed_iteration, "w") as f:
            f.write(str(self.global_steps))

    def _load_checkpoint(self):
        if self.config.trainer.resume_mode == "disable":
            return 0

        # load from hdfs
        if self.config.trainer.default_hdfs_dir is not None:
            raise NotImplementedError("load from hdfs is not implemented yet")
        else:
            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path
            if not os.path.isabs(checkpoint_folder):
                working_dir = os.getcwd()
                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest

        # find global_step_folder
        if self.config.trainer.resume_mode == "auto":
            if global_step_folder is None:
                print("Training from scratch")
                return 0
        else:
            if self.config.trainer.resume_mode == "resume_path":
                assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
                assert "global_step_" in self.config.trainer.resume_from_path, (
                    "resume ckpt must specify the global_steps"
                )
                global_step_folder = self.config.trainer.resume_from_path
                if not os.path.isabs(global_step_folder):
                    working_dir = os.getcwd()
                    global_step_folder = os.path.join(working_dir, global_step_folder)
        print(f"Load from checkpoint folder: {global_step_folder}")
        # set global step
        self.global_steps = int(global_step_folder.split("global_step_")[-1])

        print(f"Setting global step to {self.global_steps}")
        print(f"Resuming from {global_step_folder}")

        actor_path = os.path.join(global_step_folder, "actor")
        critic_path = os.path.join(global_step_folder, "critic")
        # load actor
        self.actor_rollout_wg.load_checkpoint(
            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
        )
        # load critic
        if self.use_critic:
            self.critic_wg.load_checkpoint(
                critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
            )

        # load dataloader,
        # TODO: from remote not implemented yet
        dataloader_local_path = os.path.join(global_step_folder, "data.pt")
        if os.path.exists(dataloader_local_path):
            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state_dict)
        else:
            print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=world_size, equal_size=True
        )
        # reorder based on index. The data will be automatically equally partitioned by dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(
            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
        )
        metrics.update(global_balance_stats)

    def fit_dpo(self):  # Renamed for clarity as standard PPO loop
        """
        The training loop of Online DPO using a periodically updated reference model.
        The driver process calls worker groups for computation.
        Advantage computation is replaced by DPO logic.
        """
        import traceback  # Ensure traceback is imported

        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        # Initialize logger
        logger = None
        try:
            logger = Tracking(
                project_name=self.config.trainer.project_name,
                experiment_name=self.config.trainer.experiment_name,
                default_backend=self.config.trainer.logger,
                config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False),
            )
        except Exception as e:
            print(f"Warning: Failed to initialize logger: {e}")

        self.global_steps = 0
        # Load checkpoint before doing anything
        loaded_step = self._load_checkpoint()
        self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1
        print(
            f"Starting Online DPO training from global step {self.global_steps}. "
            f"Total steps: {self.total_training_steps}"
        )
        print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}")

        # Check if reference policy is configured correctly for this mode
        if not self.use_reference_policy:
            print(
                "WARNING: 'use_reference_policy' is False. Periodic reference model update requires a "
                "reference policy worker. DPO updates might fail or use incorrect logic."
            )
            # Consider raising an error if strict adherence is required:
            # raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True "
            #                  "and a configured reference worker.")

        # Perform validation before training
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            print("Running validation before Online DPO training...")
            val_metrics = self._validate()
            pprint(f"Initial validation metrics: {val_metrics}")
            if logger and val_metrics:
                logger.log(data=val_metrics, step=max(0, self.global_steps - 1))
            if self.config.trainer.get("val_only", False):
                print("Validation only mode enabled. Exiting training.")
                if logger and hasattr(logger, "finish"):
                    logger.finish()
                return

        # Add tqdm progress bar
        progress_bar = tqdm(
            total=self.total_training_steps,
            initial=self.global_steps,
            desc="Online DPO Training Progress",
            position=0,
            leave=True,
        )

        last_val_metrics = None
        should_stop = False

        for epoch in range(self.config.trainer.total_epochs):
            if should_stop:
                break
            print(f"--- Starting Online DPO Epoch {epoch} ---")
            try:
                train_iterator = iter(self.train_dataloader)
            except TypeError:
                print("Warning: Dataloader is not iterable.")
                train_iterator = self.train_dataloader  # Fallback attempt

            for batch_idx, batch_dict in enumerate(train_iterator):
                if self.global_steps > self.total_training_steps:
                    should_stop = True
                    break

                metrics = {}
                timing_raw = {}
                step_timer = Timer(logger=None)
                ref_log_prob_computed = False  # Flag to track if ref log probs were computed

                try:  # Outer try-except for the whole step
                    step_timer.start()
                    with _timer("step", timing_raw):
                        batch: DataProto = DataProto.from_single_dict(batch_dict)
                        current_batch_size = batch.batch.batch_size[0]
                        print(
                            f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: "
                            f"{current_batch_size}"
                        )

                        # --- Reference Model Update ---
                        ref_update_freq = self.config.trainer.get("ref_update_freq", -1)
                        if (
                            self.use_reference_policy
                            and ref_update_freq > 0
                            and self.global_steps % ref_update_freq == 0
                        ):
                            print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...")
                            try:
                                # --- This requires careful implementation with FSDP ---
                                # 1. Save actor state dict (potentially to CPU memory or disk)
                                #    This needs to be done collectively across actor worker ranks.
                                #    The checkpoint_manager might be adaptable, or use FSDP APIs directly.
                                #    Example placeholder using a conceptual save/load mechanism:
                                actor_state_path = "/tmp/actor_state_mid"  # Temporary path
                                self.actor_rollout_wg.save_checkpoint(actor_state_path)  # Adapt save logic

                                # 2. Load the state dict onto the reference model worker group
                                #    This also needs collective loading on the ref worker ranks.
                                self.ref_policy_wg.load_checkpoint(actor_state_path, None, True)  # Adapt load logic

                                print(f"[Step {self.global_steps}] Reference Model Weights Updated.")
                                # Optionally remove the temporary state file
                                # os.remove(actor_state_path) # Needs rank-aware removal or shared storage

                            except Exception as sync_e:
                                print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}")
                                traceback.print_exc()

                        # Pop keys for generation
                        pop_batch_keys = ["input_ids", "attention_mask"]
                        if "position_ids" in batch.batch:
                            pop_batch_keys.append("position_ids")
                        pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else []
                        if "multi_modal_inputs" in batch.non_tensor_batch.keys():
                            pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"])
                        original_non_tensor_data = batch.non_tensor_batch
                        gen_batch = batch.pop(
                            batch_keys=pop_batch_keys,
                            non_tensor_batch_keys=pop_non_tensor_keys,
                        )
                        gen_batch = gen_batch.repeat(
                            repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                        )
                        # (Add Debug prints for gen_batch if needed)

                        # Generate sequences (chosen/rejected pairs)
                        with _timer("gen", timing_raw):
                            try:
                                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                                # (Add Debug prints for gen_batch_output if needed)
                            except Exception as gen_e:
                                print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!")
                                print(gen_e)
                                traceback.print_exc()
                                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                                step_timer.stop()
                                continue

                        # Combine original prompts with generated sequences
                        batch.non_tensor_batch = original_non_tensor_data  # Restore non-tensor data
                        batch.non_tensor_batch["uid"] = np.array(
                            [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object
                        )
                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                        batch = batch.union(gen_batch_output)
                        # (Add Debug prints after union if needed)

                        # Compute response mask (needed for ref logprob calc and DPO prep)
                        batch.batch["response_mask"] = compute_response_mask(batch)

                        if self.config.trainer.balance_batch:
                            self._balance_batch(batch, metrics=metrics)

                        batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                        # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef
                        # fallback) ---
                        # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed
                        #       unless used for other metrics or a fallback. Keep it for now.
                        with _timer("policy_log_prob", timing_raw):
                            policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch)
                            batch = batch.union(policy_log_prob_output)  # Adds 'old_log_probs'
                            # (Debug prints for old_log_probs)

                        # --- Compute Log Probs using the EXTERNAL Reference Model ---
                        if self.use_reference_policy:
                            with _timer("ref_log_prob_dpo", timing_raw):
                                # print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----")
                                try:
                                    # 'batch' contains interleaved chosen/rejected sequences
                                    ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(
                                        batch
                                    )  # Returns DataProto with 'ref_log_prob'
                                    batch = batch.union(
                                        ref_log_prob_output
                                    )  # Adds 'ref_log_prob' key [batch_size * n, seq_len]
                                    ref_log_prob_computed = True  # Mark success
                                    # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: "
                                    #       f"{batch.batch['ref_log_prob'].shape} ----")
                                except Exception as ref_e:
                                    print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}")
                                    traceback.print_exc()
                                    batch.batch["ref_log_prob"] = None  # Mark as failed
                                    ref_log_prob_computed = False
                        else:
                            print(
                                "Warning: Skipping external reference log prob calculation as use_reference_policy "
                                "is False."
                            )
                            # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor

                        # --- Compute Rewards/Scores (used to determine preference) ---
                        with _timer("reward_calc", timing_raw):
                            # (Reward calculation logic using RM or reward_fn as before)
                            # ... Ensure this calculates 'token_level_rewards' or similar ...
                            if self.use_rm:
                                reward_tensor_rm = self.rm_wg.compute_rm_score(batch)
                                batch = batch.union(reward_tensor_rm)  # Adds 'rm_scores'

                            reward_extra_infos_dict = {}
                            try:
                                if self.reward_fn is None:
                                    #  print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! "
                                    #        f"Using dummy rewards. ----")
                                    # Use rm_scores if available, otherwise zeros
                                    reward_tensor = batch.batch.get(
                                        "rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
                                    )
                                else:
                                    reward_result = self.reward_fn(batch, return_dict=True)
                                    reward_tensor = reward_result["reward_tensor"]  # Final combined reward
                                    reward_extra_infos_dict = reward_result.get("reward_extra_info", {})

                            except Exception:
                                # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. '
                                #       f'Using dummy rewards. ----')
                                traceback.print_exc()
                                reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
                                reward_extra_infos_dict = {}

                            # Use 'token_level_rewards' as the key for preference calculation
                            batch.batch["token_level_rewards"] = reward_tensor
                            if reward_extra_infos_dict:
                                batch.non_tensor_batch.update(
                                    {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
                                )

                        # --- Determine Preferences ---
                        # Uses 'token_level_rewards' to determine chosen/rejected based on score
                        batch = compute_onlineDPO_pref(batch)  # Adds 'preferences' key

                        # --- Prepare DPO Batch ---
                        dpo_update_batch_proto = None  # Initialize
                        with _timer("prepare_dpo_batch", timing_raw):
                            try:
                                if "preferences" not in batch.batch or batch.batch["preferences"] is None:
                                    raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.")

                                # Check if reference log probs were computed successfully (if needed)
                                if self.use_reference_policy and not ref_log_prob_computed:
                                    raise ValueError("Reference log probs required but failed to compute.")

                                # Check required base keys
                                required_keys = ["input_ids", "attention_mask", "response_mask"]
                                for rk in required_keys:
                                    if rk not in batch.batch or batch.batch[rk] is None:
                                        raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.")

                                preferences_mask = batch.batch["preferences"]  # Shape [batch_size * n]
                                not_preferences_mask = ~preferences_mask

                                # Gather Chosen/Rejected Base Tensors
                                chosen_input_ids = batch.batch["input_ids"][preferences_mask]
                                chosen_attention_mask = batch.batch["attention_mask"][preferences_mask]
                                rejected_input_ids = batch.batch["input_ids"][not_preferences_mask]
                                rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask]
                                chosen_position_ids = (
                                    batch.batch.get("position_ids")[preferences_mask]
                                    if "position_ids" in batch.batch
                                    else None
                                )
                                rejected_position_ids = (
                                    batch.batch.get("position_ids")[not_preferences_mask]
                                    if "position_ids" in batch.batch
                                    else None
                                )

                                # Create Labels
                                print("WARNING: Creating DPO labels using configured max_prompt_length...")
                                prompt_len = self.config.data.max_prompt_length
                                chosen_labels = chosen_input_ids.clone()
                                chosen_labels[:, :prompt_len] = -100
                                rejected_labels = rejected_input_ids.clone()
                                rejected_labels[:, :prompt_len] = -100

                                # Calculate and Gather Reference Log Probs (Sequence Level)
                                if self.use_reference_policy:
                                    ref_log_prob_tensor = batch.batch["ref_log_prob"]  # Token level [bsz * n, seq_len]
                                    response_mask_full = batch.batch[
                                        "response_mask"
                                    ]  # Response mask [bsz * n, seq_len]
                                    ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(
                                        dim=-1
                                    )  # Sequence level [bsz * n]
                                    reference_chosen_logps = ref_sequence_logps[preferences_mask]
                                    reference_rejected_logps = ref_sequence_logps[not_preferences_mask]
                                else:
                                    # If not using external ref, DPO needs ActorAsRef logic in dp_actor
                                    # We won't add the keys here, dp_actor will handle it (or fail if not modified)
                                    print(
                                        "Info: Not adding explicit reference logps to DPO batch "
                                        "(use_reference_policy=False)."
                                    )
                                    reference_chosen_logps = None  # Explicitly None
                                    reference_rejected_logps = None

                                # Package Tensors
                                dpo_tensors = {
                                    "chosen_input_ids": chosen_input_ids,
                                    "chosen_attention_mask": chosen_attention_mask,
                                    "chosen_labels": chosen_labels,
                                    "rejected_input_ids": rejected_input_ids,
                                    "rejected_attention_mask": rejected_attention_mask,
                                    "rejected_labels": rejected_labels,
                                }
                                # Conditionally add reference logps if computed
                                if reference_chosen_logps is not None:
                                    dpo_tensors["reference_chosen_logps"] = reference_chosen_logps
                                if reference_rejected_logps is not None:
                                    dpo_tensors["reference_rejected_logps"] = reference_rejected_logps
                                # Add position ids if they exist
                                if chosen_position_ids is not None:
                                    dpo_tensors["chosen_position_ids"] = chosen_position_ids
                                if rejected_position_ids is not None:
                                    dpo_tensors["rejected_position_ids"] = rejected_position_ids

                                # Prepare Meta Info
                                dpo_meta = {
                                    "dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1),
                                    "dpo_loss_type": OmegaConf.select(
                                        self.config.algorithm, "dpo_loss_type", default="sigmoid"
                                    ),
                                    "dpo_label_smoothing": OmegaConf.select(
                                        self.config.algorithm, "dpo_label_smoothing", default=0.0
                                    ),
                                    "use_reference_policy": self.use_reference_policy,
                                    "reference_free": not self.use_reference_policy,  # False if using external ref
                                    "global_step": self.global_steps,
                                }

                                dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta)
                                # print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----")
                                # print(f"  Keys: {list(dpo_update_batch_proto.batch.keys())}")
                                # print(f"  Meta Info: {dpo_meta}")

                            except Exception as e_prep:
                                print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}")
                                traceback.print_exc()
                                dpo_update_batch_proto = None  # Skip update on error

                        # --- Actor Update Step ---
                        actor_output = None
                        if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto:
                            with _timer("update_actor", timing_raw):
                                # Pass the batch containing reference log probs (if computed)
                                # The modified update_actor_dpo expects them if reference_free=False
                                actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto)
                            if actor_output and "metrics" in actor_output.meta_info:
                                metrics.update(reduce_metrics(actor_output.meta_info["metrics"]))
                        elif dpo_update_batch_proto is None:
                            print(
                                f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error."
                            )

                        # --- Validation and Saving ---
                        test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1)
                        is_last_step = self.global_steps >= self.total_training_steps
                        if (
                            self.val_reward_fn is not None
                            and test_freq > 0
                            and (is_last_step or self.global_steps % test_freq == 0)
                        ):
                            print(f"\nRunning DPO validation at step {self.global_steps}...")
                            val_timing_raw = {}
                            with _timer("testing", val_timing_raw):
                                val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                            if val_metrics:
                                metrics["time/validation_run"] = val_timing_raw.get("testing", 0)
                                metrics.update(val_metrics)
                            else:
                                print("Validation skipped or returned no metrics.")

                        save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
                        if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0):
                            print(f"\nSaving DPO checkpoint at step {self.global_steps}...")
                            with _timer("save_checkpoint", timing_raw):
                                self._save_checkpoint()  # Saves actor (and potentially critic if used elsewhere)
                            metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0)

                    # --- End main step timer context ---

                    # --- Metrics calculation AFTER the 'step' timer block ---
                    metrics.update(compute_dpo_data_metrics(batch=batch))  # Use DPO-specific metrics
                    metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                    n_gpus = self.resource_pool_manager.get_n_gpus()
                    if "step" in timing_raw:
                        metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                    else:
                        print(
                            f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. "
                            f"Skipping throughput."
                        )

                    step_timer.stop()
                    metrics["time/step"] = step_timer.last

                    # Log metrics
                    log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1)
                    if logger and self.global_steps % log_freq == 0:
                        log_payload = metrics.copy()
                        # Add learning rate to log payload
                        if actor_output and "actor/lr" in metrics:
                            log_payload["actor/lr"] = metrics["actor/lr"]

                        print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}")
                        try:
                            logger.log(data=log_payload, step=self.global_steps)
                        except Exception as e:
                            print(f"Logging failed at step {self.global_steps}: {e}")

                    # Update progress bar
                    postfix_metrics = {
                        k: f"{v:.3f}" if isinstance(v, float) else v
                        for k, v in metrics.items()
                        if isinstance(v, int | float)
                    }
                    progress_bar.set_postfix(postfix_metrics)

                except Exception as step_e:
                    print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!")
                    print(f"Caught Exception: {step_e}")
                    traceback.print_exc()
                    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    step_timer.stop()
                    should_stop = True
                    break

                if is_last_step or should_stop:
                    print(f"Stopping DPO training at step {self.global_steps}.")
                    break

                self.global_steps += 1
                progress_bar.update(1)

            # End of epoch handling
            if hasattr(self.train_dataloader, "reset"):
                try:
                    self.train_dataloader.reset()
                except Exception as e:
                    print(f"Warning: Failed to reset train dataloader state: {e}")
            if should_stop:
                break

        # --- Final cleanup and logging ---
        progress_bar.close()
        final_step = max(0, self.global_steps - 1)
        print(f"Online DPO Training finished at step {final_step}.")
        # Save final checkpoint
        save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
        if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0):
            print(f"Saving final DPO checkpoint at step {final_step}...")
            self._save_checkpoint()

        # Final validation run
        if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False):
            print("Running final validation...")
            last_val_metrics = self._validate()
            if last_val_metrics and logger:
                last_val_metrics["final_validation"] = True
                try:
                    logger.log(data=last_val_metrics, step=final_step)
                except Exception as e:
                    print(f"[Final Val Metrics Log Error]: {e}")

        pprint(f"Final validation metrics: {last_val_metrics}")
        if logger and hasattr(logger, "finish"):
            logger.finish()
        print("Online DPO Training Run Complete.")
