import json
import os
import ast
import uuid
import random
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Dict, Optional, Type
import shutil

import numpy as np
import ray
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    process_validation_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.metric import (
    reduce_metrics,
)
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
from verl.workers.rollout.async_server import AsyncLLMServerManager

WorkerType = Type[Worker]
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, compute_advantage, compute_response_mask, apply_kl_penalty, _timer, Role, AdvantageEstimator, ResourcePoolManager

def compute_advantage_with_historical_rewards(data: DataProto, adv_estimator, norm_adv_by_std_in_grpo=True, **kwargs):
    """Attach history-normalized GRPO advantages and returns to ``data.batch``.

    A ``response_mask`` is derived with ``compute_response_mask`` first if the batch
    does not carry one yet (this happens even when the estimator is rejected).

    Args:
        data: batch containing at least ``token_level_rewards`` and ``prompts``.
        adv_estimator: only ``AdvantageEstimator.GRPO`` is supported.
        norm_adv_by_std_in_grpo: forwarded to the GRPO advantage computation.
        **kwargs: ``historical_rewards`` (prompt key -> list of rewards) is read from here.

    Returns:
        The same ``data`` object, with ``advantages`` and ``returns`` set.

    Raises:
        NotImplementedError: for any estimator other than GRPO.
    """
    if "response_mask" not in data.batch:
        data.batch["response_mask"] = compute_response_mask(data)

    if adv_estimator != AdvantageEstimator.GRPO:
        raise NotImplementedError

    tensors = data.batch
    advantages, returns = compute_grpo_outcome_advantage_with_historical_rewards(
        token_level_rewards=tensors["token_level_rewards"],
        response_mask=tensors["response_mask"],
        prompts=tensors["prompts"],
        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
        historical_rewards=kwargs.get("historical_rewards"),
    )
    tensors["advantages"] = advantages
    tensors["returns"] = returns
    return data

def compute_grpo_outcome_advantage_with_historical_rewards(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    prompts: torch.Tensor,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    historical_rewards: dict = None,
):
    """Compute GRPO-style outcome advantages normalized by per-prompt reward history.

    Each response's scalar score is the sum of its token-level rewards. Rather than
    normalizing within the current group, the score is centered (and optionally
    scaled) by the mean / sample std of the rewards recorded for the same prompt in
    ``historical_rewards``. A prompt is identified by the tuple of its last 250
    token ids. When a prompt has fewer than two recorded rewards (including an empty
    list or no entry at all), mean 0 and std 1 are used.

    Args:
        token_level_rewards: (bs, response_length) rewards, summed over the last dim.
        response_mask: (bs, response_length) mask multiplied onto the broadcast score.
        prompts: (bs, prompt_length) prompt token ids used to build the history key.
        epsilon: added to the std before dividing.
        norm_adv_by_std_in_grpo: if True divide by ``std + epsilon``, otherwise only
            subtract the mean.
        historical_rewards: mapping prompt-key tuple -> list of float rewards, or None.

    Returns:
        ``(advantages, returns)`` — the same (bs, response_length) tensor twice.
    """
    scores = token_level_rewards.sum(dim=-1)
    if not scores.is_floating_point():
        # Integer rewards would otherwise be truncated by the normalization below.
        scores = scores.float()

    with torch.no_grad():
        # (mean, std) per prompt key, so repeated prompts in the batch are looked up once.
        stats = {}
        means, stds = [], []
        for i in range(scores.shape[0]):
            prompt_key = tuple(prompts[i][-250:].cpu().tolist())
            if prompt_key not in stats:
                hist = None
                if historical_rewards is not None and prompt_key in historical_rewards:
                    hist = historical_rewards[prompt_key]
                if hist is not None and len(hist) > 1:
                    hist_tensor = torch.tensor(hist, dtype=torch.float32)
                    stats[prompt_key] = (hist_tensor.mean().item(), hist_tensor.std().item())
                else:
                    # 0 or 1 recorded rewards give no usable statistics (an empty list
                    # would yield NaN mean/std), so leave the raw score untouched.
                    stats[prompt_key] = (0.0, 1.0)
            mean, std = stats[prompt_key]
            means.append(mean)
            stds.append(std)

        mean_t = torch.tensor(means, dtype=scores.dtype, device=scores.device)
        if norm_adv_by_std_in_grpo:
            std_t = torch.tensor(stds, dtype=scores.dtype, device=scores.device)
            scores = (scores - mean_t) / (std_t + epsilon)
        else:
            scores = scores - mean_t
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores

class AR3POTrainer(RayPPOTrainer):
    
    def __init__(
        self,
        config,
        tokenizer,
        processor,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
        reward_fn=None,
        val_reward_fn=None,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        collate_fn=None,
        train_sampler: Optional[Sampler] = None,
        device_name="cuda",
    ):
        """Set up the AR3PO trainer.

        ``super().__init__`` is not called; the trainer attributes are assigned here
        directly, followed by the response-reuse state:

        * ``reward_one_responses`` / ``prompt_reward_lists``: per-prompt buffers of
          stored reward-1 responses and recent scalar rewards, restored from
          ``<default_local_dir>/reward_one_responses`` if a previous save exists.
        * ``max_stored_responses`` (``trainer.max_stored_responses``, default 8) and
          ``max_reward_list_size`` (1000) bound those buffers.
        * ``use_adaptive_rollout`` (``trainer.use_adaptive_rollout``, default True).

        Raises:
            AssertionError: if the hybrid engine is disabled or ``Role.ActorRollout``
                is missing from ``role_worker_mapping``.
            NotImplementedError: for an unsupported ``algorithm.adv_estimator``.
        """
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, "Currently, only support hybrid engine"

        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reference_policy = Role.RefPolicy in role_worker_mapping
        self.use_rm = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls
        self.device_name = device_name
        self.validation_generations_logger = ValidationGenerationsLogger()

        # A LoRA actor can act as its own reference policy.
        self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0

        if config.algorithm.use_kl_in_reward:
            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)

        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
        elif self.config.algorithm.adv_estimator in [
            AdvantageEstimator.GRPO,
            AdvantageEstimator.GRPO_PASSK,
            AdvantageEstimator.REINFORCE_PLUS_PLUS,
            AdvantageEstimator.REMAX,
            AdvantageEstimator.RLOO,
            AdvantageEstimator.OPO,
            AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
        ]:
            self.use_critic = False
        else:
            raise NotImplementedError

        self.reuse_strategy = config.trainer.reuse_strategy
        self.use_off_policy_training = config.actor_rollout_ref.actor.use_off_policy_training
        self._validate_config()
        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

        # Per-prompt buffers keyed by the tuple of the last 250 prompt token ids.
        self.reward_one_responses = defaultdict(list)
        self.reward_zero_responses = defaultdict(list)
        self.prompt_reward_lists = defaultdict(list)
        self.max_reward_list_size = 1000
        self.storage_dir = os.path.join(self.config.trainer.default_local_dir, "reward_one_responses")
        os.makedirs(self.storage_dir, exist_ok=True)

        self.max_stored_responses = self.config.trainer.get("max_stored_responses", 8)

        # Initialise the counter *before* loading, so a value restored from disk
        # by _load_reward_one_responses is not overwritten afterwards.
        self.total_responses_used = 0
        self._load_reward_one_responses()

        self.use_adaptive_rollout = self.config.trainer.get("use_adaptive_rollout", True)
        print("Use historical rewards: ", self.config.algorithm.use_historical_rewards)
        print("Use off policy training: ", self.use_off_policy_training)
        print("Reuse strategy: ", self.reuse_strategy)
        print("Use adaptive rollout: ", self.use_adaptive_rollout)
        print("Max per prompt: ", self.config.trainer.max_per_prompt)
        # Use the resolved value: the config key is optional (see the .get above).
        print("Max stored responses: ", self.max_stored_responses)
        print("Adaptive rollout step: ", self.config.actor_rollout_ref.rollout.n)
        
    def _load_reward_one_responses(self):
        """Restore the reward-1 response buffer and per-prompt reward histories from disk.

        Reads the step number from ``<default_local_dir>/latest_checkpointed_iteration.txt``
        and loads ``<storage_dir>/prompt_data_step_<step>.npy`` (the format written by
        ``_save_reward_one_responses``). Missing files are not an error; the buffers
        are left unchanged. Nothing is loaded when ``trainer.resume_mode`` is
        ``"disable"``. A stored ``total_responses_used`` value is restored as well.
        """
        # With resume_mode == "disable" the run starts fresh, so it must not pick up
        # buffers (including stale old_log_probs) left in default_local_dir by an earlier run.
        if self.config.trainer.get("resume_mode", "auto") == "disable":
            print("resume_mode is 'disable', starting with empty reward responses.")
            return

        local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt")
        if not os.path.exists(local_latest_checkpointed_iteration):
            print("No latest checkpoint found, starting with empty reward responses.")
            return

        with open(local_latest_checkpointed_iteration, "r") as f:
            latest_step = int(f.read().strip())

        filepath = os.path.join(self.storage_dir, f"prompt_data_step_{latest_step}.npy")
        if not os.path.exists(filepath):
            print(f"No prompt data found for step {latest_step}, starting with empty storage.")
            return

        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load files
        # this trainer wrote itself.
        data = np.load(filepath, allow_pickle=True).item()
        if "total_responses_used" in data:
            self.total_responses_used = data["total_responses_used"]
            print(f"Loaded total_responses_used: {self.total_responses_used}")

        for raw_key, prompt_data in data.items():
            if raw_key == "total_responses_used":
                continue
            # Keys were saved as str(tuple_of_token_ids).
            prompt_key = tuple(ast.literal_eval(raw_key))

            if "reward_one_responses" in prompt_data:
                self.reward_one_responses[prompt_key] = [
                    {
                        "prompt_ids": torch.tensor(r["prompt_ids"]),
                        "response_ids": torch.tensor(r["response_ids"]),
                        "old_log_probs": torch.tensor(r["old_log_probs"]),
                        "token_level_scores": torch.tensor(r["token_level_scores"]),
                        "response_mask": torch.tensor(r["response_mask"]),
                    }
                    for r in prompt_data["reward_one_responses"]
                ]

            if "historical_rewards" in prompt_data:
                self.prompt_reward_lists[prompt_key] = prompt_data["historical_rewards"]

        print(f"Successfully loaded prompt data from step {latest_step}.")

    def _track_prompt_responses(self, batch):
        """Append every sample's scalar reward to its prompt's bounded reward history.

        The scalar reward is the sum of the sample's ``token_level_scores``; the prompt
        is keyed by the tuple of its last 250 token ids. Each history keeps at most
        ``self.max_reward_list_size`` entries, dropping the oldest first.
        """
        score_rows = batch.batch["token_level_scores"]
        limit = self.max_reward_list_size
        for row, prompt in enumerate(batch.batch["prompts"]):
            history = self.prompt_reward_lists.setdefault(tuple(prompt[-250:].cpu().tolist()), [])
            history.append(score_rows[row].sum().item())
            if len(history) > limit:
                del history[0]

    def _store_reward_one_responses(self, batch):
        """Add this batch's reward-1 responses to the per-prompt reuse buffer.

        A response qualifies when the sum of its token-level scores exceeds 1 - 1e-3.
        Responses are bucketed by the tuple of the last 250 prompt token ids. Exact
        duplicates of an already stored ``response_ids`` tensor are skipped. Each
        bucket keeps at most ``self.max_stored_responses`` entries, evicting the
        oldest first; with ``max_stored_responses <= 0`` nothing is stored.

        Args:
            batch: DataProto holding ``prompts``, ``responses``, ``old_log_probs``,
                ``token_level_scores`` and ``response_mask``.

        Returns:
            dict with the average number of newly stored responses per prompt, using
            ``config.data.train_batch_size`` as the prompt count.
        """
        prompts = batch.batch["prompts"]
        responses = batch.batch["responses"]
        old_log_probs = batch.batch["old_log_probs"]
        token_level_scores = batch.batch["token_level_scores"]
        response_mask = batch.batch["response_mask"]

        def _owned_cpu_copy(t):
            # `.cpu()` returns the tensor itself when it is already on the CPU, so storing
            # `row.cpu()` would keep a view that pins the whole batch tensor in memory for as
            # long as the entry stays buffered, and would alias later in-place batch edits.
            return t.detach().cpu().clone()

        num_stored_responses = 0
        num_prompts = self.config.data.train_batch_size
        if self.max_stored_responses > 0:
            for i in range(len(prompts)):
                reward = token_level_scores[i].sum().item()
                # Only reward-1 responses are buffered; skip the duplicate scan for the rest.
                if not reward > 1 - 1e-3:
                    continue

                prompt_key = tuple(prompts[i][-250:].cpu().tolist())
                stored = self.reward_one_responses.setdefault(prompt_key, [])
                response_cpu = responses[i].cpu()
                if any(torch.equal(response_cpu, s["response_ids"]) for s in stored):
                    continue

                # Make room for the new entry (also trims buckets loaded under a larger cap).
                overflow = len(stored) - self.max_stored_responses + 1
                if overflow > 0:
                    del stored[:overflow]
                stored.append(
                    {
                        "prompt_ids": _owned_cpu_copy(prompts[i]),
                        "response_ids": _owned_cpu_copy(responses[i]),
                        "old_log_probs": _owned_cpu_copy(old_log_probs[i]),
                        "token_level_scores": _owned_cpu_copy(token_level_scores[i]),
                        "response_mask": _owned_cpu_copy(response_mask[i]),
                    }
                )
                num_stored_responses += 1

        avg_store_per_prompt = num_stored_responses / num_prompts if num_prompts > 0 else 0.0

        print(f"Step {self.global_steps}: Stored {num_stored_responses} responses to {num_prompts} prompts, "
              f"average {avg_store_per_prompt:.2f} responses per prompt")

        return {
            "reward_balancing/avg_stored_responses_per_prompt/mean": avg_store_per_prompt,
        }

    def _save_reward_one_responses(self):
        """Persist the per-prompt buffers to ``<storage_dir>/prompt_data_step_<global_steps>.npy``.

        The file holds a dict mapping ``str(prompt_key)`` to a dict with optional
        ``"reward_one_responses"`` (list of dicts of numpy arrays) and
        ``"historical_rewards"`` (list of floats), plus a ``"total_responses_used"``
        entry. This is the format read back by ``_load_reward_one_responses``.
        The file is written to a temporary path first and then renamed into place,
        so an interrupted save does not leave a truncated file under the final name.
        """
        os.makedirs(self.storage_dir, exist_ok=True)

        filename = os.path.join(self.storage_dir, f"prompt_data_step_{self.global_steps}.npy")

        serializable_data = {}
        num_responses = 0
        # Union of both buffers: prompts with reward history but no stored reward-1
        # response (e.g. never solved) must be saved too, or their history is lost on resume.
        all_prompt_keys = set(self.reward_one_responses.keys()) | set(self.prompt_reward_lists.keys())
        for prompt_key in all_prompt_keys:
            prompt_data = {}

            if prompt_key in self.reward_one_responses:
                stored = self.reward_one_responses[prompt_key]
                num_responses += len(stored)
                prompt_data["reward_one_responses"] = [
                    {
                        "prompt_ids": r["prompt_ids"].cpu().numpy(),
                        "response_ids": r["response_ids"].cpu().numpy(),
                        "old_log_probs": r["old_log_probs"].cpu().numpy(),
                        "token_level_scores": r["token_level_scores"].cpu().numpy(),
                        "response_mask": r["response_mask"].cpu().numpy(),
                    }
                    for r in stored
                ]

            if prompt_key in self.prompt_reward_lists:
                prompt_data["historical_rewards"] = self.prompt_reward_lists[prompt_key]

            serializable_data[str(prompt_key)] = prompt_data

        # Count prompts before adding the non-prompt bookkeeping entry.
        num_prompts = len(serializable_data)
        serializable_data["total_responses_used"] = self.total_responses_used

        # Writing through a file object keeps np.save from appending another ".npy".
        tmp_filename = filename + ".tmp"
        with open(tmp_filename, "wb") as f:
            np.save(f, serializable_data)
        os.replace(tmp_filename, filename)
        print(f"Saved {num_responses} responses and historical rewards for {num_prompts} prompts to {filename}")
        print(f"Saved total_responses_used: {self.total_responses_used}")

    def _save_checkpoint(self):
        """Save the parent trainer's checkpoint, then this trainer's prompt buffers.

        The buffers go through ``_save_reward_one_responses``, which names the file
        after ``self.global_steps``.
        """
        super()._save_checkpoint()
        self._save_reward_one_responses()

    def _load_checkpoint(self):
        """Restore the parent trainer's checkpoint, then re-read the prompt buffers from disk."""
        super()._load_checkpoint()
        self._load_reward_one_responses()

    def _compute_reward_sum_distribution(self, batch):
        """Histogram of per-prompt total reward, expressed as ratios of the batch size.

        Responses are grouped by the tuple of their last 250 prompt token ids. Each
        group's summed reward is rounded and clamped to ``[0, trainer.max_per_prompt]``.
        The count in each bucket is divided by ``config.data.train_batch_size``, or is
        0 when that size is not positive.

        Returns:
            dict ``reward_distribution/prompt_reward_sum_<k>_ratio`` for
            ``k = 0 .. max_per_prompt``.
        """
        score_rows = batch.batch["token_level_scores"]
        sums_by_prompt = {}
        for row, prompt in enumerate(batch.batch["prompts"]):
            key = tuple(prompt[-250:].cpu().tolist())
            sums_by_prompt.setdefault(key, []).append(score_rows[row].sum().item())

        cap = self.config.trainer.max_per_prompt
        bucket_counts = dict.fromkeys(range(cap + 1), 0)
        for reward_sums in sums_by_prompt.values():
            bucket_counts[max(0, min(cap, round(sum(reward_sums))))] += 1

        denominator = self.config.data.train_batch_size
        return {
            f"reward_distribution/prompt_reward_sum_{k}_ratio": (bucket_counts[k] / denominator if denominator > 0 else 0)
            for k in range(cap + 1)
        }


    def _balance_rewards(self, batch):
        """Top up low-success prompts with stored reward-1 responses from earlier steps.

        Rows are grouped by prompt (the tuple of the last 250 prompt token ids) and
        split into reward-1 (summed token-level score > 1 - 1e-3) and reward-0 rows.
        ``self.reuse_strategy`` sets a target number of reward-1 rows per prompt:
        ``total // 4`` for "quarter*", ``1`` for "AK*", ``total // 2`` for "half*",
        and 0 otherwise. When a prompt falls short and stored responses exist for it,
        reward-0 rows are overwritten in place with a random sample of the stored
        responses. The overwritten fields are ``responses``, ``old_log_probs``,
        ``token_level_scores``, ``token_level_rewards``, ``response_mask``, and the
        response slice of ``input_ids`` / ``attention_mask`` when present. Unless
        ``self.use_off_policy_training`` is set, ``old_log_probs`` of the overwritten
        rows are then recomputed with the current actor.

        Args:
            batch: DataProto after generation, log-prob and reward computation.
                Modified in place; a boolean ``reuse_mask`` of shape
                (bsz, response_length) marking overwritten rows is added.

        Returns:
            dict with the average number of reused responses per prompt.
        """
        prompts = batch.batch["prompts"]
        rewards = batch.batch["token_level_scores"]
        batch_size = len(prompts)
        response_length = batch.batch["responses"].shape[1]

        # Group row indices by prompt and by outcome.
        prompt_responses = {}
        for i in range(batch_size):
            prompt_key = tuple(prompts[i][-250:].cpu().tolist())
            reward_sum = rewards[i].sum().item()
            group = prompt_responses.setdefault(prompt_key, {"reward_1": [], "reward_0": []})
            group["reward_1" if reward_sum > 1 - 1e-3 else "reward_0"].append(i)

        total_prompts_this_step = len(prompt_responses)
        replaced_with_reward_one = []
        reuse_mask = torch.zeros((batch_size, response_length), dtype=torch.bool, device=batch.batch["responses"].device)

        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 0

        has_input_ids = "input_ids" in batch.batch.keys()
        has_attention_mask = "attention_mask" in batch.batch.keys()

        def _pad_to_response_length(t, value):
            # Right-pad stored tensors that are shorter than the current response length.
            # Longer ones are returned unchanged, so the row assignment fails loudly.
            missing = response_length - t.shape[0]
            if missing <= 0:
                return t
            return torch.cat([t, torch.full((missing,), value, dtype=t.dtype, device=t.device)])

        for prompt_key, responses_dict in prompt_responses.items():
            reward_1_count = len(responses_dict["reward_1"])
            total = reward_1_count + len(responses_dict["reward_0"])

            if self.reuse_strategy.startswith("quarter"):
                target_good_count = total // 4
            elif self.reuse_strategy.startswith("AK"):
                target_good_count = 1
            elif self.reuse_strategy.startswith("half"):
                target_good_count = total // 2
            else:
                target_good_count = 0

            stored = self.reward_one_responses.get(prompt_key)
            if reward_1_count >= target_good_count or not stored:
                continue

            add_count = min(target_good_count - reward_1_count, len(stored))
            selected_responses = random.sample(stored, add_count)

            # target_good_count <= total, so there are always enough reward-0 rows to overwrite.
            for idx, selected in zip(responses_dict["reward_0"], selected_responses):
                response_ids = _pad_to_response_length(selected["response_ids"], pad_token_id)
                token_scores = _pad_to_response_length(selected["token_level_scores"], 0)
                resp_mask = _pad_to_response_length(selected["response_mask"], 0)

                batch.batch["responses"][idx] = response_ids
                batch.batch["old_log_probs"][idx] = _pad_to_response_length(selected["old_log_probs"], 0)
                batch.batch["token_level_scores"][idx] = token_scores
                batch.batch["token_level_rewards"][idx] = token_scores
                batch.batch["response_mask"][idx] = resp_mask
                # Keep the full sequence consistent with the new response. This assumes
                # input_ids / attention_mask are laid out as [prompt | response], as the
                # rollout builds them. Otherwise the log-prob recomputation below and the
                # policy update would still condition on the replaced reward-0 tokens.
                if has_input_ids:
                    batch.batch["input_ids"][idx, -response_length:] = response_ids
                if has_attention_mask:
                    # response_mask is derived from this same slice (compute_response_mask).
                    batch.batch["attention_mask"][idx, -response_length:] = resp_mask

                replaced_with_reward_one.append(idx)
                reuse_mask[idx] = True

        if replaced_with_reward_one and has_attention_mask and "global_token_num" in batch.meta_info:
            # Token counts were taken from the pre-replacement attention mask.
            batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

        if replaced_with_reward_one and not self.use_off_policy_training:
            print(f"Step {self.global_steps}: Recomputing log_probs for {len(replaced_with_reward_one)} historical reward=1 responses")

            replaced_batch = batch.select_idxs(replaced_with_reward_one)

            n_workers = self.resource_pool_manager.get_n_gpus()
            current_replaced_batch_size = len(replaced_batch)

            if current_replaced_batch_size % n_workers != 0:
                padding_size = n_workers - (current_replaced_batch_size % n_workers)
                print(f"Step {self.global_steps}: Padding replaced_batch from {current_replaced_batch_size} to {current_replaced_batch_size + padding_size} to be divisible by {n_workers} workers")
                # Padding rows are appended after the real ones, so the indices below stay aligned.
                replaced_batch.padding(padding_size, padding_candidate="first")

            new_log_prob = self.actor_rollout_wg.compute_log_prob(replaced_batch)
            recomputed = new_log_prob.batch["old_log_probs"]
            for i, idx in enumerate(replaced_with_reward_one):
                batch.batch["old_log_probs"][idx] = recomputed[i].detach()

        responses_added_this_step = len(replaced_with_reward_one)
        avg_responses_per_prompt = responses_added_this_step / total_prompts_this_step if total_prompts_this_step > 0 else 0

        print(f"Step {self.global_steps}: Added {responses_added_this_step} responses to {total_prompts_this_step} prompts, "
              f"average {avg_responses_per_prompt:.2f} responses per prompt")
        if replaced_with_reward_one:
            print(f"Step {self.global_steps}: Reused {len(replaced_with_reward_one)} reward=1 responses")

        batch.batch["reuse_mask"] = reuse_mask

        return {
            "reward_balancing/avg_added_responses_per_prompt/mean": avg_responses_per_prompt,
        }

    def adaptive_rollout(self, gen_batch, max_per_prompt=8, step=2, reward_fn=None, actor_rollout_wg=None, async_rollout_manager=None, async_mode=False, reward_threshold=1-1e-3):
        """Generate responses in rounds until each prompt has a correct one or exhausts its budget.

        Each round generates ``step`` responses for every prompt that has neither
        produced a response whose summed reward exceeds ``reward_threshold`` nor
        reached ``max_per_prompt`` generated responses. Rewards are scored with
        ``compute_reward(..., reward_fn)``. Afterwards all responses are collated
        prompt by prompt and, if needed, trimmed from the end so the total is
        divisible by the number of GPUs.

        ``async_rollout_manager`` and ``async_mode`` are accepted but not used;
        generation always goes through ``actor_rollout_wg.generate_sequences``.

        Returns:
            ``(merged_data, response_counts)``: the collated DataProto of the kept
            responses, and how many of them belong to each prompt, in ``gen_batch``
            order, after trimming.
        """
        prompt_count = len(gen_batch)
        solved = [False] * prompt_count
        responses_per_prompt = [[] for _ in range(prompt_count)]
        generated = [0] * prompt_count
        n_workers = self.resource_pool_manager.get_n_gpus()

        while True:
            pending = [p for p in range(prompt_count) if not solved[p] and generated[p] < max_per_prompt]
            if not pending:
                break

            round_input, pad_size = pad_dataproto_to_divisor(gen_batch.select_idxs(pending), n_workers)
            round_output = unpad_dataproto(actor_rollout_wg.generate_sequences(round_input), pad_size=pad_size * step)
            print(f"Step {self.global_steps}: Generated {len(round_output)} responses for {len(pending)} prompts")
            assert len(round_output) == len(pending) * step, f"Generated {len(round_output)} responses for {len(pending)} prompts, expected {len(pending) * step}"
            round_rewards, _ = compute_reward(round_output, reward_fn)

            # The walk below assumes `step` consecutive output rows per pending prompt.
            cursor = 0
            for p in pending:
                for _ in range(step):
                    row_reward = round_rewards[cursor]
                    score = row_reward.sum().item() if hasattr(row_reward, 'sum') else float(row_reward)
                    responses_per_prompt[p].append(round_output[cursor])
                    generated[p] += 1
                    if score > reward_threshold:
                        solved[p] = True
                    cursor += 1

        merged = [item for prompt_items in responses_per_prompt for item in prompt_items]
        response_counts = [len(prompt_items) for prompt_items in responses_per_prompt]

        from verl.protocol import collate_fn
        merged_data = collate_fn(merged)

        original_batch_size = len(merged_data)
        excess = original_batch_size % n_workers
        if excess != 0:
            keep_size = original_batch_size - excess
            print(f"Step {self.global_steps}: Dropping {excess} excess data from {original_batch_size} to {keep_size} to be divisible by {n_workers} workers")

            merged_data = merged_data.slice(0, keep_size)

            # Trimmed rows come from the tail, i.e. from the last prompts; shrink their counts.
            remaining = excess
            for p in reversed(range(len(response_counts))):
                if remaining <= 0:
                    break
                taken = min(response_counts[p], remaining)
                response_counts[p] -= taken
                remaining -= taken

        return merged_data, response_counts

    def fit(self):
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking
        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0
        self.total_responses_used = 0

        self._load_checkpoint()

        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        self.global_steps += 1
        last_val_metrics = None
        best_val = 0.0

        while self.global_steps <= self.total_training_steps:
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                if "multi_modal_data" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("multi_modal_data")
                if "raw_prompt" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("raw_prompt")
                if "tools_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
                
                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )
                if "reward_model" in batch.non_tensor_batch:
                    gen_batch.non_tensor_batch["reward_model"] = batch.non_tensor_batch["reward_model"]
                if "data_source" in batch.non_tensor_batch:
                    gen_batch.non_tensor_batch["data_source"] = batch.non_tensor_batch["data_source"]

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer("step", timing_raw):
                    with _timer("gen", timing_raw):
                        if self.use_adaptive_rollout:
                            gen_batch_output, response_counts = self.adaptive_rollout(
                                gen_batch,
                                max_per_prompt=self.config.trainer.max_per_prompt,
                                step=self.config.actor_rollout_ref.rollout.n,
                                reward_fn=self.reward_fn,
                                actor_rollout_wg=self.actor_rollout_wg,
                                reward_threshold=1-1e-3,
                            )
                        else:
                            if not self.async_rollout_mode:
                                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                            else:
                                self.async_rollout_manager.wake_up()
                                gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
                                self.async_rollout_manager.sleep()
                        
                        self.total_responses_used += len(gen_batch_output)

                    batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
                    if self.use_adaptive_rollout:
                        prompt_ids = gen_batch.batch["input_ids"] if "input_ids" in gen_batch.batch else gen_batch.batch["prompts"]
                        num_prompts = len(gen_batch)
                        response_counts_repeat = []
                        idx = 0
                        for i in range(num_prompts):
                            cnt = 0
                            while idx < len(gen_batch_output):
                                if torch.equal(gen_batch_output.batch["prompts"][idx], prompt_ids[i]):
                                    cnt += 1
                                    idx += 1
                                else:
                                    break
                            response_counts_repeat.append(cnt)
                        assert response_counts == response_counts_repeat, f"response_counts from adaptive_rollout and repeat logic do not match! {response_counts} vs {response_counts_repeat}"
                        batch = batch.sample_level_repeat(response_counts)
                        batch = batch.union(gen_batch_output)
                    else:
                        batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                        batch = batch.union(gen_batch_output)

                    batch.batch["response_mask"] = compute_response_mask(batch)
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    with _timer("old_log_prob", timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    with _timer("reward", timing_raw):
                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
                        reward_extra_infos_dict: dict[str, list]
                        batch.batch["token_level_scores"] = reward_tensor

                        print(f"{list(reward_extra_infos_dict.keys())=}")
                        if reward_extra_infos_dict:
                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        if self.config.algorithm.use_kl_in_reward:
                            batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                    self._track_prompt_responses(batch)
                    
                    reward_balancing_metrics = self._balance_rewards(batch)
                    metrics.update(reward_balancing_metrics)
                    
                    
                    if self.config.trainer.reuse_strategy != "none":
                        response_store_metrics = self._store_reward_one_responses(batch)
                        metrics.update(response_store_metrics)
                    
                    with _timer("adv", timing_raw):
                        
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        use_historical_rewards = self.config.algorithm.get("use_historical_rewards", False)
                        if use_historical_rewards:
                            batch = compute_advantage_with_historical_rewards(
                                batch,
                                adv_estimator=self.config.algorithm.adv_estimator,
                                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                                historical_rewards=self.prompt_reward_lists,
                            )
                        else:
                            batch = compute_advantage(
                                batch,
                                adv_estimator=self.config.algorithm.adv_estimator,
                                gamma=self.config.algorithm.gamma,
                                lam=self.config.algorithm.lam,
                                num_repeat=self.config.actor_rollout_ref.rollout.n,
                                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                                multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable,
                                use_pf_ppo=self.config.algorithm.use_pf_ppo,
                                pf_ppo_reweight_method=self.config.algorithm.pf_ppo.reweight_method,
                                pf_ppo_weight_pow=self.config.algorithm.pf_ppo.weight_pow,
                            )

                    
                    reward_distribution_metrics = self._compute_reward_sum_distribution(batch)
                    metrics.update(reward_distribution_metrics)

                    
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        
                        with _timer("update_actor", timing_raw):
                            batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)
                    
                    if self.global_steps % self.config.trainer.save_freq == 0:
                        avg_responses_per_step = self.total_responses_used / (self.global_steps)
                        print(f"Average responses per step: {avg_responses_per_step:.2f}")
                        avg_responses_per_prompt = self.total_responses_used / (self.global_steps * self.config.data.train_batch_size)
                        print(f"Average responses per prompt: {avg_responses_per_prompt:.2f}")
                    
                    val_metrics = None
                    save_flag = False
                    
                    is_test_step = (self.config.trainer.test_freq > 0 and 
                                   (is_last_step or self.global_steps % self.config.trainer.test_freq == 0))
                    is_save_step = (self.config.trainer.save_freq > 0 and 
                                   self.global_steps % self.config.trainer.save_freq == 0)
                    
                    if is_test_step:
                        
                        with _timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)
                        val_acc = float(val_metrics["val-core/Average/reward/mean@1"])
                        
                        if val_acc >= best_val:
                            print(f"Step {self.global_steps}: Saving checkpoint due to improved validation performance: {val_acc:.4f} > {best_val:.4f}")
                            with _timer("save_checkpoint", timing_raw):
                                self._save_checkpoint()
                            save_flag = True
                            best_val = val_acc
                    
                    if is_save_step:
                        
                        if val_metrics is None:
                            with _timer("testing", timing_raw):
                                val_metrics: dict = self._validate()
                                if is_last_step:
                                    last_val_metrics = val_metrics
                            metrics.update(val_metrics)
                            val_acc = float(val_metrics["val-core/Average/reward/mean@1"])
                            best_val = max(best_val, val_acc)
                        
                        if not save_flag:
                            print(f"Step {self.global_steps}: Saving checkpoint due to save frequency")
                            with _timer("save_checkpoint", timing_raw):
                                self._save_checkpoint()

                    
                
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/num_responses_this_step": float(batch.batch["responses"].shape[0] / self.config.data.train_batch_size),
                    }
                )
                
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

                
                logger.log(data=metrics, step=self.global_steps)

                progress_bar.update(1)
                self.global_steps += 1
                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    print(f"Average responses per step: {avg_responses_per_step:.2f}")
                    print(f"Average responses per prompt: {avg_responses_per_prompt:.2f}")
                    progress_bar.close()
                    return
        
        
        