import copy
import json
import math
import os
from functools import partial
from typing import Any, Dict, List

from roll.datasets.dataset import get_dataset
import ray
import torch
from codetiming import Timer
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.util.timer import _Timer

import datasets
from roll.datasets.chat_template import get_chat_template
from roll.datasets.collator import DataCollatorWithPaddingForPaddedKeys
from roll.distributed.executor.cluster import Cluster
from roll.distributed.scheduler.generate_scheduler import DynamicSamplingScheduler
from roll.distributed.scheduler.protocol import DataProto
from roll.models.model_providers import default_tokenizer_provider
from roll.pipeline.base_pipeline import BasePipeline
from roll.pipeline.rlvr.rlvr_config import RLVRConfig
from roll.utils.functionals import (
    compute_advantage,
    reduce_metrics,
    RunningMoments,
    get_sample_level_mask,
    reward_postprocess,
    compute_token_reward,
    agg_loss,
)
from roll.utils.kl_controller import get_kl_controller
from roll.utils.logging import get_logger
from roll.utils.metrics.metrics_manager import MetricsManager
from prompts_xml import prompt_stage2
from tensordict import TensorDict
from torch.nn.utils.rnn import pad_sequence
from ast import literal_eval
import re


logger = get_logger()

def process_stage1_output(s):
    """Extract the ``reason``, ``strategy`` and ``mode`` fields from a stage-1 output.

    The input is converted with ``str()`` and scanned with regular expressions
    for dict-like ``'key': 'value'`` fragments; it is never evaluated.

    Args:
        s: Any object; its ``str()`` representation is searched.

    Returns:
        A dict with keys ``"reason"``, ``"strategy"`` and ``"mode"``. A field
        that cannot be located is returned as the empty string. Escape
        sequences inside a captured value are returned verbatim (not unescaped).
    """
    text = str(s)

    def first_capture(*patterns):
        # Try the patterns in order and stop at the first one that matches.
        for candidate in patterns:
            found = re.search(candidate, text)
            if found:
                return found.group(1)
        return ""

    return {
        "reason": first_capture(r"'reason':\s*'((?:[^'\\]|\\.)*)'[,}]"),
        # The strategy value may be double-quoted (preferred) or single-quoted.
        "strategy": first_capture(
            r"'strategy':\s*\"((?:[^\"\\]|\\.)*)\"[,}]",
            r"'strategy':\s*'((?:[^\'\\]|\\.)*)'[,}]",
        ),
        "mode": first_capture(r"'mode':\s*'((?:[^'\\]|\\.)*)'[,}]"),
    }


def process_stage2_output(s):
    """Extract the ``action_type`` and ``argument`` fields from a stage-2 output.

    The input is converted with ``str()`` and searched with regular
    expressions. For each field the single-quoted form is tried first; the
    double-quoted form (which must be followed by ``,`` or ``}``) is the
    fallback.

    Args:
        s: Any object; its ``str()`` representation is searched.

    Returns:
        A dict with keys ``"action_type"`` and ``"argument"``; a field that is
        not found is returned as the empty string.
    """
    text = str(s)

    # (field name, single-quoted pattern, double-quoted fallback pattern)
    lookups = (
        (
            "action_type",
            r"'action_type'\s*:\s*'((?:[^'\\]|\\.)*)'",
            r"'action_type':\s*\"((?:[^\"\\]|\\.)*)\"[,}]",
        ),
        (
            "argument",
            r"'argument'\s*:\s*'((?:[^'\\]|\\.)*)'",
            r"'argument':\s*\"((?:[^\"\\]|\\.)*)\"[,}]",
        ),
    )

    parsed = {}
    for field, single_quoted, double_quoted in lookups:
        hit = re.search(single_quoted, text) or re.search(double_quoted, text)
        parsed[field] = hit.group(1) if hit else ""
    return parsed




def extract_first_to_last_brace(s: str):
    """Return the substring spanning the first ``{`` through the last ``}``.

    The previous implementation located the closing brace with ``str.find``,
    i.e. the *first* ``}``, which truncated any payload containing nested
    braces (``"{a: {b}}"`` became ``"{a: {b}"``). ``str.rfind`` is used so the
    result matches what the function name promises.

    Args:
        s: Text expected to contain a brace-delimited payload.

    Returns:
        The slice ``s[first '{' : last '}' + 1]``, or the literal string
        ``"None"`` when ``s`` is not a ``str`` or no ``{`` ... ``}`` pair in
        that order exists. The fault case is also logged.
    """
    if not isinstance(s, str):
        logger.info(f'Output Fault: {s}')
        return "None"
    start = s.find('{')
    end = s.rfind('}')  # last closing brace, so nested braces stay intact
    if start == -1 or end == -1 or end < start:
        logger.info(f'Output Fault: {s}')
        return "None"
    return s[start:end+1]

def preprocess_dataset(dataset, prompt_len, encode_function, num_proc):
    """Encode a dataset and drop rows whose prompt length is out of range.

    Args:
        dataset: Object exposing ``map`` and ``filter`` (used here with the
            keyword arguments of a Hugging Face ``datasets.Dataset``).
        prompt_len: Inclusive upper bound on ``len(row["input_ids"])``.
        encode_function: Batched mapping function passed to ``dataset.map``.
        num_proc: Worker count forwarded to both ``map`` and ``filter``.

    Returns:
        The encoded dataset restricted to rows with
        ``5 < len(input_ids) <= prompt_len``.
    """
    logger.info(f"Begin : {dataset}")

    def has_acceptable_length(data_i):
        # Keep prompts longer than 5 tokens and no longer than the cutoff.
        n_tokens = len(data_i["input_ids"])
        return 5 < n_tokens <= prompt_len

    # Step 1: encode; the on-disk cache is bypassed so encoding always reruns.
    encoded = dataset.map(
        encode_function,
        batched=True,
        num_proc=num_proc,
        desc="Encoding dataset",
        load_from_cache_file=False,
    )
    # Step 2: apply the length cutoff.
    dataset = encoded.filter(
        has_acceptable_length,
        num_proc=num_proc,
        desc="Filtering dataset",
    )
    # Both messages report the dataset as it stands after filtering.
    logger.info(f"Filtering prompt len: {dataset}")
    logger.info(f"Encoding: {dataset}")
    return dataset


def get_encode_function(template_name, tokenizer):
    """Build the batched encoder used to tokenize dataset rows.

    Args:
        template_name: Chat template name handed to ``get_chat_template``.
        tokenizer: Callable tokenizer; invoked once per batch on a list of
            strings.

    Returns:
        A function taking a batch (column name -> list of values) and
        returning ``tokenizer(texts)``. Texts come from the ``"message"``
        column when present (each entry rendered through the chat template,
        JSON-decoded first if it is a string), otherwise from the ``"prompt"``
        column verbatim. If neither column exists the tokenizer receives an
        empty list.
    """
    render_chat = get_chat_template(template_name, tokenizer)

    def to_messages(raw):
        # Messages may be stored as a JSON string or as an already-parsed object.
        return json.loads(raw) if isinstance(raw, str) else raw

    def encode_function(data_i):
        if "message" in data_i:
            texts = [render_chat(to_messages(entry)) for entry in data_i["message"]]
        elif "prompt" in data_i:
            texts = list(data_i["prompt"])
        else:
            texts = []
        return tokenizer(texts)

    return encode_function

def update_dataset_domain(tag_2_domain: Dict[str, set[str]], row):
    """Fill in ``row["domain"]`` from the row's tag when it is absent or ``None``.

    Args:
        tag_2_domain: Lookup table keyed by tag.
        row: Mutable mapping for one example; must contain ``"tag"`` whenever
            the domain has to be filled in.

    Returns:
        The same ``row`` object. An existing non-``None`` domain is left
        untouched; otherwise the domain becomes ``tag_2_domain[row["tag"]]``,
        falling back to ``"math_rule"`` for unknown tags.
    """
    domain_missing = "domain" not in row or row["domain"] is None
    if domain_missing:
        row["domain"] = tag_2_domain.get(row["tag"], "math_rule")
    return row

def query_filter_fn(data_list: List[DataProto], config: RLVRConfig) -> bool:
    """Decide whether the responses collected for one query should be kept.

    The filtering rule can be customised per domain through
    ``config.rewards[domain].query_filter_config``.

    Args:
        data_list: Items for a single query; each exposes
            ``batch["response_level_rewards"]`` and the first item's
            ``non_tensor_batch["domain"][0]`` selects the domain.
        config: Pipeline configuration holding the per-domain reward settings.

    Returns:
        ``True`` to keep the query, ``False`` to drop it. A single-item list is
        always kept. ``"mean_filter"`` drops the query when the mean reward is
        ``<= threshold_down`` (default ``-1``) or ``>= threshold_up`` (default
        ``inf``); ``"std_filter"`` drops it when the reward standard deviation
        is ``<= std_threshold`` (default ``-1``). Any other type, including
        ``"no_filter"``, keeps the query.
    """
    per_item_rewards = [item.batch["response_level_rewards"] for item in data_list]
    if len(per_item_rewards) == 1:
        return True
    rewards = torch.cat(per_item_rewards, dim=0)

    domain = data_list[0].non_tensor_batch["domain"][0]
    filter_cfg = config.rewards[domain].query_filter_config
    kind = filter_cfg.type

    if kind == "mean_filter":
        upper = filter_cfg.filter_args.get("threshold_up", math.inf)
        lower = filter_cfg.filter_args.get("threshold_down", -1)
        at_or_below_floor = torch.mean(rewards) <= lower
        return not (at_or_below_floor or torch.mean(rewards) >= upper)

    if kind == "std_filter":
        min_std = filter_cfg.filter_args.get("std_threshold", -1)
        return not (torch.std(rewards) <= min_std)

    return True


class LHRLPipeline(BasePipeline):

    def __init__(self, pipeline_config: RLVRConfig):
        """Build datasets, worker clusters and sampling schedulers for the pipeline.

        In order, the constructor:

        1. loads the tokenizer and the training dataset (plus a validation
           dataset when ``validation.data_args.file_name`` is set);
        2. encodes and length-filters the data, fills in the ``domain`` column
           and splits the training data into one dataset per domain listed in
           ``actor_train.data_args.domain_interleave_probs``;
        3. creates the KL controller and the actor-train, actor-infer,
           reference, optional critic (only when ``adv_estimator == "gae"``)
           and per-key reward clusters;
        4. creates one ``DynamicSamplingScheduler`` per domain and, if a
           validation dataset exists, one more for validation;
        5. initializes the clusters and registers the model-update pair and
           checkpoint clusters;
        6. creates one ``RunningMoments`` tracker per reward key.

        Args:
            pipeline_config: Pipeline configuration; also forwarded to the
                base-class constructor.

        Raises:
            AssertionError: if a domain dataset is empty, ``max_steps`` is not
                positive, the ``domain`` column is missing, or a domain's
                batch size is not smaller than its dataset.
        """
        # NOTE(review): self.resource_manager and self.state are used below but
        # not assigned here — presumably set by BasePipeline.__init__; confirm.
        super().__init__(pipeline_config)
        self.pipeline_config = pipeline_config

        self.tokenizer = default_tokenizer_provider(model_args=self.pipeline_config.actor_train.model_args)


        logger.info(f'Use training dataset type: {self.pipeline_config.actor_train.data_args.dataset_type}')
        dataset = get_dataset(data_args=self.pipeline_config.actor_train.data_args)


        # The validation dataset is optional: only loaded when a file name is configured.
        self.val_dataset = None
        if self.pipeline_config.validation.data_args.file_name:
            self.val_dataset = get_dataset(data_args=self.pipeline_config.validation.data_args)

        # Function that applies the chat-template format and then converts text to token ids.
        # The global template, when set, takes precedence over the actor_train one.
        template_name = (
            self.pipeline_config.global_template
            if self.pipeline_config.global_template
            else self.pipeline_config.actor_train.data_args.template
        )
        encode_function = get_encode_function(template_name, self.tokenizer)

        dataset = preprocess_dataset(
            dataset,
            self.pipeline_config.prompt_length,
            encode_function,
            num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
        )
        # update domain field
        dataset = dataset.map(
            partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
            num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
            desc="update_dataset_domain",
            load_from_cache_file=False
        )
        # One training dataset per configured domain; every domain must have data.
        self.domain_datasets: Dict[str, datasets.Dataset] = {}
        for domain in self.pipeline_config.actor_train.data_args.domain_interleave_probs.keys():
            self.domain_datasets[domain] = dataset.filter(
                lambda example, dom: example["domain"] == dom,
                num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
                fn_kwargs={"dom": domain},
            )
            assert len(self.domain_datasets[domain]) > 0, f"domain dataset {domain} has no data"

        # The validation data goes through the same encode/filter/domain steps
        # (reusing the actor_train worker count), but is not split per domain.
        if self.val_dataset:
            self.val_dataset = preprocess_dataset(
                self.val_dataset,
                self.pipeline_config.prompt_length,
                encode_function,
                num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
            )
            self.val_dataset = self.val_dataset.map(
                partial(update_dataset_domain, self.pipeline_config.tag_2_domain),
                num_proc=self.pipeline_config.actor_train.data_args.preprocessing_num_workers,
                desc="update_val_dataset_domain",
                load_from_cache_file=False
            )

        assert 'domain' in dataset.column_names, "domain field should set in dataset"
        if self.val_dataset:
            assert 'domain' in self.val_dataset.column_names, "domain field should set in val dataset"

        self.kl_ctrl = get_kl_controller(
            init_kl_coef=self.pipeline_config.init_kl_coef,
            target_kl=self.pipeline_config.target_kl,
            kl_horizon=self.pipeline_config.kl_horizon,
        )

        assert self.pipeline_config.max_steps > 0, "max_steps must be greater than 0"
        self.pipeline_config.set_max_steps(max_steps=self.pipeline_config.max_steps)

        # Worker clusters. All of them share this pipeline's resource manager.
        self.actor_train: Any = Cluster(
            name=self.pipeline_config.actor_train.name,
            worker_cls=self.pipeline_config.actor_train.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.actor_train,
        )
        self.actor_infer: Any = Cluster(
            name=self.pipeline_config.actor_infer.name,
            worker_cls=self.pipeline_config.actor_infer.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.actor_infer,
        )
        self.reference: Any = Cluster(
            name=self.pipeline_config.reference.name,
            worker_cls=self.pipeline_config.reference.worker_cls,
            resource_manager=self.resource_manager,
            worker_config=self.pipeline_config.reference,
        )
        if self.pipeline_config.adv_estimator == "gae":  # TODO: clarify what adv_estimator controls
            self.critic: Any = Cluster(
                name=self.pipeline_config.critic.name,
                worker_cls=self.pipeline_config.critic.worker_cls,
                resource_manager=self.resource_manager,
                worker_config=self.pipeline_config.critic,
            )

        # Reward clusters are created here: one for every entry configured in pipeline_config.rewards.
        self.rewards: Dict[str, Any] = {
            key: Cluster(
                name=f"reward-{key}",
                worker_cls=worker_config.worker_cls,
                resource_manager=self.resource_manager,
                worker_config=worker_config,
            )
            for key, worker_config in self.pipeline_config.rewards.items()
        }

        # Split rollout_batch_size across domains according to the interleave
        # probabilities; the last domain receives the remainder so the per-domain
        # sizes always sum to rollout_batch_size.
        domain_ratios = self.pipeline_config.actor_train.data_args.domain_interleave_probs
        self.generate_schedulers: Dict[str, DynamicSamplingScheduler] = {}
        self.domain_batch_size = {}
        domain_list = list(domain_ratios.keys())
        accumulated = 0
        for i, domain in enumerate(domain_list):
            if i == len(domain_list) - 1:
                domain_batch_size = self.pipeline_config.rollout_batch_size - accumulated
            else:
                domain_batch_size = int(domain_ratios[domain] * self.pipeline_config.rollout_batch_size)
            accumulated += domain_batch_size
            # TODO: the exact role of this scheduler is not fully understood yet; revisit later.
            # The scheduler actor is pinned (soft=False) to the node running this code.
            generate_scheduler = DynamicSamplingScheduler.options(
                scheduling_strategy=NodeAffinitySchedulingStrategy(
                    node_id=ray.get_runtime_context().get_node_id(),
                    soft=False,
                )
            ).remote(pipeline_config=self.pipeline_config)
            # Each scheduler gets its own domain's reward cluster plus, unless the
            # sentinel value is configured, the "mid" domain's reward cluster.
            # NOTE(review): "defalt" looks like a misspelling of "default"; verify
            # against the value actually used for RLVRConfig.mid_domain before changing it.
            mid = self.pipeline_config.mid_domain
            reward_clusters = {domain: self.rewards[domain]}
            if mid != "defalt":
                reward_clusters[mid] = self.rewards[mid]
            ray.get(
                generate_scheduler.set_scheduler.remote(
                    actor_cluster=self.actor_infer,
                    # reward_clusters={domain: self.rewards[domain],self.pipeline_config.mid_domain: self.rewards[self.pipeline_config.mid_domain]},
                    reward_clusters = reward_clusters,
                    dataset=self.domain_datasets[domain],
                    collect_fn_cls=DataCollatorWithPaddingForPaddedKeys,
                    collect_fn_kwargs=dict(max_length=self.pipeline_config.prompt_length, padding="max_length"),
                    response_filter_fn=lambda data_item, config: True,
                    query_filter_fn=query_filter_fn,
                    response_callback_fn=generate_scheduler.report_response.remote,
                    # Saved scheduler state for this domain, if any (None otherwise).
                    state=self.state.kv.get(f"scheduler_state_{domain}", None),
                )
            )
            self.generate_schedulers[domain] = generate_scheduler
            self.domain_batch_size[domain] = domain_batch_size

            assert domain_batch_size < len(self.domain_datasets[domain]), (f"domain_batch_size {domain_batch_size} must be "
                                                                           f"less than the number of domain datasets {len(self.domain_datasets[domain])}")

        # Validation scheduler: works on a deep copy of the config with
        # use_additional_prompts disabled, sees every reward cluster, and keeps
        # every query and response (both filters always return True).
        if self.val_dataset:
            val_pipeline_config = copy.deepcopy(self.pipeline_config)
            val_pipeline_config.use_additional_prompts = False
            self.val_generate_scheduler = DynamicSamplingScheduler.options(
                scheduling_strategy=NodeAffinitySchedulingStrategy(
                    node_id=ray.get_runtime_context().get_node_id(),
                    soft=False,
                )
            ).remote(pipeline_config=val_pipeline_config)
        if self.val_dataset:
            ray.get(
                self.val_generate_scheduler.set_scheduler.remote(
                    actor_cluster=self.actor_infer,
                    reward_clusters=self.rewards,
                    dataset=self.val_dataset,
                    collect_fn_cls=DataCollatorWithPaddingForPaddedKeys,
                    collect_fn_kwargs=dict(max_length=self.pipeline_config.prompt_length, padding="max_length"),
                    response_filter_fn=lambda data_item, config: True,
                    query_filter_fn=lambda data_list, config: True,
                    response_callback_fn=self.val_generate_scheduler.report_response.remote,
                )
            )

        # Cluster initialization happens in waves: actor_infer first, then the
        # reference cluster, then all reward clusters, and finally actor_train
        # together with the critic (when present).
        refs = []
        refs.extend(self.actor_infer.initialize(pipeline_config=self.pipeline_config, blocking=False))
        ray.get(refs)
        # NOTE(review): the result of a blocking initialize() is extended into
        # `refs`, which is discarded on the next line — confirm that
        # initialize(blocking=True) returns an iterable, otherwise this raises.
        refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True))
        refs = []
        for key, cluster in self.rewards.items():
            refs.extend(cluster.initialize(pipeline_config=self.pipeline_config, blocking=False))
        ray.get(refs)
        refs: List[ray.ObjectRef] = []
        refs.extend(self.actor_train.initialize(pipeline_config=self.pipeline_config, blocking=False))
        if self.pipeline_config.adv_estimator == "gae":
            refs.extend(self.critic.initialize(pipeline_config=self.pipeline_config, blocking=False))
        ray.get(refs)

        # Register actor_train -> actor_infer as the model-update pair.
        self.set_model_update_pair(
            src_cluster=self.actor_train,
            tgt_cluster=self.actor_infer,
            frequency=self.pipeline_config.actor_train.model_update_frequency,
        )

        if self.pipeline_config.adv_estimator == "gae":
            self.set_checkpoint_clusters(self.actor_train, self.critic)
        else:
            self.set_checkpoint_clusters(self.actor_train)

        # One RunningMoments instance per reward key.
        self.running = {}
        for domain in self.rewards.keys():
            self.running[domain] = RunningMoments()

    @torch.no_grad()
    def run(self):

        metrics_mgr = MetricsManager()

        tps_timer = _Timer(window_size=5)
        actor_infer_timer = _Timer(window_size=5)
        actor_infer_response_timer = _Timer(window_size=5)
        actor_train_timer = _Timer(window_size=5)

        for global_step in range(self.pipeline_config.max_steps):
            if global_step <= self.state.step:
                global_step += 1
                continue
            logger.info(f"pipeline step {global_step} start...")

            metrics_mgr.clear_metrics()
            with tps_timer, Timer(name="step_total", logger=None) as step_total_timer:

                # 先model update，resume时不需要保存infer cluster的状态
                if self.pipeline_config.adv_estimator == "gae":
                    self.critic.offload_states(blocking=True)
                self.actor_train.offload_states(blocking=True)

                with Timer(name="step_model_update", logger=None) as step_model_update_timer:
                    model_update_metrics: Dict = self.model_update(global_step)
                    metrics_mgr.add_metrics(model_update_metrics)
                    metrics_mgr.add_metric("time/step_model_update", step_model_update_timer.last)

                if self.val_dataset and global_step % self.pipeline_config.eval_steps == 0:
                    with Timer(name="val_step", logger=None) as val_step_timer:
                        val_metrics = self.val(global_step)
                        metrics_mgr.add_metrics(val_metrics)
                        metrics_mgr.add_metric("time/val_step", val_step_timer.last)

                batch: DataProto = DataProto()
                batch.meta_info = {"global_step": global_step}

                # 要按domain group by生成对应的batch
                with actor_infer_timer, actor_infer_response_timer, Timer(
                    name="step_generate", logger=None
                ) as step_generate_timer:
                    domain_batches = {}
                    batch.meta_info["generation_config"] = self.actor_infer.worker_config.generating_args.to_dict()
                    self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info))
                    for reward_cluster in self.rewards.values():
                        reward_cluster.load_states()

                    # batch.meta_info["is_offload_states"] = False
                    scheduler_refs = {}
                    if self.pipeline_config.if_2_stage:
                        for domain, scheduler in self.generate_schedulers.items():
                            logger.info(f"目前的{domain}")=
                            stage1_config = copy.deepcopy(batch.meta_info["generation_config"])
                            stage1_config["num_return_sequences"] = self.pipeline_config.k1
                            
                            if self.pipeline_config.if_mid_reward:
                                batch.meta_info["domain"] = self.pipeline_config.mid_domain
                            else:
                                batch.meta_info["skip_reward"] = True
                            
                            batch.meta_info["is_offload_states"] = False
                            batch.meta_info['generation_config'] = stage1_config
                            stage1_batch_ref = scheduler.get_batch.remote(data=batch, batch_size=self.domain_batch_size[domain])
                            stage1_output: DataProto = ray.get(stage1_batch_ref)

                            if self.pipeline_config.if_mid_reward:
                                stage1_output.meta_info["domain"] = domain
                            selected_idx = torch.randint(0, self.pipeline_config.k1, (1,)).item()
                            stage2_outputs = []
                            chosed_samples = []
                            stage2_prompt_data = self.build_stage2_prompt(stage1_output) 
                            stage2_config = copy.deepcopy(batch.meta_info["generation_config"])
                            stage2_config["num_return_sequences"] = 1
                            stage2_prompt_data.meta_info["is_offload_states"] = False
                            stage2_prompt_data.meta_info['generation_config'] = stage2_config
                            stage2_prompt_data.meta_info["skip_reward"] = False
                            stage2_batch_ref = scheduler.get_batch.remote(data=stage2_prompt_data, batch_size=self.domain_batch_size[domain]*self.pipeline_config.k1)
                            stage2_batch: DataProto = ray.get(stage2_batch_ref)
                            
                            stage1_data = copy.deepcopy(stage1_output) 
                            batch_size = stage1_data.batch.batch_size[0]
                            for group_start in range(0, batch_size, 8):
                                group_end = min(group_start + 8, batch_size)
                                group_size = group_end - group_start
                                
                                s1_scores = stage1_data.batch["scores"][group_start:group_end]
                                s2_scores = stage2_batch.batch["scores"][group_start:group_end]
                                
                                s1_mean = s1_scores.mean()
                                s1_std = s1_scores.std(unbiased=False) + 1e-5
                                s1_normalized = (s1_scores - s1_mean) / s1_std
                                
                                s2_mean = s2_scores.mean()
                                s2_std = s2_scores.std(unbiased=False) + 1e-5
                                s2_normalized = (s2_scores - s2_mean) / s2_std
                                stage1_data.batch["scores"][group_start:group_end] = 0.5*s1_normalized + 0.5*s2_normalized
                                stage1_data.batch["response_level_rewards"][group_start:group_end] = 0.5*s1_normalized + 0.5*s2_normalized
                                
                                for i in range(group_start, group_end):
                                    stage1_data.batch["token_level_rewards"][i] = torch.zeros_like(
                                        stage1_data.batch["responses"][i], dtype=torch.float16
                                    )
                            tt1 = stage1_data.batch['scores']
                            tt2 = stage1_data.batch["response_level_rewards"]

                            if not self.pipeline_config.if_share_reward:
                                if not self.pipeline_config.if_mid_reward:
                                    chosed_mode = 0
                                else:
                                    chosed_mode = self.pipeline_config.stage1_sample_mode
                                pre_stage3_prompt_data = []
                                for i in range(self.domain_batch_size[domain]):
                                    intermediate_data = stage1_output[i * self.pipeline_config.k1 : (i+1) * self.pipeline_config.k1]
                                    if chosed_mode == 0:
                                        pre_stage3_prompt_data.append(intermediate_data[[torch.randint(0, self.pipeline_config.k1, (1,)).item()]])
                                    elif chosed_mode == 1:
                                        for j in range(self.pipeline_config.k1):
                                            kk = intermediate_data.batch["chosed_index"][j]
                                            if j+1 == int(intermediate_data.batch["chosed_index"][j]):
                                                pre_stage3_prompt_data.append(intermediate_data[[j]])
                                stage3_input = DataProto.concat(pre_stage3_prompt_data)
                                stage3_prompt_data = self.build_stage2_prompt(stage3_input)
                                stage3_config = copy.deepcopy(batch.meta_info["generation_config"])
                                stage3_config["num_return_sequences"] = self.pipeline_config.k1
                                stage3_prompt_data.meta_info["is_offload_states"] = False
                                stage3_prompt_data.meta_info['generation_config'] = stage3_config
                                stage3_prompt_data.meta_info["skip_reward"] = False
                                stage3_batch_ref = scheduler.get_batch.remote(data=stage3_prompt_data, batch_size=self.domain_batch_size[domain])
                                stage3_batch: DataProto = ray.get(stage3_batch_ref)
                            else:
                                stage3_batch: DataProto = DataProto()

                            stage1_data = copy.deepcopy(stage1_output)
                            stage2_data_list = []  
                            

                            if "token_level_rewards" not in stage1_data.batch.keys():
                                zero_tokens = torch.zeros(
                                    stage1_data.batch.batch_size[0], stage1_data.batch["responses"].size(1), dtype=torch.float16, device=stage1_data.batch.device
                                )
                                stage1_data.batch.update({"token_level_rewards": zero_tokens})

                            if 'chosed_index' in stage1_data.batch:
                                del stage1_data.batch['chosed_index']
                                del stage1_data.batch['mode']
                            if not self.pipeline_config.if_share_reward:
                                domain_batch = DataProto.concat([stage1_data, stage3_batch])
                                # domain_batch = stage3_batch
                            else: #self.pipeline_config.if_share_reward is True
                                stage2_batch.meta_info['generation_config']["num_return_sequences"] = self.pipeline_config.k1
                                domain_batch = DataProto.concat([stage1_data, stage2_batch])

                            domain_batches[domain] = domain_batch

                    else:
                        for domain, scheduler in self.generate_schedulers.items():
                            scheduler_refs[domain] = scheduler.get_batch.remote(data=batch, batch_size=self.domain_batch_size[domain])
                        for domain, scheduler_ref in scheduler_refs.items():
                            domain_batch: DataProto = ray.get(scheduler_ref, timeout=self.pipeline_config.rpc_timeout)
                            metrics_mgr.add_domain_metrics(
                                domain, reduce_metrics(domain_batch.meta_info.pop("metrics", {}))
                            )
                            domain_batches[domain] = domain_batch
    
                    generate_output = DataProto.concat([domain_batch for domain_batch in domain_batches.values()])
                    generate_output.meta_info.pop("is_offload_states", None) 

                    for reward_cluster in self.rewards.values(): # 释放奖励
                        reward_cluster.offload_states()
                    
                    gen_metrics = self.actor_infer.stop_server()
                    metrics_mgr.add_domain_metrics(domain, reduce_metrics(gen_metrics.meta_info.pop("metrics", {})))
                metrics_mgr.add_metric("time/step_generate", step_generate_timer.last)

                batch = generate_output

                with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: 
                    ref_log_probs = self.reference.compute_log_probs(batch, blocking=True)
                    metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {}))
                    ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs")
                    batch = batch.union(ref_log_probs)
                metrics_mgr.add_metric("time/ref_log_probs_values", cal_ref_log_probs_timer.last)

                with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer:
                    batch.meta_info["is_offload_states"] = False
                    if self.pipeline_config.adv_estimator == "gae":
                        values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False)
                    old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False)
                    old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs)
                    agg_entropy = agg_loss(
                        loss_mat=old_log_probs.batch["entropy"],
                        loss_mask=batch.batch["response_mask"][:, 1:],
                        loss_agg_mode="token-mean",
                    )
                    batch.meta_info["agg_entropy"] = agg_entropy

                    if self.pipeline_config.adv_estimator == "gae":
                        values = DataProto.materialize_concat(data_refs=values_refs)
                        batch = batch.union(values)
                        metrics_mgr.add_reduced_metrics(values.meta_info.pop("metrics", {}))

                    batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"]
                    metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {}))
                metrics_mgr.add_metric("time/old_log_probs", cal_old_logpb_timer.last)

                # 要按domain group by处理reward
                batch.batch["prompt_id"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device)
                batch_grouped: Dict[str, DataProto] = batch.group_by("domain")
                batch_list = []
                for domain, domain_batch in batch_grouped.items():
                    # 1. 处理mask相关策略， 获取sample level mask
                    with Timer(name="get_sample_level_mask", logger=None) as get_sample_level_mask_timer:
                        domain_batch, mask_metrics = get_sample_level_mask(domain_batch, self.pipeline_config)
                        metrics_mgr.add_metrics(mask_metrics)
                    metrics_mgr.add_metric("time/get_sample_level_mask", get_sample_level_mask_timer.last)

                    # 2. 处理reward相关策略
                    with Timer(name="reward_postprocess", logger=None) as reward_postprocess_timer:
                        domain_batch, response_level_metrics = reward_postprocess(
                            domain_batch, self.pipeline_config, self.running
                        )
                        metrics_mgr.add_metrics(response_level_metrics)
                    metrics_mgr.add_metric("time/reward_postprocess", reward_postprocess_timer.last)

                    # 3. 计算token level rewards
                    with Timer(name="get_token_reward", logger=None) as get_token_reward_timer:
                        domain_batch, token_level_metrics = compute_token_reward(
                            domain_batch, self.pipeline_config, self.kl_ctrl
                        )
                        metrics_mgr.add_metrics(token_level_metrics)
                    metrics_mgr.add_metric("time/get_token_reward", get_token_reward_timer.last)

                    # 4. 计算advantage
                    final_response_mask = domain_batch.batch["final_response_mask"].clone()
                    with Timer(name="compute_advantage", logger=None) as compute_advantage_timer:
                        domain_batch = compute_advantage(
                            data=domain_batch,
                            gamma=self.pipeline_config.gamma,
                            lambd=self.pipeline_config.lambd,
                            adv_estimator=self.pipeline_config.adv_estimator,
                            advantage_clip=self.pipeline_config.advantage_clip,
                            whiten_advantages=self.pipeline_config.whiten_advantages,
                            whiten_rewards=self.pipeline_config.whiten_rewards,
                            response_mask=final_response_mask,
                        )
                        domain_metrics = reduce_metrics(domain_batch.meta_info.pop("metrics", {}))
                        metrics_mgr.add_domain_metrics(domain, domain_metrics)
                        batch_list.append(domain_batch)

                batch = DataProto.concat(batch_list)
                batch.reorder(indices=torch.argsort(batch.batch["prompt_id"]))
                batch.pop("prompt_id")
                logger.info(f"总长度为{len(batch)}")
                metrics_mgr.add_all_metrics(
                    global_step,
                    batch,
                    resource_manager=self.resource_manager,
                    actor_infer=self.actor_infer,
                    actor_train=self.actor_train,
                )
                batch_grouped: Dict[str, DataProto] = batch.group_by("domain")
                metrics_mgr.add_domain_all_metrics(global_step, batch_grouped)

                with Timer(name="step_train", logger=None) as step_train_timer:
                    if self.pipeline_config.adv_estimator == "gae":
                        critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

                    with actor_train_timer:
                        # implement critic warmup
                        if self.pipeline_config.critic_warmup <= global_step:
                            # update actor
                            actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
                            actor_train_metrics: DataProto = DataProto.materialize_concat(
                                data_refs=actor_train_metrics_refs
                            )
                            metrics_mgr.add_reduced_metrics(actor_train_metrics.meta_info.pop("metrics", {}))

                    if self.pipeline_config.adv_estimator == "gae":
                        critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
                        metrics_mgr.add_reduced_metrics(critic_train_metrics.meta_info.pop("metrics", {}))

                metrics_mgr.add_metric("time/step_train", step_train_timer.last)

                tps_timer.push_units_processed(n=torch.sum(batch.batch["attention_mask"]).detach().item())
                actor_infer_timer.push_units_processed(n=torch.sum(batch.batch["attention_mask"]).detach().item())
                actor_infer_response_timer.push_units_processed(
                    n=torch.sum(batch.batch["response_mask"]).detach().item()
                )
                actor_train_timer.push_units_processed(n=torch.sum(batch.batch["attention_mask"]).detach().item())

                metrics = metrics_mgr.get_metrics()
                # do ckpt
                self.state.step = global_step
                self.state.log_history.append(metrics)
                for domain, scheduler in self.generate_schedulers.items():
                    self.state.kv[f"scheduler_state_{domain}"] = ray.get(scheduler.get_scheduler_state.remote())

                self.do_checkpoint(global_step=global_step)

                self.tracker.log(values=metrics, step=global_step)

                if global_step % self.pipeline_config.logging_steps == 0:
                    if int(os.environ.get("RAY_PROFILING", "0")):
                        timeline_dir = os.path.join(self.pipeline_config.profiler_output_dir, "timeline")
                        os.makedirs(timeline_dir, exist_ok=True)
                        ray.timeline(
                            filename=os.path.join(timeline_dir, f"timeline-step-{global_step}.json"),
                        )

                    prompts = self.tokenizer.batch_decode(generate_output.batch["prompts"], skip_special_tokens=True)
                    input_ids = self.tokenizer.batch_decode(generate_output.batch["input_ids"], skip_special_tokens=True)
                    responses = self.tokenizer.batch_decode(
                        generate_output.batch["responses"], skip_special_tokens=True
                    )
                    generate_examples = [{"prompt": p, "response": r} for p, r in zip(prompts, responses)][:10]
                    input_examples = [{"input_ids": p, "response": r} for p, r in zip(input_ids, responses)][:10]
                    logger.info(json.dumps(generate_examples, ensure_ascii=False))
                    logger.info(json.dumps(input_examples, ensure_ascii=False))
                    logger.info(json.dumps(metrics, ensure_ascii=False))

                logger.info(f"pipeline step {global_step} finished")
                global_step += 1
        logger.info("pipeline complete!")

    @torch.no_grad()
    def val(self, global_step):
        """Run one validation pass over the whole validation set.

        Starts the inference server, generates responses (through the
        two-stage flow when ``pipeline_config.if_2_stage`` is set, otherwise
        with a single scheduler call), records the fraction of samples whose
        score equals 1 overall and per ``tag``, and optionally writes the
        generations to a JSONL file.

        Args:
            global_step: Current step; only used to name the results file.

        Returns:
            The result of ``MetricsManager.get_metrics()`` for this pass.
        """
        val_metrics_mgr = MetricsManager()
        batch = DataProto()
        rpc_timeout = self.pipeline_config.rpc_timeout

        with Timer(name="step_generate", logger=None) as step_generate_timer:
            batch.meta_info["is_offload_states"] = False
            batch.meta_info["generation_config"] = self.pipeline_config.validation.generating_args.to_dict()

            self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info))
            for reward_cluster in self.rewards.values():
                reward_cluster.load_states()
            if self.pipeline_config.if_2_stage:
                # Stage 1: one sample per prompt, with reward computation skipped.
                stage1_config = copy.deepcopy(batch.meta_info["generation_config"])
                stage1_config["num_return_sequences"] = 1
                batch.meta_info["skip_reward"] = True
                batch.meta_info["generation_config"] = stage1_config
                stage1_batch_ref = self.val_generate_scheduler.get_batch.remote(
                    data=batch, batch_size=len(self.val_dataset)
                )
                # Same timeout as the single-stage branch so a stuck generation
                # cannot block validation forever.
                stage1_output: DataProto = ray.get(stage1_batch_ref, timeout=rpc_timeout)
                batch.meta_info["skip_reward"] = False  # restore
                # NOTE(review): build_stage2_prompt reuses stage1_output.meta_info;
                # confirm it does not still carry skip_reward=True into stage 2.
                stage2_prompt_data = self.build_stage2_prompt(stage1_output)

                # Stage 2: generate the final answer from the rebuilt prompts.
                stage2_config = copy.deepcopy(batch.meta_info["generation_config"])
                stage2_config["num_return_sequences"] = 1
                stage2_prompt_data.meta_info["generation_config"] = stage2_config
                stage2_batch_ref = self.val_generate_scheduler.get_batch.remote(
                    data=stage2_prompt_data, batch_size=len(self.val_dataset)
                )
                generate_output = ray.get(stage2_batch_ref, timeout=rpc_timeout)
                logger.info(f"val result:\n{len(generate_output)}")
            else:
                generate_output: DataProto = ray.get(
                    self.val_generate_scheduler.get_batch.remote(data=batch, batch_size=len(self.val_dataset)),
                    timeout=rpc_timeout,
                )

            self.actor_infer.stop_server()
            generate_output.meta_info.pop("is_offload_states", None)
            for reward_cluster in self.rewards.values():
                reward_cluster.offload_states()

        # Timer.last is only updated when the context manager exits, so the
        # metric must be read after the ``with`` block, not inside it.
        val_metrics_mgr.add_metric("time/step_generate", step_generate_timer.last)

        batch = generate_output
        val_correct_mean = (batch.batch["scores"] == 1).detach().float().mean().item()
        val_metrics_mgr.add_metric("val_correct/all/mean", val_correct_mean)
        logger.info(json.dumps({"val_correct/all/mean": val_correct_mean}, ensure_ascii=False))

        # Save validation results
        output_dir = self.pipeline_config.validation.output_dir
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            output_file = os.path.join(output_dir, f"validation_results_{global_step}.jsonl")

            save_ds = datasets.Dataset.from_dict(generate_output.non_tensor_batch)
            if self.pipeline_config.validation.save_columns:
                # Best effort: if the column selection fails, keep every column.
                try:
                    save_ds = save_ds.select_columns(self.pipeline_config.validation.save_columns)
                except Exception as e:
                    logger.error(f"Error selecting columns: {e}, default save all columns")

            responses = self.tokenizer.batch_decode(
                generate_output.batch["responses"], skip_special_tokens=True
            )
            save_ds = save_ds.add_column("response", responses)
            if self.pipeline_config.validation.save_input_seq:
                input_seqs = self.tokenizer.batch_decode(generate_output.batch["input_ids"], skip_special_tokens=True)
                save_ds = save_ds.add_column("input_seq", input_seqs)

            save_ds.to_json(output_file, lines=True, force_ascii=False)
            logger.info(f"Validation results saved to {output_file}")

        # Keep only what the per-tag breakdown needs.
        epoch_batch = batch.pop(batch_keys=["scores"], non_tensor_batch_keys=["tag"])

        grouped_batch = epoch_batch.group_by("tag")
        for group_key, group_batch in grouped_batch.items():
            group_scores = group_batch.batch["scores"]
            score_mean = group_scores.mean().item()
            print(f"{group_key}:  {score_mean}")
            val_metrics_mgr.add_domain_metrics(
                "val_correct", {f"{group_key}/mean": (group_scores == 1).detach().float().mean().item()}
            )

        return val_metrics_mgr.get_metrics()


    def build_stage2_prompt(self, stage1_data: DataProto):
        """Build the tokenized stage-2 prompt batch from stage-1 generations.

        For every sample, the decoded stage-1 response is parsed for a
        ``strategy`` and a ``mode``. These are substituted for the
        ``<strategy>`` and ``<mode>`` placeholders of ``prompt_stage2``, and
        the result is appended to the part of the decoded stage-1 prompt that
        precedes the text "The available action types".

        Args:
            stage1_data: Stage-1 output whose tensor batch holds ``responses``
                and ``prompts``.

        Returns:
            A ``DataProto`` whose tensor batch holds ``input_ids``,
            ``attention_mask``, ``position_ids``, ``prompt_mask`` and
            ``prompts``. It reuses (does not copy) the ``non_tensor_batch``
            and ``meta_info`` objects of ``stage1_data``.
        """
        tokenizer = self.tokenizer
        stage1_batch = stage1_data.batch
        num_samples = stage1_batch.batch_size[0]

        stage2_prompts = []
        for idx in range(num_samples):
            # Decoded without skip_special_tokens, so special tokens stay in the text.
            response_text = tokenizer.decode(stage1_batch["responses"][idx])
            prompt_text = tokenizer.decode(stage1_batch["prompts"][idx])
            try:
                parsed = process_stage1_output(extract_first_to_last_brace(response_text))
                strategy, mode = parsed["strategy"], parsed["mode"]
            except (json.JSONDecodeError, KeyError):
                # Fall back to the defaults when the stage-1 output cannot be parsed.
                strategy, mode = "None", "goal-oriented"

            prefix = prompt_text.split("The available action types")[0]
            filled_template = prompt_stage2.replace("<strategy>", str(strategy)).replace("<mode>", str(mode))
            stage2_prompts.append(prefix + filled_template + "\n")

        encodings = tokenizer(
            stage2_prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.pipeline_config.prompt_length,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]

        # Every row gets positions 0..seq_len-1 regardless of padding.
        # NOTE(review): confirm this matches how position_ids are built for
        # the regular (single-stage) prompts.
        num_rows, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, dtype=torch.int64).unsqueeze(0).expand(num_rows, -1)

        prompt_mask = attention_mask.bool()

        # Strip padding, then re-pad on the right to the longest prompt in the batch.
        # NOTE(review): this drops every token equal to pad_token_id, not only
        # the padding added by the tokenizer call above.
        pad_id = tokenizer.pad_token_id
        unpadded_prompts = [row[row != pad_id] for row in input_ids]
        prompts = pad_sequence(unpadded_prompts, batch_first=True, padding_value=pad_id)

        tensors = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "prompt_mask": prompt_mask,
            "prompts": prompts,
        }
        batch_data = TensorDict(tensors, batch_size=len(stage2_prompts))

        return DataProto(
            batch=batch_data,
            non_tensor_batch=stage1_data.non_tensor_batch,
            meta_info=stage1_data.meta_info,
        )
