# coding=utf-8
import logging

from gpatch.core.utils import clear_memory, print_with_rank_and_datetime
# Raise the "httpx" logger's threshold to WARNING so its INFO/DEBUG records are dropped.
logging.getLogger("httpx").setLevel(logging.WARNING)
from functools import partial
from typing import List, Dict, Any, Tuple
from typing_extensions import override
from packaging.version import Version

import torch
from torch import Tensor

from megatron.core import mpu, package_info
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.training import get_args
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.training import print_rank_0
from megatron.training.utils import unwrap_model
from megatron.core.utils import divide

from megatron_datasets.tasks.math_rl_v3.ppo_actor_dataset import build_train_valid_test_datasets, DataCollator
from megatron_datasets.utils import get_iterator

from gpatch.training.v3.ppo_actor import train_ppo_actor_v3
from gpatch.training.v3.default_model_provider import default_actor_model_provider
from gpatch.core.aligner_helper import (
    average_losses_across_data_parallel_group,
    get_gbs_batches_seqlen,
    get_iterator_k_split_list,
    get_last_rank,
    get_max_seqlen_within_dp,
    get_max_seqlen_within_ep,
    masked_mean,
    from_parallel_logits_to_logprobs,
    )
from gpatch.core.ppo_helper import (
    vocab_parallel_entropy, 
    calculate_kl_loss,
    _VocabParallelEntropy,
    )
from gpatch.core.transformer.transformer_config import GpatchTransformerConfig
from gpatch.core.device_type import is_wxacc1
from gpatch.core.parallel_state import is_mp_and_cp_head, get_mp_and_cp_size
from gpatch.core.models.gpt import (
    GptPpoActorModel as _GptPpoActorModel,
    GptPpoRmCriticClientV3,
    GptPpoSamplerClientV3,
    GptPpoGenRmClientV3,
)
from gpatch.core.tensor_parallel.mappings import all_gather_to_context_parallel_region
from gpatch.patch_mcore import init_gpatch_for_mcore

from tasks.math_rl_v3.args import get_tasks_args
from tasks.math_rl_v3.sp import get_ppo_prompt_format
from tasks.math_rl_v3.ppo_sampling import filter_samplings
from tasks.math_rl_v3.math_rl_actor_trainer import MathRLActorTrainer
from tasks.math_rl_v3.dpp_downsample2 import (
    memory_efficient_dpp_gaussian_torch,
    memory_efficient_dpp_gaussian_torch_v2,
    adaptive_sigma_selection,
)

# This task reuses the default actor model provider unchanged; it is handed to
# train_ppo_actor_v3 in the __main__ block below.
model_provider = default_actor_model_provider

class GptPpoActorModel(_GptPpoActorModel):
    """Task-specific GRPO/GSPO actor for math RL.

    Adds two switches, read from the global Megatron args, on top of the base
    ``GptPpoActorModel``:

    * ``use_gspo`` (``args.use_gspo``): use the GSPO sequence-level importance
      ratio instead of the per-token PPO ratio.
    * ``token_dropout_strategy`` (``args.ppo_token_dropout_strategy``): when not
      ``'all'``, only a subset of the valid response tokens is kept in the
      actor-loss, ratio-metric and KL means. Values handled here: ``'d2s'``,
      ``'linear-decrease'``, ``'linear-increase'``, ``'dpp'`` and ``'random'``.
      Any other non-``'all'`` value keeps every token (no dropout) but still
      takes the per-token-entropy code path.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``args`` is the positional-argument tuple here; read the global
        # Megatron args under a different name instead of shadowing it.
        train_args = get_args()
        self.token_dropout_strategy = train_args.ppo_token_dropout_strategy
        self.use_gspo = train_args.use_gspo

    @staticmethod
    def _topk_keep_ratio(strategy: str, iteration: int, total_iters: int) -> float:
        """Return the fraction of valid tokens kept by a top-k dropout strategy.

        * ``'d2s'``: constant 0.2.
        * ``'linear-decrease'``: linear from 0.2 down to 0.05 as ``iteration``
          goes from 0 to ``total_iters - 1``.
        * ``'linear-increase'``: the reverse schedule, 0.05 up to 0.2.
        """
        if strategy == 'd2s':
            return 0.2
        # Training progress clamped to [0, 1]; the 1e-8 avoids dividing by zero
        # when total_iters == 1.
        progress = max(min(float(iteration / (total_iters - 1 + 1e-8)), 1.0), 0.0)
        if strategy == 'linear-increase':
            progress = 1.0 - progress
        top_ratio_original = 0.2
        top_ratio_new = 0.05
        return top_ratio_original * (1 - progress) + top_ratio_new * progress

    @staticmethod
    def _topk_fusion_mask(actor_loss: Tensor, per_token_entropy: Tensor, mask: Tensor, top_ratio: float) -> Tensor:
        """Return a 0/1 mask selecting the top ``top_ratio`` share of valid tokens.

        Each token is scored by ``|actor_loss| + per_token_entropy``. Positions
        where ``mask`` is zero are scored ``-inf``. ``k = int(mask.sum() * top_ratio)``
        tokens are selected globally across the whole micro-batch (all rows
        flattened together), not per sequence. At least one token is kept
        whenever ``mask`` has any valid position. The result has ``mask``'s
        shape, dtype and device, and no input is modified.

        Runs under ``torch.no_grad()``: only the selected indices are used, so
        the score needs no autograd graph.
        """
        with torch.no_grad():
            fusion = torch.abs(actor_loss) + per_token_entropy
            fusion_masked = torch.where(mask.to(torch.bool), fusion, float('-inf'))
            k = int(float(mask.sum()) * top_ratio)
            if bool(mask.any()):
                # Keep at least one token (as the 'dpp' branch does) so the masked
                # means computed from the result never run over an empty selection.
                k = max(1, k)
            # Global top-k over the flattened micro-batch.
            _, top_indices_flat = torch.topk(fusion_masked.view(-1), k=k)
            top_mask = torch.zeros_like(mask, dtype=mask.dtype, device=mask.device)
            top_mask.view(-1)[top_indices_flat] = 1
        return top_mask

    @override
    def get_loss_and_metrics(
        self,
        batch: List[Dict[str, Any]],
        num_microbatches: int,
        forward_only: bool,
        iteration: int,
        total_iters: int,
    ):
        """Run the forward(/backward) schedule over ``batch`` and collect metrics.

        The padded sequence length is maxed across the DP group when dynamic
        micro-batching is enabled (``dynamic_mbs_target_seq`` set), otherwise
        across the EP group. With dynamic micro-batching, the micro-batch size is
        derived from the token budget, capped at ``dynamic_mbs_limit`` and
        decreased until it evenly divides ``len(batch)``.

        Returns:
            ``(loss, metrics)``. ``metrics`` holds ``seq_length`` plus the
            micro-batch means of ``loss``, ``ppo_ratio``, ``ppo_ratio_clamped``
            and ``scaled_entropy`` (and ``grpo_kl_loss`` under GRPO), each
            broadcast from ``get_last_rank()``.
        """
        seq_length = get_gbs_batches_seqlen(batch, self.pad_to_multi_of)
        if self.dynamic_mbs_target_seq is not None:
            seq_length = get_max_seqlen_within_dp(seq_length)
        else:
            seq_length = get_max_seqlen_within_ep(seq_length)
        dynamic_num_microbatches = 0
        dynamic_mbs = 0
        batch_size = len(batch)

        if self.dynamic_mbs_target_seq is not None:
            dynamic_mbs = self.dynamic_mbs_target_seq // seq_length * self.config.micro_batch_size
            if dynamic_mbs == 0:
                dynamic_mbs = 1
            dynamic_mbs = min(self.dynamic_mbs_limit, dynamic_mbs)
            # Largest micro-batch size within the budget that divides the batch.
            while dynamic_mbs >= 1:
                if batch_size % dynamic_mbs == 0:
                    dynamic_num_microbatches = batch_size // dynamic_mbs
                    break
                dynamic_mbs -= 1
            if dynamic_num_microbatches > 0:
                assert dynamic_num_microbatches * dynamic_mbs == batch_size, f"{dynamic_mbs=} {dynamic_num_microbatches=} {batch_size=} mismatch!"

        enable_dynamic_mbs = dynamic_num_microbatches > 0
        effective_num_microbatches = dynamic_num_microbatches if enable_dynamic_mbs else num_microbatches
        effective_mbs = dynamic_mbs if enable_dynamic_mbs else self.config.micro_batch_size
        print_with_rank_and_datetime(f"[TRAIN] {seq_length=} {batch_size=} {num_microbatches=} {enable_dynamic_mbs=} {dynamic_num_microbatches=} {dynamic_mbs=}", rank=0)
        data_iter = get_iterator_k_split_list(batch, effective_num_microbatches)

        fwd_bwd_function = get_forward_backward_func()
        if self.use_grpo:
            forward_step_func = self.get_actor_grpo_forward_output_and_loss_func(seqlen=seq_length, iteration=iteration, total_iters=total_iters)
        else:
            forward_step_func = self.get_actor_forward_output_and_loss_func()
        losses_reduced_per_micro_batch = fwd_bwd_function(
            forward_step_func=forward_step_func,
            data_iterator=data_iter,
            model=self.model,
            num_microbatches=effective_num_microbatches,
            forward_only=forward_only,
            seq_length=seq_length,
            decoder_seq_length=seq_length,
            micro_batch_size=effective_mbs,
        )
        clear_memory()

        metrics = {"seq_length": seq_length}
        metrics_key = ["loss", "ppo_ratio", "ppo_ratio_clamped", "scaled_entropy"]
        if self.use_grpo:
            metrics_key += ["grpo_kl_loss"]
        for key in metrics_key:
            if losses_reduced_per_micro_batch:
                metric_mean = torch.stack(
                    [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch]).mean()
            else:
                # Ranks without loss outputs contribute a placeholder that the
                # broadcast below overwrites.
                metric_mean = torch.tensor(0.0, device=torch.cuda.current_device())
            torch.distributed.broadcast(metric_mean, get_last_rank())
            metrics[key] = metric_mean.cpu().item()
        return metrics["loss"], metrics

    @override
    def get_actor_grpo_forward_output_and_loss_func(self, seqlen: int, iteration: int, total_iters: int):
        """Build the forward-step function passed to the GRPO forward-backward schedule.

        The returned callable takes ``(data_iterator, model)`` (``seqlen``,
        ``iteration`` and ``total_iters`` are pre-bound). It pulls one
        micro-batch, runs the model and returns ``(logits, loss_func)``.
        """

        def fwd_output_and_loss_func(seqlen: int, iteration: int, total_iters: int, data_iterator, model):
            batches: List[Dict[str, Any]] = next(data_iterator)

            batch, fwd_kwargs = self.prepare_data_for_grpo_loss(batches, seqlen)
            for key in ["mask", "advantages", "prev_log_probs", "ref_log_probs", "target"]:
                assert key in batch

            parallel_logits = model(**fwd_kwargs)
            if isinstance(parallel_logits, tuple):
                parallel_logits = parallel_logits[0]
            assert isinstance(parallel_logits, torch.Tensor)

            def loss_func(iteration: int, total_iters: int, parallel_logits):
                """Compute the GRPO/GSPO loss for one micro-batch.

                Returns ``(bwd_loss, metrics)``. ``metrics`` holds ``loss`` and
                ``grpo_kl_loss`` (passed through
                ``average_losses_across_data_parallel_group``) plus the local
                ``ppo_ratio``, ``ppo_ratio_clamped`` and ``scaled_entropy``.

                Token dropout never modifies ``batch["mask"]``; the selection is
                applied to a local ``mask`` that is rebound out of place.
                """
                parallel_logits = parallel_logits.float()
                mask = batch["mask"]
                advantages = batch["advantages"]
                prev_log_probs = batch["prev_log_probs"]
                ref_log_probs = batch["ref_log_probs"]
                assert advantages.dtype == torch.float32
                assert prev_log_probs.dtype == torch.float32
                target = batch["target"]

                # Detached copy: the entropies below carry no gradient.
                parallel_logits_clone = parallel_logits.clone().detach()

                assert not self.config.cross_entropy_loss_fusion
                curr_log_probs = from_parallel_logits_to_logprobs(
                    vocab_parallel_logits=parallel_logits, target=target)

                if self.use_gspo:
                    # Sequence-level importance ratio:
                    # s_i(θ) = (π_θ(y_i|x) / π_θold(y_i|x))^(1/|y_i|)
                    #        = exp[(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t) / π_θold(y_i,t|x,y_i,<t))]
                    # [B, N-1]
                    negative_approx_kl = curr_log_probs - prev_log_probs
                    # [B]
                    seq_lengths = torch.sum(mask, dim=-1).clamp(min=1)
                    # [B]
                    negative_approx_kl_seq = torch.sum(negative_approx_kl * mask, dim=-1) / seq_lengths

                    # Combined token-level ratio:
                    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
                    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]
                    # [B, N-1]
                    log_seq_importance_ratio = curr_log_probs - curr_log_probs.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
                    log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)  # clamp for numerical stability

                    # Finally exp() to leave log space. [B, N-1]
                    ratios = log_seq_importance_ratio.exp()
                    assert ratios.shape == prev_log_probs.shape, f"{ratios.shape} != {prev_log_probs.shape}"
                else:
                    # Per-token ratio for the clipped PPO surrogate.
                    ratios = (curr_log_probs - prev_log_probs).exp()
                ratios_clamped = ratios.clamp(1.0 - self.ratio_eps, 1.0 + self.ratio_eps)

                loss1 = -advantages * ratios
                loss2 = -advantages * ratios_clamped
                clip_max_loss = torch.maximum(loss1, loss2)
                # Dual-clip PPO, ref: https://arxiv.org/pdf/1912.09729
                if self.dual_clip_ratio_c is not None:
                    loss3 = -advantages * self.dual_clip_ratio_c
                    clip_min_loss = torch.min(loss3, clip_max_loss)
                    actor_loss = torch.where(advantages < 0, clip_min_loss, clip_max_loss)
                else:
                    actor_loss = clip_max_loss

                strategy = self.token_dropout_strategy
                if strategy != 'all':
                    # Per-token entropy, used for the entropy metric and to score tokens.
                    with torch.no_grad():
                        if mpu.get_context_parallel_world_size() > 1:
                            # entropy_unmasked: [b, s], gathered across the context-parallel region.
                            entropy_unmasked = _VocabParallelEntropy.apply(parallel_logits_clone)
                            per_token_entropy = all_gather_to_context_parallel_region(entropy_unmasked)[:, :-1]
                        else:
                            per_token_entropy = _VocabParallelEntropy.apply(parallel_logits_clone)[:, :-1]
                        scaled_entropy = per_token_entropy.mean() if mask is None else masked_mean(per_token_entropy, mask)
                        # Min-max normalize to [0, 1] over the whole tensor (padded positions included).
                        per_token_entropy = (per_token_entropy - per_token_entropy.min()) / (
                            per_token_entropy.max() - per_token_entropy.min() + 1e-8)
                        per_token_entropy = per_token_entropy.clamp(0, 1)

                    if strategy in ('d2s', 'linear-decrease', 'linear-increase'):
                        top_ratio = self._topk_keep_ratio(strategy, iteration, total_iters)
                        top_mask = self._topk_fusion_mask(actor_loss, per_token_entropy, mask, top_ratio)
                        print_rank_0(f"[MATH][DEBUG] {iteration}/{total_iters} {top_mask.sum()=} {mask.sum()=}")
                        mask = mask * top_mask
                        print_rank_0(f"[MATH][DEBUG] {iteration}/{total_iters} {mask.sum()=}")
                    elif strategy == 'dpp':
                        # Flat indices of the valid (non-zero) mask positions.
                        valid_indices = torch.where(mask.flatten().to(torch.bool))[0]

                        if len(valid_indices) > 0:
                            dpp_ratio = 0.2
                            k = max(1, int(float(len(valid_indices)) * dpp_ratio))
                            with torch.no_grad():
                                # One feature row per valid token: entropy, advantage,
                                # current log-prob, previous log-prob.
                                features = torch.stack([
                                    per_token_entropy.flatten()[valid_indices].detach(),
                                    advantages.flatten()[valid_indices].detach(),
                                    curr_log_probs.flatten()[valid_indices].detach(),
                                    prev_log_probs.flatten()[valid_indices].detach(),
                                ], dim=1)
                                # Standardize each feature column.
                                features = (features - features.mean(dim=0, keepdim=True)) / (features.std(dim=0, keepdim=True) + 1e-8)

                                # DPP sampling; sigma is chosen adaptively from the features.
                                sigma = adaptive_sigma_selection(features, target_similarity=0.1)
                                dpp_indices_in_valid = memory_efficient_dpp_gaussian_torch(
                                    features,
                                    sigma=sigma,
                                    max_length=k,
                                )
                                print_rank_0(f"[MATH][DEBUG] {sigma=} {len(dpp_indices_in_valid)=}")

                            # Keep only the selected valid tokens.
                            dpp_mask = torch.zeros_like(mask.flatten(), dtype=mask.dtype)
                            dpp_mask[valid_indices[dpp_indices_in_valid]] = 1
                            dpp_mask = dpp_mask.view_as(mask)
                            print_rank_0(f"[MATH][DEBUG]{dpp_mask.sum()=} {mask.sum()=}")
                            mask = mask * dpp_mask
                            print_rank_0(f"[MATH][DEBUG]{mask.sum()=}")
                    elif strategy == 'random':
                        valid_indices = torch.where(mask.flatten().to(torch.bool))[0]
                        num_valid = len(valid_indices)
                        if num_valid > 0:
                            random_ratio = 0.2
                            # At least one token, as in the 'dpp' branch.
                            k = max(1, int(float(num_valid) * random_ratio))
                            # Permute the valid positions themselves so a non-binary
                            # mask can never yield out-of-range indices.
                            random_indices_in_valid = torch.randperm(num_valid)[:k]
                            random_mask = torch.zeros_like(mask.flatten(), dtype=mask.dtype, device=mask.device)
                            random_mask[valid_indices[random_indices_in_valid]] = 1
                            random_mask = random_mask.view_as(mask)
                            print_rank_0(f"[MATH][DEBUG]{random_mask.sum()=} {mask.sum()=}")
                            mask = mask * random_mask
                            print_rank_0(f"[MATH][DEBUG]{mask.sum()=}")

                    actor_loss = masked_mean(actor_loss, mask)
                else:
                    scaled_entropy = vocab_parallel_entropy(parallel_logits_clone, mask)
                    actor_loss = masked_mean(actor_loss, mask)

                loss = actor_loss - scaled_entropy * self.entropy_bonus

                with torch.no_grad():
                    ppo_ratio = masked_mean(ratios.detach(), mask)
                    ppo_ratio_clamped = masked_mean(ratios_clamped.detach(), mask)
                    scaled_entropy = scaled_entropy.detach()

                kl_loss = masked_mean(
                    calculate_kl_loss(
                        cur_log_probs=curr_log_probs,
                        ref_log_probs=ref_log_probs,
                        use_absolute_kl=False,
                        use_low_var_kl=True,
                        clamp_kl_loss=self.dual_clip_ratio_c is not None,
                    ),
                    mask)
                loss = loss + kl_loss * self.config.grpo_kl_loss_beta

                reduced_actor_loss = average_losses_across_data_parallel_group([loss, kl_loss])

                # Megatron-Core < 0.12.1: the backward loss is scaled by the
                # context-parallel size; newer versions use the loss as-is.
                if Version(package_info.__version__) < Version("0.12.1"):
                    bwd_loss = loss * self.config.context_parallel_size
                else:
                    bwd_loss = loss.clone()

                return (
                    bwd_loss,
                    {
                        "loss": reduced_actor_loss[0],
                        "ppo_ratio": ppo_ratio,
                        "ppo_ratio_clamped": ppo_ratio_clamped,
                        "scaled_entropy": scaled_entropy,
                        "grpo_kl_loss": reduced_actor_loss[1],
                    },
                )

            return parallel_logits, partial(loss_func, iteration, total_iters)

        return partial(fwd_output_and_loss_func, seqlen, iteration, total_iters)


def actor_provider(model, ref_model_state):
    """Wrap ``model`` in this task's ``GptPpoActorModel``.

    Hyper-parameters come from the global Megatron args. The pad token id
    comes from the global tokenizer.
    """
    cfg = get_args()
    pad_id = get_tokenizer()._tokenizer.pad_token_id
    return GptPpoActorModel(
        model=model,
        ref_model_state=ref_model_state,
        unwrap_model_func=unwrap_model,
        # PPO settings
        forward_micro_batch_size=cfg.ppo_logps_fwd_micro_batch_size,
        ppo_rollout_temperature=cfg.ppo_rollout_temperature,
        # SMART-PAD settings
        pad_to_multi_of=cfg.ppo_rollout_pad_to_multiple_of,
        pad_token_id=pad_id,
        dynamic_mbs_target_seqlen=cfg.ppo_train_dynamic_mbs_target_seq,
        dynamic_mbs_limit=cfg.ppo_train_dynamic_mbs_limit,
    )



def sampler_client_provider():
    """Create the RPC client for the rollout sampler endpoints.

    Endpoints, timeouts, retry count, inference engine and the weight-update
    chunk size all come from the global Megatron args.
    """
    args = get_args()
    return GptPpoSamplerClientV3(
        ep_ips=args.ppo_sampler_ips,
        ep_ports=args.ppo_sampler_ports,
        timeout=args.ppo_sampler_client_timeout,
        update_timeout=args.ppo_sampler_client_update_timeout,
        rpc_max_retries=args.grpo_rpc_max_retries,
        unwrap_model_func=unwrap_model,
        infer_engine_impl=args.infer_engine_impl,
        update_weight_max_size_mb=args.update_weight_max_size_mb,
    )



def rm_critic_client_provider():
    """Create the RPC client for the RM / critic endpoints.

    Connection settings come from the global Megatron args. The global
    tokenizer is passed through and also supplies the pad token id.
    """
    cfg = get_args()
    tok = get_tokenizer()
    client_kwargs = dict(
        pad_token_id=tok._tokenizer.pad_token_id,
        ep_ips=cfg.ppo_critic_ips,
        ep_ports=cfg.ppo_critic_ports,
        combine_rm_and_critic_server=cfg.combine_rm_and_critic_server,
        timeout=cfg.ppo_rm_critic_client_timeout,
        tokenizer=tok,
        rpc_max_retries=cfg.grpo_rpc_max_retries,
        num_rm=cfg.ppo_num_rm,
        ppo_debug_fake_rm_critic=cfg.ppo_debug_fake_rm_critic,
        ppo_value_truncate_head=cfg.ppo_value_truncate_head,
    )
    return GptPpoRmCriticClientV3(**client_kwargs)


def gen_rm_client_provider():
    """Create the RPC client for the generative-RM endpoints.

    Only valid when ``args.use_gen_rm`` is set (asserted).
    """
    cfg = get_args()
    assert cfg.use_gen_rm
    client_kwargs = dict(
        ep_ips=cfg.ppo_gen_rm_ips,
        ep_ports=cfg.ppo_gen_rm_ports,
        timeout=cfg.ppo_gen_rm_client_timeout,
        rpc_max_retries=cfg.grpo_rpc_max_retries,
        unwrap_model_func=unwrap_model,
    )
    return GptPpoGenRmClientV3(**client_kwargs)


def build_ppo_dataloader(dataset, batch_size, collate_fn, num_workers, prefetch_factor):
    """Build a DataLoader with the settings shared by the train/valid/test splits.

    Always uses ``drop_last=True`` and ``pin_memory=True``. ``prefetch_factor``
    is forwarded only when ``num_workers > 0``: PyTorch's ``DataLoader`` raises
    ``ValueError`` if a prefetch factor is supplied while loading in the main
    process (``num_workers == 0``).
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets and wrap them as iterators.

    Args:
        train_val_test_num_samples: not used by this implementation.

    Returns:
        ``(train_iter, valid_iter, test_iter)`` produced by ``get_iterator``. The
        valid/test loaders passed to ``get_iterator`` are ``None`` when the
        corresponding dataset is ``None``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    print_rank_0('> building train, validation, and test datasets ...')
    prompt_format, eos_token = get_ppo_prompt_format(args, tokenizer)
    if isinstance(prompt_format, list):
        for sub_pt in prompt_format:
            print_rank_0(f"building dataset with sub_prompt_format {sub_pt} eos_token {eos_token}")
    else:
        print_rank_0(f"building dataset with prompt_format {prompt_format} eos_token {eos_token}")

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        args,
        tokenizer,
        dp_rank=mpu.get_data_parallel_rank(),
        dp_size=mpu.get_data_parallel_world_size(),
        prompt_format=prompt_format,
        eos_token=eos_token)
    print_rank_0("> finished creating datasets ...")

    collate_fn = DataCollator(tokenizer=tokenizer,
                              seq_len=args.seq_length,
                              resp_seq_len=args.ppo_resp_seq_len,
                              gen_left_pad=args.gen_left_pad)
    make_loader = partial(
        build_ppo_dataloader,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
        prefetch_factor=args.px_dataloader_prefetch_factor,
    )

    # Train and test use the rollout micro-batch size; eval has its own.
    batch_size = args.ppo_rollout_micro_batch_size
    train_dataloader = make_loader(train_ds, batch_size)

    eval_dataloader = None
    if valid_ds is not None:
        eval_dataloader = make_loader(valid_ds, args.ppo_eval_rollout_micro_batch_size)
    test_dataloader = None
    if test_ds is not None:
        test_dataloader = make_loader(test_ds, batch_size)
    print_rank_0("> finished creating dataloader ...")

    return get_iterator(train_dataloader), get_iterator(eval_dataloader), get_iterator(
        test_dataloader)


def rollout_get_batch(data_iterator):
    """Pull one batch from ``data_iterator`` and shape it for the rollout sampler.

    Must only run on the model-parallel + context-parallel head rank (asserted).

    Returns a dict with:
      * ``prompt_token_ids``: one ``{'prompt_token_ids': [...]}`` dict per row,
        holding the first ``lpad_lens[i]`` entries of ``input_ids[i]``.
      * ``lpad_lens`` and ``gt_label``: passed through unchanged.
    """
    assert is_mp_and_cp_head(), '只有 mp_head 会走到这里'

    assert data_iterator is not None
    data = next(data_iterator)

    tokens = data['input_ids']
    lpad_lens = data['lpad_lens']
    gt_label = data['gt_label']

    # Each prompt is wrapped in a dict such as
    #   {'prompt_token_ids': [0, 114, 514, 19, 19, 810, 1, 2, 2]}
    # because vLLM AsyncLLM's input type `vllm.inputs.data.TokensPrompt` is
    # itself a dict. The extra nesting looks odd but is more compatible.
    prompt_token_ids = [
        {'prompt_token_ids': row[:prompt_len].tolist()}
        for row, prompt_len in zip(tokens, lpad_lens.tolist())
    ]

    return {
        "prompt_token_ids": prompt_token_ids,
        "lpad_lens": lpad_lens,
        "gt_label": gt_label,
    }


# MathRLActorTrainer is constructed before the args are available, so it cannot
# tell whether a regular RM or a generative RM is in use; that decision is
# deferred to extra_metric_info_provider, which reads the args when called.
def extra_metric_info_provider():
    """Describe the extra per-sample metrics as ``{'key_name', 'dtype'}`` dicts.

    With a generative RM only ``rm_rewards`` is reported. Otherwise the
    accuracy/format reward split and the ``sample_useful`` flag are added.
    """
    metric_specs = [('rm_rewards', torch.float32)]
    if not get_args().use_gen_rm:
        metric_specs += [
            ('acc_rewards', torch.float32),
            ('fmt_rewards', torch.float32),
            ('sample_useful', torch.bool),
        ]
    return [{'key_name': name, 'dtype': dtype} for name, dtype in metric_specs]

def _main():
    """Script entry point: patch Megatron-Core, then launch PPO actor training."""
    init_gpatch_for_mcore()
    train_valid_test_datasets_provider.is_distributed = True

    trainer = MathRLActorTrainer(extra_metric_info=extra_metric_info_provider)
    train_ppo_actor_v3(
        trainer,
        model_provider,
        actor_provider,
        sampler_client_provider,
        rm_critic_client_provider,
        gen_rm_client_provider,
        train_valid_test_datasets_provider,
        rollout_get_batch,
        filter_samplings,
        ModelType.encoder_or_decoder,
        extra_args_provider=get_tasks_args,
    )


if __name__ == "__main__":
    _main()
