# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""

import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from tqdm import tqdm
from codetiming import Timer
from verl import DataProto
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask

from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F
from verl.utils.save_dataproto_plaintext import save_plaintext_to_disk
import os
from verl.utils.stat_utils import compute_entropy_statistics, compute_kurtosis_vectorized


def append_wait_and_regenerate(raw_batch, tokenizer, actor_rollout_wg, max_token_num):
    """
    Append a "Wait" marker to every response and ask the rollout workers to continue generating.

    For each sample the valid (non-padded) prompt and response tokens are extracted,
    a trailing EOS token (if any) is replaced by the tokenized marker
    (newline + "Wait" + newline), and the concatenation is left-padded so that it
    can be fed back to the rollout workers as a new prompt. Prompt and response are
    handled at the token level on purpose: a decode/encode round trip could change
    the token ids.

    Args:
        raw_batch: DataProto holding a finished generation. The code reads
            ``batch["prompts"]`` (valid tokens taken from the right, i.e. left-padded),
            ``batch["responses"]`` (valid tokens taken from the left, i.e. right-padded),
            ``batch["attention_mask"]`` (prompt part followed by response part) and
            ``non_tensor_batch``. It is deep-copied before being read.
        tokenizer: Tokenizer providing ``encode``, ``eos_token_id`` and ``pad_token_id``.
        actor_rollout_wg: Worker group whose ``generate_sequences`` performs the regeneration.
        max_token_num: Upper bound for the length of the rebuilt prompt (a floor of 64 is applied).

    Returns:
        DataProto: the non-tensor data of ``raw_batch`` unioned with the output of
        ``actor_rollout_wg.generate_sequences``. Debug runs showed the tensors
        ``input_ids``, ``attention_mask``, ``position_ids``, ``prompts`` and ``responses``,
        e.g. ``prompts`` of shape [64, 2048] and ``responses`` of shape [64, 64].
    """
    # Read from a private copy so the caller's batch is never touched.
    batch = deepcopy(raw_batch)
    longest_valid_length = batch.batch['attention_mask'].sum(dim=-1).max().item()
    # 100 tokens of headroom over the longest sequence, capped by the caller's budget (at least 64).
    batch_max_length = min(longest_valid_length + 100, max(max_token_num, 64))
    print(f"my_debug: batch_max_length={batch_max_length}")

    # The marker tokenization does not depend on the sample, so encode it once.
    wait_token_ids = tokenizer.encode("\nWait\n", add_special_tokens=False)
    wait_length = len(wait_token_ids)
    eos_token_id = tokenizer.eos_token_id

    new_input_ids_list = []
    new_attention_mask_list = []
    new_position_ids_list = []
    for i in range(len(batch)):
        data_item = batch[i]  # DataProtoItem
        item_attention_mask = data_item.batch["attention_mask"]

        prompt_ids = data_item.batch["prompts"]  # [prompt_seq_len]
        prompt_length = prompt_ids.shape[-1]
        valid_prompt_length = int(item_attention_mask[:prompt_length].sum())
        # Index from the front: ``prompt_ids[-0:]`` would return the whole padded prompt.
        valid_prompt_ids = prompt_ids[prompt_length - valid_prompt_length:]  # [valid_prompt_seq_len]

        response_ids = data_item.batch["responses"]  # [response_seq_len]
        valid_response_length = int(item_attention_mask[prompt_length:].sum())
        valid_response_ids = response_ids[:valid_response_length]  # [valid_response_seq_len]

        if valid_prompt_length + valid_response_length + wait_length > batch_max_length - 10:
            # Not enough room for the marker: only truncate the response.
            # NOTE(review): room for the marker is reserved here but the marker itself is
            # never appended in this branch -- confirm that this is intended.
            print("my_debug: valid_prompt_length + valid_response_length + wait_length > batch_max_length - 10")
            # Clamp at 0: a negative slice bound would drop tokens from the end instead of keeping none.
            keep_length = max(batch_max_length - valid_prompt_length - wait_length - 10, 0)
            new_valid_response_ids = valid_response_ids[:keep_length]
        else:
            add_token_ids = torch.tensor(wait_token_ids, device=prompt_ids.device)
            # An empty response has no last token; treat it like a response without EOS.
            last_token_id = valid_response_ids[-1].item() if valid_response_length > 0 else None
            ends_with_eos = last_token_id is not None and last_token_id == eos_token_id
            if ends_with_eos:
                # Drop the EOS so that generation can continue after the marker.
                print(f"my_debug: last valid_response_id is eos token {ends_with_eos}")
                new_valid_response_ids = torch.cat([valid_response_ids[:-1], add_token_ids], dim=-1)
            else:
                print(f"my_warning: last token is not eos token! last_token_id={last_token_id}, eos_token_id={eos_token_id}")
                new_valid_response_ids = torch.cat([valid_response_ids, add_token_ids], dim=-1)

        # Final safety net: never hand more than batch_max_length tokens to the padding helper.
        total_length = valid_prompt_length + new_valid_response_ids.shape[-1]
        if total_length > batch_max_length:
            new_valid_response_ids = new_valid_response_ids[:max(batch_max_length - valid_prompt_length, 0)]

        new_input_ids = torch.cat([valid_prompt_ids, new_valid_response_ids]).unsqueeze(0)  # [1, seq_len]
        new_attention_mask = torch.ones_like(new_input_ids)  # [1, seq_len]

        # Left-pad to batch_max_length with the same helper used by the regular data pipeline.
        new_input_ids, new_attention_mask = verl_F.postprocess_data(
            input_ids=new_input_ids,
            attention_mask=new_attention_mask,
            max_length=batch_max_length,
            pad_token_id=tokenizer.pad_token_id,
            left_pad=True,
            truncation="error"
        )
        # Position ids are derived from the mask with the same helper as the regular pipeline.
        new_position_ids = compute_position_id_with_mask(new_attention_mask)

        new_input_ids_list.append(new_input_ids[0])
        new_attention_mask_list.append(new_attention_mask[0])
        new_position_ids_list.append(new_position_ids[0])

    # Assemble the rebuilt prompts together with the original non-tensor data.
    new_batch = DataProto.from_dict(
        tensors={
            'input_ids': torch.stack(new_input_ids_list, dim=0),  # [batch_size, seq_len]
            'attention_mask': torch.stack(new_attention_mask_list, dim=0),  # [batch_size, seq_len]
            'position_ids': torch.stack(new_position_ids_list, dim=0)  # [batch_size, seq_len]
        },
        non_tensors=raw_batch.non_tensor_batch
    )

    gen_batch = new_batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])

    # Every row already corresponds to one sampled response, so request one continuation per row.
    gen_batch.meta_info["n"] = 1
    gen_batch_output = actor_rollout_wg.generate_sequences(gen_batch)
    new_batch = new_batch.union(gen_batch_output)

    return new_batch

def restore_to_original_format(extended_batch, origin_batch, tokenizer, max_response_length_final):
    """
    Restore multi-inference results to single-inference format using the original prompt.

    For every sample the valid tokens of the extended sequence are split into the
    original prompt (which must match ``origin_batch`` token for token) and everything
    that follows it. The latter becomes the new response, right-padded to
    ``max_response_length_final`` and appended to the original (padded) prompt tensors.

    Args:
        extended_batch: DataProto after multi-inference. The code reads
            ``batch["input_ids"]``, ``batch["attention_mask"]`` and ``non_tensor_batch``.
            Debug example: ``input_ids`` / ``attention_mask`` of shape [64, 925].
        origin_batch: Original prompt-only DataProto with ``input_ids`` and
            ``attention_mask`` (valid tokens are taken from the right, i.e. left-padded).
            Debug example: shape [64, 2048].
        tokenizer: Tokenizer providing ``pad_token_id``.
        max_response_length_final: Width to which every restored response is right-padded.
            The padding helper is called with ``truncation="error"``, so a longer
            response is presumably rejected there -- confirm against ``postprocess_data``.

    Returns:
        DataProto: with tensors ``input_ids``, ``attention_mask``, ``position_ids``
        (prompt followed by response), ``prompts`` and ``responses``, and the
        non-tensor data of ``extended_batch``.

    Raises:
        AssertionError: if the batch sizes differ or an extended sequence does not
            start with the corresponding original prompt.
    """
    assert len(extended_batch) == len(origin_batch), (
        f"batch size mismatch: {len(extended_batch)=} vs {len(origin_batch)=}"
    )

    # Phase 1: strip padding and cut the original prompt off every extended sequence.
    response_rows = []
    for i in range(len(extended_batch)):
        origin_item = origin_batch[i]  # DataProtoItem
        extended_item = extended_batch[i]  # DataProtoItem

        origin_input_ids = origin_item.batch['input_ids']
        origin_width = origin_input_ids.shape[-1]
        valid_prompt_length = int(origin_item.batch['attention_mask'][:origin_width].sum())
        # Index from the front: ``ids[-0:]`` would return the whole padded prompt.
        valid_prompt_ids = origin_input_ids[origin_width - valid_prompt_length:]

        # Keep only the tokens the attention mask marks as valid.
        extended_input_ids = extended_item.batch['input_ids']
        valid_extended_ids = extended_input_ids[extended_item.batch['attention_mask'] == 1]

        # The extended sequence must start with exactly the original prompt.
        assert torch.equal(valid_prompt_ids, valid_extended_ids[:valid_prompt_length]), (
            f"sample {i}: extended sequence does not start with the original prompt"
        )

        response_rows.append(valid_extended_ids[valid_prompt_length:].unsqueeze(0))  # [1, seq_len]

    # Phase 2: the responses have different lengths, so right-pad each to the common width.
    padded_response_ids = []
    padded_response_masks = []
    for response_row in response_rows:
        padded_ids, padded_mask = verl_F.postprocess_data(
            input_ids=response_row,
            attention_mask=torch.ones_like(response_row),
            max_length=max_response_length_final,
            pad_token_id=tokenizer.pad_token_id,
            left_pad=False,
            truncation="error"
        )
        padded_response_ids.append(padded_ids[0])
        padded_response_masks.append(padded_mask[0])

    responses = torch.stack(padded_response_ids, dim=0)  # [batch_size, max_response_length_final]
    response_attention_mask = torch.stack(padded_response_masks, dim=0)  # [batch_size, max_response_length_final]

    # Phase 3: glue the padded original prompts and the padded responses together.
    prompts = origin_batch.batch['input_ids']
    input_ids = torch.cat([prompts, responses], dim=1)
    attention_mask = torch.cat([origin_batch.batch['attention_mask'], response_attention_mask], dim=1)
    position_ids = compute_position_id_with_mask(attention_mask)
    restored_batch = DataProto.from_dict(
        tensors={
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'prompts': prompts,
            'responses': responses
        },
        non_tensors=extended_batch.non_tensor_batch
    )
    return restored_batch


def normalize_to_open_interval(values, epsilon=1e-6):
    """
    Min-max normalize values into [epsilon, 1 - epsilon], preserving their ordering.

    The result therefore lies strictly inside the open interval (0, 1) for any
    ``0 < epsilon < 0.5``.

    Args:
        values: Input tensor (torch.Tensor). Non-floating tensors are converted to
            the default floating dtype first, so every code path returns floats.
        epsilon: Margin kept from 0 and 1, default 1e-6.
    Returns:
        A new tensor of the same shape. The minimum maps to ``epsilon`` and the
        maximum to ``1 - epsilon``; if all values are equal every entry is 0.5;
        an empty input yields an empty output.
    Example:
        >>> values = torch.tensor([1.0, 5.0, 3.0, 10.0])
        >>> normalized = normalize_to_open_interval(values)
        >>> print(normalized)  # Output similar to: [0.000001, 0.444445, 0.222223, 0.999999]
        >>> # Preserve ordering: 1.0 < 3.0 < 5.0 < 10.0 corresponds to 0.000001 < 0.222223 < 0.444445 < 0.999999
    """
    if not values.is_floating_point():
        # An integer tensor cannot hold 0.5 (constant case), while the division below
        # yields floats anyway; convert up front so both paths agree.
        values = values.to(torch.get_default_dtype())
    if values.numel() == 0:
        # min()/max() are undefined for empty tensors; there is nothing to normalize.
        return values.clone()

    min_val = values.min()
    max_val = values.max()
    # All values identical: the range is zero, so place everything in the middle.
    if min_val == max_val:
        return torch.full_like(values, 0.5)

    # Min-max normalization to [0, 1].
    normalized = (values - min_val) / (max_val - min_val)
    # Affine squeeze y = x * (1 - 2*epsilon) + epsilon: x=0 -> epsilon, x=1 -> 1 - epsilon.
    return normalized * (1 - 2 * epsilon) + epsilon


class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.
        """
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        timing_raw = defaultdict(float)
        batch = None
        num_prompt_in_batch = 0
        num_gen_batches = 0
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                # pop those keys for generation
                if "multi_modal_inputs" in new_batch.non_tensor_batch.keys():
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
                    )
                else:
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"], # raw_prompt_ids is actually discarded
                    )

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer("step", timing_raw):
                    # generate a batch
                    # with _timer("gen", timing_raw):
                    #     gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                    with _timer("gen", timing_raw):
                        if not self.async_rollout_mode:
                            if self.config.actor_rollout_ref.rollout.topn_metric:
                                print(f"my_debug: topn_metric=True ...")
                                duplicate_times = 3
                                gen_batch.meta_info["n"] = self.config.actor_rollout_ref.rollout.n * duplicate_times
                                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                                del gen_batch.meta_info["n"]
                                # gen_batch_output.meta_info['train_mode'] = True
                                # actor_output_entropy = self.actor_rollout_wg.compute_entropy_for_every_source(data=gen_batch_output)
                                old_log_probs = self.actor_rollout_wg.compute_log_prob(data=gen_batch_output).batch['old_log_probs'].detach()
                                response_seq_len = gen_batch_output.batch['responses'].shape[-1]
                                response_mask = gen_batch_output.batch['attention_mask'][:, -response_seq_len:]
                                prob_kurtosis = compute_kurtosis_vectorized(torch.exp(old_log_probs), response_mask)
                                
                                # Group gen_batch_output by duplicate_times*n samples
                                group_size = self.config.actor_rollout_ref.rollout.n
                                batch_size = len(gen_batch_output)
                                num_groups = batch_size // (duplicate_times * group_size)
                                
                                selected_indices = []
                                for group_idx in range(num_groups):
                                    start_idx = group_idx * duplicate_times * group_size
                                    end_idx = start_idx + duplicate_times * group_size
                                    # Get prob_kurtosis for current group
                                    group_prob_kurtosis = prob_kurtosis[start_idx:end_idx]
                                    
                                    # Handle outliers: replace NaN and infinity with maximum finite value
                                    finite_mask = torch.isfinite(group_prob_kurtosis)
                                    if not finite_mask.all():
                                        invalid_count = (~finite_mask).sum().item()
                                        print(f"Warning: Found {invalid_count} NaN/inf values in group {group_idx}, replacing with max")
                                        if finite_mask.any():
                                            max_val = group_prob_kurtosis[finite_mask].max()
                                            group_prob_kurtosis = group_prob_kurtosis.clone()
                                            group_prob_kurtosis[~finite_mask] = max_val
                                        else:
                                            group_prob_kurtosis = torch.ones_like(group_prob_kurtosis)  # Set to 1 when all values are outliers
                                    
                                    # Normalize to open interval (0, 1), preserving ordering properties
                                    group_prob_kurtosis = normalize_to_open_interval(group_prob_kurtosis, epsilon=1e-6)
                                    
                                    # Weight by prob_kurtosis to get weights summing to 1, then sample n samples without replacement according to this probability
                                    # tmp_weights = (1.0 - group_prob_kurtosis)
                                    tmp_weights = 1.0 - group_prob_kurtosis
                                    tmp_weights = tmp_weights / tmp_weights.sum()  # Normalize to probability distribution
                                    selected_indices_group = torch.multinomial(tmp_weights, num_samples=group_size, replacement=False)
                                    # Convert relative indices to global indices
                                    global_indices = start_idx + selected_indices_group
                                    selected_indices.extend(global_indices.tolist())
                                
                                # Reconstruct gen_batch_output using selected indices
                                gen_batch_output = gen_batch_output[selected_indices]

                            else:
                                print(f"my_debug: vanilla generate_sequences ...")
                                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                        else:
                            self.async_rollout_manager.wake_up()
                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
                            self.async_rollout_manager.sleep()
                            
                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with _timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            new_batch = new_batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(new_batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            new_batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    new_batch = new_batch.union(gen_batch_output)
                    
                    

                    if self.config.actor_rollout_ref.rollout.extend_context_k > 0:
                        # print(f"my_debug: extend_context_idx=-1 ...")
                        # old_batch = deepcopy(new_batch)
                        # old_batch.meta_info['train_mode'] = True
                        # actor_output_entropy = self.actor_rollout_wg.compute_entropy_for_every_source(data=old_batch)
                        # actor_output_entropy_final = actor_output_entropy.batch['output_entropy_loss_val'] # (batch_size)
                        # print(f"my_debug: old_batch entropy: {actor_output_entropy_final.mean().item()}")
                        # print("--------------------------------")
                        # Preserve original gen_batch for restoring final prompt+multi-turn response to original prompt+response
                        origin_gen_batch = deepcopy(gen_batch)
                        origin_gen_batch = origin_gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

                        for extend_context_idx in range(self.config.actor_rollout_ref.rollout.extend_context_k):
                            
                            new_batch = append_wait_and_regenerate(
                                raw_batch=new_batch, 
                                tokenizer=self.tokenizer,
                                actor_rollout_wg=self.actor_rollout_wg,
                                max_token_num=self.config.data.max_prompt_length+self.config.data.max_response_length * (1 + self.config.actor_rollout_ref.rollout.extend_context_k) + 32
                            )

                        # Restore to original format
                        new_batch = restore_to_original_format(
                            extended_batch=new_batch,
                            origin_batch=origin_gen_batch,
                            tokenizer=self.tokenizer,
                            max_response_length_final=(self.config.data.max_response_length + 32) * (1 + self.config.actor_rollout_ref.rollout.extend_context_k)
                        )

                        data_item = new_batch[0]
                        prompt_ids = data_item.batch['prompts']
                        prompt_length = prompt_ids.shape[-1]
                        valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
                        valid_prompt_ids = prompt_ids[-valid_prompt_length:]
                        response_ids = data_item.batch['responses']
                        valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
                        valid_response_ids = response_ids[:valid_response_length]
                        prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
                        response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
                        print(f"my_debug: {prompt_str=}")
                        print(f"my_debug: {response_str=}")

                    with _timer("reward", timing_raw):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm:
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
                            new_batch = new_batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_extra_infos_dict: dict[str, list]
                        try:
                            reward_result = self.reward_fn(new_batch, return_dict=True)
                            reward_tensor = reward_result["reward_tensor"]
                            reward_extra_infos_dict = reward_result["reward_extra_info"]
                        except Exception as e:
                            print(f"Error in reward_fn: {e}")
                            reward_tensor = self.reward_fn(new_batch)
                            reward_extra_infos_dict = {}

                        new_batch.batch["token_level_scores"] = reward_tensor

                        print(f"{list(reward_extra_infos_dict.keys())=}")
                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # compute rewards. apply_kl_penalty if available
                        if self.config.algorithm.use_kl_in_reward:
                            new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)  # TODO: This will be cleared if we use multiple genenration batches
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

                    if not self.config.algorithm.filter_groups.enable:
                        batch = new_batch
                    else:  # NOTE: When prompts after filtering is less than train batch size,
                        # we skip to the next generation batch
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            # Turn to numpy for easier filtering
                            new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()

                        # Collect the sequence reward for each trajectory
                        prompt_uid2metric_vals = defaultdict(list)
                        for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
                            prompt_uid2metric_vals[uid].append(metric_val)

                        prompt_uid2metric_std = {}
                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

                        kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1]
                        num_prompt_in_batch += len(kept_prompt_uids)

                        kept_traj_idxs = []
                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
                            if traj_from_prompt_uid in kept_prompt_uids:
                                kept_traj_idxs.append(idx)

                        new_batch = new_batch[kept_traj_idxs]
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
                        print(f"my_debug: batch.batch['responses'].shape={batch.batch['responses'].shape}")

                        prompt_bsz = self.config.data.train_batch_size
                        if num_prompt_in_batch < prompt_bsz:
                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                print(f"{num_gen_batches=}. Keep generating...")
                                continue
                            else:
                                raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try set max_num_gen_batches=0 to enable endless trials.")
                        else:
                            # Align the batch
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            print(f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. Collecting finished.")
                            batch = batch[:traj_bsz]

                    # === Updating ===

                    batch.batch["response_mask"] = compute_response_mask(batch)
                    
                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)
                    
                    # todo: compute train entropy
                    with Timer(name='compute_all_entropy', text="{name}: {seconds:.1f} seconds") as timer:
                        batch.meta_info['is_filtered'] = False
                        batch.meta_info['train_mode'] = False
                        batch.meta_info['val_set'] = False
                        data_source_lst = []
                        data_source_reward = {}
                        data_source_entropy = {}
                      
                        actor_output_entropy = self.actor_rollout_wg.compute_entropy_for_every_source(data=batch)
                        actor_output_entropy_final = actor_output_entropy.batch['output_entropy_loss_val'] 
                        
                        batch.batch['entropy'] = actor_output_entropy.batch['output_entropy_val']
                        batch.batch['entropy_loss'] = actor_output_entropy.batch['output_entropy_loss_val']

                        responses = batch.batch['responses']
                        response_length = responses.size(1)
                        attention_mask = batch.batch['attention_mask']
                        response_mask = attention_mask[:, -response_length:]
                        

                        # New code: filter entropy and count proportions in different intervals
                        entropy = batch.batch['entropy']
                        masked_entropy = entropy[response_mask.bool()]

                        # Define intervals
                        # bins = [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
                        # hist, _ = torch.histogram(masked_entropy, bins=torch.tensor(bins, dtype=torch.float32, device=masked_entropy.device))
                        # total = masked_entropy.numel()
                        # ratios = hist / total if total > 0 else torch.zeros_like(hist)

                        # # Print results
                        # print("Entropy ratios in different intervals:")
                        # for i in range(len(bins) - 1):
                        #     metrics[f"train_entropy_ratio/{bins[i]:.2f}-{bins[i + 1]:.2f}"].append(ratios[i].item())

                        #     print(f"{bins[i]:.2f}-{bins[i + 1]:.2f}: {ratios[i].item():.4f}")
                        # print(f"> {bins[-1]:.2f}: {1 - ratios.sum().item():.4f}")

                        # ... existing code ...

        
                        data_source_lst.append(batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))

                        data_sources = np.concatenate(data_source_lst, axis=0)
                        
                        entropy_statistics = compute_entropy_statistics(entropy, response_mask)
                        # put these statistics into train_entropy/...
                        for key, value in entropy_statistics.items():
                            metrics[f'train_entropy/{key}'] = value
        
                        for i in range(actor_output_entropy_final.shape[0]):
                            data_source = data_sources[i]
                            if data_source not in data_source_entropy:
                                data_source_entropy[data_source] = []
                                
                            data_source_entropy[data_source].append( actor_output_entropy_final[i].item())
                
                            
                        for data_source, entropy in data_source_entropy.items():
                            metrics[f'train_entropy/{data_source}'] = np.mean(entropy)
                            
                        metrics[f'train_entropy/all'] = actor_output_entropy_final.mean().item()
                
                        # metrics.update(actor_output_metrics)
                    metrics['timing/compute_all_entropy'] = timer.last

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # recompute old_log_probs
                    with _timer("old_log_prob", timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        batch = batch.union(old_log_prob)    
                        if self.config.actor_rollout_ref.actor.minp_old_log_prob:
                            minp_old_log_prob = self.actor_rollout_wg.compute_minp_log_prob(batch)
                            batch = batch.union(minp_old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer("ref", timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with _timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer("adv", timing_raw):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                        )

                    # update critic
                    if self.use_critic:
                        with _timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with _timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                        with _timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual TFLOPS and theoretical TFLOPS
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)  # clear timing

                # save DataProto batch to self.config.trainer.default_local_dir
                if self.config.trainer.save_train_samples_freq > 0 and \
                    (is_last_step or self.global_steps == 1 or self.global_steps % self.config.trainer.save_train_samples_freq == 0):
                    train_samples_dir = self.config.trainer.train_samples_dir
                    if not os.path.exists(train_samples_dir):
                        os.makedirs(train_samples_dir)

                    batch.save_to_disk(os.path.join(train_samples_dir, f'batch_{self.global_steps}.pt'))
                    if self.config.trainer.get('save_plaintext', False):
                        save_plaintext_to_disk(batch, os.path.join(train_samples_dir, f'batch_{self.global_steps}.txt'))
                
                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                # TODO: make a canonical logger that supports various backends
                logger.log(data=metrics, step=self.global_steps)


                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1
