# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with Hugging Face.
"""

import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from tqdm import tqdm
import hashlib

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    reduce_metrics,
    compute_utils_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
from verl.utils.debug import marked_timer
import random
import json
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto

def compute_encourage_advantage(data: DataProto, offset=None, reactivate_all_correct=False, filter_wrong_in_resp=False, wr_corr_max_ratio=2, multi_turn=False, norm_adv_by_std_in_grpo=True, config=None):
    """Compute GRPO-style outcome advantages with optional "encouragement" tweaks.

    Each response's sequence-level score is the sum of its
    ``token_level_rewards``. Scores are centred (and optionally scaled) with
    the statistics of all responses that share the same ``uid``, broadcast
    over the response tokens, and written to ``data.batch["advantages"]`` and
    ``data.batch["returns"]`` (``data`` is modified in place).

    Args:
        data (DataProto): Batch providing ``batch["token_level_rewards"]`` and
            ``non_tensor_batch["uid"]``. ``batch["response_mask"]`` is computed
            with ``compute_response_mask`` when absent. ``batch["loss_mask"]``
            is read when ``multi_turn`` is set, and ``non_tensor_batch["acc"]``
            is read when ``offset`` is given.
        offset (float, optional): Per-response step used to shrink advantages
            of unbalanced groups. For a group with more incorrect than correct
            responses, negative advantages are multiplied by
            ``1 - offset * (num_incorrect - num_correct)``; for a group with
            more correct than incorrect responses, positive advantages are
            multiplied by ``1 - offset * (num_correct - num_incorrect)``.
            ``None`` disables the adjustment.
        reactivate_all_correct (bool): If True, a group whose mean score equals
            1 has its mean/std computed as if one extra response scored -1.0,
            so its responses receive a non-zero advantage.
        filter_wrong_in_resp (bool): If True, return a subset of ``data`` in
            which, per group, the number of responses whose score is not 1 is
            capped at ``num_correct * wr_corr_max_ratio`` whenever the
            wrong/correct ratio exceeds ``wr_corr_max_ratio``. A group without
            any score-1 response is therefore dropped entirely.
        wr_corr_max_ratio (int | float): Maximum wrong-to-correct ratio used by
            ``filter_wrong_in_resp``.
        multi_turn (bool): If True, the trailing ``response_length`` columns of
            ``batch["loss_mask"]`` are used as the token mask instead of
            ``response_mask``.
        norm_adv_by_std_in_grpo (bool): If True, divide the centred score by
            the group standard deviation (plus a small epsilon).
        config: Unused; accepted for call-site compatibility.

    Returns:
        DataProto: ``data`` itself, or ``data[kept_index]`` when
        ``filter_wrong_in_resp`` is True.
    """
    # Back-compatible with trainers that do not compute response mask in fit.
    if "response_mask" not in data.batch.keys():
        data.batch["response_mask"] = compute_response_mask(data)

    response_mask = data.batch["response_mask"]
    if multi_turn:
        # Keep only the response part of the loss mask; its width is taken
        # from the response mask.
        response_length = response_mask.size(1)
        response_mask = data.batch["loss_mask"][:, -response_length:]

    # `.sum` allocates a new tensor, so the in-place updates below never touch
    # the rewards stored in `data`.
    scores = data.batch["token_level_rewards"].sum(dim=-1)
    index = data.non_tensor_batch["uid"]
    epsilon = 1e-6

    id2score = defaultdict(list)  # uid -> raw sequence scores
    id2index = defaultdict(list)  # uid -> row positions in the batch
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i].item())
            id2index[index[i]].append(i)

        for uid, group_scores in id2score.items():
            if len(group_scores) == 1:
                # A single sample has no group statistics: leave it unchanged.
                id2mean[uid] = torch.tensor(0.0)
                id2std[uid] = torch.tensor(1.0)
                continue
            group = torch.tensor(group_scores)
            mean = torch.mean(group)
            if reactivate_all_correct and mean == 1:
                # Reactivate groups with all-correct responses by simulating
                # one negative response when computing the statistics. The
                # stored raw scores are left untouched.
                group = torch.tensor(group_scores + [-1.0])
                mean = torch.mean(group)
            id2mean[uid] = mean
            id2std[uid] = torch.std(group)

        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        scores = scores.unsqueeze(-1) * response_mask

    if offset is not None:
        # Adaptive advantage offset ("v6"): shrink the advantage of the
        # majority outcome of each prompt group.
        # The multipliers are created on the same device as `scores` so the
        # element-wise products below cannot hit a device mismatch.
        positive_offset = torch.ones(scores.size(0), device=scores.device)
        negative_offset = torch.ones(scores.size(0), device=scores.device)
        acc_all = np.asarray(data.non_tensor_batch["acc"])

        # Row indices come from `id2index` (plain ints); they are kept apart
        # from `acc` so a float-valued `acc` cannot turn them into floats.
        for members in id2index.values():
            group_acc = acc_all[members]
            num_correct = np.sum(group_acc)
            num_incorrect = len(group_acc) - num_correct
            if num_correct < num_incorrect:
                negative_offset[members] = float(1 - offset * (num_incorrect - num_correct))
            if num_correct > num_incorrect:
                positive_offset[members] = float(1 - offset * (num_correct - num_incorrect))

        scores = torch.where(scores > 0, scores * positive_offset.unsqueeze(-1), scores)
        scores = torch.where(scores < 0, scores * negative_offset.unsqueeze(-1), scores)

    data.batch["advantages"] = scores
    data.batch["returns"] = scores

    if filter_wrong_in_resp:
        kept_index = []
        for uid, members in id2index.items():
            cur_score = np.array(id2score[uid])
            cur_index = np.array(members)
            correct_index = cur_index[cur_score == 1]
            wrong_index = cur_index[cur_score != 1]
            num_correct = len(correct_index)
            num_wrong = len(wrong_index)
            if num_wrong / (num_correct + epsilon) > wr_corr_max_ratio:
                kept_index.extend(correct_index.tolist())
                # int(): slice bounds must be integers even when the ratio is
                # configured as a float.
                max_wrong = int(num_correct * wr_corr_max_ratio)
                kept_index.extend(wrong_index[:max_wrong].tolist())
            else:
                kept_index.extend(cur_index.tolist())
        return data[kept_index]

    return data


class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.
        """
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        timing_raw = defaultdict(float)
        batch = None
        gen_distinct_batch = None
        already_kept_prompt_uids = []
        num_prompt_in_batch = 0
        num_prompt_in_gen_batch, num_distinct_filter = 0, 0
        num_gen_batches = 0
        stop_fork = False
        prompt_uid2distinct_history = {}
        prompt_uid2num_correct_history, prompt_uid2correct_times_history = {}, {}
        prompt_uid2correct_res_history = {}

        # Debug history ===========================
        # import pandas as pd
        # parquet_path = "/opt/tiger/llm/verl/data/dapo-math-17k.parquet"   # <- change this
        # df = pd.read_parquet(parquet_path)

        # # extract extra_info.index robustly (works if the cell is a dict or a JSON string)
        # def get_idx(cell):
        #     if isinstance(cell, dict):
        #         return cell.get("index")
        #     if isinstance(cell, str):
        #         try:
        #             return json.loads(cell).get("index")
        #         except Exception:
        #             return None
        #     return None

        # idx_series = df["extra_info"].map(get_idx).dropna()

        # # randomly assign 0 or 1 to each index (set seed for reproducibility if you want)
        # rng = np.random.default_rng(seed=42)  # remove/modify seed if not desired
        # values = rng.integers(0, 2, size=len(idx_series)).tolist()

        # prompt_uid2correct_times_history = dict(zip(idx_series.astype(str), values))
        # Debug history ===========================

        enable_distint_res_filter = False
        enable_correct_gen_temp = False
        enable_replace_correct_from_history = False
        advantage_offset = None
        if "Qwen2.5-Math-7B" in self.config.actor_rollout_ref.model.path:
            enable_filter_groups = False
        else:
            enable_filter_groups = self.config.algorithm.filter_groups.enable

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                do_profile = self.global_steps in (self.config.trainer.profile_steps or [])
                if do_profile:
                    self.actor_rollout_wg.start_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.start_profile()
                    if self.use_critic:
                        self.critic_wg.start_profile()
                    if self.use_rm:
                        self.rm_wg.start_profile()

                metrics = {}

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                # new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
                # new_batch.non_tensor_batch["uid"] = np.array([
                #                                             hashlib.sha256(prompt.encode('utf-8')).hexdigest()
                #                                             for prompt in new_batch.non_tensor_batch["question"]
                #                                             ], dtype=object)
                new_batch.non_tensor_batch["uid"] = np.array([
                                                            item["index"]
                                                            for item in new_batch.non_tensor_batch["extra_info"]
                                                            ], dtype=object)

                # select the prompt based on distinct number. Probability is distinct / n_rollout
                if len(prompt_uid2distinct_history) > 0:
                    kept_prompt_uids = [uid for uid in new_batch.non_tensor_batch["uid"]
                                        if uid not in prompt_uid2distinct_history or
                                        random.random() >= (prompt_uid2distinct_history[uid] / self.config.actor_rollout_ref.rollout.n)**0.5]
                    num_prompt_in_gen_batch += len(kept_prompt_uids)
                    kept_traj_idxs = []
                    for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
                        if traj_from_prompt_uid in kept_prompt_uids and traj_from_prompt_uid not in already_kept_prompt_uids:
                            kept_traj_idxs.append(idx)

                    already_kept_prompt_uids += kept_prompt_uids
                    new_gen_batch = new_batch[kept_traj_idxs]
                    gen_distinct_batch = new_gen_batch if gen_distinct_batch is None else DataProto.concat([gen_distinct_batch, new_gen_batch])

                    gen_prompt_bsz = self.config.data.gen_batch_size
                    if num_prompt_in_gen_batch < gen_prompt_bsz:
                        print(f"{num_prompt_in_gen_batch=} < {gen_prompt_bsz=}, keep selecting...")
                        num_distinct_filter += 1
                        continue
                    else:
                        # Align the batch
                        new_batch = gen_distinct_batch[:gen_prompt_bsz]

                    metrics["train/num_distinct_filter"] = num_distinct_filter
                    gen_distinct_batch = None
                    already_kept_prompt_uids = []
                    num_prompt_in_gen_batch, num_distinct_filter = 0, 0

                num_gen_batches += 1

                # pop those keys for generation
                if "multi_modal_data" in new_batch.non_tensor_batch.keys():
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer("step", timing_raw):
                    # generate a batch
                    with marked_timer("gen", timing_raw, "red"):
                        if stop_fork is True:
                            new_extra_args = deepcopy(self.config.actor_rollout_ref.rollout.extra_args)
                            new_extra_args["fork_top_percent"] = 0
                            new_extra_args["fork_bottom_percent"] = 0
                            self.actor_rollout_wg.update_sampling_params(extra_args=new_extra_args)
                        
                        if len(prompt_uid2correct_times_history) > 0:
                            customize_temperature = [self.config.actor_rollout_ref.rollout.temperature] * len(gen_batch)
                            for i, uid_ in enumerate(new_batch.non_tensor_batch["uid"]):
                                if uid_ in prompt_uid2correct_times_history:
                                    customize_temperature[i] = min(self.config.algorithm.ours.max_temperature, customize_temperature[i] + self.config.actor_rollout_ref.rollout.customize_temperature_step * prompt_uid2correct_times_history[uid_])
                                    # customize_temperature[i] += self.config.actor_rollout_ref.rollout.customize_temperature_step * prompt_uid2correct_times_history[uid_]
                            metrics["critic/customize_temperature_mean"] = np.mean(customize_temperature)
                            metrics["critic/customize_temperature_max"] = np.max(customize_temperature)
                            # add customize_temperature to gen_batch
                            gen_batch.non_tensor_batch["customize_temperature"] = np.array(customize_temperature)
                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                            gen_batch.non_tensor_batch.pop("customize_temperature")
                        else:
                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: # False
                        with marked_timer("gen_max", timing_raw, "red"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            new_batch = new_batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(new_batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            new_batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    # new_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    new_batch = new_batch.union(gen_batch_output)

                    # TODO: set advantage for all-correct data, remove negative-advantage responses.
                    with marked_timer("reward", timing_raw, "yellow"):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm: # False
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
                            new_batch = new_batch.union(reward_tensor)

                        # we combine with rule-based rm
                        reward_extra_infos_dict: dict[str, list]
                        try:
                            reward_result = self.reward_fn(new_batch, return_dict=True)
                            reward_tensor = reward_result["reward_tensor"]
                            reward_extra_infos_dict = reward_result.get("reward_extra_info", {})
                        except Exception as e:
                            print(f"Error in reward_fn: {e}")
                            reward_tensor = self.reward_fn(new_batch)
                            reward_extra_infos_dict = {}

                        new_batch.batch["token_level_scores"] = reward_tensor

                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # compute rewards. apply_kl_penalty if available
                        if self.config.algorithm.use_kl_in_reward:
                            new_batch, kl_metrics = apply_kl_penalty(new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)  # TODO: This will be cleared if we use multiple genenration batches
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

                    #update prompt_uid2distinct_history, prompt_uid2num_correct_history, prompt_uid2correct_times_history, all_correct, all_incorrect and after each generation
                    metrics.update(compute_utils_metrics(batch=new_batch, rollout_n=self.config.actor_rollout_ref.rollout.n)) #CHANGE
                    if enable_distint_res_filter:
                        prompt_uid2distinct_history.update(metrics["utils/prompt_uid2distinct"])

                    if enable_correct_gen_temp:
                        prompt_uid2num_correct_history.update(metrics["utils/prompt_uid2num_correct"])
                        for uid_, num in metrics["utils/prompt_uid2num_correct"].items():
                            if uid_ in prompt_uid2correct_times_history.keys():
                                if num >= self.config.actor_rollout_ref.rollout.n - self.config.data.correct_threshold:
                                    prompt_uid2correct_times_history[uid_] += 1
                            else:
                                prompt_uid2correct_times_history[uid_] = 1 if num >= self.config.actor_rollout_ref.rollout.n - self.config.data.correct_threshold else 0

                    if enable_replace_correct_from_history:
                        prompt_uid2idx_acc = defaultdict(list)
                        for idx, (uid, acc) in enumerate(zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch["acc"])):
                            prompt_uid2idx_acc[uid].append([idx, acc])

                        kept_idx = []
                        replace_num_prompts = 0
                        history_data = None
                        for uid, idx_acc in prompt_uid2idx_acc.items():
                            idx = [item[0] for item in idx_acc]
                            acc = [item[1] for item in idx_acc]

                            # debug
                            # prompt_uid2correct_res_history[uid] = new_batch[0:1]

                            corr_idx = [i for i, a in zip(idx, acc) if a]
                            wr_idx = [i for i, a in zip(idx, acc) if not a]
                            tar_num_corr = int(self.config.actor_rollout_ref.rollout.n * self.config.algorithm.ours.replace_corr_ratio)
                            
                            if len(corr_idx) < tar_num_corr and prompt_uid2correct_res_history.get(uid) is not None:
                                min_corr4wr = tar_num_corr if self.config.algorithm.ours.replace_corr_ratio < 0.5 else int(tar_num_corr / 2)
                                if len(corr_idx) == 0: # skip all incorrect data. Mark as "replace_correct_history_v2"
                                    continue
                                # if len(corr_idx) == 0 and len(prompt_uid2correct_res_history.get(uid)) < min_corr4wr: # for all wrong prompt, collect min_corr4wr correct responses for training.
                                #     continue
                                cur_need_history = prompt_uid2correct_res_history.get(uid)[-tar_num_corr+len(corr_idx):]
                                kept_idx.extend(corr_idx + wr_idx[len(cur_need_history):])
                                history_data = cur_need_history if history_data is None else DataProto.concat([history_data, cur_need_history])
                                replace_num_prompts += 1
                            else:
                                kept_idx.extend(idx)

                            # save after fetch. First two try save before fetch, so we repeat the correct responses to increase num_correct in batch. I named it as repeat_corr in log.
                            if 0 < len(corr_idx) <= tar_num_corr + 3: # only save if there may be responses less than tar_num_corr
                                new_corr = new_batch[corr_idx[0:tar_num_corr]]
                                if prompt_uid2correct_res_history.get(uid) is None or len(new_corr) >= tar_num_corr:
                                    prompt_uid2correct_res_history[uid] = new_batch[corr_idx[0:tar_num_corr]]
                                else:  # A queue to store correct responses.
                                    old_corr = prompt_uid2correct_res_history.get(uid)
                                    if len(new_corr) + len(old_corr) > tar_num_corr:
                                        prompt_uid2correct_res_history[uid] = DataProto.concat([old_corr[len(new_corr) + len(old_corr) - tar_num_corr:], new_corr])
                                    else:
                                        prompt_uid2correct_res_history[uid] = DataProto.concat([old_corr, new_corr])

                        new_batch = DataProto.concat([history_data, new_batch[kept_idx]]) if history_data is not None else new_batch
                        metrics["ours/replace_history_data_size"] = len(history_data) if history_data is not None else 0
                        metrics["ours/replace_num_prompts"] = replace_num_prompts

                    if not enable_filter_groups:
                        batch = new_batch
                    else:  # NOTE: When prompts after filtering is less than train batch size,
                        # we skip to the next generation batch
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            # Turn to numpy for easier filtering
                            new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = new_batch.batch["token_level_scores"].sum(dim=-1).numpy()

                        # Collect the sequence reward for each trajectory
                        prompt_uid2metric_vals = defaultdict(list)
                        for uid, metric_val in zip(new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]):
                            prompt_uid2metric_vals[uid].append(metric_val)

                        prompt_uid2metric_std = {}
                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

                        kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0 or len(prompt_uid2metric_vals[uid]) == 1]
                        if self.config.algorithm.ours.reactivate_all_correct: # add all correct data into training batch
                            all_correct_prompt_uids = [uid for uid, acc in prompt_uid2metric_vals.items() if len(acc) == sum(np.array(acc) == True)]
                            kept_prompt_uids.extend(all_correct_prompt_uids)

                        num_prompt_in_batch += len(kept_prompt_uids)

                        kept_traj_idxs = []
                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
                            if traj_from_prompt_uid in kept_prompt_uids:
                                kept_traj_idxs.append(idx)

                        new_batch = new_batch[kept_traj_idxs]
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

                        prompt_bsz = self.config.data.train_batch_size
                        if num_prompt_in_batch < prompt_bsz:
                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                print(f"{num_gen_batches=}. Keep generating...")
                                progress_bar.update(1)
                                continue
                            else:
                                raise ValueError(f"{num_gen_batches=} >= {max_num_gen_batches=}." + " Generated too many. Please check if your data are too difficult." + " You could also try set max_num_gen_batches=0 to enable endless trials.")
                        else:
                            # Align the batch
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            batch = batch[:traj_bsz]

                    # === Updating ===

                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # recompute old_log_probs
                    with marked_timer("old_log_prob", timing_raw, "blue"):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if metrics.get("actor/entropy") > self.config.actor_rollout_ref.rollout.stop_fork_entropy and self.config.actor_rollout_ref.rollout.stop_fork is True:
                        stop_fork = True
                        print(f"Stop forking after step {self.global_steps}")

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with marked_timer("ref", timing_raw, "olive"):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with marked_timer("values", timing_raw, "cyan"):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with marked_timer("adv", timing_raw, "brown"):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        if self.config.algorithm.ours.reactivate_all_correct or self.config.algorithm.ours.filter_wrong_in_resp or advantage_offset is not None:
                            batch = compute_encourage_advantage(
                                batch,
                                offset=advantage_offset,
                                reactivate_all_correct=self.config.algorithm.ours.reactivate_all_correct,
                                filter_wrong_in_resp=self.config.algorithm.ours.filter_wrong_in_resp,
                                wr_corr_max_ratio=self.config.algorithm.ours.wr_corr_max_ratio,
                                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                            )
                        else:
                            batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                            )
                        if len(batch) == 0:
                            continue

                    # update critic
                    if self.use_critic:
                        with marked_timer("update_critic", timing_raw, "pink"):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with marked_timer("update_actor", timing_raw, "red"):
                            batch_padded, pad_size = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size)
                            actor_output = self.actor_rollout_wg.update_actor(batch_padded)
                            actor_output = unpad_dataproto(actor_output, pad_size=pad_size)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with marked_timer("testing", timing_raw, "green"):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                        with marked_timer("save_checkpoint", timing_raw, "green"):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, rollout_n=self.config.actor_rollout_ref.rollout.n)) #CHANGE
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual TFLOPS and theoretical TFLOPS
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)  # clear timing

                # if enable_distint_res_filter:
                #     prompt_uid2distinct_history.update(metrics["utils/prompt_uid2distinct"])

                # if enable_correct_gen_temp:
                #     prompt_uid2num_correct_history.update(metrics["utils/prompt_uid2num_correct"])
                #     for uid_, num in metrics["utils/prompt_uid2num_correct"].items():
                #         if uid_ in prompt_uid2correct_times_history.keys():
                #             if num >= self.config.actor_rollout_ref.rollout.n - self.config.data.correct_threshold:
                #                 prompt_uid2correct_times_history[uid_] += 1
                #         else:
                #             prompt_uid2correct_times_history[uid_] = 1 if num >= self.config.actor_rollout_ref.rollout.n - self.config.data.correct_threshold else 0

                if metrics.get("utils/prompt_uid2distinct") is not None:
                    metrics.pop("utils/prompt_uid2distinct")
                if metrics.get("utils/prompt_uid2num_correct") is not None:
                    metrics.pop("utils/prompt_uid2num_correct")

                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if do_profile:
                    self.actor_rollout_wg.stop_profile()
                    if self.use_reference_policy:
                        self.ref_policy_wg.stop_profile()
                    if self.use_critic:
                        self.critic_wg.stop_profile()
                    if self.use_rm:
                        self.rm_wg.stop_profile()

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return
                
                if self.config.data.enable_res_filter == True and metrics.get("critic/pred/invalid/mean") < 1:
                    enable_distint_res_filter = True
                    try:
                        if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                            with open(f'{self.config.trainer.default_local_dir}/prompt_uid2distinct_history_step{self.global_steps}.json', 'w') as f:
                                json.dump(prompt_uid2distinct_history, f, indent=4)
                    except:
                        print("error in save prompt_uid2distinct_history")

                if self.config.algorithm.ours.enable_replace_correct_from_history == True and metrics.get("critic/pred/invalid/mean") < 1:
                    enable_replace_correct_from_history = True
                    try:
                        if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                            with open(f'{self.config.trainer.default_local_dir}/prompt_uid2correct_res_history_step{self.global_steps}.json', 'w') as f:
                                json.dump(prompt_uid2correct_res_history, f, indent=4)
                    except:
                        print("error in save prompt_uid2distinct_history")

                if self.config.algorithm.ours.advantage_offset is not None and metrics.get("critic/pred/invalid/mean") < 1:
                    advantage_offset = self.config.algorithm.ours.advantage_offset
                
                if self.config.data.enable_correct_gen_temp == True and metrics.get("critic/pred/invalid/mean") < 1:
                    enable_correct_gen_temp = True
                    try:
                        if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                            with open(f'{self.config.trainer.default_local_dir}/prompt_uid2correct_times_history_step{self.global_steps}.json', 'w') as f:
                                json.dump(prompt_uid2correct_times_history, f, indent=4)
                            with open(f'{self.config.trainer.default_local_dir}/prompt_uid2num_correct_history_step{self.global_steps}.json', 'w') as f:
                                json.dump(prompt_uid2num_correct_history, f, indent=4)
                    except:
                        print("error in save prompt_uid2correct_times_history")

                if "Qwen2.5-Math-7B" in self.config.actor_rollout_ref.model.path and metrics.get("critic/pred/invalid/mean") < 1:
                    enable_filter_groups = self.config.algorithm.filter_groups.enable

                # automatically move checkpoint to hdfs
                if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0) and self.config.trainer.save_hdfs_dir is not None:
                    from concurrent.futures import ThreadPoolExecutor
                    import shutil
                    import os
                    pool = ThreadPoolExecutor(max_workers=2)

                    def move_task():
                        """Move the local checkpoint directory and history dumps into ``save_hdfs_dir``.

                        Best-effort: nothing is raised to the caller. Any failure is
                        printed together with its cause so it can be diagnosed from the log.

                        NOTE(review): ``self.global_steps`` is read when each path is
                        built, i.e. while this function runs, not when it is defined.
                        The checkpoint directory uses ``global_steps - 1`` while the JSON
                        dumps use ``global_steps`` unchanged; if this runs on a worker
                        thread while the caller advances the counter, these may not refer
                        to the same step -- confirm the intended offsets.
                        """
                        try:
                            print("\n moving to hdfs \n")
                            hdfs_dir = f"{self.config.trainer.save_hdfs_dir}"
                            local_dir = f"{self.config.trainer.default_local_dir}"
                            os.makedirs(hdfs_dir, exist_ok=True)
                            shutil.move(f"{local_dir}/global_step_{self.global_steps-1}", hdfs_dir)
                            if enable_distint_res_filter is True:
                                shutil.move(f"{local_dir}/prompt_uid2distinct_history_step{self.global_steps}.json", hdfs_dir)
                            if enable_correct_gen_temp is True:
                                shutil.move(f"{local_dir}/prompt_uid2correct_times_history_step{self.global_steps}.json", hdfs_dir)
                                shutil.move(f"{local_dir}/prompt_uid2num_correct_history_step{self.global_steps}.json", hdfs_dir)
                            if enable_replace_correct_from_history is True:
                                shutil.move(f"{local_dir}/prompt_uid2correct_res_history_step{self.global_steps}.json", hdfs_dir)
                        except Exception as exc:
                            # Stay best-effort, but say what went wrong instead of
                            # hiding the cause behind a fixed message.
                            print(f"error in move checkpoint: {type(exc).__name__}: {exc}")

                    def delete_task():
                        """Remove the checkpoint and history dumps from two save intervals ago.

                        Targets ``global_step_{global_steps - 2 * save_freq}`` under
                        ``save_hdfs_dir`` plus the matching JSON history files for the
                        enabled features. Best-effort: nothing is raised to the caller.
                        A failure (including the target not existing yet) is printed
                        together with its cause.

                        NOTE(review): an earlier variant subtracted one more step from the
                        directory name; confirm which offset matches the step the
                        checkpoint was saved under.
                        """
                        try:
                            hdfs_dir = f"{self.config.trainer.save_hdfs_dir}"
                            # Compute the stale step once so the directory and every JSON
                            # file refer to the same step even if global_steps changes
                            # while this task is running.
                            stale_step = self.global_steps - 2*self.config.trainer.save_freq
                            shutil.rmtree(f"{hdfs_dir}/global_step_{stale_step}")
                            if enable_distint_res_filter is True:
                                os.remove(f"{hdfs_dir}/prompt_uid2distinct_history_step{stale_step}.json")
                            if enable_correct_gen_temp is True:
                                os.remove(f"{hdfs_dir}/prompt_uid2correct_times_history_step{stale_step}.json")
                                os.remove(f"{hdfs_dir}/prompt_uid2num_correct_history_step{stale_step}.json")
                            if enable_replace_correct_from_history is True:
                                os.remove(f"{hdfs_dir}/prompt_uid2correct_res_history_step{stale_step}.json")
                        except Exception as exc:
                            # Stay best-effort, but say what went wrong instead of
                            # hiding the cause behind a fixed message.
                            print(f"error in delete checkpoint: {type(exc).__name__}: {exc}")

                    fut = pool.submit(move_task)
                    def log_error(f):
                        """Done-callback that prints the exception a finished future raised, if any.

                        Args:
                            f: The completed ``concurrent.futures.Future``.
                        """
                        # exception() raises CancelledError on a cancelled future, and
                        # done-callbacks run for cancelled futures too, so skip those.
                        if f.cancelled():
                            return
                        exc = f.exception()
                        if exc is not None:
                            print("Background error:", exc)
                    fut.add_done_callback(log_error)
                    fut = pool.submit(delete_task)
                    fut.add_done_callback(log_error)

                progress_bar.update(1)
                self.global_steps += 1
