# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from collections import defaultdict
from pprint import pprint
from typing import Optional, Type

import numpy as np
import torch
from omegaconf import OmegaConf
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.metric_utils import (
    compute_throughout_metrics,
    compute_timing_metrics,
    process_benchmark_metrics,
)
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, compute_response_mask
from verl.utils.debug.performance import _timer
from verl.utils.metric import (
    reduce_metrics,
)
from verl.utils.tracking import ValidationGenerationsLogger
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance

WorkerType = Type[Worker]


class RayOSFTTrainer(RayPPOTrainer):
    """
    RayOSFTTrainer is a trainer for online supervised fine-tuning (OSFT) using Ray.
    """

    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
        processor=None,
        reward_fn=None,
        val_reward_fn=None,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        collate_fn=None,
        train_sampler: Optional[Sampler] = None,
        device_name="cuda",
    ):
        """Initialize the Ray-backed online SFT (OSFT) trainer.

        ``RayPPOTrainer.__init__`` is not called; instead the attributes used by
        the inherited helpers (presumably ``_validate_config`` / ``_create_dataloader``
        and the worker setup — confirm against the base class) are assigned here.
        No critic, reference policy or reward model is used.

        Args:
            config: Full trainer config; ``config.actor_rollout_ref.hybrid_engine``
                must be truthy.
            tokenizer: Tokenizer used for decoding prompts/responses.
            role_worker_mapping: Role -> worker class; must contain ``Role.ActorRollout``.
            resource_pool_manager: Ray resource pool manager.
            ray_worker_group_cls: Worker-group class used to spawn workers.
            processor: Optional multimodal processor.
            reward_fn: Reward function used during training.
            val_reward_fn: Reward function used during validation.
            train_dataset / val_dataset / collate_fn / train_sampler: Forwarded to
                ``_create_dataloader``.
            device_name: Device type string, defaults to ``"cuda"``.
        """
        self.config = config
        self.tokenizer, self.processor = tokenizer, processor
        self.reward_fn, self.val_reward_fn = reward_fn, val_reward_fn

        # OSFT trains only the actor: no critic, no reference policy, no reward model.
        self.use_critic = self.use_reference_policy = self.use_rm = False

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, "Currently, only support hybrid engine"
        # Reaching this line means hybrid_engine is truthy, so the actor/rollout role must exist.
        assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.ray_worker_group_cls = ray_worker_group_cls
        self.device_name = device_name
        self.validation_generations_logger = ValidationGenerationsLogger()

        self._validate_config()
        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

    def fit(self):
        """Run the online supervised fine-tuning (OSFT) training loop.

        The driver process only calls worker-group methods over RPC. Each step:

        1. generates responses for the prompt batch (sync or async rollout);
        2. scores them with ``self.reward_fn``; the scores are also stored as
           ``token_level_rewards`` because no advantage is estimated;
        3. if ``trainer.rejection_sampling`` is set, keeps only the top
           ``trainer.rs_n`` responses per prompt;
        4. optionally balances sequence lengths across DP ranks, then updates the actor;
        5. logs metrics, optionally dumps rollouts, and validates / checkpoints on schedule.

        Returns ``None`` after the last training step. It also returns right after
        the initial validation when ``trainer.val_only`` is set.
        """

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything (may restore self.global_steps)
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                for optional_key in ("multi_modal_data", "raw_prompt", "tools_kwargs"):
                    if optional_key in batch.non_tensor_batch:
                        non_tensor_batch_keys_to_pop.append(optional_key)
                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer("step", timing_raw):
                    # generate a batch
                    with _timer("gen", timing_raw):
                        if not self.async_rollout_mode:
                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        else:
                            self.async_rollout_manager.wake_up()
                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
                            self.async_rollout_manager.sleep()
                        # Move rollout-side timings (if reported) into timing_raw and drop them from meta_info.
                        timing_raw.update(gen_batch_output.meta_info.pop("timing", {}))

                    batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)

                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # compute scores using function-based reward 'model'
                    reward_tensor = self.reward_fn(batch)
                    batch.batch["token_level_scores"] = reward_tensor
                    # no advantage estimation in OSFT, so the scores are used directly as rewards
                    batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                    if self.config.trainer.rejection_sampling:
                        batch = self._apply_rejection_sampling(batch)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    if self.config.trainer.enable_train_temperature:
                        batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
                    else:
                        batch.meta_info["temperature"] = 1.0

                    # update actor (core part)
                    with _timer("update_actor", timing_raw):
                        batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                        actor_output = self.actor_rollout_wg.update_actor(batch)
                    actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                    metrics.update(actor_output_metrics)

                    # Score statistics are computed outside the dump block so they are logged
                    # every step, whether or not rollout dumping is enabled.
                    score_metrics, sequence_scores = self._compute_sequence_score_metrics(batch)
                    metrics.update(score_metrics)

                    # Log rollout generations if enabled
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir:
                        with _timer("dump_rollout_generations", timing_raw):
                            inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                            outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                            if sequence_scores is None:
                                print("DEBUG dump_rollout_generations: 'token_level_scores' not found.")
                                sequence_scores = [0 for _ in range(len(inputs))]  # placeholder when no scores exist

                            response_lengths = batch.batch["response_mask"].sum(dim=-1).cpu().tolist()

                            self._dump_generations(
                                inputs=inputs,
                                outputs=outputs,
                                scores=sequence_scores,
                                reward_extra_infos_dict={"response_lengths": response_lengths},
                                dump_path=rollout_data_dir,
                            )

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with _timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                        with _timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # training metrics
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/epoch": epoch,
                    }
                )

                # compute_data_metrics is not used here; presumably it needs PPO-only fields
                # (advantages/returns) that OSFT does not produce — TODO confirm.
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                progress_bar.update(1)
                self.global_steps += 1
                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

        # Dataloader exhausted before total_training_steps was reached: still release the bar.
        progress_bar.close()

    def _apply_rejection_sampling(self, batch: DataProto) -> DataProto:
        """Keep the top ``trainer.rs_n`` responses per prompt and record the new sizes in ``meta_info``.

        Sequence scores are ``token_level_scores`` summed over the response mask.
        Responses of one prompt are assumed to be contiguous, as produced by
        ``batch.repeat(..., interleave=True)`` in ``fit``.
        """
        rs_n = self.config.trainer.rs_n
        n_responses_per_prompt = self.config.actor_rollout_ref.rollout.n
        assert rs_n > 0, f"Rejection sampling requires rs_n > 0, got {rs_n}"
        assert n_responses_per_prompt >= rs_n, f"Number of generated responses per prompt ({n_responses_per_prompt}) must be >= rs_n ({rs_n})"

        response_mask = batch.batch["response_mask"]
        sequence_scores = (batch.batch["token_level_scores"] * response_mask).sum(dim=-1)
        selected_indices = self._select_rejection_sampling_indices(sequence_scores, response_mask, n_responses_per_prompt, rs_n)

        batch = batch.select_idxs(selected_indices)
        # The sample count changed; publish it so the actor worker can use it instead of
        # the static config (whether the worker reads these keys is not visible here).
        batch.meta_info["train_batch_size"] = len(batch)
        batch.meta_info["n_responses_per_prompt"] = rs_n
        return batch

    @staticmethod
    def _select_rejection_sampling_indices(
        sequence_scores: torch.Tensor,
        response_mask: torch.Tensor,
        n_responses_per_prompt: int,
        rs_n: int,
    ) -> torch.Tensor:
        """Return flat indices of the ``rs_n`` highest-scoring responses in each prompt group.

        Args:
            sequence_scores: 1-D tensor with one score per response. The
                ``n_responses_per_prompt`` responses of a prompt are contiguous.
            response_mask: Per-token response mask; its sum over the last dim is the
                response length used for tie-breaking.
            n_responses_per_prompt: Group size.
            rs_n: Number of responses to keep per group.

        Groups whose scores are all exactly zero carry no ranking signal. In those
        groups, responses closest to the group's median length rank highest.
        ``torch.median`` returns the lower median for even group sizes. The order
        among exact ties is whatever ``torch.topk`` yields.
        """
        original_batch_size = sequence_scores.shape[0] // n_responses_per_prompt
        # Rank on a private copy in at least float32, so that length-based tie-break
        # values are not truncated by integer/low-precision reward dtypes.
        rank_dtype = torch.promote_types(sequence_scores.dtype, torch.float32)
        grouped_scores = sequence_scores.view(original_batch_size, n_responses_per_prompt).to(rank_dtype, copy=True)

        all_scores_zero_mask = (grouped_scores == 0).all(dim=1)
        if all_scores_zero_mask.any():
            response_lengths = response_mask.sum(dim=-1)
            grouped_lengths = response_lengths.view(original_batch_size, n_responses_per_prompt)
            tie_break_lengths = grouped_lengths[all_scores_zero_mask].to(rank_dtype)
            median_lengths, _ = torch.median(tie_break_lengths, dim=1, keepdim=True)
            # topk selects the largest values, so negate the distance: closer to the median ranks higher.
            grouped_scores[all_scores_zero_mask] = -torch.abs(tie_break_lengths - median_lengths)

        _, top_indices_in_group = torch.topk(grouped_scores, k=rs_n, dim=1)

        # Convert per-group positions into indices of the flattened batch.
        base_indices = torch.arange(original_batch_size, device=top_indices_in_group.device) * n_responses_per_prompt
        return (base_indices.unsqueeze(1) + top_indices_in_group).view(-1)

    @staticmethod
    def _compute_sequence_score_metrics(batch: DataProto) -> tuple[dict[str, float], Optional[list]]:
        """Summarize per-response scores (``token_level_scores`` summed over the last dim).

        Returns ``(metrics, scores)``:
        - ``metrics`` holds ``reward/score/{mean,max,min}``;
        - ``scores`` is the per-sample list in the batch's current order.

        Returns ``({}, None)`` when the batch has no ``token_level_scores``. After
        rejection sampling, these values describe only the retained responses.
        """
        if "token_level_scores" not in batch.batch or batch.batch["token_level_scores"] is None:
            return {}, None
        sequence_score = batch.batch["token_level_scores"].sum(-1)
        metrics = {
            "reward/score/mean": torch.mean(sequence_score).item(),
            "reward/score/max": torch.max(sequence_score).item(),
            "reward/score/min": torch.min(sequence_score).item(),
        }
        return metrics, sequence_score.cpu().tolist()

    def _validate(self):
        data_source_lst = []
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)

        # Lists to collect samples for the table
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        # Lists to collect lengths and sources for metric calculation
        all_response_lengths = []
        all_data_sources = []

        # # Collect padded math500 samples for logits analysis ---
        # math500_indices = []
        # math500_prompts = set()  # Track unique prompts
        # max_math500_samples = 16

        # Collect samples for logits analysis for each data source
        logits_analysis_samples: dict[str, DataProto] = {}
        unique_prompts_per_source: dict[str, set] = defaultdict(set)
        max_samples_per_source = 16

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)

            # repeat test batch
            test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True)

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
                return {}

            # Store original inputs
            input_ids = test_batch.batch["input_ids"]
            # TODO: Can we keep special tokens except for padding tokens?
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
            if "multi_modal_data" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("multi_modal_data")
            if "raw_prompt" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("raw_prompt")
            if "tools_kwargs" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("tools_kwargs")
            test_gen_batch = test_batch.pop(
                batch_keys=batch_keys_to_pop,
                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
            )

            test_gen_batch.meta_info = {
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "recompute_log_prob": False,
                "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
                "validate": True,
            }
            print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

            # pad to be divisible by dp_size
            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            if not self.async_rollout_mode:
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
            else:
                self.async_rollout_manager.wake_up()
                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
                self.async_rollout_manager.sleep()

        # # Collect math500 samples from padded output (before unpadding)
        # if len(math500_indices) < max_math500_samples:
        #     data_sources = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(test_batch))
            
        #     for i, ds in enumerate(data_sources):
        #         if ds == "math500" and len(math500_indices) < max_math500_samples:
        #             current_prompt = input_texts[i]
        #             # Check for unique prompts
        #             if current_prompt not in math500_prompts:
        #                 math500_indices.append(i)
        #                 math500_prompts.add(current_prompt)
        #                 print(f"Collected math500 index {len(math500_indices)}/{max_math500_samples} for logits analysis")

            # Collect samples for logits analysis from the CURRENT padded batch
            data_sources_in_batch = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(test_batch))
            indices_to_collect_by_source = defaultdict(list)

            for i, ds in enumerate(data_sources_in_batch):
                num_collected = len(logits_analysis_samples.get(ds, [])) if ds in logits_analysis_samples else 0
                if num_collected < max_samples_per_source:
                    current_prompt = input_texts[i]
                    if current_prompt not in unique_prompts_per_source[ds]:
                        indices_to_collect_by_source[ds].append(i)
                        unique_prompts_per_source[ds].add(current_prompt)

            for ds, indices in indices_to_collect_by_source.items():
                num_collected = len(logits_analysis_samples.get(ds, [])) if ds in logits_analysis_samples else 0
                num_needed = max_samples_per_source - num_collected
                indices_to_take = indices[:num_needed]

                if not indices_to_take:
                    continue

                print(f"Collecting {len(indices_to_take)} samples for data source '{ds}' for logits analysis...")
                new_samples_to_add = test_output_gen_batch_padded.select_idxs(indices_to_take)

                if ds not in logits_analysis_samples:
                    logits_analysis_samples[ds] = new_samples_to_add
                else:
                    logits_analysis_samples[ds] = logits_analysis_samples[ds].union(new_samples_to_add)


            # unpad
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print("validation generation end")

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)

            test_batch = test_batch.union(test_output_gen_batch)

            # Calculate and store response lengths and their data sources
            test_batch.batch["response_mask"] = compute_response_mask(test_batch)
            response_lengths = test_batch.batch["response_mask"].sum(dim=-1)
            data_sources = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(response_lengths))
            
            all_response_lengths.append(response_lengths.cpu())
            all_data_sources.extend(data_sources)

            # evaluate using reward_function
            result = self.val_reward_fn(test_batch, return_dict=True)
            reward_tensor = result["reward_tensor"]
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_extra_infos_dict["reward"].extend(scores)
            reward_extra_infos_dict["response_lengths"].extend(response_lengths.cpu().tolist())
            if "reward_extra_info" in result:
                for key, lst in result["reward_extra_info"].items():
                    reward_extra_infos_dict[key].extend(lst)

            data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

        # # Process collected math500 samples for logits analysis ---
        # logits_analysis_metrics = {}
        # if math500_indices:
        #     print(f"Processing {len(math500_indices)} math500 samples for logits analysis...")

        #     # Select math500 samples from the last padded batch in one operation
        #     math500_batch_padded = test_output_gen_batch_padded.select_idxs(math500_indices)

        #     # Compute logits, norm, and entropy for the selected math500 batch
        #     logits_output = self.actor_rollout_wg.compute_logits_norm_n_entropy(math500_batch_padded)

        #     # Extract metrics from the logits computation
        #     if "logits_l2_norms" in logits_output.batch:
        #         logits_l2_norms = logits_output.batch["logits_l2_norms"]
        #         logits_maxs = logits_output.batch["logits_maxs"]
        #         logits_dispersions = logits_output.batch["logits_dispersions"]
        #         logits_analysis_metrics["val/16-math500/logits/avg_norm"] = logits_l2_norms.mean().item()
        #         logits_analysis_metrics["val/16-math500/logits/avg_max"] = logits_maxs.mean().item()
        #         logits_analysis_metrics["val/16-math500/logits/avg_dispersion"] = logits_dispersions.mean().item()

        #     if "entropys" in logits_output.batch:
        #         entropys = logits_output.batch["entropys"]
        #         logits_analysis_metrics["val/16-math500/entropy/avg"] = entropys.mean().item()
        #         logits_analysis_metrics["val/16-math500/entropy/std"] = entropys.std().item()
        #         logits_analysis_metrics["val/16-math500/entropy/max"] = entropys.max().item()
        #         logits_analysis_metrics["val/16-math500/entropy/min"] = entropys.min().item()

        #     if "log_probs" in logits_output.batch:
        #         log_probs = logits_output.batch["log_probs"]
        #         log_probs_dispersions = logits_output.batch["log_probs_dispersions"]
        #         logits_analysis_metrics["val/16-math500/log_probs/avg"] = log_probs.mean().item()
        #         logits_analysis_metrics["val/16-math500/log_probs/std"] = log_probs.std().item()
        #         logits_analysis_metrics["val/16-math500/log_probs/max"] = log_probs.max().item()
        #         logits_analysis_metrics["val/16-math500/log_probs/min"] = log_probs.min().item()
        #         logits_analysis_metrics["val/16-math500/log_probs/avg_dispersion"] = log_probs_dispersions.mean().item()

        # Process collected samples for logits analysis for each data source
        logits_analysis_metrics = {}
        for data_source, collected_samples in logits_analysis_samples.items():
            if len(collected_samples) > 0:
                print(f"Processing {len(collected_samples)} samples for data source '{data_source}' for logits analysis...")

                logits_output = self.actor_rollout_wg.compute_logits_norm_n_entropy(collected_samples)
                
                # Dynamically generate metric names
                if "logits_l2_norms" in logits_output.batch:
                    logits_l2_norms = logits_output.batch["logits_l2_norms"]
                    logits_maxs = logits_output.batch["logits_maxs"]
                    logits_dispersions = logits_output.batch["logits_dispersions"]
                    logits_variances = logits_output.batch["logits_variances"]
                    logits_skewnesses = logits_output.batch["logits_skewnesses"]
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/avg_norm"] = logits_l2_norms.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/avg_max"] = logits_maxs.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/avg_dispersion"] = logits_dispersions.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/avg_variance"] = logits_variances.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/std_variance"] = logits_variances.std().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/avg_skewness"] = logits_skewnesses.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/std_skewness"] = logits_skewnesses.std().item()

                if "global_logits_skewness" in logits_output.meta_info:
                    global_skew = logits_output.meta_info["global_logits_skewness"]
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/logits/global_skewness"] = global_skew.item()
                    
                if "entropys" in logits_output.batch:
                    entropys = logits_output.batch["entropys"]
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/entropy/avg"] = entropys.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/entropy/std"] = entropys.std().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/entropy/max"] = entropys.max().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/entropy/min"] = entropys.min().item()

                if "log_probs" in logits_output.batch:
                    log_probs = logits_output.batch["log_probs"]
                    log_probs_dispersions = logits_output.batch["log_probs_dispersions"]
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/log_probs/avg"] = log_probs.mean().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/log_probs/std"] = log_probs.std().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/log_probs/max"] = log_probs.max().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/log_probs/min"] = log_probs.min().item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/log_probs/avg_dispersion"] = log_probs_dispersions.mean().item()
                    true_perplexity = torch.exp(-log_probs.mean()).item()
                    logits_analysis_metrics[f"val/{max_samples_per_source}-{data_source}/perplexity"] = true_perplexity

        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        # dump generations
        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir:
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
            )

        for key_info, lst in reward_extra_infos_dict.items():
            assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"

        data_sources = np.concatenate(data_source_lst, axis=0)

        data_src2var2metric2val = process_benchmark_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
        metric_dict = {}
        for data_source, var2metric2val in data_src2var2metric2val.items():
            core_var = "acc" if "acc" in var2metric2val else "reward"
            for var_name, metric2val in var2metric2val.items():
                if not metric2val:
                    continue
                n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
                for metric_name, metric_val in metric2val.items():
                    if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["pass", "avg", "avg_pass"]) and (f"@{n_max}" in metric_name):
                        metric_sec = "val-core"
                    else:
                        metric_sec = "val-aux"
                    pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
                    metric_dict[pfx] = metric_val

        # Calculate and add length-based metrics per data source
        if all_response_lengths:
            all_response_lengths_tensor = torch.cat(all_response_lengths)
            all_data_sources_arr = np.array(all_data_sources)
            max_response_length = self.config.actor_rollout_ref.rollout.response_length

            for data_source in np.unique(all_data_sources_arr):
                mask = all_data_sources_arr == data_source
                source_lengths = all_response_lengths_tensor[mask].float()
                
                if len(source_lengths) > 0:
                    metric_dict[f"val-aux/{data_source}/response/avg_length"] = source_lengths.mean().item()
                    metric_dict[f"val-aux/{data_source}/response/min_length"] = source_lengths.min().item()
                    metric_dict[f"val-aux/{data_source}/response/max_length"] = source_lengths.max().item()
                    
                    cut_off_count = (source_lengths >= max_response_length).sum().item()
                    cut_off_ratio = cut_off_count / len(source_lengths)
                    metric_dict[f"val-aux/{data_source}/response/cut_off_ratio"] = cut_off_ratio

        metric_dict.update(logits_analysis_metrics)

        return metric_dict

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
        """Reorder ``batch`` in place so that each DP rank receives a similar total token count.

        Before reordering, rollout response-length statistics are written into ``metrics``
        when the batch carries a ``response_mask``.

        Args:
            batch: Rollout batch. Must contain ``attention_mask``; it is reordered in place
                via ``batch.reorder``.
            metrics: Dict that receives the response-length metrics and the balance
                statistics returned by ``log_seqlen_unbalance``.
            logging_prefix: Prefix forwarded to ``log_seqlen_unbalance`` for its metric keys.
        """
        # --- Response length and cut-off statistics ---
        if "response_mask" in batch.batch:
            response_lengths = batch.batch["response_mask"].sum(dim=-1).float()
            current_batch_size = len(response_lengths)

            # NOTE(review): validation reads the limit from
            # actor_rollout_ref.rollout.response_length, while this reads
            # data.max_response_length with a hard-coded 3072 fallback. Confirm both
            # agree for the current run, or the cut-off ratio will be misleading.
            max_response_length = self.config.data.get("max_response_length", 3072)

            if current_batch_size > 0:
                cut_off_count = (response_lengths >= max_response_length).sum().item()
                metrics["rollout/avg_response_length"] = response_lengths.mean().item()
                metrics["rollout/cut_off_ratio"] = cut_off_count / current_batch_size
                metrics["rollout/max_response_length"] = response_lengths.max().item()
                metrics["rollout/min_response_length"] = response_lengths.min().item()
                metrics["rollout/median_response_length"] = response_lengths.median().item()
            else:
                # mean() of an empty tensor is NaN, so skip length stats and report
                # a zero cut-off ratio instead of logging NaN.
                metrics["rollout/cut_off_ratio"] = 0.0

        # --- Batch balancing across data-parallel ranks ---
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        # Per-sample count of attended tokens; this is the weight being balanced.
        global_seqlen_lst = attention_mask.view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=world_size, equal_size=True)
        # Flatten the partitions into one permutation. The dispatch function later splits
        # the reordered batch into equal-size chunks, one per rank.
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix)
        metrics.update(global_balance_stats)


    def _simple_validate_base(self):
        """Generate and score responses for the whole validation set without computing metrics.

        For every batch of ``self.val_dataloader`` this:
        1. repeats each prompt ``actor_rollout_ref.rollout.val_kwargs.n`` times,
        2. generates responses (sync worker group or async rollout manager),
        3. scores them with ``self.val_reward_fn``.

        The collected inputs, outputs and scores are then passed to
        ``_maybe_log_val_generations``. When ``trainer.validation_data_dir`` is set, they
        are also passed to ``_dump_generations``, together with the reward extra infos and
        per-sample response lengths.

        Returns:
            Always an empty dict. It returns ``{}`` early, without logging anything, when the
            reward model is enabled with ``style == "model"``; only rule-based rewards are
            supported here.
        """
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)

        # Collected across all batches for logging / dumping.
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)

            # repeat test batch
            test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True)

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
                return {}

            # Store original inputs
            input_ids = test_batch.batch["input_ids"]
            # TODO: Can we keep special tokens except for padding tokens?
            sample_inputs.extend(self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids)

            # Move the generation inputs out of test_batch; optional keys are taken only if present.
            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
            for optional_key in ("multi_modal_data", "raw_prompt", "tools_kwargs"):
                if optional_key in test_batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append(optional_key)
            test_gen_batch = test_batch.pop(
                batch_keys=batch_keys_to_pop,
                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
            )

            test_gen_batch.meta_info = {
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "recompute_log_prob": False,
                "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
                "validate": True,
            }
            print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

            # pad to be divisible by dp_size
            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            if not self.async_rollout_mode:
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
            else:
                self.async_rollout_manager.wake_up()
                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
                self.async_rollout_manager.sleep()

            # unpad
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print("validation generation end")

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            sample_outputs.extend(self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids)

            test_batch = test_batch.union(test_output_gen_batch)

            # Attach response_mask to the batch (it is passed to the reward function) and
            # derive per-sample response lengths for the dump.
            test_batch.batch["response_mask"] = compute_response_mask(test_batch)
            response_lengths = test_batch.batch["response_mask"].sum(dim=-1).cpu().tolist()

            # evaluate using reward_function
            result = self.val_reward_fn(test_batch, return_dict=True)
            scores = result["reward_tensor"].sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_extra_infos_dict["reward"].extend(scores)
            reward_extra_infos_dict["response_lengths"].extend(response_lengths)
            if "reward_extra_info" in result:
                for key, lst in result["reward_extra_info"].items():
                    reward_extra_infos_dict[key].extend(lst)

        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        # dump generations
        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir:
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
            )

        return {}


    def _validate_base(self):
        """Run a single validation pass (no training) and log and optionally dump the metrics.

        This:
        - creates a ``Tracking`` logger from the trainer config,
        - resets ``self.global_steps`` to 0 and loads any checkpoint,
        - runs ``self._validate()`` when ``self.val_reward_fn`` is set, logging the result
          at step 0.

        If ``trainer.validation_data_dir`` is configured and metrics were produced, they are
        also written as JSON to ``<validation_data_dir>/the_sub_metric.json``.

        Returns:
            The validation metrics dict, or an empty dict when ``self.val_reward_fn`` is None.
        """
        import json
        import os

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # Bound up front so the dump and return below still work when no reward
        # function is configured.
        val_metrics: dict = {}

        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None:
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)

        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir and val_metrics:
            os.makedirs(val_data_dir, exist_ok=True)
            # Context manager guarantees the file is flushed and closed.
            with open(os.path.join(val_data_dir, "the_sub_metric.json"), "w") as f:
                json.dump(val_metrics, f, indent=4)

        # Close the swanlab run explicitly if that backend was initialised.
        if "swanlab" in logger.logger:
            logger.logger["swanlab"].finish()
        return val_metrics