import os
import sys
import importlib.util

import numpy as np
import torch
from tensordict import TensorDict

from verl import DataProto

from verl.trainer.ppo.core_algos import (
    register_policy_loss,
    agg_loss,
)
import verl.utils.torch_functional as verl_F


def _import_dapo_trainer():
    """Locate ``recipe/dapo/dapo_ray_trainer.py`` and return its ``RayDAPOTrainer``.

    Search order:
      1. every entry of ``sys.path`` (first hit wins);
      2. a ``verl`` directory sitting next to this file's parent directory.

    The file is executed as a standalone module named ``dapo_ray_trainer``
    through ``importlib``; it is not added to ``sys.modules``.

    Raises:
        ImportError: if the trainer file is found in neither location.
    """
    trainer_relpath = os.path.join("recipe", "dapo", "dapo_ray_trainer.py")

    # First sys.path entry that actually contains the recipe file.
    module_path = next(
        (
            hit
            for hit in (os.path.join(root, trainer_relpath) for root in sys.path)
            if os.path.exists(hit)
        ),
        None,
    )

    if module_path is None:
        # Fall back to a sibling checkout: <this file's dir>/../verl
        here = os.path.dirname(os.path.abspath(__file__))
        sibling_hit = os.path.join(here, "..", "verl", trainer_relpath)
        if os.path.exists(sibling_hit):
            module_path = sibling_hit

    if module_path is None:
        raise ImportError("Cannot find verl/recipe/dapo/dapo_ray_trainer.py")

    trainer_spec = importlib.util.spec_from_file_location("dapo_ray_trainer", module_path)
    trainer_module = importlib.util.module_from_spec(trainer_spec)
    trainer_spec.loader.exec_module(trainer_module)
    return trainer_module.RayDAPOTrainer

# Resolve the base trainer class once, at import time. This raises ImportError
# (from _import_dapo_trainer) when the DAPO recipe file cannot be located.
RayDAPOTrainer = _import_dapo_trainer()

from .external_api_generator import ExternalAPIGenerator


class RayDAPOTrainerOnestage(RayDAPOTrainer):

    def __init__(
        self,
        config,
        tokenizer,
        processor,
        role_worker_mapping,
        resource_pool_manager,
        ray_worker_group_cls,
        reward_fn=None,
        val_reward_fn=None,
    ):
        """Build the trainer and, if enabled, the external-API response generator.

        All arguments are forwarded unchanged to ``RayDAPOTrainer.__init__``.

        Config keys read here (beyond what the base class reads):
          * ``external_api.enable`` (default ``True``) - toggles the external
            generator; when false ``external_generator`` and
            ``external_min_log_prob`` are both ``None``.
          * ``external_api.{url,key,model,max_workers,timeout,max_retries,
            temperature,max_tokens,debug}`` - passed to ``ExternalAPIGenerator``.
            ``key`` falls back to the ``EXTERNAL_API_KEY`` environment variable.
          * ``data.train_files`` - passed to the generator when present.
          * ``actor_rollout_ref.actor.policy_loss.external_min_log_prob`` -
            optional log-prob floor; a positive value is negated so the stored
            floor is never positive.
        """
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )

        from verl.trainer.ppo.reward import compute_reward
        self.compute_reward = compute_reward

        external_api_config = config.get("external_api", {})
        self.external_api_enabled = external_api_config.get("enable", True)

        if self.external_api_enabled:
            train_files = None
            if hasattr(config, 'data') and hasattr(config.data, 'train_files'):
                train_files = config.data.train_files

            self.external_generator = ExternalAPIGenerator(
                api_url=external_api_config.get("url", "https://api.zyai.online/v1"),
                api_key=external_api_config.get("key", os.environ.get("EXTERNAL_API_KEY", "")),
                model_name=external_api_config.get("model", "gpt-5-chat"),
                max_workers=external_api_config.get("max_workers", 32),
                timeout=external_api_config.get("timeout", 120),
                max_retries=external_api_config.get("max_retries", 3),
                temperature=external_api_config.get("temperature", 0.7),
                max_tokens=external_api_config.get("max_tokens", 10000),
                debug=external_api_config.get("debug", False),
                train_files=train_files,
            )

            policy_loss_config = config.actor_rollout_ref.actor.get("policy_loss", {})
            min_log_prob = policy_loss_config.get("external_min_log_prob", None) if policy_loss_config else None
            # The floor is optional: only normalise the sign when a value was
            # actually configured (comparing None with 0 raises TypeError).
            if min_log_prob is not None and min_log_prob > 0:
                min_log_prob = -min_log_prob
            self.external_min_log_prob = min_log_prob
        else:
            self.external_generator = None
            self.external_min_log_prob = None

    def _replace_last_responses_with_external(
        self,
        gen_batch_output: DataProto,
        gen_batch: DataProto,
        n_per_prompt: int,
    ) -> DataProto:
        """Overwrite each prompt's last rollout slot with an external-API response.

        For prompt ``i`` the slot at row ``i * n_per_prompt + (n_per_prompt - 1)``
        of ``gen_batch_output`` is rewritten in place whenever the external
        generator reports ``success`` for that prompt. The response is
        right-padded with the pad token (or truncated) to the existing response
        width, and ``input_ids``, ``attention_mask``, ``responses`` and, when
        present, ``response_mask`` / ``position_ids`` are kept in sync.

        Afterwards ``non_tensor_batch['response_source']`` holds ``'external'``
        or ``'model'`` per row, and cached ``rm_scores`` / ``old_log_probs`` /
        ``rollout_log_probs`` tensors plus ``acc`` / ``score`` / ``reward_score``
        arrays are dropped because they no longer describe the replaced rows.

        Args:
            gen_batch_output: Rollout output; mutated in place and returned.
                Assumed to hold ``n_per_prompt`` consecutive rows per prompt
                (the slot formula above depends on it) - TODO confirm at call sites.
            gen_batch: The un-repeated generation batch; supplies ``raw_prompt``
                and the prompt width.
            n_per_prompt: Number of rollouts per prompt.

        Returns:
            ``gen_batch_output``. It is returned untouched when no external
            generator is configured or ``raw_prompt`` is missing.

        Raises:
            AssertionError: if a slot's prompt tokens do not match the
                chat-templated ``raw_prompt`` (compared on their common suffix).
        """
        if self.external_generator is None:
            return gen_batch_output

        raw_prompts = gen_batch.non_tensor_batch.get("raw_prompt", None)
        if raw_prompts is None:
            return gen_batch_output

        if isinstance(raw_prompts, np.ndarray):
            raw_prompts = raw_prompts.tolist()

        batch_size = len(gen_batch)

        external_results = self.external_generator.generate_batch(
            prompts=raw_prompts,
            tokenizer=self.tokenizer,
            max_response_length=self.config.data.max_response_length,
        )

        # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        device = gen_batch_output.batch['input_ids'].device
        prompt_length = gen_batch.batch['input_ids'].shape[1]

        if 'prompts' not in gen_batch_output.batch:
            gen_batch_output.batch['prompts'] = gen_batch_output.batch['input_ids'][:, :prompt_length].clone()

        if 'responses' not in gen_batch_output.batch:
            gen_batch_output.batch['responses'] = gen_batch_output.batch['input_ids'][:, prompt_length:].clone()

        actual_response_length = gen_batch_output.batch['responses'].shape[1]

        last_indices = [i * n_per_prompt + (n_per_prompt - 1) for i in range(batch_size)]

        for i, (last_idx, result) in enumerate(zip(last_indices, external_results)):
            prompt_ids = gen_batch_output.batch['input_ids'][last_idx, :prompt_length]

            # Sanity check: the slot must belong to prompt i. Only the common
            # suffix is compared, so padding/truncation on the left is tolerated.
            expected_prompt_ids = self.tokenizer.apply_chat_template(raw_prompts[i], add_generation_prompt=True, return_tensors='pt')[0]
            assert torch.equal(prompt_ids[-len(expected_prompt_ids):].cpu(), expected_prompt_ids[-len(prompt_ids):]), \
                f"Prompt mismatch at i={i}, last_idx={last_idx}"

            if not result['success']:
                continue

            # Truncate to the response width, then right-pad the remainder.
            response_ids = result['response_ids'].to(device)[:actual_response_length]
            valid_len = len(response_ids)
            pad_len = actual_response_length - valid_len

            response_ids_padded = torch.cat([
                response_ids,
                torch.full((pad_len,), pad_token_id, dtype=torch.long, device=device),
            ])
            response_mask = torch.cat([
                torch.ones(valid_len, dtype=torch.long, device=device),
                torch.zeros(pad_len, dtype=torch.long, device=device),
            ])

            prompt_mask = gen_batch_output.batch['attention_mask'][last_idx, :prompt_length]

            new_input_ids = torch.cat([prompt_ids, response_ids_padded])
            new_attention_mask = torch.cat([prompt_mask, response_mask])

            gen_batch_output.batch['input_ids'][last_idx].copy_(new_input_ids)
            gen_batch_output.batch['attention_mask'][last_idx].copy_(new_attention_mask)

            # Keep the derived tensors consistent with the new input_ids, the
            # same way _selective_inject_external_responses does; otherwise
            # `responses` would still hold the model's own tokens for this row.
            gen_batch_output.batch['responses'][last_idx].copy_(response_ids_padded)
            if 'response_mask' in gen_batch_output.batch:
                gen_batch_output.batch['response_mask'][last_idx].copy_(response_mask)
            if 'position_ids' in gen_batch_output.batch:
                # NOTE(review): assumes 1-D per-row position ids; a
                # multi-axis layout would receive this row broadcast.
                new_position_ids = (new_attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
                gen_batch_output.batch['position_ids'][last_idx].copy_(new_position_ids)

        # Per-row provenance: only a successfully replaced last slot is 'external'.
        source_markers = [
            'external' if (j == n_per_prompt - 1 and external_results[i]['success']) else 'model'
            for i in range(batch_size)
            for j in range(n_per_prompt)
        ]
        gen_batch_output.non_tensor_batch['response_source'] = np.array(source_markers, dtype=object)

        fields_to_clear = ['rm_scores', 'old_log_probs', 'rollout_log_probs']
        for key in fields_to_clear:
            if key in gen_batch_output.batch:
                del gen_batch_output.batch[key]

        reward_keys_to_clear = ['acc', 'score', 'reward_score']
        for key in reward_keys_to_clear:
            if key in gen_batch_output.non_tensor_batch:
                del gen_batch_output.non_tensor_batch[key]
        return gen_batch_output

    def _get_external_responses(
        self,
        gen_batch_output: DataProto,
        gen_batch: DataProto,
    ):
        """Request one external-API response per prompt in ``gen_batch``.

        Args:
            gen_batch_output: Unused; kept so existing call sites keep working.
            gen_batch: The un-repeated generation batch; its
                ``non_tensor_batch['raw_prompt']`` entries are sent to the
                external generator.

        Returns:
            Whatever ``external_generator.generate_batch`` returns, or ``None``
            when the external API is disabled, no generator exists, or the
            batch carries no ``raw_prompt``.
        """
        if not self.external_api_enabled or self.external_generator is None:
            return None

        raw_prompts = gen_batch.non_tensor_batch.get("raw_prompt", None)
        if raw_prompts is None:
            return None

        if isinstance(raw_prompts, np.ndarray):
            raw_prompts = raw_prompts.tolist()

        return self.external_generator.generate_batch(
            prompts=raw_prompts,
            tokenizer=self.tokenizer,
            max_response_length=self.config.data.max_response_length,
        )

    def _selective_inject_external_responses(
        self,
        batch: DataProto,
        external_results: list,
        gen_batch: DataProto,
        n_per_prompt: int,
        replace_threshold: float = 0.01,
    ) -> tuple:
        """Swap in external responses that out-score the model's own rollouts.

        Every successful external response is scored with ``self.compute_reward``.
        For prompt ``i`` the candidate slot is row
        ``i * n_per_prompt + (n_per_prompt - 1)`` of ``batch``. The slot is
        overwritten in place only when the external reward is at least
        ``replace_threshold`` above the best reward among the *other* rows that
        share the slot's ``uid``.

        NOTE(review): the original code named this reference value
        ``min_model_reward`` while computing ``max(...)``. The ``max``
        behaviour is preserved here; confirm it is the intended criterion.

        On replacement ``input_ids``, ``attention_mask``, ``responses``,
        ``token_level_rewards``, ``token_level_scores`` and, when present,
        ``response_mask`` / ``position_ids`` and matching reward extra-info
        arrays are updated for that row.

        Args:
            batch: Scored training batch (needs ``token_level_rewards``,
                ``token_level_scores`` and ``uid``); mutated in place.
            external_results: Per-prompt results with ``success`` and
                ``response_ids`` entries, or ``None``.
            gen_batch: The un-repeated generation batch the prompts came from.
            n_per_prompt: Number of rollouts per prompt in ``batch``.
            replace_threshold: Margin the external reward must exceed the best
                model reward by.

        Returns:
            ``(batch, metrics)``. ``metrics`` is empty when the external API is
            disabled, ``external_results`` is ``None``, or no external call
            succeeded; otherwise it holds the ``external/*`` counters.

        Raises:
            AssertionError: if a slot about to be replaced does not hold the
                chat-templated prompt it is expected to.
        """
        from collections import defaultdict

        external_metrics = {}

        if not self.external_api_enabled or external_results is None:
            return batch, external_metrics

        batch_size = len(gen_batch)

        rewards = batch.batch["token_level_rewards"].sum(dim=-1).cpu().numpy()
        uids = batch.non_tensor_batch["uid"]

        # uid -> [(row index, sequence reward)] for every row of the batch.
        uid2model_info = defaultdict(list)
        for idx, (uid, reward) in enumerate(zip(uids, rewards)):
            uid2model_info[uid].append((idx, reward))

        # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        device = batch.batch['input_ids'].device

        prompt_length = gen_batch.batch['input_ids'].shape[1]
        if 'prompts' not in batch.batch:
            batch.batch['prompts'] = batch.batch['input_ids'][:, :prompt_length].clone()
        if 'responses' not in batch.batch:
            batch.batch['responses'] = batch.batch['input_ids'][:, prompt_length:].clone()

        actual_response_length = batch.batch['responses'].shape[1]
        has_position_ids = 'position_ids' in batch.batch

        # ---- 1. Build one padded sample per successful external response ----
        external_batch_list = []
        external_to_prompt_idx = []

        for i, result in enumerate(external_results):
            if not result['success']:
                continue

            # Truncate to the response width, then right-pad the remainder.
            response_ids = result['response_ids'].to(device)[:actual_response_length]
            valid_len = len(response_ids)
            pad_len = actual_response_length - valid_len

            response_ids_padded = torch.cat([
                response_ids,
                torch.full((pad_len,), pad_token_id, dtype=torch.long, device=device),
            ])
            response_mask_ext = torch.cat([
                torch.ones(valid_len, dtype=torch.long, device=device),
                torch.zeros(pad_len, dtype=torch.long, device=device),
            ])

            prompt_ids = gen_batch.batch['input_ids'][i, :prompt_length].to(device)
            prompt_mask = gen_batch.batch['attention_mask'][i, :prompt_length].to(device)

            new_input_ids = torch.cat([prompt_ids, response_ids_padded])
            new_attention_mask = torch.cat([prompt_mask, response_mask_ext])

            sample = {
                'input_ids': new_input_ids,
                'attention_mask': new_attention_mask,
                'prompts': prompt_ids,
                'responses': response_ids_padded,
                'response_mask': response_mask_ext,
            }
            if has_position_ids:
                sample['position_ids'] = (new_attention_mask.cumsum(dim=-1) - 1).clamp(min=0)

            external_batch_list.append(sample)
            external_to_prompt_idx.append(i)

        if not external_batch_list:
            batch.non_tensor_batch['response_source'] = np.array(['model'] * len(batch), dtype=object)
            return batch, external_metrics

        # ---- 2. Score the external responses with the training reward fn ----
        external_batch_dict = {
            key: torch.stack([sample[key] for sample in external_batch_list], dim=0)
            for key in external_batch_list[0]
        }
        external_batch_tensordict = TensorDict(external_batch_dict, batch_size=len(external_batch_list))

        external_non_tensor = {}
        for key in ['raw_prompt', 'data_source', 'reward_model']:
            if key in gen_batch.non_tensor_batch:
                external_non_tensor[key] = np.array(
                    [gen_batch.non_tensor_batch[key][i] for i in external_to_prompt_idx], dtype=object
                )

        external_batch_proto = DataProto(batch=external_batch_tensordict, non_tensor_batch=external_non_tensor)

        external_reward_tensor, external_reward_extra_infos = self.compute_reward(external_batch_proto, self.reward_fn)
        external_rewards = external_reward_tensor.sum(dim=-1).cpu().numpy()

        # ---- 3. Replace the last slot of a prompt where external wins ----
        total_success_external = len(external_batch_list)
        skipped_count = 0
        replaced_indices = set()

        last_indices = [i * n_per_prompt + (n_per_prompt - 1) for i in range(batch_size)]

        for ext_idx, prompt_i in enumerate(external_to_prompt_idx):
            last_idx = last_indices[prompt_i]
            uid = uids[last_idx]
            external_reward = external_rewards[ext_idx]

            # Rewards of the sibling rollouts, excluding the slot at stake.
            model_rewards = [r for idx, r in uid2model_info.get(uid, []) if idx != last_idx]
            if not model_rewards:
                skipped_count += 1
                continue

            best_model_reward = max(model_rewards)
            if external_reward < (best_model_reward + replace_threshold):
                skipped_count += 1
                continue

            new_input_ids = external_batch_tensordict['input_ids'][ext_idx]
            new_attention_mask = external_batch_tensordict['attention_mask'][ext_idx]
            response_ids_padded = external_batch_tensordict['responses'][ext_idx]
            response_mask_ext = external_batch_tensordict['response_mask'][ext_idx]

            # Sanity check: the slot must belong to this prompt. Only the common
            # suffix is compared, so padding/truncation on the left is tolerated.
            batch_prompt_ids = batch.batch['input_ids'][last_idx, :prompt_length]
            expected_prompt_ids = self.tokenizer.apply_chat_template(
                gen_batch.non_tensor_batch['raw_prompt'][prompt_i],
                add_generation_prompt=True,
                return_tensors='pt'
            )[0]
            assert torch.equal(batch_prompt_ids[-len(expected_prompt_ids):].cpu(), expected_prompt_ids[-len(batch_prompt_ids):]), \
                f"Prompt mismatch at prompt_i={prompt_i}, last_idx={last_idx}"

            batch.batch['input_ids'][last_idx].copy_(new_input_ids)
            batch.batch['attention_mask'][last_idx].copy_(new_attention_mask)
            batch.batch['responses'][last_idx].copy_(response_ids_padded)

            # `response_mask` is not created by this method, so it may be absent.
            if 'response_mask' in batch.batch:
                batch.batch['response_mask'][last_idx].copy_(response_mask_ext)

            if has_position_ids:
                batch.batch['position_ids'][last_idx].copy_(external_batch_tensordict['position_ids'][ext_idx])

            batch.batch['token_level_rewards'][last_idx].copy_(external_reward_tensor[ext_idx])
            batch.batch['token_level_scores'][last_idx].copy_(external_reward_tensor[ext_idx])

            if external_reward_extra_infos:
                for key, values in external_reward_extra_infos.items():
                    if key in batch.non_tensor_batch:
                        batch.non_tensor_batch[key][last_idx] = values[ext_idx]

            replaced_indices.add(last_idx)

        # ---- 4. Provenance markers, metrics and cache invalidation ----
        batch.non_tensor_batch['response_source'] = np.array(
            ['external' if i in replaced_indices else 'model' for i in range(len(batch))],
            dtype=object,
        )

        replaced_count = len(replaced_indices)
        external_metrics = {
            "external/total_success": total_success_external,
            "external/replaced_count": replaced_count,
            "external/skipped_count": skipped_count,
            "external/replacement_ratio": replaced_count / max(1, total_success_external),
        }

        # Cached log-probs no longer describe the replaced rows.
        fields_to_clear = ['old_log_probs', 'rollout_log_probs']
        for key in fields_to_clear:
            if key in batch.batch:
                del batch.batch[key]

        return batch, external_metrics

    def _apply_log_prob_floor_to_external(
        self,
        batch: DataProto,
        min_log_prob: float = None,
    ) -> DataProto:
        """Clamp ``old_log_probs`` of external responses from below.

        Rows whose ``non_tensor_batch['response_source']`` equals ``'external'``
        have every ``old_log_probs`` entry below ``min_log_prob`` raised to
        ``min_log_prob``. When ``response_mask`` is present only positions with
        ``response_mask > 0`` are touched. ``old_log_probs`` is modified in
        place, and a boolean tensor of the same shape marking the floored
        positions is stored as ``batch.batch['old_log_prob_floor_mask']``.

        Args:
            batch: Batch to modify; returned for convenience.
            min_log_prob: Floor value. ``None`` falls back to
                ``self.external_min_log_prob``.

        Returns:
            ``batch``. It is returned untouched when no floor is configured,
            ``response_source`` is missing, no row is external, or
            ``old_log_probs`` is absent.
        """
        if min_log_prob is None:
            min_log_prob = self.external_min_log_prob

        if min_log_prob is None:
            return batch

        response_sources = batch.non_tensor_batch.get("response_source", None)
        if response_sources is None:
            return batch

        external_indices = [i for i, s in enumerate(response_sources) if s == 'external']
        if not external_indices:
            return batch

        if 'old_log_probs' not in batch.batch:
            return batch

        old_log_probs = batch.batch['old_log_probs']
        response_mask = batch.batch.get('response_mask', None)

        # Positions of external rows that fall below the floor (and, when a
        # mask is available, that are real response tokens rather than padding).
        external_log_probs = old_log_probs[external_indices]
        below_floor = external_log_probs < min_log_prob
        if response_mask is not None:
            below_floor = below_floor & (response_mask[external_indices] > 0)

        floor_mask = torch.zeros_like(old_log_probs, dtype=torch.bool)
        floor_mask[external_indices] = below_floor

        # In-place write so the tensor stored in the batch is the one updated.
        old_log_probs[external_indices] = external_log_probs.masked_fill(below_floor, min_log_prob)

        batch.batch['old_log_prob_floor_mask'] = floor_mask

        return batch

    def fit(self):
        import uuid
        from collections import defaultdict
        from copy import deepcopy

        import torch
        from omegaconf import OmegaConf
        from tqdm import tqdm

        from verl import DataProto
        from verl.trainer.ppo.ray_trainer import (
            AdvantageEstimator,
            apply_kl_penalty,
            compute_advantage,
        )
        from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics
        from verl.utils.metric import reduce_metrics
        from verl.utils.profiler import marked_timer
        from verl.utils.rollout_skip import RolloutSkip
        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0
        self.gen_steps = 0
        self._load_checkpoint()

        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
            rollout_skip.wrap_generate_sequences()

        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        self.global_steps += 1
        self.gen_steps += 1
        last_val_metrics = None

        prev_step_profile = False
        curr_step_profile = (
            self.global_steps in self.config.global_profiler.steps
            if self.config.global_profiler.steps is not None
            else False
        )
        next_step_profile = False

        timing_raw = defaultdict(float)
        batch = None
        num_prompt_in_batch = 0
        num_gen_batches = 0

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
                    self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False)

                warmup_cfg = getattr(self.config.trainer, 'warmup', None)
                if warmup_cfg is not None and getattr(warmup_cfg, 'steps', 0) > 0:
                    warmup_steps = warmup_cfg.steps
                    if self.global_steps <= warmup_steps:
                        if self.global_steps == 1:
                            self._original_clip_ratio_low = self.config.actor_rollout_ref.actor.clip_ratio_low
                            self.config.actor_rollout_ref.actor.clip_ratio_low = warmup_cfg.clip_ratio_low
                    elif self.global_steps == warmup_steps + 1:
                        self.config.actor_rollout_ref.actor.clip_ratio_low = self._original_clip_ratio_low

                metrics = {}

                with marked_timer("start_profile", timing_raw):
                    self._start_profiling(
                        not prev_step_profile and curr_step_profile
                        if self.config.global_profiler.profile_continuous_steps
                        else curr_step_profile
                    )

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                gen_batch = self._get_gen_batch(new_batch)
                gen_batch_output = gen_batch.repeat(
                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                )

                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer("step", timing_raw):
                    with marked_timer("gen", timing_raw, "red"):
                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
                        timing_raw.update(gen_batch_output.meta_info.get("timing", {}))
                        gen_batch_output.meta_info.pop("timing", None)

                        external_results = None
                        with marked_timer("gen_external", timing_raw, "purple"):
                            external_results = self._get_external_responses(gen_batch_output, gen_batch)

                    new_batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
                    )
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    new_batch = new_batch.union(gen_batch_output)

                    with marked_timer("reward", timing_raw, "yellow"):
                        if self.use_rm and "rm_scores" not in new_batch.batch.keys():
                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
                            new_batch = new_batch.union(reward_tensor)
                        reward_tensor, reward_extra_infos_dict = self.compute_reward(new_batch, self.reward_fn)
                        new_batch.batch["token_level_scores"] = reward_tensor
                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update(
                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
                            )
                        if self.config.algorithm.use_kl_in_reward:
                            new_batch, kl_metrics = apply_kl_penalty(
                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                            )
                            metrics.update(kl_metrics)
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

                        new_batch, external_metrics = self._selective_inject_external_responses(
                            batch=new_batch,
                            external_results=external_results,
                            gen_batch=gen_batch,
                            n_per_prompt=self.config.actor_rollout_ref.rollout.n,
                        )
                        metrics.update(external_metrics)

                        if 'response_source' in new_batch.non_tensor_batch:
                            response_sources = new_batch.non_tensor_batch['response_source']
                            external_indices = [i for i, s in enumerate(response_sources) if s == 'external']

                            if len(external_indices) > 0:
                                ext_idx = external_indices[-1]
                                prompt_length = gen_batch.batch['input_ids'].shape[1]

                                ext_attention_mask = new_batch.batch['attention_mask'][ext_idx]
                                ext_response_mask = new_batch.batch['response_mask'][ext_idx]

                                attention_response_sum = ext_attention_mask[prompt_length:].sum().item()
                                response_mask_sum = ext_response_mask.sum().item()

                                assert attention_response_sum == response_mask_sum, \
                                    f"Mask mismatch! attention_mask[{prompt_length}:].sum()={attention_response_sum} != response_mask.sum()={response_mask_sum}"

                    if not self.config.algorithm.filter_groups.enable:
                        batch = new_batch
                    else:
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            new_batch.non_tensor_batch["seq_final_reward"] = (
                                new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                            )
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = (
                                new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
                            )

                        prompt_uid2metric_vals = defaultdict(list)
                        prompt_uid2source_and_metric = defaultdict(list)

                        response_sources = new_batch.non_tensor_batch.get("response_source", None)

                        for idx, (uid, metric_val) in enumerate(zip(
                            new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
                        )):
                            prompt_uid2metric_vals[uid].append(metric_val)
                            if response_sources is not None:
                                prompt_uid2source_and_metric[uid].append((response_sources[idx], metric_val))

                        prompt_uid2metric_std = {}
                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

                        kept_prompt_uids = [
                            uid for uid, std in prompt_uid2metric_std.items()
                            if (std > 0.001 or len(prompt_uid2metric_vals[uid]) == 1)
                        ]

                        num_prompt_in_batch += len(kept_prompt_uids)

                        kept_traj_idxs = [
                            idx for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"])
                            if traj_from_prompt_uid in kept_prompt_uids
                        ]

                        new_batch = new_batch[kept_traj_idxs]
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

                        prompt_bsz = self.config.data.train_batch_size
                        if num_prompt_in_batch < prompt_bsz:
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                self.gen_steps += 1
                                continue
                            else:
                                raise ValueError(
                                    f"{num_gen_batches=} >= {max_num_gen_batches=}. "
                                    "Generated too many. Please check if your data are too difficult."
                                )
                        else:
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            batch = batch[:traj_bsz]

                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    assert not self.config.algorithm.use_kl_in_reward
                    if not self.config.algorithm.use_kl_in_reward:
                        old_log_probs_exist_before = 'old_log_probs' in batch.batch
                        rollout_log_probs_exist = 'rollout_log_probs' in batch.batch

                        batch = self.compute_kl_related_metrics(batch, metrics, timing_raw)

                        batch = self._apply_log_prob_floor_to_external(batch)

                    if self.use_critic:
                        with marked_timer("values", timing_raw, "cyan"):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with marked_timer("adv", timing_raw, "brown"):
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                        )

                    if self.use_critic:
                        with marked_timer("update_critic", timing_raw, "pink"):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    if self.config.trainer.critic_warmup <= self.global_steps:
                        with marked_timer("update_actor", timing_raw, "red"):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir:
                        self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)

                if (
                    self.val_reward_fn is not None
                    and self.config.trainer.test_freq > 0
                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
                ):
                    with marked_timer("testing", timing_raw, "green"):
                        val_metrics: dict = self._validate()
                        if is_last_step:
                            last_val_metrics = val_metrics
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and (
                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0
                ):
                    with marked_timer("save_checkpoint", timing_raw, "green"):
                        self._save_checkpoint()

                with marked_timer("stop_profile", timing_raw):
                    next_step_profile = (
                        self.global_steps + 1 in self.config.global_profiler.steps
                        if self.config.global_profiler.steps is not None
                        else False
                    )
                    self._stop_profiling(
                        curr_step_profile and not next_step_profile
                        if self.config.global_profiler.profile_continuous_steps
                        else curr_step_profile
                    )
                    prev_step_profile = curr_step_profile
                    curr_step_profile = next_step_profile

                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)

                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
                        self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1
                self.gen_steps += 1

        import os
        checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")
        if not os.path.exists(checkpoint_dir):
            timing_raw = defaultdict(float)
            with marked_timer("save_checkpoint", timing_raw, "green"):
                self._save_checkpoint()
            metrics = {f"timing/{k}": v for k, v in timing_raw.items()}
            logger.log(data=metrics, step=self.global_steps)
