# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""

import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint
import os

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
    # compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    reduce_metrics,
)
from recipe.length_src.ray_trainer import AdvantageEstimator, RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
from verl.utils.debug import marked_timer
from verl.utils.model import compute_position_id_with_mask
from verl.trainer.ppo.metric_utils import _compute_response_info
from typing import Dict, Optional, Type, Any


class RayDAPOTrainer(RayPPOTrainer):
    """DAPO-style PPO trainer with an auxiliary chunk-wise ground-truth log-prob signal.

    Note that this trainer runs on the driver process on a single CPU/GPU node.

    Before rewards are computed, every generation batch is annotated with two
    ``non_tensor_batch`` fields:

    * ``chunk_gt_suffix_logprob_sums``: per trajectory, a list with one entry per
      cumulative response prefix (prefix lengths are multiples of
      ``trainer.gt_prob_chunk_size``, default 512; the last prefix is the full
      response). Each entry is the summed actor log-prob of the text
      `` So the final answer is \\boxed{<ground_truth>}`` appended after that prefix.
    * ``chunk_gt_slopes``: the least-squares slope of that list against the chunk index.
    """

    def fit(self):
        """Run the PPO/DAPO training loop.

        The driver process only needs to call the compute functions of the worker groups
        through RPC to construct the PPO dataflow. The light-weight advantage computation
        is done on the driver process.
        """
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        self.best_avg_score = float("-inf")
        self.best_four_sets = float("-inf")
        self.best_four_model_step = 0
        self.best_model_step = 0

        # load checkpoint (and best-model bookkeeping) before doing anything
        self._load_checkpoint()
        self._load_best_model_info()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        # Chunk size (in response tokens) for the auxiliary ground-truth log-prob trajectory.
        # Loop-invariant, so read it once; missing/None/non-numeric values fall back to 512.
        try:
            chunk_size = int(self.config.trainer.get("gt_prob_chunk_size", 512))
        except (TypeError, ValueError):
            chunk_size = 512

        timing_raw = defaultdict(float)
        batch = None
        num_prompt_in_batch = 0
        num_gen_batches = 0
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                do_profile = self.global_steps in (self.config.trainer.profile_steps or [])
                if do_profile:
                    self.actor_rollout_wg.start_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.start_profile()
                    if self.use_critic:
                        self.critic_wg.start_profile()
                    if self.use_rm:
                        self.rm_wg.start_profile()

                metrics = {}

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                # pop those keys for generation
                if "multi_modal_data" in new_batch.non_tensor_batch.keys():
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                is_last_step = self.global_steps >= self.total_training_steps

                print("original input", self.tokenizer.decode(gen_batch.batch["input_ids"][0], skip_special_tokens=True))

                with marked_timer("step", timing_raw):
                    # generate a batch
                    with marked_timer("gen", timing_raw, "red"):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)
                    print("output shape: ", gen_batch_output.batch["responses"].shape, "output", self.tokenizer.decode(gen_batch_output.batch["responses"][0], skip_special_tokens=True))

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with marked_timer("gen_max", timing_raw, "red"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            new_batch = new_batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(new_batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            new_batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    new_batch = new_batch.union(gen_batch_output)

                    # Auxiliary chunk-wise ground-truth log-prob trajectory (best-effort: on any failure
                    # every trajectory keeps an empty list and a slope of 0.0).
                    gt_suffix_logprob_sums_per_sample: list[list[float]] = [[] for _ in range(len(new_batch))]
                    try:
                        gt_suffix_logprob_sums_per_sample = self._compute_chunk_gt_suffix_logprobs(new_batch, chunk_size)
                    except Exception as e:
                        # Do not break training for this auxiliary computation.
                        print(f"[WARN] chunk-wise ground-truth logprob computation failed: {e}")
                        import traceback

                        traceback.print_exc()

                    new_batch.non_tensor_batch["chunk_gt_slopes"] = np.array(self._fit_logprob_slopes(gt_suffix_logprob_sums_per_sample), dtype=np.float32)
                    new_batch.non_tensor_batch["chunk_gt_suffix_logprob_sums"] = self._as_object_array(gt_suffix_logprob_sums_per_sample)

                    with marked_timer("reward", timing_raw, "yellow"):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm:
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
                            new_batch = new_batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_extra_infos_dict: dict[str, list]
                        # Reset per generation batch: the fallback path below never sets it, and a value
                        # left over from a previous batch must not be attached to this one.
                        original_scores = None
                        try:
                            reward_result = self.reward_fn(new_batch, return_dict=True)
                            reward_tensor = reward_result["reward_tensor"]
                            reward_extra_infos_dict = reward_result.get("reward_extra_info", {})
                            original_scores = reward_extra_infos_dict.get("original_scores", None)
                        except Exception as e:
                            print(f"Error in reward_fn: {e}")
                            reward_tensor = self.reward_fn(new_batch)
                            reward_extra_infos_dict = {}

                        new_batch.batch["token_level_scores"] = reward_tensor
                        # NOTE: compute_data_metrics in this file reads batch["original_scores"], so reward_fn
                        # is expected to provide it via reward_extra_info["original_scores"].
                        if original_scores is not None:
                            new_batch.batch["original_scores"] = original_scores
                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # compute rewards. apply_kl_penalty if available
                        if self.config.algorithm.use_kl_in_reward:
                            new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)  # TODO: This will be cleared if we use multiple genenration batches
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

                    if not self.config.algorithm.filter_groups.enable:
                        batch = new_batch
                    else:  # NOTE: When prompts after filtering is less than train batch size,
                        # we skip to the next generation batch
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            # Turn to numpy for easier filtering
                            new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()

                        # Collect the sequence reward for each trajectory
                        prompt_uid2metric_vals = defaultdict(list)
                        for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
                            prompt_uid2metric_vals[uid].append(metric_val)

                        prompt_uid2metric_std = {}
                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

                        # Keep groups with non-zero reward spread (or singleton groups). A set gives O(1)
                        # membership tests in the per-trajectory scan below.
                        kept_prompt_uids = {uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1}
                        num_prompt_in_batch += len(kept_prompt_uids)

                        kept_traj_idxs = [idx for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]) if traj_from_prompt_uid in kept_prompt_uids]

                        new_batch = new_batch[kept_traj_idxs]
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

                        prompt_bsz = self.config.data.train_batch_size
                        if num_prompt_in_batch < prompt_bsz:
                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                print(f"{num_gen_batches=}. Keep generating...")
                                progress_bar.update(1)
                                continue
                            else:
                                raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try set max_num_gen_batches=0 to enable endless trials.")
                        else:
                            # Align the batch
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            batch = batch[:traj_bsz]

                    # === Updating ===

                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # recompute old_log_probs
                    with marked_timer("old_log_prob", timing_raw, "blue"):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with marked_timer("ref", timing_raw, "olive"):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with marked_timer("values", timing_raw, "cyan"):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with marked_timer("adv", timing_raw, "brown"):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                        )

                    # update critic
                    if self.use_critic:
                        with marked_timer("update_critic", timing_raw, "pink"):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with marked_timer("update_actor", timing_raw, "red"):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with marked_timer("testing", timing_raw, "green"):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics

                            # Check and save best model
                            self._check_and_save_best_model(val_metrics)

                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                        with marked_timer("save_checkpoint", timing_raw, "green"):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)  # clear timing

                if self.config.trainer.get("save_train_data", False):
                    # Save training data for each step
                    self._save_training_data(batch, self.global_steps)

                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if do_profile:
                    self.actor_rollout_wg.stop_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.stop_profile()
                    if self.use_critic:
                        self.critic_wg.stop_profile()
                    if self.use_rm:
                        self.rm_wg.stop_profile()

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1

    def _compute_chunk_gt_suffix_logprobs(self, new_batch: DataProto, chunk_size: int) -> list[list[float]]:
        """Score the ground-truth answer after every cumulative response prefix of each trajectory.

        Each valid response is cut into prefixes ending at multiples of ``chunk_size`` tokens
        (the last prefix is the full response). For every prefix this builds
        ``[prompt tokens] + [response prefix] + [suffix tokens]``, where the suffix is the
        tokenization of `` So the final answer is \\boxed{<ground_truth>}``. It then asks the
        actor worker group for the log-probs and sums them over the suffix tokens only.

        Args:
            new_batch: Generation batch after ``union`` with the rollout output. It must contain
                ``prompts``, ``input_ids``, ``attention_mask`` and ``responses`` tensors, plus a
                ``reward_model`` non-tensor entry with a ``ground_truth`` per trajectory.
            chunk_size: Prefix granularity in response tokens; must be positive.

        Returns:
            A ragged list with one inner list per trajectory, in batch order. Trajectories with
            an empty response get an empty list.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"gt_prob_chunk_size must be positive, got {chunk_size}")

        per_sample: list[list[float]] = [[] for _ in range(len(new_batch))]

        prompt_len = int(new_batch.batch["prompts"].shape[-1])
        input_ids_full = new_batch.batch["input_ids"]
        attention_mask_full = new_batch.batch["attention_mask"]
        responses_full = new_batch.batch["responses"]
        reward_models = new_batch.non_tensor_batch["reward_model"]
        # Number of real (non-padding) response tokens per trajectory.
        valid_resp_lens = attention_mask_full[:, prompt_len:].sum(dim=-1).to(torch.long)

        pad_token_id = self.tokenizer.pad_token_id
        # Fill value for padded positions; those positions are masked out, so any id works.
        fill_id = int(pad_token_id) if pad_token_id is not None else 0

        entry_seqs: list[torch.Tensor] = []  # prompt + response prefix + suffix
        entry_suffix_lens: list[int] = []  # suffix length of each entry
        entry_owner: list[int] = []  # trajectory index each entry belongs to

        for i in range(len(new_batch)):
            valid_len = int(valid_resp_lens[i].item())
            if valid_len <= 0:
                continue

            # Select the real prompt tokens with the attention mask rather than by comparing ids
            # with pad_token_id, which would also drop genuine tokens sharing that id (e.g. eos).
            prompt_mask = attention_mask_full[i, :prompt_len].bool()
            prompt_ids = input_ids_full[i, :prompt_len][prompt_mask]
            resp_ids = responses_full[i, :valid_len].to(torch.long)

            gt_str = str(reward_models[i]["ground_truth"]).strip()
            gt_boxed = gt_str if "\\boxed" in gt_str else ("\\boxed{" + gt_str + "}")
            suffix_text = " So the final answer is " + gt_boxed
            suffix_ids = torch.tensor(self.tokenizer.encode(suffix_text, add_special_tokens=False), dtype=torch.long)
            if suffix_ids.numel() == 0:
                continue

            num_chunks = (valid_len + chunk_size - 1) // chunk_size
            for chunk_idx in range(num_chunks):
                end = min((chunk_idx + 1) * chunk_size, valid_len)
                entry_seqs.append(torch.cat([prompt_ids, resp_ids[:end], suffix_ids], dim=0))
                entry_suffix_lens.append(int(suffix_ids.numel()))
                entry_owner.append(i)

        if not entry_seqs:
            return per_sample

        num_entries = len(entry_seqs)
        max_seq_len = max(int(seq.numel()) for seq in entry_seqs)
        max_suffix_len = max(entry_suffix_lens)

        # Left-pad so that every entry's suffix ends at the last column.
        input_ids = torch.full((num_entries, max_seq_len), fill_value=fill_id, dtype=torch.long)
        attention_mask = torch.zeros_like(input_ids)
        # Mask over the trailing `max_suffix_len` columns, right-aligned to match the suffixes.
        suffix_mask = torch.zeros((num_entries, max_suffix_len), dtype=torch.float32)
        for j, (seq, suffix_len) in enumerate(zip(entry_seqs, entry_suffix_lens)):
            seq_len = int(seq.numel())
            input_ids[j, -seq_len:] = seq
            attention_mask[j, -seq_len:] = 1
            suffix_mask[j, -suffix_len:] = 1.0

        # Use exactly the trailing window of input_ids as "responses". The actor presumably scores
        # the last response_length positions (TODO confirm against the worker), so each returned
        # log-prob column then lines up with the right-aligned suffix, whatever the suffix length.
        responses = input_ids[:, -max_suffix_len:].clone()
        position_ids = compute_position_id_with_mask(attention_mask)

        logprob_batch = DataProto.from_dict(
            tensors={
                "responses": responses,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            },
            auto_padding=True,
        )

        # Scoring only; no gradients are needed.
        with torch.no_grad():
            logprob_out = self.actor_rollout_wg.compute_log_prob(logprob_batch)

        suffix_log_probs = logprob_out.batch["old_log_probs"]  # expected [num_entries, max_suffix_len]
        suffix_logprob_sums = (suffix_log_probs * suffix_mask).sum(dim=-1).tolist()

        for owner, value in zip(entry_owner, suffix_logprob_sums):
            per_sample[owner].append(float(value))
        return per_sample

    @staticmethod
    def _fit_logprob_slopes(per_sample_logprobs: list[list[float]]) -> list[float]:
        """Return the least-squares slope of each trajectory's chunk-wise log-prob list.

        A trajectory with fewer than two points, or with any non-finite value, gets 0.0.
        """
        slopes: list[float] = []
        for probs in per_sample_logprobs:
            values = np.asarray(probs, dtype=np.float64)
            if values.size < 2 or not np.all(np.isfinite(values)):
                slopes.append(0.0)
                continue
            # Fit y = k * x + b over chunk indices x = 0..len-1 and keep only k.
            k, _ = np.polyfit(np.arange(values.size), values, 1)
            slopes.append(float(k))
        return slopes

    @staticmethod
    def _as_object_array(values: list) -> np.ndarray:
        """Pack ``values`` into a 1-D ``dtype=object`` array of length ``len(values)``.

        ``np.array(list_of_lists, dtype=object)`` silently becomes 2-D when every inner list
        has the same length. That breaks axis-0 concatenation with a batch whose inner lists
        differ in length. Assigning element-wise always keeps one object per trajectory.
        """
        out = np.empty(len(values), dtype=object)
        for idx, value in enumerate(values):
            out[idx] = value
        return out

    def _save_training_data(self, batch, step):
        """Dump one step's training trajectories as JSON lines grouped by prompt uid.

        The output file is ``<trainer.default_local_dir>/training_data/step_<step>_traindata.jsonl``.
        Each line holds the step, uid, decoded prompt and ground truth, plus per-trajectory lists
        of decoded responses, chunk-wise ground-truth slopes and log-prob sums, summed
        token-level scores, and response lengths.

        Args:
            batch: The ``DataProto`` used for this step's update; ``None`` makes this a no-op.
            step: Global step number, used in the file name and stored in every record.
        """
        if batch is None:
            return

        import json

        from verl.utils.fs import local_mkdir_safe

        # Create training data directory
        training_data_dir = os.path.join(self.config.trainer.default_local_dir, "training_data")
        local_mkdir_safe(training_data_dir)

        # Save data for this step
        step_data_file = os.path.join(training_data_dir, f"step_{step}_traindata.jsonl")

        # Extract data from batch
        prompts = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
        responses = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
        final_scores = batch.batch["token_level_scores"].sum(dim=-1).tolist()
        response_length = _compute_response_info(batch)["response_length"].tolist()
        chunk_gt_slopes = batch.non_tensor_batch["chunk_gt_slopes"].tolist()
        chunk_gt_suffix_logprob_sums = batch.non_tensor_batch["chunk_gt_suffix_logprob_sums"].tolist()
        reward_models = batch.non_tensor_batch["reward_model"]
        # Get UIDs
        uids = batch.non_tensor_batch.get("uid", [f"unknown_{i}" for i in range(len(prompts))])

        # Group trajectories by uid; dict insertion order keeps first-seen prompt order.
        uid_data = {}
        for i in range(len(prompts)):
            uid = str(uids[i]) if i < len(uids) else f"unknown_{i}"

            if uid not in uid_data:
                uid_data[uid] = {
                    "step": step,
                    "uid": uid,
                    "input": prompts[i],  # single value per prompt
                    "ground_truth": reward_models[i]["ground_truth"],  # single value per prompt
                    "responses": [],  # one entry per trajectory
                    "chunk_gt_slopes": [],
                    "chunk_gt_suffix_logprob_sums": [],
                    "final_scores": [],
                    "response_length": [],
                }

            record = uid_data[uid]
            record["responses"].append(responses[i])
            record["final_scores"].append(final_scores[i])
            record["response_length"].append(response_length[i])
            record["chunk_gt_slopes"].append(chunk_gt_slopes[i])
            record["chunk_gt_suffix_logprob_sums"].append(chunk_gt_suffix_logprob_sums[i])

        # Save grouped data
        with open(step_data_file, "w", encoding="utf-8") as f:
            for data in uid_data.values():
                f.write(json.dumps(data, ensure_ascii=False) + "\n")

        print(f"Training data saved for step {step}: {step_data_file}")


def _flatten_to_float_tensor(values: Any) -> torch.Tensor:
    """Flatten per-sample metric values into a 1-D float64 tensor.

    ``values`` may be a torch tensor, a numeric numpy array, a numpy object
    array (whose elements may be scalars or variable-length sequences), or a
    plain Python list. All elements are flattened into one 1-D tensor so the
    mean/max/min reductions can be applied uniformly.

    Returns a ``[0.0]`` placeholder tensor when there are no values, so that
    ``torch.max``/``torch.min`` do not raise on an empty tensor.
    """
    if isinstance(values, torch.Tensor):
        flat = values.detach().reshape(-1).to(torch.float64)
    else:
        arr = values if isinstance(values, np.ndarray) else np.asarray(values, dtype=object)
        if arr.dtype == object:
            # torch.from_numpy rejects object arrays, so flatten each element
            # (scalar or sequence) into a numeric array first.
            pieces = [np.asarray(item, dtype=np.float64).reshape(-1) for item in arr.reshape(-1)]
            flat_np = np.concatenate(pieces) if pieces else np.empty(0, dtype=np.float64)
        else:
            # Cast numeric (possibly integer/bool) arrays to float so torch.mean works.
            flat_np = arr.astype(np.float64).reshape(-1)
        flat = torch.from_numpy(np.ascontiguousarray(flat_np))
    if flat.numel() == 0:
        # Placeholder for an empty batch of values (avoids reductions on empty tensors).
        return torch.tensor([0.0], dtype=torch.float64)
    return flat


def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]:
    """
    Computes various metrics from a batch of data for PPO training.

    This function calculates metrics related to scores, rewards, advantages, returns, values,
    sequence lengths, and the per-sample chunk ground-truth statistics from a batch of data.
    It provides statistical information (mean, max, min) for each metric category.

    Args:
        batch: A DataProto object. ``batch.batch`` must contain "token_level_scores",
            "token_level_rewards", "original_scores", "advantages", "returns",
            "responses", "attention_mask" (and "values" when ``use_critic``).
            ``batch.non_tensor_batch`` must contain "chunk_gt_slopes" and
            "chunk_gt_suffix_logprob_sums" (scalars or per-sample sequences).
        use_critic: Whether to include critic-specific metrics. Defaults to True.

    Returns:
        A dictionary of metrics including:
            - chunk_gt_slopes/mean, max, min: Statistics over all chunk GT slopes
            - chunk_gt_suffix_logprob_sums/mean, max, min: Statistics over all flattened
              chunk GT suffix log-prob sums
            - critic/score/mean, max, min: Statistics about sequence scores
            - critic/acc/mean: Mean of "original_scores"
            - critic/rewards/mean, max, min: Statistics about sequence rewards
            - critic/advantages/mean, max, min: Statistics about advantages
            - critic/returns/mean, max, min: Statistics about returns
            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
            - response_length/mean, max, min, clip_ratio: Statistics about response lengths
            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths

        When a chunk_gt_* field holds no values, its statistics are reported from a
        ``[0.0]`` placeholder.
    """
    sequence_score = batch.batch["token_level_scores"].sum(-1)
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)
    original_scores = batch.batch["original_scores"]
    if not original_scores.is_floating_point():
        # torch.mean is undefined for integer/bool tensors.
        original_scores = original_scores.float()

    advantages = batch.batch["advantages"]
    returns = batch.batch["returns"]

    max_response_length = batch.batch["responses"].shape[-1]

    prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
    response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch["values"]
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    # Both chunk_gt_* fields may arrive as object arrays (scalars or ragged lists),
    # numeric arrays, or lists; normalize them the same way.
    chunk_gt_slopes = _flatten_to_float_tensor(batch.non_tensor_batch["chunk_gt_slopes"])
    chunk_gt_suffix_logprob_sums = _flatten_to_float_tensor(batch.non_tensor_batch["chunk_gt_suffix_logprob_sums"])

    metrics = {
        # chunk gt slopes
        "chunk_gt_slopes/mean": torch.mean(chunk_gt_slopes).detach().item(),
        "chunk_gt_slopes/max": torch.max(chunk_gt_slopes).detach().item(),
        "chunk_gt_slopes/min": torch.min(chunk_gt_slopes).detach().item(),

        # chunk gt suffix logprob sums
        "chunk_gt_suffix_logprob_sums/mean": torch.mean(chunk_gt_suffix_logprob_sums).detach().item(),
        "chunk_gt_suffix_logprob_sums/max": torch.max(chunk_gt_suffix_logprob_sums).detach().item(),
        "chunk_gt_suffix_logprob_sums/min": torch.min(chunk_gt_suffix_logprob_sums).detach().item(),

        # score
        "critic/score/mean": torch.mean(sequence_score).detach().item(),
        "critic/score/max": torch.max(sequence_score).detach().item(),
        "critic/score/min": torch.min(sequence_score).detach().item(),
        "critic/acc/mean": torch.mean(original_scores).detach().item(),
        # reward
        "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
        "critic/rewards/max": torch.max(sequence_reward).detach().item(),
        "critic/rewards/min": torch.min(sequence_reward).detach().item(),
        # adv
        "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
        "critic/advantages/max": torch.max(valid_adv).detach().item(),
        "critic/advantages/min": torch.min(valid_adv).detach().item(),
        # returns
        "critic/returns/mean": torch.mean(valid_returns).detach().item(),
        "critic/returns/max": torch.max(valid_returns).detach().item(),
        "critic/returns/min": torch.min(valid_returns).detach().item(),
        **(
            {
                # values
                "critic/values/mean": torch.mean(valid_values).detach().item(),
                "critic/values/max": torch.max(valid_values).detach().item(),
                "critic/values/min": torch.min(valid_values).detach().item(),
                # vf explained var
                "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
            }
            if use_critic
            else {}
        ),
        # response length
        "response_length/mean": torch.mean(response_length).detach().item(),
        "response_length/max": torch.max(response_length).detach().item(),
        "response_length/min": torch.min(response_length).detach().item(),
        "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
        # prompt length
        "prompt_length/mean": torch.mean(prompt_length).detach().item(),
        "prompt_length/max": torch.max(prompt_length).detach().item(),
        "prompt_length/min": torch.min(prompt_length).detach().item(),
        "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics