# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""

import json
import math
import os
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Type, Dict
from copy import deepcopy

import ray
import numpy as np
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.tracking import ValidationGenerationsLogger
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from sklearn.linear_model import LinearRegression
import numpy as np
from compressor_client import compress_prompt
import pdb

# Type alias for a Worker *class* (not an instance); used as the value type of the
# role -> worker-class mapping passed to RayPPOTrainer.
WorkerType = Type[Worker]


class Role(Enum):
    """
    Roles a worker group can play in the PPO trainer.

    Note: Python's ``Enum`` forbids subclassing an enum that already defines
    members, so new roles must be added here rather than in a subclass.
    """
    Actor = 0
    Rollout = 1
    # Actor and rollout in one worker; RayPPOTrainer asserts this role is present
    # when the hybrid engine is enabled.
    ActorRollout = 2
    Critic = 3
    RefPolicy = 4  # presence in role_worker_mapping enables the reference policy / KL control
    RewardModel = 5  # presence in role_worker_mapping enables the reward model
    ActorRolloutRef = 6


class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator.

    Because of the ``str`` mixin, members compare equal to their raw string
    values (e.g. ``AdvantageEstimator.GAE == 'gae'``), so config strings can be
    passed directly to ``compute_advantage``.
    """
    GAE = 'gae'  # generalized advantage estimation (needs critic values)
    GRPO = 'grpo'  # group-relative outcome advantage, grouped by 'uid'
    REINFORCE_PLUS_PLUS = 'reinforce_plus_plus'
    REMAX = 'remax'  # outcome advantage against 'reward_baselines'
    RLOO = 'rloo'  # leave-one-out outcome advantage, grouped by 'uid'
    LGAE = "lgae"  # GAE on rewards shaped by compressor-based token statistics


@dataclass
class ResourcePoolManager:
    """Describe the Ray resource pools and which pool each ``Role`` is placed on.

    Resource pools are created first (``create_resource_pool``); worker groups are
    then looked up by role via ``get_resource_pool``.
    """
    # pool name -> one entry per node, each entry the number of GPUs (processes) on that node
    resource_pool_spec: dict[str, list[int]]
    # role -> name of the pool (a key of resource_pool_spec) that role runs on
    mapping: dict[Role, str]
    # pool name -> created RayResourcePool; filled in by create_resource_pool()
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        """Create one ``RayResourcePool`` per spec entry, then verify the cluster can host them.

        Raises:
            ValueError: if the Ray cluster does not have enough GPUs (see
                ``_check_resource_available``).
        """
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.
            # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for different models
            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
                                            use_gpu=True,
                                            max_colocate_count=1,
                                            name_prefix=resource_pool_name)
            self.resource_pool_dict[resource_pool_name] = resource_pool

        self._check_resource_available()

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Return the resource pool assigned to ``role`` (``KeyError`` if unmapped or not created)."""
        return self.resource_pool_dict[self.mapping[role]]

    def get_n_gpus(self) -> int:
        """Return the total number of GPUs requested across all pools and nodes."""
        return sum(n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes)

    def _check_resource_available(self):
        """Check that the current Ray cluster can satisfy every resource pool.

        First compares the total requested GPUs with the total available GPUs, then
        greedily assigns each pool's nodes to cluster nodes in iteration order. A pool's
        per-node GPU count is taken from its first entry, and each cluster node can host
        at most one of that pool's entries.

        Raises:
            ValueError: if the total or any individual pool cannot be satisfied.
        """
        node_available_resources = ray.state.available_resources_per_node()
        node_available_gpus = {node: node_info.get('GPU', 0) for node, node_info in node_available_resources.items()}

        # check total required gpus can be satisfied
        total_available_gpus = sum(node_available_gpus.values())
        total_required_gpus = self.get_n_gpus()
        if total_available_gpus < total_required_gpus:
            raise ValueError(
                f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}")

        # check each resource pool can be satisfied, O(#resource_pools * #nodes)
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            if not process_on_nodes:
                # A pool spanning no nodes requests nothing; the original code crashed with IndexError here.
                continue
            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
            nodes_left = num_nodes
            for node, available_gpus in node_available_gpus.items():
                if available_gpus >= num_gpus:
                    node_available_gpus[node] -= num_gpus
                    nodes_left -= 1
                    if nodes_left == 0:
                        break
            if nodes_left > 0:
                # Report the *requested* node count, not the number still unplaced.
                raise ValueError(
                    f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster"
                )


import torch
from verl.utils.torch_functional import masked_mean


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'):
    """Derive KL-penalized token-level rewards from token-level scores.

    Writes ``data.batch['token_level_rewards'] = token_level_scores - beta * kld``, where
    ``kld`` is the per-token penalty from ``core_algos.kl_penalty`` (between
    ``old_log_probs`` and ``ref_log_prob``) masked to the response tokens, and
    ``beta = kl_ctrl.value``. If ``'ref_log_prob'`` is not in the batch, ``beta`` is 0 and
    ``kld`` is all zeros.

    The KL controller is then updated with the batch-mean of the per-sequence masked KL.

    Returns:
        ``(data, metrics)`` with ``metrics`` holding ``'critic/kl'`` and ``'critic/kl_coeff'``.
    """
    batch = data.batch
    n_response_tokens = batch['responses'].size(1)
    scores = batch['token_level_scores']
    n_samples = batch.batch_size[0]
    mask = batch['attention_mask'][:, -n_response_tokens:]

    if 'ref_log_prob' not in batch.keys():
        # No reference policy: the penalty is disabled.
        beta = 0
        kld = torch.zeros_like(mask, dtype=torch.float32)
    else:
        raw_kl = core_algos.kl_penalty(batch['old_log_probs'], batch['ref_log_prob'],
                                       kl_penalty=kl_penalty)  # (batch_size, response_length)
        kld = raw_kl * mask
        beta = kl_ctrl.value

    rewards = scores - beta * kld

    # Average over each sequence's response tokens, then over the batch.
    per_sequence_kl = masked_mean(kld, mask=mask, axis=-1)
    current_kl = torch.mean(per_sequence_kl, dim=0).item()

    kl_ctrl.update(current_kl=current_kl, n_steps=n_samples)
    batch['token_level_rewards'] = rewards

    return data, {'critic/kl': current_kl, 'critic/kl_coeff': beta}

def normalize_probs(probs_list):
    """Min-max normalize each chunk of probabilities into the [0, 1] range.

    Args:
        probs_list: iterable of chunks; each chunk is a sequence of numbers
            (anything ``np.array`` accepts).

    Returns:
        A list of numpy arrays, one per input chunk and in the same order. Each array
        is ``(x - min) / (max - min)``. A constant chunk (``max == min``) becomes all
        ones (``np.ones_like``, so it keeps the input dtype). An empty chunk is
        returned as an empty array instead of raising.
    """
    normalized_probs = []
    for chunk_probs in probs_list:
        chunk_probs = np.array(chunk_probs)
        if chunk_probs.size == 0:
            # np.min / np.max raise ValueError on zero-size arrays; there is nothing to scale.
            normalized_probs.append(chunk_probs)
            continue
        min_val = np.min(chunk_probs)
        max_val = np.max(chunk_probs)
        if max_val == min_val:
            normalized_chunk = np.ones_like(chunk_probs)
        else:
            normalized_chunk = (chunk_probs - min_val) / (max_val - min_val)
        normalized_probs.append(normalized_chunk)
    return normalized_probs

def _response_mask(data):
    """Return the last ``responses.size(-1)`` columns of ``data.batch['attention_mask']``."""
    response_length = data.batch['responses'].size(-1)
    return data.batch['attention_mask'][:, -response_length:]


def _collect_token_prob_stats(token_rows, lengths, tokenizer, i_threshold, label, compress_rate=0.6):
    """Count low/high-probability tokens per row, as scored by ``compress_prompt``.

    The first ``lengths[i]`` tokens of row ``i`` are decoded with ``tokenizer`` and sent to
    ``compress_prompt(text, compress_rate)``. Every chunk of ``result["original_probs"]``
    is min-max normalized on its own (``normalize_probs``). A token counts as
    low-probability when its normalized value is below ``i_threshold``.

    Counts are summed over all chunks of a row, so each returned list has exactly one entry
    per row and stays aligned with the batch index. Rows with ``length <= 0``, or for which
    the compressor yields no tokens, contribute zeros. One summary line per row is printed,
    prefixed with ``label``.

    Returns:
        ``(total_tokens, low_prob_tokens, high_prob_tokens)``: three lists of ints.
    """
    totals, lows, highs = [], [], []
    for i, (row, length) in enumerate(zip(token_rows, lengths)):
        n_valid = int(length)
        total = low = 0
        if n_valid > 0:
            # Take the first n_valid tokens of the row as the valid ones.
            text = tokenizer.decode(row[:n_valid], skip_special_tokens=True)
            result = compress_prompt(text, compress_rate)
            for probs in normalize_probs(result["original_probs"]):
                total += len(probs)
                low += sum(1 for p in probs if p < i_threshold)
        high = total - low
        totals.append(total)
        lows.append(low)
        highs.append(high)
        if total > 0:
            print(f"{label} {i+1}: Total tokens: {total}, Low prob: {low} ({low/total*100:.2f}%), High prob: {high} ({high/total*100:.2f}%)")
        else:
            print(f"{label} {i+1}: Empty sequence")
    return totals, lows, highs


def _masked_ratio(numer, denom, mask):
    """Return float32 ``numer / denom`` where ``mask`` is True and 0 elsewhere (shape of ``denom``)."""
    out = torch.zeros_like(denom, dtype=torch.float32)
    if mask.any():
        out[mask] = numer[mask].float() / denom[mask].float()
    return out


def _group_token_stats(mask, high_tokens, low_tokens, high_ratio, low_ratio):
    """Mean high/low token counts and ratios over rows selected by ``mask``; all zeros if none."""
    if not mask.any():
        return 0, 0, 0, 0
    return (high_tokens[mask].float().mean().item(),
            low_tokens[mask].float().mean().item(),
            high_ratio[mask].mean().item(),
            low_ratio[mask].mean().item())


def _summarize_values(values):
    """Return ``(mean, min, max)`` of a list of floats, or ``(0, 0, 0)`` for an empty list."""
    if not values:
        return 0, 0, 0
    tensor = torch.tensor(values)
    return tensor.mean().item(), tensor.min().item(), tensor.max().item()


def _compute_lgae_advantage(data, gamma, lam, lambda_c, lambda_w_n, lambda_w_e, tokenizer, i_threshold,
                            initial_slope, dynamic_lambda_w_e, ppo):
    """GAE on token-level rewards shaped by compressor-based token statistics (``'lgae'``).

    1. For every current response and every reference response (``ref_responses``), count
       tokens whose normalized ``compress_prompt`` probability is below / not below
       ``i_threshold`` (``_collect_token_prob_stats``).
    2. Per sample, compute the ratios ``low/ref_low`` and ``high/ref_high``. They are 0 unless
       both sides have tokens and the reference has both low- and high-probability tokens.
    3. Samples whose ``token_level_scores`` contain a 1 ("correct") get
       ``lambda_c * cos(clamp(low/ref_low, 0, pi/2))`` added at every score-1 position.
       Other ("wrong") samples get
       ``min(0, lambda_w_n * (cos(clamp(low/ref_low, 0, pi/2)) - 1) + c * high/ref_high - lambda_w_e)``
       at their last response token. Here ``c = dynamic_lambda_w_e`` when ``initial_slope``
       is not None, else ``lambda_w_e / (pi/2)``. The shaping is masked to response tokens.
    4. Run ``core_algos.compute_gae_advantage_return`` on the shaped rewards. When ``ppo``
       is True, the unshaped ``token_level_rewards`` are used instead; the statistics are
       still printed and logged.

    Statistics are printed and merged into ``data.meta_info['metrics']``;
    ``data.batch['token_level_rewards']`` itself is left unchanged.

    NOTE: if ``initial_slope`` is set, ``dynamic_lambda_w_e`` must be too, otherwise a
    TypeError is raised as soon as a "wrong" sample is processed.

    Returns:
        ``(advantages, returns)`` from ``compute_gae_advantage_return``.

    Raises:
        ValueError: if ``tokenizer`` is None (responses must be decoded to text).
    """
    if tokenizer is None:
        raise ValueError("adv_estimator='lgae' requires a tokenizer to decode responses")

    values = data.batch['values']
    responses = data.batch['responses']
    response_mask = _response_mask(data)
    token_level_rewards = data.batch['token_level_rewards']
    ref_responses = data.batch['ref_responses']
    device = responses.device

    # Number of valid response tokens per sample, for current and reference responses.
    current_lengths = response_mask.sum(dim=1)  # [batch_size]
    ref_response_mask = data.batch['ref_attention_mask'][:, -token_level_rewards.size(1):]
    ref_lengths = ref_response_mask.sum(dim=1)  # [batch_size]

    # ---- 1. per-sample low/high-probability token counts (one entry per sample) ----
    print("\n===== 原始响应的统计信息 =====")
    totals, lows, highs = _collect_token_prob_stats(responses, current_lengths, tokenizer, i_threshold, "原始样本")
    print("\n===== 参考响应的统计信息 =====")
    ref_totals, ref_lows, ref_highs = _collect_token_prob_stats(ref_responses, ref_lengths, tokenizer, i_threshold,
                                                                "参考样本")

    all_total_tokens_tensor = torch.tensor(totals, device=device)
    all_low_prob_tokens_tensor = torch.tensor(lows, device=device)
    all_high_prob_tokens_tensor = torch.tensor(highs, device=device)
    ref_total_tokens_tensor = torch.tensor(ref_totals, device=device)
    ref_low_prob_tokens_tensor = torch.tensor(ref_lows, device=device)
    ref_high_prob_tokens_tensor = torch.tensor(ref_highs, device=device)

    # Fraction of low/high-probability tokens per sample (0 for samples without tokens).
    valid_indices = all_total_tokens_tensor > 0
    ref_valid_indices = ref_total_tokens_tensor > 0
    all_low_prob_ratio_tensor = _masked_ratio(all_low_prob_tokens_tensor, all_total_tokens_tensor, valid_indices)
    all_high_prob_ratio_tensor = _masked_ratio(all_high_prob_tokens_tensor, all_total_tokens_tensor, valid_indices)
    ref_low_prob_ratio_tensor = _masked_ratio(ref_low_prob_tokens_tensor, ref_total_tokens_tensor, ref_valid_indices)
    ref_high_prob_ratio_tensor = _masked_ratio(ref_high_prob_tokens_tensor, ref_total_tokens_tensor, ref_valid_indices)

    # ---- 2. current-vs-reference count ratios (guarded against division by zero) ----
    valid_for_ratio = (valid_indices & ref_valid_indices
                       & (ref_high_prob_tokens_tensor > 0) & (ref_low_prob_tokens_tensor > 0))
    high_to_ref_high_ratio = _masked_ratio(all_high_prob_tokens_tensor, ref_high_prob_tokens_tensor, valid_for_ratio)
    low_to_ref_low_ratio = _masked_ratio(all_low_prob_tokens_tensor, ref_low_prob_tokens_tensor, valid_for_ratio)
    has_ratio = valid_for_ratio.any()
    avg_high_to_ref_high = high_to_ref_high_ratio[valid_for_ratio].mean().item() if has_ratio else 0
    avg_low_to_ref_low = low_to_ref_low_ratio[valid_for_ratio].mean().item() if has_ratio else 0

    print("\n===== 相对比例分析 =====")
    print(f"高概率token比例 (current/ref): {avg_high_to_ref_high:.4f}")
    print(f"低概率token比例 (current/ref): {avg_low_to_ref_low:.4f}")

    avg_low_prob_ratio = all_low_prob_ratio_tensor.mean().item()
    avg_high_prob_ratio = all_high_prob_ratio_tensor.mean().item()
    ref_avg_low_prob_ratio = ref_low_prob_ratio_tensor.mean().item()
    ref_avg_high_prob_ratio = ref_high_prob_ratio_tensor.mean().item()
    avg_tokens = all_total_tokens_tensor.float().mean().item()
    ref_avg_tokens = ref_total_tokens_tensor.float().mean().item()

    print("\n===== 汇总统计 =====")
    print("原始响应:")
    print(f"- 平均低概率token比例: {avg_low_prob_ratio:.4f} ({avg_low_prob_ratio*100:.2f}%)")
    print(f"- 平均高概率token比例: {avg_high_prob_ratio:.4f} ({avg_high_prob_ratio*100:.2f}%)")
    print(f"- 平均token总数: {avg_tokens:.1f}")

    print("\n参考响应:")
    print(f"- 平均低概率token比例: {ref_avg_low_prob_ratio:.4f} ({ref_avg_low_prob_ratio*100:.2f}%)")
    print(f"- 平均高概率token比例: {ref_avg_high_prob_ratio:.4f} ({ref_avg_high_prob_ratio*100:.2f}%)")
    print(f"- 平均token总数: {ref_avg_tokens:.1f}")

    # Metrics picked up by the trainer's logger.
    metrics = data.meta_info.setdefault('metrics', {})
    metrics.update({
        'token_stats/original_low_prob_ratio': avg_low_prob_ratio,
        'token_stats/original_high_prob_ratio': avg_high_prob_ratio,
        'token_stats/original_avg_tokens_mean': avg_tokens,
        'token_stats/ref_low_prob_ratio': ref_avg_low_prob_ratio,
        'token_stats/ref_high_prob_ratio': ref_avg_high_prob_ratio,
        'token_stats/ref_avg_tokens_mean': ref_avg_tokens,
        'token_stats/high_to_ref_high_ratio': avg_high_to_ref_high,
        'token_stats/low_to_ref_low_ratio': avg_low_to_ref_low,
        'token_stats/original_low_prob_tokens_mean': all_low_prob_tokens_tensor.float().mean().item(),
        'token_stats/original_high_prob_tokens_mean': all_high_prob_tokens_tensor.float().mean().item(),
        'token_stats/ref_low_prob_tokens_mean': ref_low_prob_tokens_tensor.float().mean().item(),
        'token_stats/ref_high_prob_tokens_mean': ref_high_prob_tokens_tensor.float().mean().item(),
    })

    # ---- correct / wrong split: a sample is "correct" if any token-level score equals 1 ----
    token_scores = data.batch['token_level_scores']  # [batch_size, seq_len]
    correct_samples_mask = (token_scores == 1).any(dim=1)  # [batch_size]
    wrong_samples_mask = ~correct_samples_mask
    correct_samples_count = correct_samples_mask.sum().item()
    wrong_samples_count = wrong_samples_mask.sum().item()
    total_samples = len(responses)
    correct_pct = correct_samples_count / total_samples * 100 if total_samples else 0.0
    wrong_pct = wrong_samples_count / total_samples * 100 if total_samples else 0.0

    print(f"\n===== 样本统计 =====")
    print(f"正样本数量: {correct_samples_count}/{total_samples} ({correct_pct:.2f}%)")
    print(f"负样本数量: {wrong_samples_count}/{total_samples} ({wrong_pct:.2f}%)")

    (correct_high_tokens_mean, correct_low_tokens_mean,
     correct_high_ratio, correct_low_ratio) = _group_token_stats(
        valid_indices & correct_samples_mask, all_high_prob_tokens_tensor, all_low_prob_tokens_tensor,
        all_high_prob_ratio_tensor, all_low_prob_ratio_tensor)
    (wrong_high_tokens_mean, wrong_low_tokens_mean,
     wrong_high_ratio, wrong_low_ratio) = _group_token_stats(
        valid_indices & wrong_samples_mask, all_high_prob_tokens_tensor, all_low_prob_tokens_tensor,
        all_high_prob_ratio_tensor, all_low_prob_ratio_tensor)

    # ---- 3. build the per-token shaping term ----
    ratio_term_token = torch.zeros_like(token_level_rewards)  # [batch_size, seq_len]
    # Slope applied to high/ref_high for wrong samples.
    wrong_slope = dynamic_lambda_w_e if initial_slope is not None else lambda_w_e / (math.pi / 2)
    correct_cos_values = []
    wrong_cos_values = []

    for i in range(token_scores.shape[0]):
        # cos of a value clamped to [0, pi/2] lies in [0, 1].
        low_term = torch.cos(torch.clamp(low_to_ref_low_ratio[i], min=0.0, max=math.pi / 2))
        correct_positions = (token_scores[i] == 1).nonzero(as_tuple=True)[0]
        if len(correct_positions) > 0:
            cos_value = lambda_c * low_term
            ratio_term_token[i, correct_positions] = cos_value
            # Logged as 1 + bonus (the score-1 reward plus the shaping term).
            correct_cos_values.append(cos_value.item() + 1)
        else:
            valid_positions = response_mask[i].nonzero(as_tuple=True)[0]
            if len(valid_positions) > 0:
                cos_value = lambda_w_n * (low_term - 1) + wrong_slope * high_to_ref_high_ratio[i] - lambda_w_e
                # Only a non-positive term is applied, at the last valid response token.
                ratio_term_token[i, valid_positions[-1]] = torch.clamp(cos_value, max=0.0)
                # Logged before clamping.
                wrong_cos_values.append(cos_value.item())

    # Restrict the shaping term to real response positions.
    ratio_term_token = ratio_term_token * response_mask

    if ppo:
        # Baseline mode: statistics are computed and logged, but the shaping is not applied.
        rewards_for_gae = token_level_rewards
    else:
        rewards_for_gae = token_level_rewards + ratio_term_token

    correct_cos_mean, correct_cos_min, correct_cos_max = _summarize_values(correct_cos_values)
    wrong_cos_mean, wrong_cos_min, wrong_cos_max = _summarize_values(wrong_cos_values)

    metrics.update({
        'samples/correct_count': correct_samples_count,
        'samples/wrong_count': wrong_samples_count,
        'samples/correct_ratio': correct_samples_count / total_samples if total_samples else 0.0,

        'cos_value/correct_mean': correct_cos_mean,
        'cos_value/correct_min': correct_cos_min,
        'cos_value/correct_max': correct_cos_max,

        'cos_value/wrong_mean': wrong_cos_mean,
        'cos_value/wrong_min': wrong_cos_min,
        'cos_value/wrong_max': wrong_cos_max,

        # high/low-probability token statistics of correct samples
        'token_stats/correct_high_tokens_mean': correct_high_tokens_mean,
        'token_stats/correct_low_tokens_mean': correct_low_tokens_mean,
        'token_stats/correct_high_ratio': correct_high_ratio,
        'token_stats/correct_low_ratio': correct_low_ratio,

        # high/low-probability token statistics of wrong samples
        'token_stats/wrong_high_tokens_mean': wrong_high_tokens_mean,
        'token_stats/wrong_low_tokens_mean': wrong_low_tokens_mean,
        'token_stats/wrong_high_ratio': wrong_high_ratio,
        'token_stats/wrong_low_ratio': wrong_low_ratio,
    })

    # ---- 4. GAE on the (possibly shaped) rewards ----
    return core_algos.compute_gae_advantage_return(
        token_level_rewards=rewards_for_gae,
        values=values,
        eos_mask=response_mask,
        gamma=gamma,
        lam=lam,
    )


def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, lambda_c=1.0, lambda_w_n=1.0, lambda_w_e=1.0, tokenizer=None, i_threshold=0.3, initial_slope=None, dynamic_lambda_w_e=None, ppo=False):
    """Compute advantages and returns and store them in ``data.batch``.

    Args:
        data: batch to update. Always reads ``responses`` and ``attention_mask``, plus the
            estimator-specific entries (``token_level_rewards``, ``values``, ``uid``,
            ``reward_baselines``, ...).
        adv_estimator: an ``AdvantageEstimator`` member or its string value.
        gamma: discount factor (GAE, LGAE, REINFORCE++).
        lam: GAE lambda (GAE, LGAE).
        num_repeat: not used here; kept for interface compatibility.
        lambda_c, lambda_w_n, lambda_w_e, tokenizer, i_threshold, initial_slope,
            dynamic_lambda_w_e, ppo: LGAE-only reward-shaping options (see
            ``_compute_lgae_advantage``).

    Returns:
        ``data`` itself, with ``data.batch['advantages']`` and ``data.batch['returns']`` set.

    Raises:
        NotImplementedError: for an unsupported ``adv_estimator``.
        ValueError: for LGAE without a ``tokenizer``.
    """
    if adv_estimator == AdvantageEstimator.GAE:
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards=data.batch['token_level_rewards'],
            values=data.batch['values'],
            eos_mask=_response_mask(data),
            gamma=gamma,
            lam=lam)
    elif adv_estimator == AdvantageEstimator.GRPO:
        advantages, returns = core_algos.compute_grpo_outcome_advantage(
            token_level_rewards=data.batch['token_level_rewards'],
            eos_mask=_response_mask(data),
            index=data.non_tensor_batch['uid'])
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
            token_level_rewards=data.batch['token_level_rewards'],
            eos_mask=_response_mask(data),
            gamma=gamma)
    elif adv_estimator == AdvantageEstimator.REMAX:
        advantages, returns = core_algos.compute_remax_outcome_advantage(
            token_level_rewards=data.batch['token_level_rewards'],
            reward_baselines=data.batch['reward_baselines'],
            eos_mask=_response_mask(data))
    elif adv_estimator == AdvantageEstimator.RLOO:
        advantages, returns = core_algos.compute_rloo_outcome_advantage(
            token_level_rewards=data.batch['token_level_rewards'],
            eos_mask=_response_mask(data),
            index=data.non_tensor_batch['uid'])
    elif adv_estimator == AdvantageEstimator.LGAE:
        advantages, returns = _compute_lgae_advantage(
            data,
            gamma=gamma,
            lam=lam,
            lambda_c=lambda_c,
            lambda_w_n=lambda_w_n,
            lambda_w_e=lambda_w_e,
            tokenizer=tokenizer,
            i_threshold=i_threshold,
            initial_slope=initial_slope,
            dynamic_lambda_w_e=dynamic_lambda_w_e,
            ppo=ppo)
    else:
        raise NotImplementedError

    data.batch['advantages'] = advantages
    data.batch['returns'] = returns
    return data


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    """Time the enclosed block and store its duration in ``timing_raw[name]``.

    The stored value is ``codetiming.Timer.last``. If the block raises, nothing is
    stored, because the assignment only runs after the block exits normally.
    """
    stopwatch = Timer(name=name, logger=None)
    with stopwatch:
        yield
    timing_raw[name] = stopwatch.last


class RayPPOTrainer(object):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role have individual ray_worker_group_cls,
    # i.e., support different backend of different role
    def __init__(self,
                 config,
                 tokenizer,
                 role_worker_mapping: dict[Role, WorkerType],
                 resource_pool_manager: ResourcePoolManager,
                 ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
                 processor=None,
                 reward_fn=None,
                 val_reward_fn=None):
        """Store trainer configuration, pick the KL controller, and build the dataloaders.

        Worker groups are not created here; that happens in ``init_workers``.

        Args:
            config: trainer config; the ``actor_rollout_ref``, ``algorithm``, ``critic``,
                ``data``, ``trainer`` and ``reward_model`` sections are read here and by
                ``_validate_config`` / ``_create_dataloader``.
            tokenizer: tokenizer handed to the datasets.
            role_worker_mapping: worker class per ``Role``. Must contain
                ``Role.ActorRollout``; the presence of ``Role.RefPolicy`` /
                ``Role.RewardModel`` enables the reference policy / reward model.
            resource_pool_manager: resource-pool manager, stored for ``init_workers``.
            ray_worker_group_cls: worker-group class, stored for ``init_workers``.
            processor: optional processor handed to the datasets.
            reward_fn: stored as ``self.reward_fn``.
            val_reward_fn: stored as ``self.val_reward_fn``.

        Raises:
            AssertionError: if the hybrid engine is disabled, ``Role.ActorRollout`` is
                missing, or an adaptive KL controller has a non-positive horizon.
            NotImplementedError: for an unknown KL controller type or advantage estimator.
        """
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, 'Currently, only support hybrid engine'

        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}'

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reference_policy = Role.RefPolicy in role_worker_mapping
        self.use_rm = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls
        self.validation_generations_logger = ValidationGenerationsLogger()

        # define KL control; without a reference policy a zero fixed coefficient is used
        if self.use_reference_policy:
            kl_ctrl_cfg = config.algorithm.kl_ctrl
            if kl_ctrl_cfg.type == 'fixed':
                self.kl_ctrl = core_algos.FixedKLController(kl_coef=kl_ctrl_cfg.kl_coef)
            elif kl_ctrl_cfg.type == 'adaptive':
                # Report the same horizon value that is being checked.
                assert kl_ctrl_cfg.horizon > 0, f'horizon must be larger than 0. Got {kl_ctrl_cfg.horizon}'
                self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=kl_ctrl_cfg.kl_coef,
                                                               target_kl=kl_ctrl_cfg.target_kl,
                                                               horizon=kl_ctrl_cfg.horizon)
            else:
                raise NotImplementedError(f'Unsupported kl_ctrl type: {kl_ctrl_cfg.type}')
        else:
            self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.)

        # GAE-style estimators need a learned value function (critic); the others do not.
        adv_estimator = self.config.algorithm.adv_estimator
        if adv_estimator in [AdvantageEstimator.GAE, AdvantageEstimator.LGAE]:
            self.use_critic = True
        elif adv_estimator in [
                AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX,
                AdvantageEstimator.RLOO
        ]:
            self.use_critic = False
        else:
            raise NotImplementedError(f'Unsupported adv_estimator: {adv_estimator}')

        self._validate_config()
        self._create_dataloader()

        # State used to track validation scores over training steps
        self.validation_scores = []
        self.validation_steps = []
        self.dynamic_lambda_w_e = self.config.algorithm.lambda_w_e
        self.initial_slope = None
        self.slope_scale = None  # scale used to map the slope
        self.low_slope_count = 0  # number of consecutive low-slope observations

        self.training_scores = []
        self.training_steps = []

    
    def _validate_config(self):
        """Sanity-check batch-size, sequence-parallel and validation settings in ``self.config``.

        Raises:
            ValueError: when a micro-batch size is given both (or neither) as the deprecated
                ``micro_batch_size`` and as ``micro_batch_size_per_gpu``.
            AssertionError: on inconsistent batch sizes, sequence parallelism without
                ``use_remove_padding``, or validation sampling with a non-positive temperature.
        """
        config = self.config
        # number of GPUs total
        n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

        # 1. Check total batch size for data correctness
        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
        assert real_train_batch_size % n_gpus == 0, \
            f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."

        # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu".
        # Exactly one must be set; the new convention is "..._micro_batch_size_per_gpu".
        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
            if mbs is None and mbs_per_gpu is None:
                raise ValueError(f"[{name}] Please set at least one of '{name}.micro_batch_size' or "
                                 f"'{name}.micro_batch_size_per_gpu'.")

            if mbs is not None and mbs_per_gpu is not None:
                raise ValueError(f"[{name}] You have set both '{name}.micro_batch_size' AND "
                                 f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' "
                                 f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated).")

        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
            check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size,
                                     config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
                                     "actor_rollout_ref.actor")

            # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
            check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size,
                                     config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
                                     "actor_rollout_ref.ref")

            # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
            check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
                                     config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
                                     "actor_rollout_ref.rollout")

        if self.use_critic and not config.critic.use_dynamic_bsz:
            # Check for critic micro-batch size conflicts
            check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu,
                                     "critic")

        # Check for reward model micro-batch size conflicts
        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
            check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu,
                                     "reward_model")

        # Actor
        # check if train_batch_size is larger than ppo_mini_batch_size
        # if NOT dynamic_bsz, we must ensure:
        #    ppo_mini_batch_size is divisible by ppo_micro_batch_size
        #    ppo_micro_batch_size * sequence_parallel_size >= n_gpus
        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
            sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1)
            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
                assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

        # critic
        if self.use_critic and not config.critic.use_dynamic_bsz:
            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
            sp_size = config.critic.get('ulysses_sequence_parallel_size', 1)
            if config.critic.ppo_micro_batch_size is not None:
                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

        # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
        if config.actor_rollout_ref.actor.strategy == 'fsdp':
            if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \
                    config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1:
                assert config.actor_rollout_ref.model.use_remove_padding, \
                    "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."

        if self.use_critic and config.critic.strategy == 'fsdp':
            if config.critic.get('ulysses_sequence_parallel_size', 1) > 1:
                assert config.critic.model.use_remove_padding, \
                    "When using sequence parallelism for critic, you must enable `use_remove_padding`."

        if config.data.get('val_batch_size', None) is not None:
            print(
                "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves."
            )

        # check eval config: sampling during validation requires a positive validation temperature.
        # NOTE(review): assumes the rollout worker samples validation generations with
        # `val_kwargs.temperature` (not visible here); fall back to the training temperature
        # when that key is absent so older configs keep validating.
        rollout_cfg = config.actor_rollout_ref.rollout
        if rollout_cfg.val_kwargs.do_sample:
            val_temperature = rollout_cfg.val_kwargs.get('temperature', rollout_cfg.temperature)
            assert val_temperature > 0, \
                "validation gen temperature should be greater than 0 when enabling do_sample"

        print("[validate_config] All configuration checks passed successfully!")

    def _create_dataloader(self):
        """Build the train/validation ``RLHFDataset``s and their ``StatefulDataLoader``s.

        The train loader uses a ``RandomSampler`` seeded from ``data.seed`` (default 1)
        when ``data.shuffle`` is set, otherwise a ``SequentialSampler``, and drops an
        incomplete last batch. The validation loader serves the whole validation set
        as exactly one batch.

        Side effects: sets ``self.total_training_steps`` (``trainer.total_training_steps``
        if given, else ``len(train_dataloader) * trainer.total_epochs``) and writes it into
        ``actor_rollout_ref.actor.optim`` and ``critic.optim`` of the config.
        """
        # TODO: we have to make sure the batch size is divisible by the dp size
        data_cfg = self.config.data
        trainer_cfg = self.config.trainer

        def make_dataset(parquet_files):
            return RLHFDataset(parquet_files=parquet_files,
                               tokenizer=self.tokenizer,
                               processor=self.processor,
                               prompt_key=data_cfg.prompt_key,
                               image_key=data_cfg.get('image_key', 'images'),
                               max_prompt_length=data_cfg.max_prompt_length,
                               filter_prompts=True,
                               return_raw_chat=data_cfg.get('return_raw_chat', False),
                               truncation=data_cfg.get('truncation', 'error'),
                               filter_overlong_prompts=data_cfg.filter_overlong_prompts)

        def check_truncation(dataset):
            assert dataset.truncation == data_cfg.get(
                'truncation', 'error'
            ), f'dataset truncation {dataset.truncation} must be the same as config {data_cfg.get("truncation", "error")}'

        self.train_dataset = make_dataset(data_cfg.train_files)
        check_truncation(self.train_dataset)

        # A seeded sampler lets the sampling order be restored when resuming from a checkpoint.
        if not data_cfg.shuffle:
            sampler = SequentialSampler(data_source=self.train_dataset)
        else:
            seeded_generator = torch.Generator()
            seeded_generator.manual_seed(data_cfg.get('seed', 1))
            sampler = RandomSampler(data_source=self.train_dataset, generator=seeded_generator)

        self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
                                                   batch_size=data_cfg.train_batch_size,
                                                   num_workers=8,
                                                   drop_last=True,
                                                   collate_fn=collate_fn,
                                                   sampler=sampler)

        self.val_dataset = make_dataset(data_cfg.val_files)
        check_truncation(self.val_dataset)

        # Validation datasets are sent to inference engines as a whole batch,
        # which will schedule the memory themselves.
        self.val_dataloader = StatefulDataLoader(dataset=self.val_dataset,
                                                 batch_size=len(self.val_dataset),
                                                 num_workers=8,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 collate_fn=collate_fn)

        num_train_batches = len(self.train_dataloader)
        assert num_train_batches >= 1
        assert len(self.val_dataloader) == 1, \
            "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves."

        print(f'Size of train dataloader: {len(self.train_dataloader)}')

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        steps_from_epochs = len(self.train_dataloader) * trainer_cfg.total_epochs
        explicit_steps = trainer_cfg.total_training_steps
        total_training_steps = steps_from_epochs if explicit_steps is None else explicit_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

    def _maybe_log_val_generations(self, inputs, outputs, scores):
        """Log a table of validation samples to the configured logger (wandb or swanlab)"""

        generations_to_log = self.config.trainer.val_generations_to_log_to_wandb

        if generations_to_log == 0:
            return

        import numpy as np

        # Create tuples of (input, output, score) and sort by input text
        samples = list(zip(inputs, outputs, scores))
        samples.sort(key=lambda x: x[0])  # Sort by input text

        # Use fixed random seed for deterministic shuffling
        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        # Take first N samples after shuffling
        samples = samples[:generations_to_log]

        # Log to each configured logger
        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)

    def _validate(self):
        """Generate on the validation set, score it, and return validation metrics.

        Each validation batch is repeated ``actor_rollout_ref.rollout.val_kwargs.n`` times
        (interleaved), padded to the actor-rollout world size, generated, unpadded and
        scored with ``self.val_reward_fn``. Every non-empty response is also decoded and
        sent to the external ``compress_prompt`` service. The normalized token
        probabilities it returns are split into low-probability tokens (below
        ``config.algorithm.i_threshold``) and high-probability tokens. These counts are
        aggregated per data source, both overall and split by correctness; a sample counts
        as correct when its summed reward is > 0. If the compressor step fails for a
        sample, the error is printed and that sample is left out of these statistics.

        Returns:
            dict: ``acc/<data_source>`` -> mean summed reward, followed by the
            ``length/<data_source>/...`` statistics built by ``_val_token_stats_to_metrics``.
            Returns ``{}`` when a model-based reward model is configured, because validation
            only supports rule-based rewards.
        """
        reward_tensor_lst = []
        data_source_lst = []

        # Lists to collect samples for the table
        sample_inputs = []
        sample_outputs = []
        sample_scores = []
        sample_data_sources = []

        # data_source -> accumulated token-probability statistics
        data_source_stats = {}

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)

            # repeat test batch
            test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n,
                                           interleave=True)

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
                return {}

            # Store original inputs
            input_ids = test_batch.batch['input_ids']
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            if 'multi_modal_inputs' in test_batch.non_tensor_batch.keys():
                test_gen_batch = test_batch.pop(
                    batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                    non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
                )
            else:
                test_gen_batch = test_batch.pop(
                    batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                    non_tensor_batch_keys=['raw_prompt_ids'],
                )

            test_gen_batch.meta_info = {
                'eos_token_id': self.tokenizer.eos_token_id,
                'pad_token_id': self.tokenizer.pad_token_id,
                'recompute_log_prob': False,
                'do_sample': self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
                'validate': True,
            }
            print(f'test_gen_batch meta info: {test_gen_batch.meta_info}')

            # pad to be divisible by dp_size
            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)

            # unpad
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print('validation generation end')

            # Store generated outputs
            output_ids = test_output_gen_batch.batch['responses']
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)

            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using reward_function
            reward_tensor = self.val_reward_fn(test_batch)

            # Store scores
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            # Number of valid response tokens per sample, taken from the response part of the attention mask.
            responses = test_batch.batch['responses']
            response_length = responses.size(1)
            response_mask = test_batch.batch['attention_mask'][:, -response_length:]
            current_lengths = response_mask.sum(dim=1).cpu().numpy()  # [batch_size]

            data_sources = test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])
            is_correct = (reward_tensor.sum(-1) > 0).cpu().numpy()  # correct <=> summed reward is positive
            sample_data_sources.extend(data_sources)

            for response, length, is_correct_sample, data_source in zip(responses, current_lengths, is_correct,
                                                                        data_sources):
                if length <= 0:
                    continue
                # Keeps the first `length` tokens; assumes responses are right-padded, as this slice implies.
                text = self.tokenizer.decode(response[:length], skip_special_tokens=True)
                try:
                    total_tokens, high_prob_tokens, low_prob_tokens = self._val_token_prob_counts(text)
                except Exception as e:
                    # Best-effort: a failing compressor step skips this sample instead of aborting validation.
                    print(f"处理样本时出错: {e}")
                    continue
                self._accumulate_val_token_stats(data_source_stats, data_source, bool(is_correct_sample),
                                                 total_tokens, high_prob_tokens, low_prob_tokens)

            reward_tensor_lst.append(reward_tensor)
            data_source_lst.append(data_sources)

        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu()  # (batch_size,)
        data_sources = np.concatenate(data_source_lst, axis=0)

        # evaluate test_score based on data source
        data_source_reward = {}
        for data_source, reward in zip(data_sources, reward_tensor.tolist()):
            data_source_reward.setdefault(data_source, []).append(reward)

        metric_dict = {f'acc/{data_source}': np.mean(rewards) for data_source, rewards in data_source_reward.items()}
        metric_dict.update(self._val_token_stats_to_metrics(data_source_stats))

        self._log_validation_sample(sample_inputs, sample_outputs, sample_scores, self.global_steps, sample_data_sources)

        return metric_dict

    def _val_token_prob_counts(self, text):
        """Return ``(total, high, low)`` token counts for ``text`` from the prompt compressor.

        Calls ``compress_prompt(text, 0.6)``, normalizes its ``original_probs`` with
        ``normalize_probs``, and counts a token as low-probability when its normalized
        probability is below ``config.algorithm.i_threshold``; all other tokens are high.
        Any exception from these calls propagates to the caller.
        """
        result = compress_prompt(text, 0.6)
        normalized_probs = normalize_probs(result["original_probs"])
        threshold = self.config.algorithm.i_threshold
        total_tokens = 0
        low_prob_tokens = 0
        for probs in normalized_probs:
            total_tokens += len(probs)
            low_prob_tokens += sum(1 for p in probs if p < threshold)
        return total_tokens, total_tokens - low_prob_tokens, low_prob_tokens

    @staticmethod
    def _accumulate_val_token_stats(data_source_stats, data_source, is_correct, total_tokens, high_prob_tokens,
                                    low_prob_tokens):
        """Add one sample's token counts to ``data_source_stats[data_source]``.

        The counts go into the ``total_*`` bucket and into either ``correct_*`` or
        ``wrong_*`` depending on ``is_correct``. Each bucket has ``count`` (samples),
        ``length`` (tokens), ``high_prob`` and ``low_prob`` entries.
        """
        stats = data_source_stats.get(data_source)
        if stats is None:
            stats = {
                f'{group}_{stat}': 0 for group in ('total', 'correct', 'wrong')
                for stat in ('count', 'length', 'high_prob', 'low_prob')
            }
            data_source_stats[data_source] = stats
        for group in ('total', 'correct' if is_correct else 'wrong'):
            stats[f'{group}_count'] += 1
            stats[f'{group}_length'] += total_tokens
            stats[f'{group}_high_prob'] += high_prob_tokens
            stats[f'{group}_low_prob'] += low_prob_tokens

    @staticmethod
    def _val_token_stats_to_metrics(data_source_stats):
        """Turn accumulated per-data-source token statistics into ``length/<data_source>/...`` metrics.

        Averages are taken per sample of each bucket; empty buckets divide by 1 so they report 0.
        """
        metrics = {}
        for data_source, stats in data_source_stats.items():
            prefix = f'length/{data_source}'
            for group, key_prefix in (('total', ''), ('correct', 'correct_'), ('wrong', 'wrong_')):
                denom = max(1, stats[f'{group}_count'])
                metrics[f'{prefix}/{key_prefix}avg_length'] = stats[f'{group}_length'] / denom
                metrics[f'{prefix}/{key_prefix}avg_high_prob'] = stats[f'{group}_high_prob'] / denom
                metrics[f'{prefix}/{key_prefix}avg_low_prob'] = stats[f'{group}_low_prob'] / denom

            # sample counts
            metrics[f'{prefix}/total_samples'] = stats['total_count']
            metrics[f'{prefix}/correct_samples'] = stats['correct_count']
            metrics[f'{prefix}/wrong_samples'] = stats['wrong_count']

            # fraction of (compressor-processed) samples that were correct
            metrics[f'{prefix}/correct_ratio'] = stats['correct_count'] / max(1, stats['total_count'])
        return metrics
        
    def init_workers(self):
        """Create resource pools, spawn colocated worker groups, and initialize their models.

        Each enabled role is registered on the resource pool assigned to it, and every
        pool's roles are combined into one colocated Ray worker class. Models are then
        initialized in the order critic -> ref -> rm -> actor_rollout.

        Raises:
            NotImplementedError: if the hybrid engine is disabled.
        """
        pool_manager = self.resource_pool_manager
        pool_manager.create_resource_pool()

        self.resource_pool_to_cls = {pool: {} for pool in pool_manager.resource_pool_dict.values()}

        def register(role, prefix, **init_kwargs):
            # Look up the role's pool, wrap its worker class, and record it under `prefix`.
            pool = pool_manager.get_resource_pool(role)
            self.resource_pool_to_cls[pool][prefix] = RayClassWithInitArgs(cls=self.role_worker_mapping[role],
                                                                           **init_kwargs)

        # create actor and rollout
        if not self.hybrid_engine:
            raise NotImplementedError
        register(Role.ActorRollout, 'actor_rollout', config=self.config.actor_rollout_ref, role='actor_rollout')

        # create critic
        if self.use_critic:
            register(Role.Critic, 'critic', config=self.config.critic)

        # create reference policy if needed
        # NOTE: this passes role='actor_rollout_ref'; the previously used value here was role='ref'.
        if self.use_reference_policy:
            register(Role.RefPolicy, 'ref', config=self.config.actor_rollout_ref, role='actor_rollout_ref')

        # create a reward model worker when the RewardModel role is mapped
        if self.use_rm:
            register(Role.RewardModel, 'rm', config=self.config.reward_model)

        # initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
        # See XXXX for more information.
        all_wg = {}
        self.wg_dicts = []
        for pool, role_classes in self.resource_pool_to_cls.items():
            colocated_cls = create_colocated_worker_cls(class_dict=role_classes)
            pool_wg = self.ray_worker_group_cls(resource_pool=pool, ray_cls_with_init=colocated_cls)
            all_wg.update(pool_wg.spawn(prefix_set=role_classes.keys()))
            # keep the reference of WorkerDict to support ray >= 2.31. Ref: XXXX
            self.wg_dicts.append(pool_wg)

        if self.use_critic:
            self.critic_wg = all_wg['critic']
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg['ref']
            self.ref_policy_wg.init_model()

        if self.use_rm:
            self.rm_wg = all_wg['rm']
            self.rm_wg.init_model()

        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
        self.actor_rollout_wg = all_wg['actor_rollout']
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self):
        """Save actor (and critic) checkpoints plus train-dataloader state for ``self.global_steps``.

        Local layout under ``trainer.default_local_dir``::

            global_step_{N}/actor               actor checkpoint
            global_step_{N}/critic              critic checkpoint (only when a critic is used)
            global_step_{N}/data.pt             train dataloader state_dict
            latest_checkpointed_iteration.txt   N, the last completed step

        When ``trainer.default_hdfs_dir`` is set, the actor/critic saves are also passed a
        matching ``global_step_{N}/{actor,critic}`` remote path.
        """
        # path: given_path + `/global_step_{global_steps}` + `/actor`
        local_global_step_folder = os.path.join(self.config.trainer.default_local_dir,
                                                f'global_step_{self.global_steps}')
        actor_local_path = os.path.join(local_global_step_folder, 'actor')

        actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
            self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
        self.actor_rollout_wg.save_checkpoint(actor_local_path,
                                              actor_remote_path,
                                              self.global_steps,
                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

        if self.use_critic:
            critic_local_path = os.path.join(local_global_step_folder, 'critic')
            critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
            self.critic_wg.save_checkpoint(critic_local_path,
                                           critic_remote_path,
                                           self.global_steps,
                                           remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

        # save dataloader. The checkpoints above are written by the (remote) worker groups, which are
        # not guaranteed to share this driver's filesystem, so ensure the step folder exists here.
        os.makedirs(local_global_step_folder, exist_ok=True)
        dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
        dataloader_state_dict = self.train_dataloader.state_dict()
        torch.save(dataloader_state_dict, dataloader_local_path)

        # latest checkpointed iteration tracker: write a temp file and atomically rename it so a
        # crash mid-write never leaves a truncated tracker behind.
        local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir,
                                                           'latest_checkpointed_iteration.txt')
        tmp_tracker_path = local_latest_checkpointed_iteration + '.tmp'
        with open(tmp_tracker_path, 'w') as f:
            f.write(str(self.global_steps))
        os.replace(tmp_tracker_path, local_latest_checkpointed_iteration)

    def _load_checkpoint(self):
        """Restore actor/critic weights and dataloader state according to ``trainer.resume_mode``.

        ``resume_mode`` values:

        * ``'disable'``: load nothing and return 0.
        * ``'auto'``: resume from the latest checkpoint found under ``trainer.default_local_dir``;
          return 0 (train from scratch) when there is none.
        * ``'resume_path'``: resume from ``trainer.resume_from_path``.
        * any other string: treated as the checkpoint folder itself (legacy behaviour).

        On resume, ``self.global_steps`` is parsed from the ``global_step_<N>`` folder name.

        Returns:
            0 when nothing is loaded; ``None`` after a successful resume.

        Raises:
            NotImplementedError: if ``trainer.default_hdfs_dir`` is set.
            AssertionError: if an explicit resume path is not a string or has no ``global_step_``.
        """
        trainer_cfg = self.config.trainer
        if trainer_cfg.resume_mode == 'disable':
            return 0

        # load from hdfs
        if trainer_cfg.default_hdfs_dir is not None:
            raise NotImplementedError('load from hdfs is not implemented yet')
        checkpoint_folder = trainer_cfg.default_local_dir  # TODO: check path
        if not os.path.isabs(checkpoint_folder):
            checkpoint_folder = os.path.join(os.getcwd(), checkpoint_folder)
        global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest

        # find global_step_folder
        if trainer_cfg.resume_mode == 'auto':
            if global_step_folder is None:
                print('Training from scratch')
                return 0
        else:
            if trainer_cfg.resume_mode == 'resume_path':
                resume_path = trainer_cfg.get('resume_from_path', None)
            else:
                # Legacy: the checkpoint folder itself was passed as resume_mode.
                resume_path = trainer_cfg.resume_mode
            assert isinstance(resume_path, str), "resume ckpt must be str type"
            assert 'global_step_' in resume_path, "resume ckpt must specify the global_steps"
            global_step_folder = resume_path
            if not os.path.isabs(global_step_folder):
                # Resolve the requested folder itself relative to the working directory.
                global_step_folder = os.path.join(os.getcwd(), global_step_folder)
        print(f'Load from checkpoint folder: {global_step_folder}')
        # set global step
        self.global_steps = int(global_step_folder.split('global_step_')[-1])

        print(f'Setting global step to {self.global_steps}')
        print(f'Resuming from {global_step_folder}')

        actor_path = os.path.join(global_step_folder, 'actor')
        critic_path = os.path.join(global_step_folder, 'critic')
        # load actor
        self.actor_rollout_wg.load_checkpoint(actor_path,
                                              del_local_after_load=trainer_cfg.del_local_ckpt_after_load)
        # load critic
        if self.use_critic:
            self.critic_wg.load_checkpoint(critic_path,
                                           del_local_after_load=trainer_cfg.del_local_ckpt_after_load)

        # load dataloader,
        # TODO: from remote not implemented yet
        dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
        if os.path.exists(dataloader_local_path):
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; confirm the
            # StatefulDataLoader state_dict still loads under that default.
            dataloader_state_dict = torch.load(dataloader_local_path)
            self.train_dataloader.load_state_dict(dataloader_state_dict)
        else:
            print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
        """Reorder ``batch`` in place so every dp rank receives a similar number of valid tokens.

        Each row's length is the sum of its attention mask. The rows are partitioned into
        ``self.actor_rollout_wg.world_size`` equal-sized groups by
        ``get_seqlen_balanced_partitions``. The batch is then reordered so the groups are contiguous;
        the later dispatch splits the batch evenly, so each rank gets one group. Balance statistics
        from ``log_seqlen_unbalance`` are merged into ``metrics``.
        """
        mask = batch.batch['attention_mask']
        num_rows = mask.shape[0]
        # Per-row count of non-padding tokens (train_batch_size entries).
        seqlens = mask.view(num_rows, -1).sum(-1).tolist()

        partitions = get_seqlen_balanced_partitions(
            seqlens,
            k_partitions=self.actor_rollout_wg.world_size,
            equal_size=True,
        )

        # Concatenate the partitions so each one occupies a contiguous slice of the batch.
        new_order = []
        for part in partitions:
            new_order.extend(part)
        batch.reorder(torch.tensor(new_order))

        balance_stats = log_seqlen_unbalance(seqlen_list=seqlens, partitions=partitions, prefix=logging_prefix)
        metrics.update(balance_stats)
    
    def _log_validation_sample(self, sample_inputs, sample_outputs, sample_scores, global_steps, data_sources=None):
        """
        按照data_source分组保存验证样本
        
        Args:
            sample_inputs: 输入文本列表
            sample_outputs: 输出文本列表
            sample_scores: 得分列表
            global_steps: 全局步数
            data_sources: 数据源列表，与sample_inputs等长度相同
        """
        output_dir = os.path.join(self.config.trainer.default_local_dir, 'val')
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        
        # 如果未提供data_sources参数，则默认全部为'unknown'
        if data_sources is None:
            data_sources = ['unknown'] * len(sample_inputs)
        
        # 按照data_source分组
        data_source_samples = {}
        for input_text, output_text, score, data_source in zip(sample_inputs, sample_outputs, sample_scores, data_sources):
            if data_source not in data_source_samples:
                data_source_samples[data_source] = []
            data_source_samples[data_source].append({
                'input': input_text,
                'output': output_text,
                'score': score
            })
        
        # 为每个data_source创建单独的文件
        for data_source, samples in data_source_samples.items():
            # 创建干净的文件名（去除可能的特殊字符）
            safe_data_source = ''.join(c if c.isalnum() else '_' for c in data_source)
            output_path = os.path.join(output_dir, f'val_sample_{safe_data_source}_global_step_{global_steps}.jsonl')
            
            with open(output_path, 'w') as f:
                for sample in samples:
                    f.write(json.dumps(sample))
                    f.write('\n')
            
            print(f"已保存 {len(samples)} 个样本到 {output_path}")
        
        # 也保存一个包含所有样本的文件，保持向后兼容
        all_output_path = os.path.join(output_dir, f'val_sample_global_step_{global_steps}.jsonl')
        with open(all_output_path, 'w') as f:
            for input_text, output_text, score in zip(sample_inputs, sample_outputs, sample_scores):
                f.write(json.dumps({'input': input_text, 'output': output_text, 'score': score}))
                f.write('\n')

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.

        In addition to the PPO dataflow:
          * with ``AdvantageEstimator.LGAE`` a second rollout y' is generated by the reference
            policy (``_generate_ref_rollout``). Its responses, attention mask and log-probs are
            attached to the batch as ``ref_responses`` / ``ref_attention_mask`` /
            ``ref_old_log_probs`` before ``compute_advantage`` is called;
          * ``compute_advantage`` also receives the LGAE hyper-parameters from ``config.algorithm``
            plus ``self.initial_slope`` and ``self.dynamic_lambda_w_e``;
          * after every step ``_update_dynamic_lambda_w_e`` refits the training-score trend and
            updates ``self.dynamic_lambda_w_e``, which is used from the next step on.
        """
        from verl.utils.tracking import Tracking

        logger = Tracking(project_name=self.config.trainer.project_name,
                          experiment_name=self.config.trainer.experiment_name,
                          default_backend=self.config.trainer.logger,
                          config=OmegaConf.to_container(self.config, resolve=True))

        self.global_steps = 0

        # load checkpoint before doing anything (sets self.global_steps when resuming)
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                return

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                if 'multi_modal_inputs' in batch.non_tensor_batch.keys():
                    gen_batch = batch.pop(
                        batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                        non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
                    )
                else:
                    gen_batch = batch.pop(
                        batch_keys=['input_ids', 'attention_mask', 'position_ids'],
                        non_tensor_batch_keys=['raw_prompt_ids'],
                    )

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer('step', timing_raw):
                    # generate a batch
                    with _timer('gen', timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with _timer('gen_max', timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info['do_sample'] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            batch.batch['reward_baselines'] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                             dtype=object)
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                    # recompute old_log_probs
                    with _timer('old_log_prob', timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer('ref', timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with _timer('values', timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer('adv', timing_raw):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm:
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(batch)
                            batch = batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_tensor = self.reward_fn(batch)
                        batch.batch['token_level_scores'] = reward_tensor

                        # compute rewards. apply_kl_penalty if available
                        if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
                            batch, kl_metrics = apply_kl_penalty(batch,
                                                                 kl_ctrl=self.kl_ctrl,
                                                                 kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                        # LGAE additionally needs a rollout y' generated by the reference policy.
                        if self.config.algorithm.adv_estimator == AdvantageEstimator.LGAE:
                            batch_ref = self._generate_ref_rollout(batch_dict, timing_raw)
                            # Without a reference policy there is nothing to attach (a warning was printed);
                            # reading batch_ref unconditionally would raise UnboundLocalError here.
                            if batch_ref is not None:
                                batch.batch['ref_responses'] = batch_ref.batch['ref_responses']
                                batch.batch['ref_attention_mask'] = batch_ref.batch['ref_attention_mask']
                                batch.batch['ref_old_log_probs'] = batch_ref.batch['old_log_probs']

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(batch,
                                                  adv_estimator=self.config.algorithm.adv_estimator,
                                                  gamma=self.config.algorithm.gamma,
                                                  lam=self.config.algorithm.lam,
                                                  num_repeat=self.config.actor_rollout_ref.rollout.n,
                                                  lambda_c=self.config.algorithm.lambda_c,
                                                  lambda_w_n=self.config.algorithm.lambda_w_n,
                                                  lambda_w_e=self.config.algorithm.lambda_w_e,
                                                  tokenizer=self.tokenizer,
                                                  i_threshold=self.config.algorithm.i_threshold,
                                                  initial_slope=self.initial_slope,
                                                  dynamic_lambda_w_e=self.dynamic_lambda_w_e,
                                                  ppo=self.config.algorithm.ppo)

                    # update critic
                    if self.use_critic:
                        with _timer('update_critic', timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer('update_actor', timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                        (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with _timer('testing', timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or
                            self.global_steps % self.config.trainer.save_freq == 0):
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()

                # collect metrics
                if 'metrics' in batch.meta_info:
                    metrics.update(batch.meta_info['metrics'])
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

                print("Available metrics keys:", list(metrics.keys()))
                # Adapt lambda_w_e from the training-score trend; takes effect from the next step.
                self._update_dynamic_lambda_w_e(metrics)

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f'Final validation metrics: {last_val_metrics}')
                    return

                self.global_steps += 1

    def _generate_ref_rollout(self, batch_dict, timing_raw):
        """Generate responses y' with the reference policy for the LGAE advantage estimator.

        Args:
            batch_dict: The raw dataloader dict for the current step; the prompts are rebuilt from it
                because the training ``batch`` has already had its generation keys popped.
            timing_raw: Timing dict that receives the ``'ref_old_log_prob'`` entry.

        Returns:
            A ``DataProto`` holding ``ref_responses``, ``ref_attention_mask`` and the reference
            worker's ``old_log_probs`` (plus critic values when a critic is configured), or ``None``
            when no reference policy is configured.
        """
        if not self.use_reference_policy:
            print("Warning: adv_estimator=CUSTOM_FORMULA but no reference policy found.")
            return None

        print("====================generate ref response=====================")
        # NOTE(review): batch_ref is built from the un-repeated, un-balanced prompts, whereas the
        # training batch was repeated rollout.n times and possibly reordered by _balance_batch.
        # Rows only line up if the reference rollout reproduces that layout -- TODO confirm.
        batch_ref = DataProto.from_single_dict(batch_dict)

        ref_gen_batch = batch_ref.pop(
            batch_keys=['input_ids', 'attention_mask', 'position_ids'],
            non_tensor_batch_keys=['raw_prompt_ids']
        )
        # Sample y' from the reference policy and keep its responses / attention mask under ref_* keys.
        ref_gen_output = self.ref_policy_wg.generate_sequences(ref_gen_batch)
        batch_ref.batch['ref_responses'] = ref_gen_output.batch['responses']
        batch_ref.batch['ref_attention_mask'] = ref_gen_output.batch['attention_mask']
        batch_ref = batch_ref.union(ref_gen_output)

        with _timer('ref_old_log_prob', timing_raw):
            ref_old_log_prob = self.ref_policy_wg.compute_log_prob(batch_ref)
            batch_ref = batch_ref.union(ref_old_log_prob)

        # Critic values for y'. They are merged into batch_ref, but fit() currently copies only the
        # responses, mask and log-probs back into the training batch.
        ref_values = self.critic_wg.compute_values(batch_ref) if self.use_critic else None
        if ref_values is not None:
            batch_ref = batch_ref.union(ref_values)
        else:
            print("Warning: no critic found. Can't compute ref_advantages. Will skip.")
        return batch_ref

    def _update_dynamic_lambda_w_e(self, metrics):
        """Record the training score and adapt ``self.dynamic_lambda_w_e`` from its recent trend.

        Reads ``metrics['critic/score/mean']``; when absent, nothing happens. Otherwise the score and
        ``self.global_steps`` are appended to in-memory histories (``self.training_scores`` /
        ``self.training_steps``, created on first use and not restored on resume by this method).
        Once ``global_steps >= algorithm.slope_start_epoch`` and at least two scores are recorded:

          * a line ``score = m * step + b`` is fitted over the last ``algorithm.slope_period``
            recorded points (all of them if fewer exist);
          * the first fitted slope (while ``self.initial_slope`` is None) fixes ``self.slope_scale``
            so that it maps to ``lambda_w_e / (pi / 2)``;
          * if the scaled slope is <= ``algorithm.slope_threshold``, ``self.low_slope_count`` is
            incremented and lambda_w_e is set to ``-0.5 * low_slope_count``, floored at
            ``-2 * lambda_w_e / (pi / 2)``; otherwise the counter resets and the scaled slope becomes
            lambda_w_e.

        The raw/scaled slopes, the new lambda_w_e and the counter are written back into ``metrics``.
        Assumes ``self.initial_slope``, ``self.low_slope_count`` and ``self.dynamic_lambda_w_e`` were
        initialised elsewhere (presumably ``__init__``).
        """
        training_score = metrics.get('critic/score/mean', None)
        if training_score is None:
            return

        if not hasattr(self, 'training_scores'):
            self.training_scores = []
            self.training_steps = []
        self.training_scores.append(training_score)
        self.training_steps.append(self.global_steps)

        if self.global_steps < self.config.algorithm.slope_start_epoch or len(self.training_scores) < 2:
            return

        # Use the last slope_period points, or all recorded points if fewer exist. The len() >= 2
        # guard above then guarantees the regression never sees a single point.
        look_back = min(self.config.algorithm.slope_period, len(self.training_scores))
        X = np.array(self.training_steps[-look_back:]).reshape(-1, 1)
        Y = np.array(self.training_scores[-look_back:])

        # Linear regression: fit y = m * x + b and keep m.
        model = LinearRegression().fit(X, Y)
        slope = model.coef_[0]

        # On the first fit, fix the scale that maps the initial slope to lambda_w_e / (pi / 2).
        if self.initial_slope is None:
            self.initial_slope = slope
            if self.initial_slope == 0:
                # Zero initial slope: fall back to an unscaled default.
                self.slope_scale = self.config.algorithm.lambda_w_e / (math.pi / 2)
            else:
                self.slope_scale = (self.config.algorithm.lambda_w_e / (math.pi / 2)) / max(1e-6, abs(self.initial_slope))
            print(f"\n===== 初始斜率设置 =====")
            print(f"初始斜率: {self.initial_slope:.8f}")
            print(f"设置尺度: {self.slope_scale:.8f}")

        # Map the current slope with the same fixed scale.
        scaled_slope = slope * self.slope_scale

        print(f"\n===== 斜率映射 =====")
        print(f"原始斜率: {slope:.8f}")
        print(f"映射后斜率: {scaled_slope:.4f} (初始斜率映射为{self.config.algorithm.lambda_w_e/(math.pi/2):.2f})")
        print(f"计算使用了最近{look_back}个步骤的数据")

        if scaled_slope <= self.config.algorithm.slope_threshold:
            # Another consecutive low-slope step.
            self.low_slope_count += 1

            # Decrease lambda linearly by 0.5 per consecutive low-slope step, floored at
            # -2 * lambda_w_e / (pi / 2).
            # NOTE(review): an earlier comment said the floor was -lambda_w_e; the code uses the
            # value above -- confirm which is intended.
            negative_lambda = -0.5 * self.low_slope_count
            self.dynamic_lambda_w_e = max(-2 * self.config.algorithm.lambda_w_e / (math.pi / 2), negative_lambda)

            print(f"映射后斜率 <= {self.config.algorithm.slope_threshold}，连续低斜率次数: {self.low_slope_count}")
            print(f"lambda_w_e线性递减至: {self.dynamic_lambda_w_e:.4f}")
        else:
            # Slope recovered: reset the counter and use the scaled slope directly.
            self.low_slope_count = 0
            self.dynamic_lambda_w_e = scaled_slope
            print(f"使用映射后斜率作为lambda_w_e: {self.dynamic_lambda_w_e:.4f}")

        # Expose the adaptation state to the logger.
        metrics['lambda_w_e/train_raw_slope'] = slope
        metrics['lambda_w_e/train_scaled_slope'] = scaled_slope
        metrics['lambda_w_e/train_value'] = self.dynamic_lambda_w_e
        metrics['lambda_w_e/train_low_slope_count'] = self.low_slope_count
