# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""

import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from tqdm import tqdm

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    reduce_metrics,
    process_benchmark_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask


class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.

        For every dataloader batch the loop generates rollouts and scores them. When
        ``algorithm.filter_groups.enable`` is set, prompts whose rollouts all share the same
        value of ``algorithm.filter_groups.metric`` are dropped (prompts with a single rollout
        are kept), and generation batches are accumulated until at least
        ``data.train_batch_size`` prompts are kept; the accumulated batch is then truncated to
        ``train_batch_size * rollout.n`` trajectories before the update phase.

        The method returns early after the initial validation when ``trainer.val_only`` is set,
        and returns after logging once ``self.global_steps`` reaches
        ``self.total_training_steps``.

        Raises:
            ValueError: if group filtering is enabled, fewer than ``train_batch_size`` prompts
                have been kept, and ``filter_groups.max_num_gen_batches`` (when positive)
                generation batches have already been consumed.
        """
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        # State that survives across generation batches while a training batch is being accumulated.
        timing_raw = defaultdict(float)
        batch = None
        num_prompt_in_batch = 0
        num_gen_batches = 0
        for _epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                # NOTE: metrics are reset for every generation batch, so anything recorded before a
                # `continue` below (e.g. KL metrics) is lost when several generation batches are needed.
                metrics = {}

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                # pop those keys for generation
                if "multi_modal_data" in new_batch.non_tensor_batch.keys():
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer("step", timing_raw):
                    # generate a batch
                    with _timer("gen", timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        # REMAX needs a non-sampled baseline rollout whose reward is stored as "reward_baselines".
                        with _timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            new_batch = new_batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(new_batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            # Remove the baseline rollout tensors again; only the summed reward is kept.
                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            new_batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    # One uid per prompt; after the repeat below all rollouts of a prompt share it.
                    new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    new_batch = new_batch.union(gen_batch_output)

                    with _timer("reward", timing_raw):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm:
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
                            new_batch = new_batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_extra_infos_dict: dict[str, list]
                        try:
                            reward_result = self.reward_fn(new_batch, return_dict=True)
                            reward_tensor = reward_result["reward_tensor"]
                            reward_extra_infos_dict = reward_result["reward_extra_info"]
                        except Exception as e:
                            # Best-effort fallback: retry without return_dict and carry no extra info.
                            print(f"Error in reward_fn: {e}")
                            reward_tensor = self.reward_fn(new_batch)
                            reward_extra_infos_dict = {}

                        new_batch.batch["token_level_scores"] = reward_tensor

                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # compute rewards. apply_kl_penalty if available
                        if self.config.algorithm.use_kl_in_reward:
                            new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)  # TODO: This will be cleared if we use multiple generation batches
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

                    if not self.config.algorithm.filter_groups.enable:
                        batch = new_batch
                    else:  # NOTE: When prompts after filtering is less than train batch size,
                        # we skip to the next generation batch
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            # Turn to numpy for easier filtering
                            new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
                        # Any other metric name is read as-is from non_tensor_batch below.

                        # Collect the sequence reward for each trajectory
                        prompt_uid2metric_vals = defaultdict(list)
                        for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
                            prompt_uid2metric_vals[uid].append(metric_val)

                        # Spread of the metric across the rollouts of each prompt.
                        prompt_uid2metric_std = {prompt_uid: np.std(metric_vals) for prompt_uid, metric_vals in prompt_uid2metric_vals.items()}

                        # Keep prompts whose rollouts disagree (std > 0) and prompts with a single rollout.
                        # A set keeps the per-trajectory membership test below O(1).
                        kept_prompt_uids = {uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1}
                        num_prompt_in_batch += len(kept_prompt_uids)

                        kept_traj_idxs = [idx for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]) if traj_from_prompt_uid in kept_prompt_uids]

                        new_batch = new_batch[kept_traj_idxs]
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

                        prompt_bsz = self.config.data.train_batch_size
                        if num_prompt_in_batch < prompt_bsz:
                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                print(f"{num_gen_batches=}. Keep generating...")
                                progress_bar.update(1)
                                continue
                            else:
                                raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try set max_num_gen_batches=0 to enable endless trials.")
                        else:
                            # Align the batch
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            batch = batch[:traj_bsz]

                    # === Updating ===

                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # recompute old_log_probs
                    with _timer("old_log_prob", timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/old_log_prob_entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        # Entropy is only reported as a metric; drop it before merging into the batch.
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer("ref", timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with _timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer("adv", timing_raw):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                        )

                    # update critic
                    if self.use_critic:
                        with _timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with _timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                        with _timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)  # clear timing

                # Reset the accumulation state for the next training step.
                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1


    def _validate(self):
        """
        Run validation over ``self.val_dataloader`` and return a flat dict of metrics.

        For every validation batch this generates responses with the rollout worker group,
        scores them with ``self.val_reward_fn`` and records the response lengths. In addition,
        up to ``max_samples_per_source`` generated samples with distinct prompt texts are kept
        per data source and sent to ``compute_logits_norm_n_entropy`` on the actor worker group
        to obtain logits / entropy / log-prob statistics.

        Returns:
            dict: metric name -> value, made of
                * ``val-core/...`` and ``val-aux/...`` entries derived from
                  ``process_benchmark_metrics``,
                * ``val-aux/{data_source}/response/...`` length statistics,
                * ``val/{N}-{data_source}/...`` logits-analysis statistics.
            An empty dict is returned as soon as a batch is encountered whose reward style is
            ``"model"`` while ``reward_model.enable`` is set.
        """
        data_source_lst = []
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)

        # Lists to collect samples for the table
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        # Lists to collect lengths and sources for metric calculation
        all_response_lengths = []
        all_data_sources = []

        # Collect samples for logits analysis for each data source
        logits_analysis_samples: dict[str, DataProto] = {}
        unique_prompts_per_source: dict[str, set] = defaultdict(set)
        max_samples_per_source = 16

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)

            # repeat test batch
            test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True)

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
                return {}

            # Store original inputs
            input_ids = test_batch.batch["input_ids"]
            # TODO: Can we keep special tokens except for padding tokens?
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            # Move the generation inputs (and optional extras, when present) into a separate batch.
            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
            for optional_key in ("multi_modal_data", "raw_prompt", "tools_kwargs"):
                if optional_key in test_batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append(optional_key)
            test_gen_batch = test_batch.pop(
                batch_keys=batch_keys_to_pop,
                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
            )

            test_gen_batch.meta_info = {
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "recompute_log_prob": False,
                "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
                "validate": True,
            }
            print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

            # pad to be divisible by dp_size
            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            if not self.async_rollout_mode:
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
            else:
                self.async_rollout_manager.wake_up()
                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
                self.async_rollout_manager.sleep()

            # Collect samples for logits analysis from the CURRENT padded batch.
            # Indices are positions in the unpadded `test_batch`; this presumably relies on padding
            # being appended after the real rows of the padded output — TODO confirm.
            data_sources_in_batch = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(test_batch))
            indices_to_collect_by_source = defaultdict(list)

            for i, ds in enumerate(data_sources_in_batch):
                num_collected = len(logits_analysis_samples[ds]) if ds in logits_analysis_samples else 0
                if num_collected >= max_samples_per_source:
                    continue
                # Only the first occurrence of each prompt text is considered per data source.
                current_prompt = input_texts[i]
                if current_prompt not in unique_prompts_per_source[ds]:
                    indices_to_collect_by_source[ds].append(i)
                    unique_prompts_per_source[ds].add(current_prompt)

            for ds, indices in indices_to_collect_by_source.items():
                num_collected = len(logits_analysis_samples[ds]) if ds in logits_analysis_samples else 0
                num_needed = max_samples_per_source - num_collected
                indices_to_take = indices[:num_needed]

                if not indices_to_take:
                    continue

                print(f"Collecting {len(indices_to_take)} samples for data source '{ds}' for logits analysis...")
                new_samples_to_add = test_output_gen_batch_padded.select_idxs(indices_to_take)

                if ds not in logits_analysis_samples:
                    logits_analysis_samples[ds] = new_samples_to_add
                else:
                    # Append rows from a later validation batch. `union` merges keys of equally
                    # sized batches, so row-wise accumulation has to go through `concat`.
                    logits_analysis_samples[ds] = DataProto.concat([logits_analysis_samples[ds], new_samples_to_add])

            # unpad
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print("validation generation end")

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)

            test_batch = test_batch.union(test_output_gen_batch)

            # Calculate and store response lengths and their data sources
            test_batch.batch["response_mask"] = compute_response_mask(test_batch)
            response_lengths = test_batch.batch["response_mask"].sum(dim=-1)
            data_sources = test_batch.non_tensor_batch.get("data_source", ["unknown"] * len(response_lengths))

            all_response_lengths.append(response_lengths.cpu())
            all_data_sources.extend(data_sources)

            # evaluate using reward_function
            result = self.val_reward_fn(test_batch, return_dict=True)
            reward_tensor = result["reward_DAPO_for_OSFT_tensor"]
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            # NOTE: any "reward_extra_info" returned by the reward function is not merged here;
            # only the summed reward and the response length are tracked per sample.
            reward_extra_infos_dict["reward"].extend(scores)
            reward_extra_infos_dict["response_lengths"].extend(response_lengths.cpu().tolist())

            data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

        # Process collected samples for logits analysis for each data source
        logits_analysis_metrics = {}
        for data_source, collected_samples in logits_analysis_samples.items():
            if len(collected_samples) == 0:
                continue

            print(f"Processing {len(collected_samples)} samples for data source '{data_source}' for logits analysis...")

            logits_output = self.actor_rollout_wg.compute_logits_norm_n_entropy(collected_samples)

            # Metric names carry the sample cap and the data source, e.g. "val/16-<source>/...".
            prefix = f"val/{max_samples_per_source}-{data_source}"

            if "logits_l2_norms" in logits_output.batch:
                logits_l2_norms = logits_output.batch["logits_l2_norms"]
                logits_maxs = logits_output.batch["logits_maxs"]
                logits_dispersions = logits_output.batch["logits_dispersions"]
                logits_variances = logits_output.batch["logits_variances"]
                logits_skewnesses = logits_output.batch["logits_skewnesses"]
                logits_analysis_metrics[f"{prefix}/logits/avg_norm"] = logits_l2_norms.mean().item()
                logits_analysis_metrics[f"{prefix}/logits/avg_max"] = logits_maxs.mean().item()
                logits_analysis_metrics[f"{prefix}/logits/avg_dispersion"] = logits_dispersions.mean().item()
                logits_analysis_metrics[f"{prefix}/logits/avg_variance"] = logits_variances.mean().item()
                logits_analysis_metrics[f"{prefix}/logits/std_variance"] = logits_variances.std().item()
                logits_analysis_metrics[f"{prefix}/logits/avg_skewness"] = logits_skewnesses.mean().item()
                logits_analysis_metrics[f"{prefix}/logits/std_skewness"] = logits_skewnesses.std().item()

            if "global_logits_skewness" in logits_output.meta_info:
                global_skew = logits_output.meta_info["global_logits_skewness"]
                logits_analysis_metrics[f"{prefix}/logits/global_skewness"] = global_skew.item()

            if "entropys" in logits_output.batch:
                entropys = logits_output.batch["entropys"]
                logits_analysis_metrics[f"{prefix}/entropy/avg"] = entropys.mean().item()
                logits_analysis_metrics[f"{prefix}/entropy/std"] = entropys.std().item()
                logits_analysis_metrics[f"{prefix}/entropy/max"] = entropys.max().item()
                logits_analysis_metrics[f"{prefix}/entropy/min"] = entropys.min().item()

            if "log_probs" in logits_output.batch:
                log_probs = logits_output.batch["log_probs"]
                log_probs_dispersions = logits_output.batch["log_probs_dispersions"]
                logits_analysis_metrics[f"{prefix}/log_probs/avg"] = log_probs.mean().item()
                logits_analysis_metrics[f"{prefix}/log_probs/std"] = log_probs.std().item()
                logits_analysis_metrics[f"{prefix}/log_probs/max"] = log_probs.max().item()
                logits_analysis_metrics[f"{prefix}/log_probs/min"] = log_probs.min().item()
                logits_analysis_metrics[f"{prefix}/log_probs/avg_dispersion"] = log_probs_dispersions.mean().item()
                # Perplexity as exp of the negated mean log-prob over every element of the tensor.
                true_perplexity = torch.exp(-log_probs.mean()).item()
                logits_analysis_metrics[f"{prefix}/perplexity"] = true_perplexity

        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        # dump generations
        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir:
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
            )

        # Every tracked per-sample list must line up with the scores.
        for key_info, lst in reward_extra_infos_dict.items():
            assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"

        data_sources = np.concatenate(data_source_lst, axis=0)

        # pass@k version
        data_src2var2metric2val = process_benchmark_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
        metric_dict = {}
        for data_source, var2metric2val in data_src2var2metric2val.items():
            core_var = "acc" if "acc" in var2metric2val else "reward"
            for var_name, metric2val in var2metric2val.items():
                if not metric2val:
                    continue
                # Metric names are expected to contain "@<n>"; the largest n marks the core metrics.
                n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
                for metric_name, metric_val in metric2val.items():
                    if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["pass", "avg", "avg_pass"]) and (f"@{n_max}" in metric_name):
                        metric_sec = "val-core"
                    else:
                        metric_sec = "val-aux"
                    pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
                    metric_dict[pfx] = metric_val

        # Calculate and add length-based metrics per data source
        if all_response_lengths:
            all_response_lengths_tensor = torch.cat(all_response_lengths)
            all_data_sources_arr = np.array(all_data_sources)
            max_response_length = self.config.actor_rollout_ref.rollout.response_length

            for data_source in np.unique(all_data_sources_arr):
                mask = all_data_sources_arr == data_source
                source_lengths = all_response_lengths_tensor[mask].float()

                if len(source_lengths) > 0:
                    length_prefix = f"val-aux/{data_source}/response"
                    metric_dict[f"{length_prefix}/avg_length"] = source_lengths.mean().item()
                    metric_dict[f"{length_prefix}/min_length"] = source_lengths.min().item()
                    metric_dict[f"{length_prefix}/max_length"] = source_lengths.max().item()

                    # Responses that reached the configured response length count as cut off.
                    cut_off_count = (source_lengths >= max_response_length).sum().item()
                    cut_off_ratio = cut_off_count / len(source_lengths)
                    metric_dict[f"{length_prefix}/cut_off_ratio"] = cut_off_ratio

        metric_dict.update(logits_analysis_metrics)

        return metric_dict