import asyncio
import json
import math
import os
import random
from functools import partial
from heapq import heapify, heappop, heappush
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
from collections import defaultdict

import wandb
import ray
import torch
import sys
from loguru import logger
from omegaconf.dictconfig import DictConfig
from ray.util.placement_group import PlacementGroup, placement_group
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from thinker_task.ppo.actors import PPORayActorGroup
from thinker_task.ppo.replay_buffer import Experience, NaiveReplayBuffer, BufferItem
from thinker_task.ppo.utils import ORZDeepspeedStrategy as DeepspeedStrategy
from thinker_task.ppo.utils import (
    Timer,
    create_vllm_engines,
    compute_approx_kl,
    compute_reward,
    get_advantages_and_returns,
    masked_mean,
    normalize_advantages,
    packed_create_token_mask,
    save_debug_data,
    cum_clip,
)
from thinker_task.exp_engine.accelerators.inference.utils import encode_prompts

class RayPPOTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        strategy: DeepspeedStrategy,
        tokenizer,
        train_dataset,
        eval_dataset=None,
        colocate_pg: Optional[PlacementGroup] = None,
    ):
        """Set up the dataloader, logging sinks and replay buffers.

        Args:
            cfg: Experiment configuration; read here for TensorBoard, wandb
                and replay-buffer settings.
            strategy: Training strategy object, stored as ``self.strategy``.
            tokenizer: Tokenizer, stored as ``self.tokenizer``.
            train_dataset: Dataset passed to ``self.build_dataloader`` to
                create ``self.prompts_dataloader``.
            eval_dataset: Optional evaluation dataset, stored as-is.
            colocate_pg: Optional placement group, stored as-is and later
                forwarded to the vLLM engines when ``cfg.colocate_all`` is set.
        """
        # Define the attributes __del__ reads before anything below can raise,
        # so a failed construction does not cause an AttributeError there.
        self.writer = None
        self._wandb = None

        self.cfg = cfg
        self.strategy = strategy
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.prompts_dataloader = self.build_dataloader(train_dataset)
        self.colocate_pg = colocate_pg

        self.writer = SummaryWriter(log_dir=self.cfg.tensorboard_log_dir)

        # Replace loguru's default sink with one that prefixes the run name.
        logger.remove()
        logger.add(
            sys.stderr,
            format=f"\033[32m{self.cfg.run_name}\033[0m: {{time:YYYY-MM-DD HH:mm:ss}} | {{level}} | {{message}}",
            level="INFO"
        )

        if self.cfg.use_wandb:
            self._wandb = wandb
            if not wandb.api.api_key:
                wandb.login(key=self.cfg.wandb_api_key)
            logger.info("Initializing wandb")
            self.wandb_run = wandb.init(
                entity=self.cfg.wandb_entity,
                project=self.cfg.wandb_project,
                group=self.cfg.wandb_group,
                name=self.cfg.wandb_run_name,
                id=self.cfg.wandb_run_name,
                config=self.cfg.__dict__,
                resume="allow",
                allow_val_change=True,
                settings=wandb.Settings(init_timeout=300),
            )
            logger.info("Initialized wandb")
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/*", step_metric="train/global_step", step_sync=True)

            # Columns of the per-step text example table; optional columns
            # depend on which features are enabled in the config.
            self.wandb_columns = ["global_step", "prompt", "response", "final_answer", "answer_status", "stop_reason", "score"]
            if self.cfg.multi_attempt:
                self.wandb_columns.extend(["attempt_used"])
            if self.cfg.summary:
                self.wandb_columns.extend(["response_status"])
                if self.cfg.summary_consist_coef > 0 and not self.cfg.summary_skip:
                    self.wandb_columns.extend(["consist_return"])
            try:
                # The artifact name must be an f-string: without interpolation
                # the literal "{self.wandb_run.id}" is requested and the
                # previous table can never be found on resume.
                artifact_name = f"run-{self.wandb_run.id}-text_example:latest"
                self.wandb_table = self.wandb_run.use_artifact(artifact_name).get("text_example")
                self.wandb_table = wandb.Table(columns=self.wandb_columns, data=self.wandb_table.data)
                logger.info("Loaded previous wandb table")
            except Exception:
                # Best effort: any failure to restore falls back to an empty table.
                self.wandb_table = wandb.Table(columns=self.wandb_columns)
                logger.info("Created new wandb table")

        self.replay_buffer = NaiveReplayBuffer(
            sample_batch_size=self.cfg.micro_train_batch_size,
            limit=0,
            cpu_offload=True,
            packing_samples=True,
        )

        if self.cfg.summary:
            self.summary_buffer = NaiveReplayBuffer(
                sample_batch_size=self.cfg.micro_train_batch_size,
                limit=self.cfg.summary_buffer_size,
                cpu_offload=True,
                packing_samples=True,
            )

            loaded_count = 0
            if self.cfg.load_checkpoint:
                summary_buffer_path = os.path.join(self.cfg.save_path, "summary_buffer.jsonl")
                if os.path.exists(summary_buffer_path):
                    logger.info(f"Loading summary buffer from {summary_buffer_path}")
                    with open(summary_buffer_path, "r") as f:
                        for line in f:
                            # One JSON object per line; a bad line is logged
                            # and skipped rather than aborting the restore.
                            try:
                                item_data = json.loads(line.strip())
                                # Top-level JSON lists are turned back into tensors.
                                item_data = {
                                    k: torch.tensor(v) if isinstance(v, list) else v
                                    for k, v in item_data.items()
                                }
                                buffer_item = BufferItem(**item_data)
                                self.summary_buffer.items.append(buffer_item)
                                loaded_count += 1
                            except Exception as e:
                                logger.error(f"Failed to load item from summary buffer: {e}")
                    logger.info(f"Successfully loaded {loaded_count} items into the summary buffer")

    def __del__(self):
        if self.writer is not None: self.writer.close()
        if self._wandb is not None: self._wandb.finish()

    async def eval(self):
        raise NotImplementedError("Eval function should be implemented in user's exp")
    
    async def create_inference_engine(self):
        """Assemble vLLM engine options from ``self.cfg`` and build the engines.

        The base options are always passed; the multi-attempt and summary
        option groups are added only when the matching config flag is set.

        Returns:
            The result of ``create_vllm_engines`` called with those options.
        """
        # The shared placement group is forwarded only when all models are colocated.
        placement = self.colocate_pg if self.cfg.colocate_all else None

        engine_kwargs = {
            "num_engines": self.cfg.vllm_num_engines,
            "tensor_parallel_size": self.cfg.vllm_tensor_parallel_size,
            "pretrain": self.cfg.pretrain,
            "seed": self.cfg.seed,
            "summary": self.cfg.summary,
            "multi_attempt": self.cfg.multi_attempt,
            "enable_prefix_caching": self.cfg.enable_prefix_caching,
            "enforce_eager": self.cfg.enforce_eager,
            "max_model_len": self.cfg.max_len,
            "colocate_with_actor": self.cfg.colocate_all,
            "enable_chunked_prefill": self.cfg.enable_chunked_prefill,
            "max_num_batched_tokens": self.cfg.max_num_batched_tokens,
            "gpu_memory_utilization": self.cfg.gpu_memory_utilization,
            "max_num_seqs": self.cfg.micro_rollout_batch_size,
            "colocate_pg": placement,
        }

        if self.cfg.multi_attempt:
            engine_kwargs.update(
                {
                    "min_attempt": self.cfg.min_attempt,
                    "max_attempt": self.cfg.max_attempt,
                    "repeat_question": self.cfg.repeat_question,
                    "prompt_type": self.cfg.prompt_type,
                }
            )

        if self.cfg.summary:
            engine_kwargs.update(
                {
                    "prompt_type": self.cfg.prompt_type,
                    "summary_min_token": self.cfg.summary_min_token,
                    "summary_max_token": self.cfg.summary_max_token,
                    "verify_max_token": self.cfg.verify_max_token,
                    "slow_max_token": self.cfg.slow_max_token,
                    "reward_right_format": self.cfg.reward_right_format,
                    "summary_temperature": self.cfg.summary_temperature,
                    "summary_skip": self.cfg.summary_skip,
                    "verify_skip": self.cfg.verify_skip,
                    "summary_reward_coef": self.cfg.summary_reward_coef,
                    "fast_reward_coef": self.cfg.fast_reward_coef,
                    "summary_nonstop_discount": self.cfg.summary_nonstop_discount,
                }
            )

        return create_vllm_engines(**engine_kwargs)
    
    async def init_vllm_engines(self):
        self.vllm_engines = await self.create_inference_engine()

    async def train(self):
        """Run the full PPO training loop.

        Phases, in order:
          1. Create the vLLM engines, register them with the policy actor
             group and sync the current policy weights to them.
          2. Restore progress counters from the checkpoint client state and/or
             ``cfg.resume_global_step``.
          3. For every prompt batch: optionally evaluate, generate
             experiences, train critic/policy, log, and periodically save
             weights and checkpoints.
          4. Save the final policy weights under the ``"final"`` tag.

        Reads ``self.policy_model``, ``self.critic_model``, ``self.ref_model``
        and ``self.num_update_steps_per_episodes``, none of which are assigned
        in ``__init__``; they must be provided before this is awaited.
        """
        # 1. create rank0 policy model and vllm_engines groups, then broadcast weights to vllm engines
        logger.info("Start training")
        if self.cfg.colocate_all:
            logger.info("Moving actor to GPU")
            await self.policy_model.backload_to_gpu()
            logger.info("Moving vllm to CPU")

        await self.init_vllm_engines()

        if self.cfg.colocate_all:
            await self._backload_vllm_engines()

        logger.info("Creating vllm engine gourps")
        await self.policy_model.async_run_method("_init_vllm_engines_actor_group", self.vllm_engines)
        logger.info("Created vllm engine gourps done")

        async with Timer("Sync actor weights to vllm engines"):
            await self._sync_policy_weights_to_vllm()

        if self.cfg.colocate_all:
            async with Timer("Offload policy model to cpu"):
                await self.policy_model.offload_to_cpu()

        # 2. main training loop
        consumed_samples = 0
        best_eval_score = float("-inf")
        self.summary_step = 0

        if self.cfg.load_checkpoint:
            client_state = await self.policy_model.get_client_state()
            consumed_samples = client_state["consumed_samples"]
            best_eval_score = client_state["best_eval_score"]
            self.summary_step = client_state["summary_step"]

        num_rollouts_per_episodes = (
            self.num_update_steps_per_episodes
            * self.cfg.train_batch_size
            // self.cfg.max_epochs
            // self.cfg.rollout_batch_size
            // self.cfg.n_samples_per_prompt
        )

        # An explicit resume step overrides whatever the checkpoint recorded.
        if self.cfg.resume_global_step > 0:
            consumed_samples = self.cfg.resume_global_step * self.cfg.rollout_batch_size

        self.global_step = consumed_samples // self.cfg.rollout_batch_size
        start_episode = consumed_samples // self.cfg.rollout_batch_size // num_rollouts_per_episodes
        # NOTE(review): the within-episode remainder is computed but not used
        # below, so a resumed episode restarts from its first batch.
        consumed_samples = consumed_samples % (num_rollouts_per_episodes * self.cfg.rollout_batch_size)

        await self.policy_model.async_set_tokenizer(self.tokenizer)

        for episode in range(start_episode, self.cfg.num_episodes):
            pbar = tqdm(
                range(len(self.prompts_dataloader)), desc=f"Episode [{episode + 1}/{self.cfg.num_episodes}]"
            )
            for rand_prompts in self.prompts_dataloader:
                logger.info(f"\033[32m Global step: {self.global_step} \033[0m")

                status_train = {}

                # 1. eval if enable eval
                if self.cfg.enable_eval and self.global_step % self.cfg.eval_interval == 0:
                    status_eval = await self.eval()

                    if status_eval["avg_acc"] > best_eval_score:
                        best_eval_score = status_eval["avg_acc"]
                        await self.policy_model.async_save_model(self.tokenizer, tag="best")
                        if self.critic_model is not None:
                            await self.critic_model.async_save_model(self.tokenizer, tag="best")
                        # Only report a save when one actually happened.
                        logger.info("Successfully save best model weights, training continue")
                else:
                    status_eval = None

                # 3. make experiences, calculate advantages and returns
                status = await self.make_experience(rand_prompts)
                # make_experience may yield nothing when no usable data was
                # generated; dict.update(None) would raise, so guard it and
                # let the empty-buffer check below take the skip path.
                if status:
                    status_train.update(status)

                if self.cfg.summary and self.cfg.summary_consist_coef > 0 and not self.cfg.summary_skip:
                    # Move the first sample's consistency return into the wandb
                    # table row; both keys are absent on an empty rollout.
                    first_consist_return = status_train.pop("first_consist_return", None)
                    if first_consist_return is not None and "table" in status_train:
                        status_train["table"]["consist_return"] = first_consist_return

                # check if has enough data
                # NOTE(review): the skip path does not advance global_step or
                # the progress bar — confirm this is intended.
                if len(self.replay_buffer) <= 0:
                    if self.cfg.colocate_all:
                        # skip, but transfer weight
                        await self.policy_model.backload_to_gpu()
                        await self._backload_vllm_engines()
                        await self._sync_policy_weights_to_vllm()
                        await self.policy_model.offload_to_cpu()
                    continue

                if self.cfg.advantage_normalize:
                    self.replay_buffer = normalize_advantages(self.replay_buffer)

                num_policy_dp_nodes = self.cfg.actor_num_nodes * self.cfg.actor_num_gpus_per_node
                num_critic_dp_nodes = self.cfg.critic_num_nodes * self.cfg.critic_num_gpus_per_node
                policy_buffers = self.replay_buffer.split_to_n_batches(num_policy_dp_nodes)
                # Reuse the policy split when both groups have the same size.
                if num_policy_dp_nodes != num_critic_dp_nodes:
                    critic_buffers = self.replay_buffer.split_to_n_batches(num_critic_dp_nodes)
                else:
                    critic_buffers = policy_buffers

                # Summary data is only trained on once the buffer is at least half full.
                summary_buffers = None
                if self.cfg.summary:
                    if len(self.summary_buffer) >= self.summary_buffer.limit // 2:
                        self.summary_step += 1
                        summary_buffers = self.summary_buffer.split_to_n_batches(num_policy_dp_nodes)

                # 4. train policy/critic model
                status_critic = {}
                if self.cfg.colocate_all:
                    # Colocated models share devices, so they are trained one
                    # after another with explicit backload/offload around each.
                    if self.critic_model is not None:
                        async with Timer("Critic model training"):
                            await self.critic_model.backload_to_gpu()
                            status_critic = await self.ppo_local_train_critic(critic_buffers, self.global_step)
                            await self.critic_model.offload_to_cpu()
                    async with Timer("Actor model training"):
                        await self.policy_model.backload_to_gpu()
                        status_policy = await self.ppo_local_train_policy(policy_buffers, summary_buffers, self.global_step)
                        await self.policy_model.offload_to_cpu()

                else:
                    if self.ref_model is not None:
                        await self.ref_model.async_run_method("empty_cache")
                    if self.critic_model is not None:
                        async with Timer("Actor and Critic model training"):
                            status_policy, status_critic = await asyncio.gather(
                                self.ppo_local_train_policy(policy_buffers, summary_buffers, self.global_step),
                                self.ppo_local_train_critic(critic_buffers, self.global_step),
                            )
                            await asyncio.gather(
                                self.policy_model.async_run_method("empty_cache"),
                                self.critic_model.async_run_method("empty_cache"),
                            )
                    else:
                        async with Timer("Actor model training"):
                            status_policy = await self.ppo_local_train_policy(policy_buffers, summary_buffers, self.global_step)
                            await self.policy_model.async_run_method("empty_cache")
                status_train.update(status_policy)
                status_train.update(status_critic)
                self.replay_buffer.clear()

                # 5. set logs
                logger.info(f"Generation status: {status}")
                pbar.update()
                # log epoch info
                self.writer.add_scalar("episode_idx", episode, self.global_step)

                if self._wandb is not None:
                    logs = {}
                    if "table" in status_train:
                        table = status_train.pop("table")
                        table["global_step"] = self.global_step
                        new_row = [table[s] for s in self.wandb_columns]
                        self.wandb_table.add_data(*new_row)

                    logs.update({
                        "train/%s" % k: v
                        for k, v in {
                            **status_train,
                            "global_step": self.global_step,
                        }.items()
                    })
                    if status_eval is not None:
                        logs.update({f"eval/{k}": v for k, v in status_eval.items()})
                    self._wandb.log(logs)

                    if self.global_step % 10 == 0:
                        # A fresh Table object is built from the accumulated rows before logging.
                        self.wandb_table = wandb.Table(columns=self.wandb_columns, data=self.wandb_table.data)
                        self._wandb.log({"text_example": self.wandb_table})

                self.global_step += 1
                tag = "global_step_%d" % self.global_step
                if self.global_step % self.cfg.save_interval == 0:
                    await self.policy_model.async_save_model(self.tokenizer, tag=tag)
                    if self.critic_model is not None:
                        await self.critic_model.async_save_model(self.tokenizer, tag=tag)
                    logger.info("Successfully save model weights, training continue")

                if self.global_step % self.cfg.ckpt_interval == 0:
                    client_state = {
                        "consumed_samples": self.global_step * self.cfg.rollout_batch_size,
                        "best_eval_score": best_eval_score,
                        "summary_step": self.summary_step,
                    }
                    await self.policy_model.async_save_ckpt(tag=tag, client_state=client_state)
                    if self.critic_model is not None:
                        await self.critic_model.async_save_ckpt(tag=tag, client_state=client_state)
                    logger.info("Successfully save model checkpoints, training continue")

            if self.cfg.update_ref_every_epoch and self.ref_model is not None:
                # Save the policy weights, then re-initialise the reference
                # model from that saved directory.
                tag = "global_step_%d" % self.global_step
                await self.policy_model.backload_to_gpu()
                await self.policy_model.async_save_model(self.tokenizer, tag=tag)
                await self.policy_model.offload_to_cpu()
                await asyncio.gather(
                    *self.ref_model.async_init_model_from_pretrained(
                        self.strategy, os.path.join(self.cfg.save_path, "_actor_hf", tag)
                    )
                )
                logger.info("Successfully update ref model with policy model, training continue")

        await self.policy_model.async_save_model(self.tokenizer, tag="final")
        logger.info("Successfully save model weights, training done")

    @torch.no_grad()
    async def make_experience(self, all_inputs: Union[Tuple[str, dict], List[Tuple[str, dict]]], **generate_kwargs):
        """Generate rollouts for a prompt batch and fill the replay buffer.

        Steps: repeat and shuffle the prompts, generate responses on the vLLM
        engines, score them with the reward function, pack the samples, run
        ``inference_and_calculates``, compute advantages/returns, append every
        experience to ``self.replay_buffer`` and log averaged metrics to
        TensorBoard.

        Args:
            all_inputs: Prompt items; element ``[0]`` of each item is used as
                the prompt and element ``[1]`` as its extras.
            **generate_kwargs: Forwarded unchanged to ``self.generate_vllm``.

        Returns:
            A dict of status values (reward-function status merged with the
            averaged metrics). When generation produces no outputs or the
            reward function leaves no samples, the status collected so far
            (possibly empty) is returned early and the replay buffer is left
            untouched.

        NOTE(review): ``torch.no_grad()`` applied as a decorator to an
        ``async def`` presumably only covers creation of the coroutine, not
        its execution — confirm whether grad mode matters on this path.
        """
        n_samples = self.cfg.n_samples_per_prompt
        # Repeat every prompt/extras pair n_samples times (flat lists).
        all_prompts = [prompt[0] for prompt in all_inputs for _ in range(n_samples)]
        all_extras = [prompt[1] for prompt in all_inputs for _ in range(n_samples)]
        # shuffle all_prompts and all_extras together
        # (fixed seed: the permutation is the same for equal batch sizes)
        indices = list(range(len(all_prompts)))
        rng = random.Random(42)
        rng.shuffle(indices)
        all_prompts = [all_prompts[i] for i in indices]
        all_extras = [all_extras[i] for i in indices]

        # 1. generate sequences and inference, calculate values, log probs, rewards, kl divergence
        # 1.1 generate sequences via vllm engines
        outputs = []
        num_vllm_dp_groups = len(self.vllm_engines)
        status_train = {}

        async with Timer("Generate sequences via vllm engines"):
            # Ceil-divide so every prompt is assigned to some engine.
            dp_prompt_size = (len(all_prompts) + num_vllm_dp_groups - 1) // num_vllm_dp_groups
            dp_tasks = []
            for dp_rank in range(num_vllm_dp_groups):
                dp_inputs = all_prompts[dp_rank * dp_prompt_size : (dp_rank + 1) * dp_prompt_size]
                dp_extras = all_extras[dp_rank * dp_prompt_size : (dp_rank + 1) * dp_prompt_size]
                # trailing engines may receive no data
                if len(dp_inputs) <= 0:
                    continue
                gen_func = self._get_generate_function(dp_rank)
                dp_tasks.append(self.generate_vllm(gen_func, dp_inputs, extras=dp_extras, **generate_kwargs))

            logger.info(f"Start vllm generation")
            local_responses = await asyncio.gather(*dp_tasks)
            for engine_responses in local_responses:
                outputs.extend(engine_responses)
            logger.info(f"Finish vllm generation")

            # offload vllm engines when colocate all models
            if self.cfg.colocate_all:
                async with Timer("Offload vllm engines to cpu"):
                    await self._offload_vllm_engines()

        # skip when data is not enough; return a dict so callers can merge it
        if len(outputs) <= 0:
            return status_train

        # In these modes the prompt actually used is taken from each output.
        if self.cfg.multi_attempt or self.cfg.summary:
            all_prompts = [x["prompt"] for x in outputs]

        assert len(all_prompts) == len(outputs), "generate objects number must be equal to all inputs number"

        # 1.2 calculate custom rewards if has custom reward function
        async with Timer("Calculate custom rewards"):
            reward_fn = self.custom_reward_fn if not self.cfg.summary else self.summary_reward_fn
            all_prompts, outputs, custom_rewards, info, status = await reward_fn(all_prompts, outputs, all_extras)
            assert len(all_prompts) == len(
                outputs
            ), "generate objects number after custom reward function must be equal to all inputs number"
        status_train.update(status)

        # empty data: nothing survived the reward function
        if len(all_prompts) == 0:
            return status_train

        # 1.3 packing samples
        async with Timer("Packing samples"):
            if self.cfg.summary:
                output_ids = info["response_ids"]
            else:
                output_ids = None
            (
                ret_sequences,
                ret_sequence_types,
                ret_attention_masks,
                ret_num_actions,
                ret_packed_seq_lens,
                ret_custom_rewards,
            ) = self._convert_prompts_outputs_to_batch_tensors_packing(
                all_prompts, outputs, output_ids, custom_rewards, self.cfg.packing_max_len
            )

            action_masks = None
            dir_sum = None

            if self.cfg.summary:
                (
                    sum_sequences,
                    sum_sequences_types,
                    sum_attention_masks,
                    sum_num_actions,
                    sum_packed_seq_lens,
                ) = self._filter_summary_to_packing(info, self.cfg.summary_packing_max_len)
                new_sequence_count = 0
                for i in range(len(sum_sequences)):
                    # Only sequences, attention mask, num_actions, packed
                    # lengths and the info dict are filled for summary items.
                    self.summary_buffer.append(
                        Experience(
                            sum_sequences[i],
                            None,
                            None,
                            None,
                            None,
                            None,
                            sum_attention_masks[i],
                            None,
                            torch.tensor(sum_num_actions[i]).unsqueeze(0),
                            torch.tensor(sum_packed_seq_lens[i]).unsqueeze(0),
                            {"global_step": torch.tensor(self.global_step).unsqueeze(0)},
                            None,
                            None,
                            None
                        )
                    )
                    new_sequence_count += len(sum_num_actions[i])
                # Average global step of the buffered items; guard the empty
                # buffer so the division cannot fail.
                summary_buffer_len = len(self.summary_buffer)
                if summary_buffer_len > 0:
                    summary_buffer_step = (
                        sum(x.info["global_step"] for x in self.summary_buffer.items) / summary_buffer_len
                    )
                else:
                    summary_buffer_step = 0
                logger.info(f"Summary buffer size: {len(self.summary_buffer)} with {new_sequence_count} new sequences and avg global step: {summary_buffer_step}")

                if not self.cfg.summary_skip:
                    (
                        dir_sum_sequences,
                        dir_sum_sequnece_types,
                        dir_sum_attention_masks,
                        dir_sum_num_actions,
                        dir_sum_packed_seq_lens,
                    ) = self._filter_dir_summary_to_packing(info, self.cfg.summary_packing_max_len)
                    # Keys are consumed downstream; spelling kept as-is.
                    dir_sum = {
                        "sequences": dir_sum_sequences,
                        "dir_sum_sequnece_types": dir_sum_sequnece_types,
                        "attention_masks": dir_sum_attention_masks,
                        "num_actions": dir_sum_num_actions,
                        "packed_seq_lens": dir_sum_packed_seq_lens,
                    }

        # 1.4 inference and calculate values, log probs, rewards, kl divergence
        async with Timer("Inference and calculate values, log probs, rewards, kl divergence"):
            experiences = await self.inference_and_calculates(
                ret_sequences,
                ret_sequence_types,
                ret_attention_masks,
                action_masks,
                ret_num_actions,
                ret_packed_seq_lens,
                ret_custom_rewards,
                dir_sum,
            )
            logger.info(f"Experiences size: {len(experiences)}")

        # 2. visualization generated results example
        vis = self._detokenize(experiences[0].sequences[0][: int(experiences[0].info["total_length"].flatten()[0])])
        self.writer.add_text("generated_sequences", vis, self.global_step)
        self.writer.flush()

        # 3. calculate advantages and returns / along with tensorboard logging
        avg_rewards = 0
        avg_kl = 0
        avg_kl_max = 0
        avg_response_length = 0
        avg_orm_score = 0
        avg_custom_rewards = 0
        avg_advantages = 0
        avg_advantages_abs = 0

        consist_score_sum = 0
        consist_score_count = 0
        consist_score_max = float("-inf")
        consist_returns_sum = 0
        consist_returns_count = 0
        consist_returns_first = 0

        track_consist = self.cfg.summary and self.cfg.summary_consist_coef > 0 and not self.cfg.summary_skip

        async with Timer("Calculate advantages and returns"):
            adv_tasks = [self._calc_advantages_and_returns(experience) for experience in experiences]

            results = await asyncio.gather(*adv_tasks)
            for idx, (experience, metrics) in enumerate(results):
                avg_rewards += metrics["avg_rewards"]
                avg_kl += metrics["avg_kl"]
                avg_kl_max += metrics["avg_kl_max"]
                avg_response_length += metrics["avg_response_length"]
                avg_orm_score += metrics["avg_orm_score"]
                avg_custom_rewards += metrics["avg_custom_rewards"]
                avg_advantages += metrics["avg_advantages"]
                avg_advantages_abs += metrics["avg_advantages_abs"]
                if track_consist:
                    consist_score_sum += metrics["consist_score_sum"]
                    consist_score_count += metrics["consist_score_count"]
                    consist_score_max = max(consist_score_max, metrics["consist_score_max"])
                    consist_returns_sum += metrics["consist_returns_sum"]
                    consist_returns_count += metrics["consist_returns_count"]
                    if idx == 0:
                        consist_returns_first = metrics["consist_returns_first"]

                self.replay_buffer.append(experience)

        # 4. tensorboard logging
        status = {
            "reward": avg_rewards / len(experiences),
            "kl": avg_kl / len(experiences),
            "kl_max": avg_kl_max / len(experiences),
            "response_length": avg_response_length / len(experiences),
            "orm_score": avg_orm_score / len(experiences),
            "custom_reward": avg_custom_rewards / len(experiences),
            "advantage": avg_advantages / len(experiences),
            "advantages_abs": avg_advantages_abs / len(experiences),
        }

        if track_consist:
            status["consist_score"] = (consist_score_sum / consist_score_count) if consist_score_count > 0 else 0.
            status["consist_score_max"] = consist_score_max
            status["consist_return"] = (consist_returns_sum / consist_returns_count) if consist_returns_count > 0 else 0.
            status["first_consist_return"] = consist_returns_first
            status["summary_len"] = (consist_score_count / consist_returns_count) if consist_returns_count > 0 else 0.
            status["summary_buffer_size"] = len(self.summary_buffer)
            status["summary_buffer_step"] = summary_buffer_step

        info_keys = ["reward", "kl", "response_length", "orm_score", "custom_reward"]
        logger.info(f'Experience status: {", ".join([f"{k}: {status[k]}" for k in info_keys])}')

        for k, v in status.items():
            self.writer.add_scalar(k, v, self.global_step)
        self.writer.flush()

        status_train.update(status)

        return status_train

    async def micro_infer_model(self, num_dps, model_type, sequences, num_actions, attention_mask, packed_seq_lens, **kwargs):
        dp_iterator = self._split_dp_batch(
            (sequences, num_actions, attention_mask, packed_seq_lens),
            num_dps,
        )
        dp_tasks = []
        for dp_rank, (
            micro_sequences,
            micro_num_actions,
            micro_attention_mask,
            micro_packed_seq_lens,
        ) in enumerate(dp_iterator):
            model = self._get_dp_group_models(dp_rank, model_type)

            async def forward_fn(
                local_model, fwd_sequences, fwd_num_actions, fwd_attention_mask, fwd_packed_seq_lens
            ):
                return await local_model.forward.remote(
                    sequences=fwd_sequences,
                    num_actions=fwd_num_actions,
                    attention_mask=fwd_attention_mask,
                    packed_seq_lens=fwd_packed_seq_lens,
                    **kwargs,
                )

            dp_tasks.append(
                self._split_and_run_micro_batch(
                    partial(forward_fn, model),
                    (micro_sequences, micro_num_actions, micro_attention_mask, micro_packed_seq_lens),
                    self.cfg.micro_forward_batch_size,
                )
            )
        results = await asyncio.gather(*dp_tasks)
        results = sum(results, [])
        return results

    @torch.no_grad()
    async def inference_and_calculates(
        self,
        sequences_all: List[torch.Tensor], # list of tensor with shape (1, *)
        sequence_types_all: List[torch.Tensor], # list of tensor with shape (1, *)
        attention_mask_all: List[torch.Tensor], # list of tensor with shape (1, *)
        action_mask_all: Optional[List[torch.Tensor]], # list of tensor with shape (1, *)
        num_actions_all: Optional[List[int]], # list of list of int
        packed_seq_lens_all: Optional[List[int]], # list of list of int
        custom_rewards_all: Optional[List[torch.Tensor]], # list of list of tensor with shape (*)
        dir_sum: Optional[dict] = None,
    ):
        """Run critic/ref/policy inference over packed batches and build ``Experience`` objects.

        Forward passes are dispatched through ``self.micro_infer_model``. Depending on the
        colocation flags in ``self.cfg`` the results are awaited one model at a time (with GPU
        backload/offload or cache clearing in between) or gathered concurrently. When
        ``cfg.summary`` is set, ``cfg.summary_consist_coef > 0`` and ``cfg.summary_skip`` is
        unset, the policy is additionally run on ``dir_sum`` and the resulting log-probs are
        added, scaled by ``summary_consist_coef``, to the custom rewards of the summary tokens.

        Args:
            sequences_all: One packed token tensor per pack.
            sequence_types_all: Per-token type ids aligned with ``sequences_all``; value 7 is
                used as the summary-token marker and even values as user input.
            attention_mask_all: Attention masks forwarded to the models.
            action_mask_all: Optional action mask; see the review note in the body.
            num_actions_all: Per pack, the number of actions of each packed sequence.
            packed_seq_lens_all: Per pack, the length of each packed sequence.
            custom_rewards_all: Per pack, one reward tensor per packed sequence.
            dir_sum: Inputs of the extra summary pass, with keys ``"sequences"``,
                ``"num_actions"``, ``"attention_masks"`` and ``"packed_seq_lens"``. Required
                when summary-consistency rewards are enabled.

        Returns:
            A list with one ``Experience`` per entry of ``sequences_all``.
        """
        cfg = self.cfg
        num_policy_dp_groups = cfg.actor_num_nodes * cfg.actor_num_gpus_per_node
        num_critic_dp_groups = cfg.critic_num_nodes * cfg.critic_num_gpus_per_node
        num_ref_dp_groups = cfg.ref_num_nodes * cfg.ref_num_gpus_per_node

        # NOTE(review): the annotation says ``action_mask_all`` is a list, which has no
        # ``.size``; this branch only works for a tensor argument -- confirm against callers.
        if action_mask_all is not None:
            num_actions_all = action_mask_all.size(1)

        # calculate critic values
        if cfg.colocate_all and self.critic_model is not None:
            await self.critic_model.backload_to_gpu()

        values = None
        if self.critic_model is not None:
            value_ref = self.micro_infer_model(
                num_critic_dp_groups,
                "critic_model",
                sequences_all,
                num_actions_all,
                attention_mask_all,
                packed_seq_lens_all,
            )
            if cfg.colocate_all:
                values = await value_ref
                await self.critic_model.offload_to_cpu()

        # calculate ref log probs
        if self.ref_model is not None:
            base_action_log_probs_ref = self.micro_infer_model(
                num_ref_dp_groups, "ref_model", sequences_all, num_actions_all, attention_mask_all, packed_seq_lens_all
            )
        else:
            base_action_log_probs_ref = None
        base_log_probs = None

        # handle colocate critic and reward model
        if cfg.colocate_critic_reward and not cfg.colocate_all and self.critic_model is not None:
            values = await value_ref
            await self.critic_model.async_run_method("empty_cache")

        # handle colocate actor and ref model
        if cfg.colocate_actor_ref or cfg.colocate_all:
            if base_action_log_probs_ref is not None:
                base_log_probs = await base_action_log_probs_ref
            if self.ref_model is not None:
                await self.ref_model.async_run_method("empty_cache")

        # calculate action log probs
        if cfg.colocate_all:
            await self.policy_model.backload_to_gpu()

        action_log_probs_ref = self.micro_infer_model(
            num_policy_dp_groups,
            "policy_model",
            sequences_all,
            num_actions_all,
            attention_mask_all,
            packed_seq_lens_all,
        )
        action_log_probs = None

        # Whether the extra policy pass for summary-consistency rewards is needed.
        use_summary_consist = cfg.summary and cfg.summary_consist_coef > 0 and not cfg.summary_skip

        if use_summary_consist:
            summary_log_probs_ref = self.micro_infer_model(
                num_policy_dp_groups,
                "policy_model",
                dir_sum["sequences"],
                dir_sum["num_actions"],
                dir_sum["attention_masks"],
                dir_sum["packed_seq_lens"],
            )
            summary_log_probs = None

        if cfg.colocate_all:
            action_log_probs = await action_log_probs_ref
            if use_summary_consist:
                summary_results = await summary_log_probs_ref

            await self.policy_model.offload_to_cpu()

        # wait all models done
        # if not colocate_actor_ref, then need to gather base_log_probs
        # if not colocate_critic_reward and self.critic_model is not None, then need to gather value
        if not cfg.colocate_all:
            pending = []
            base_idx = value_idx = summary_ref_idx = None
            # Only gather the reference pass when it exists: asyncio.gather rejects None,
            # which is what the ref coroutine is when no reference model is configured.
            if not cfg.colocate_actor_ref and base_action_log_probs_ref is not None:
                base_idx = len(pending)
                pending.append(base_action_log_probs_ref)
            action_idx = len(pending)
            pending.append(action_log_probs_ref)
            if not cfg.colocate_critic_reward and self.critic_model is not None:
                value_idx = len(pending)
                pending.append(value_ref)
            if use_summary_consist:
                summary_ref_idx = len(pending)
                pending.append(summary_log_probs_ref)

            results = await asyncio.gather(*pending)

            action_log_probs = results[action_idx]
            if base_idx is not None:
                base_log_probs = results[base_idx]
            if value_idx is not None:
                values = results[value_idx]
            if summary_ref_idx is not None:
                summary_results = results[summary_ref_idx]

        if not cfg.colocate_all:
            empty_cache_tasks = [
                self.policy_model.async_run_method("empty_cache"),
            ]
            if self.critic_model:
                empty_cache_tasks.append(self.critic_model.async_run_method("empty_cache"))
            if self.reward_model:
                empty_cache_tasks.extend([rm.async_run_method("empty_cache") for rm in self.reward_model])
            if self.ref_model:
                empty_cache_tasks.append(self.ref_model.async_run_method("empty_cache"))
            await asyncio.gather(*empty_cache_tasks)

        if use_summary_consist:
            # Flatten the summary pass into one (1, total_summary_actions) tensor.
            summary_log_probs = summary_results[:len(dir_sum["sequences"])]
            summary_log_probs = torch.cat(summary_log_probs, dim=1)
            summary_num_actions_all = sum(dir_sum["num_actions"], []) # list of int

            summary_action_mask_all = [seq_type == 7 for seq_type in sequence_types_all]
            sum_summary_action = torch.sum(torch.cat(summary_action_mask_all, dim=1)).item()
            if sum_summary_action != summary_log_probs.shape[1]:
                # The mismatch is dumped and logged, not raised; per-sequence lengths are
                # reconciled in the loop below.
                save_debug_data(
                    prefix="dir_sum",
                    sequences_all=sequences_all,
                    num_actions_all=num_actions_all,
                    attention_mask_all=attention_mask_all,
                    packed_seq_lens_all=packed_seq_lens_all,
                    dir_sequences_all=dir_sum["sequences"],
                    dir_num_actions_all=dir_sum["num_actions"],
                    dir_attention_mask_all=dir_sum["attention_masks"],
                    dir_packed_seq_lens_all=dir_sum["packed_seq_lens"],
                    summary_log_probs_=summary_results,
                    summary_log_probs=summary_log_probs,
                    summary_action_mask_all=summary_action_mask_all,
                )
                ray.logger.info(f"Summary log probs shape mismatch: {sum_summary_action} vs {summary_log_probs.shape[1]}")

        # 6. calculate kl divergence

        experiences = []
        if self.critic_model is not None:
            values = values[: len(sequences_all)]
        if self.ref_model is not None:
            base_log_probs = base_log_probs[: len(sequences_all)]
        action_log_probs = action_log_probs[: len(sequences_all)]
        # Cursors into the flattened summary outputs; they advance across all packs.
        sum_idx, sum_a_idx = 0, 0

        for i in range(len(sequences_all)):

            if cfg.multi_attempt:
                sys_mask = packed_create_token_mask(sequences_all[i], self.tokenizer, num_actions_all[i], packed_seq_lens_all[i])
            elif cfg.summary:
                sys_mask = (sequence_types_all[i] % 2) == 0 # even number are user inputs
                sys_mask = sys_mask[sequence_types_all[i] != 0].unsqueeze(0)
            else:
                sys_mask = None

            response_length = torch.Tensor(num_actions_all[i]).unsqueeze(0)
            total_length = torch.Tensor(packed_seq_lens_all[i]).unsqueeze(0)
            if cfg.multi_attempt or cfg.summary:
                action_mask = torch.logical_not(sys_mask)
            else:
                action_mask = None

            if self.ref_model is not None:
                kl = compute_approx_kl(
                    action_log_probs[i],
                    base_log_probs[i],
                    action_mask=action_mask,
                    use_kl_estimator_k3=cfg.use_kl_estimator_k3,
                    use_abs_kl=cfg.use_abs_kl,
                )
                kl_max = torch.max(kl.abs(), dim=-1)[0]
                kl_mean = masked_mean(kl, action_mask, dim=-1)
            else:
                kl = None
                kl_max = None
                kl_mean = None

            custom_rewards = custom_rewards_all[i] # list of tensor with shape (num_action,)

            if use_summary_consist:
                summary_action_mask = summary_action_mask_all[i]

                consist_returns = []
                custom_rewards_ = []
                consist_scores = []

                consist_returns_count = 0
                consist_score_count = 0

                s_idx = 0  # token offset of the current sequence inside the pack
                c = cfg.summary_consist_coef

                for seq_n, custom_reward in enumerate(custom_rewards):
                    # custom_reward is a tensor of shape (num_actions,)
                    seq_len = packed_seq_lens_all[i][seq_n]

                    summary_len = torch.sum(summary_action_mask[0, s_idx:s_idx+seq_len]).item()
                    consist_score_count += summary_len

                    if summary_len >= 1:
                        summary_len_ = summary_num_actions_all[sum_idx]

                        seq_summary_log_probs = summary_log_probs[:, sum_a_idx:sum_a_idx+summary_len_]
                        sum_a_idx = sum_a_idx + summary_len_
                        sum_idx += 1

                        if summary_len_ != summary_len:
                            ray.logger.info(f"Summary length mismatch: {summary_len_} vs {summary_len}; seq: {self.tokenizer.decode(sequences_all[i][0, s_idx+seq_len-summary_len-10: s_idx+seq_len-summary_len+10])}")
                            if summary_len < summary_len_:
                                # Too many log-probs: keep the first ``summary_len`` of them.
                                seq_summary_log_probs = seq_summary_log_probs[:, :summary_len]
                            else:
                                # Too few log-probs: left-pad with zeros up to ``summary_len``.
                                zero_pad = torch.zeros(
                                    1,
                                    summary_len-summary_len_,
                                    device=seq_summary_log_probs.device,
                                    dtype=seq_summary_log_probs.dtype,
                                )
                                seq_summary_log_probs = torch.cat([zero_pad, seq_summary_log_probs], dim=1)

                        consist_score = seq_summary_log_probs
                        consist_score = consist_score.squeeze(0)
                        consist_score[torch.isnan(consist_score)] = 0.

                        if cfg.summary_consist_mean:
                            consist_score = consist_score / summary_len

                        consist_score = cum_clip(consist_score, min_val=-0.95/c, max_val=+0.95/c)
                        consist_score = consist_score * cfg.summary_reward_coef

                        if cfg.summary_consist_end:
                            # Collapse the whole score onto the last summary token.
                            consist_score_sum = torch.sum(consist_score)
                            consist_score.zero_()
                            consist_score[-1] = consist_score_sum

                        # The score occupies the last ``summary_len`` action slots.
                        consist_score_appended = torch.zeros_like(custom_reward)
                        consist_score_appended[-summary_len:] = consist_score

                        consist_scores.append(consist_score_appended)
                        custom_rewards_.append(custom_reward + c * consist_score_appended)
                        consist_returns.append(torch.sum(c * consist_score_appended))
                        consist_returns_count += 1
                    else:
                        custom_rewards_.append(custom_reward)
                        consist_returns.append(torch.tensor(0., dtype=custom_reward.dtype, device=custom_reward.device))

                    s_idx += seq_len

                custom_rewards = custom_rewards_
                if len(consist_scores) > 0:
                    consist_scores = torch.cat(consist_scores)
                else:
                    consist_scores = None

            info = {
                "kl": kl_mean,
                "kl_max": kl_max,
                "reward": None,
                "custom_rewards": custom_rewards,
                "response_length": response_length,
                "total_length": total_length,
                "num_actions": num_actions_all[i],
            }

            if use_summary_consist:
                info.update({
                    "consist_score_sum": torch.sum(consist_scores).item() if consist_scores is not None else 0.,
                    "consist_score_count": consist_score_count,
                    "consist_score_max": torch.max(consist_scores.abs(), dim=-1)[0].item() if consist_scores is not None else 0.,
                    "consist_returns_first": consist_returns[0].item(),
                    "consist_returns_sum": torch.sum(torch.stack(consist_returns)).item(),
                    "consist_returns_count": consist_returns_count,
                })

            experiences.append(
                Experience(
                    sequences_all[i],
                    action_log_probs[i],
                    base_log_probs[i] if self.ref_model is not None else None,
                    values[i] if self.critic_model is not None else None,
                    None,
                    None,
                    attention_mask_all[i],
                    None,
                    response_length,
                    torch.Tensor(packed_seq_lens_all[i]).unsqueeze(0),
                    info,
                    kl,
                    sys_mask,
                    None
                )
            )
        return experiences

    @torch.no_grad()
    async def generate_vllm(
        self,
        gen_func: Callable[[List[str]], Awaitable[List[str | Any]]],
        prompts: List[str],
        extras: Optional[List[dict]] = None,
        **kwargs,
    ) -> List[str | Any]:
        """Generate responses for ``prompts`` through ``gen_func``.

        Sampling options are read from ``kwargs`` (``temperature``, ``top_p``, ``top_k``,
        ``max_new_tokens``, ``min_new_tokens``, ``skip_special_tokens``); missing options
        use the defaults listed below. ``extras`` is accepted but not used here.

        Returns:
            The first element of the pair returned by ``gen_func``.
        """
        from vllm import SamplingParams

        # (SamplingParams field, caller-facing kwarg name, default value)
        option_table = (
            ("temperature", "temperature", 1.0),
            ("top_p", "top_p", 1.0),
            ("top_k", "top_k", -1),
            ("max_tokens", "max_new_tokens", 1024),
            ("min_tokens", "min_new_tokens", 1),
            ("skip_special_tokens", "skip_special_tokens", False),
        )
        sampling_options = {
            field: kwargs.get(kwarg_name, default) for field, kwarg_name, default in option_table
        }
        params = SamplingParams(**sampling_options)

        generated, _ = await gen_func(prompts=prompts, sampling_params=params, use_tqdm=False)
        return generated

    def build_dataloader(self, dataset):
        # prepare dataloader
        prompts_dataloader = DataLoader(
            dataset, batch_size=self.cfg.rollout_batch_size, shuffle=True, collate_fn=dataset.collate_fn, num_workers=8
        )
        self.num_update_steps_per_episodes = (
            len(dataset) * self.cfg.n_samples_per_prompt // self.cfg.train_batch_size * self.cfg.max_epochs
        )
        max_steps = math.ceil(self.cfg.num_episodes * self.num_update_steps_per_episodes)
        self._max_steps = max_steps
        return prompts_dataloader

    async def build_models(self, PolicyRayActor, CriticRayActor, RefRayActor, RewardRayActor=None):
        """Create the policy/ref/critic/reward actor groups and load their pretrained weights.

        Args:
            PolicyRayActor: Actor class for the policy group (always created).
            CriticRayActor: Actor class for the critic group, created when
                ``cfg.critic_pretrain`` is set.
            RefRayActor: Actor class for the reference group, created when
                ``cfg.ref_pretrain`` is not None.
            RewardRayActor: Optional actor class for reward groups; one group is created
                per comma-separated entry of ``cfg.reward_pretrain``.

        With ``cfg.colocate_all`` every group uses ``self.colocate_pg`` and models are
        loaded one after another, with policy and critic offloaded to CPU after loading.
        Otherwise placement groups are created on demand for the actor/ref and
        critic/reward pairs and all models are loaded concurrently.

        The results are stored on ``self.policy_model``, ``self.critic_model``,
        ``self.ref_model`` and ``self.reward_model`` (None for groups that were not created).
        """
        cfg = self.cfg
        pg = None

        def build_reward_groups(group_pg, gpus_per_actor):
            """Return the reward actor groups, or None when no reward model is requested."""
            if RewardRayActor is None or not cfg.reward_pretrain:
                return None, []
            pretrain_paths = cfg.reward_pretrain.split(",")
            groups = [
                PPORayActorGroup(
                    cfg.reward_num_nodes,
                    cfg.reward_num_gpus_per_node,
                    RewardRayActor,
                    pg=group_pg,
                    num_gpus_per_actor=gpus_per_actor,
                )
                for _ in pretrain_paths
            ]
            return groups, pretrain_paths

        if cfg.colocate_all:
            assert (
                cfg.actor_num_nodes == cfg.critic_num_nodes
                and cfg.actor_num_gpus_per_node == cfg.critic_num_gpus_per_node
                and cfg.actor_num_nodes == cfg.ref_num_nodes
                and cfg.actor_num_gpus_per_node == cfg.ref_num_gpus_per_node
                and cfg.actor_num_gpus_per_node == 1
                and cfg.actor_num_nodes == cfg.vllm_num_engines
            ), "num_nodes and num_gpus_per_node must be the same when colocate all models and each actor has only one gpu."
            pg = self.colocate_pg

            policy_model = PPORayActorGroup(
                cfg.actor_num_nodes,
                cfg.actor_num_gpus_per_node,
                PolicyRayActor,
                pg=pg,
                num_gpus_per_actor=0.2,
            )
            if cfg.ref_pretrain is not None:
                ref_model = PPORayActorGroup(
                    cfg.ref_num_nodes,
                    cfg.ref_num_gpus_per_node,
                    RefRayActor,
                    pg=pg,
                    num_gpus_per_actor=0.2,
                )
            else:
                ref_model = None
            if cfg.critic_pretrain:
                critic_model = PPORayActorGroup(
                    cfg.critic_num_nodes,
                    cfg.critic_num_gpus_per_node,
                    CriticRayActor,
                    pg=pg,
                    num_gpus_per_actor=0.2,
                )
            else:
                critic_model = None

            # multiple reward models
            reward_models, reward_pretrains = build_reward_groups(pg, 0.2)

        else:
            if cfg.colocate_actor_ref:
                assert (
                    cfg.actor_num_nodes == cfg.ref_num_nodes
                    and cfg.actor_num_gpus_per_node == cfg.ref_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate actor and ref model."

                bundles = [
                    {"GPU": cfg.actor_num_gpus_per_node, "CPU": cfg.actor_num_gpus_per_node * 2}
                    for _ in range(cfg.actor_num_nodes)
                ]
                pg = placement_group(bundles, strategy="PACK")
                ray.get(pg.ready())

            # When sharing a placement group, policy and ref split each GPU 0.75 / 0.25.
            policy_model = PPORayActorGroup(
                cfg.actor_num_nodes,
                cfg.actor_num_gpus_per_node,
                PolicyRayActor,
                pg=pg,
                num_gpus_per_actor=0.75 if pg else 1,
            )
            if cfg.ref_pretrain is not None:
                ref_model = PPORayActorGroup(
                    cfg.ref_num_nodes,
                    cfg.ref_num_gpus_per_node,
                    RefRayActor,
                    pg=pg,
                    num_gpus_per_actor=0.25 if pg else 1,
                )
            else:
                ref_model = None

            # if colocated, create placement group for critic and reward model explicitly.
            pg = None
            if cfg.colocate_critic_reward:
                assert (
                    cfg.critic_num_nodes == cfg.reward_num_nodes
                    and cfg.critic_num_gpus_per_node == cfg.reward_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate critic and reward model."

                bundles = [
                    {"GPU": cfg.critic_num_gpus_per_node, "CPU": cfg.critic_num_gpus_per_node * 2}
                    for _ in range(cfg.critic_num_nodes)
                ]
                pg = placement_group(bundles, strategy="PACK")
                ray.get(pg.ready())

            if cfg.critic_pretrain:
                critic_model = PPORayActorGroup(
                    cfg.critic_num_nodes,
                    cfg.critic_num_gpus_per_node,
                    CriticRayActor,
                    pg=pg,
                    num_gpus_per_actor=0.75 if pg else 1,
                )
            else:
                critic_model = None

            # multiple reward models
            reward_models, reward_pretrains = build_reward_groups(pg, 0.25 if pg else 1)

        # Reward weights are loaded only when reward groups were actually built. Checking
        # cfg.reward_pretrain alone would crash when no RewardRayActor class was supplied.
        if not cfg.colocate_all:
            refs = []
            if cfg.ref_pretrain is not None:
                refs.extend(ref_model.async_init_model_from_pretrained(self.strategy, cfg.ref_pretrain))
            refs.extend(policy_model.async_init_model_from_pretrained(self.strategy, cfg.pretrain))
            if cfg.critic_pretrain:
                refs.extend(critic_model.async_init_model_from_pretrained(self.strategy, cfg.critic_pretrain))
            if reward_models is not None:
                for reward_model, reward_pretrain in zip(reward_models, reward_pretrains):
                    refs.extend(reward_model.async_init_model_from_pretrained(self.strategy, reward_pretrain))
            await asyncio.gather(*refs)
            await policy_model.async_run_method("_set_pad_token_id", self.tokenizer.pad_token_id)
        else:
            # Load one model at a time and offload policy/critic before loading the next.
            await asyncio.gather(*policy_model.async_init_model_from_pretrained(self.strategy, cfg.pretrain))
            await policy_model.async_run_method("_set_pad_token_id", self.tokenizer.pad_token_id)
            await policy_model.offload_to_cpu()
            if cfg.critic_pretrain:
                await asyncio.gather(*critic_model.async_init_model_from_pretrained(self.strategy, cfg.critic_pretrain))
                await critic_model.offload_to_cpu()
            if reward_models is not None:
                for reward_model, reward_pretrain in zip(reward_models, reward_pretrains):
                    await asyncio.gather(*reward_model.async_init_model_from_pretrained(self.strategy, reward_pretrain))
            if cfg.ref_pretrain is not None:
                await asyncio.gather(*ref_model.async_init_model_from_pretrained(self.strategy, cfg.ref_pretrain))

        self.policy_model = policy_model
        self.critic_model = critic_model
        self.ref_model = ref_model
        self.reward_model = reward_models

        logger.info("Initialized policy/ref/critic/reward models")

    async def ppo_local_train_policy(self, replay_buffers: List[NaiveReplayBuffer], summary_buffers: List[NaiveReplayBuffer], global_steps: int):
        """Run one policy training round and push the weights to the vLLM engines.

        Training (and its scalar logging) only happens once ``global_steps`` exceeds
        ``cfg.freezing_actor_steps``; the weight sync to vLLM always runs.

        Returns:
            The first training status dict when training ran, otherwise None.
        """
        actor_is_trainable = global_steps > self.cfg.freezing_actor_steps
        first_status = None

        if actor_is_trainable:
            async with Timer("Policy model training"):
                all_status = await self.policy_model.async_ppo_train(global_steps, replay_buffers, summary_buffers)
            first_status = all_status[0]
            # (tensorboard tag, status key)
            for tag, key in (
                ("ppo_clip_count", "clip_ratio"),
                ("policy_update_steps", "policy_update_steps"),
                ("policy_entropy", "entropy"),
            ):
                self.writer.add_scalar(tag, first_status[key], global_steps)
            await self.policy_model.async_run_method("empty_cache")

        if self.cfg.colocate_all:
            async with Timer("Backload vllm engines to gpu"):
                await self._backload_vllm_engines()

        async with Timer("Broadcast actor weights to vllm engines"):
            await self._sync_policy_weights_to_vllm()

        if not actor_is_trainable:
            return None
        return first_status

    async def ppo_local_train_critic(self, replay_buffers: List[NaiveReplayBuffer], global_steps: int):
        """Run one critic training round and log its loss.

        The scalars are written only when the first status dict carries a truthy
        ``"critic_loss"`` entry.

        Returns:
            The first training status dict.
        """
        async with Timer("Critic model training"):
            all_status = await self.critic_model.async_ppo_train(global_steps, replay_buffers)

        first_status = all_status[0]
        loss_value = first_status.get("critic_loss", None)
        if loss_value:
            self.writer.add_scalar("critic_loss", loss_value, global_steps)
            self.writer.add_scalar("critic_update_steps", first_status["critic_update_steps"], global_steps)
        return all_status[0]

    async def custom_reward_fn(
        self,
        prompts: List[str],
        outputs: List[Any],
        extras: List[dict],
        reward_model_fn: Callable[[List[str], List[str]], Awaitable[torch.Tensor]],
    ) -> Tuple[List[str], List[str], List[torch.Tensor]]:
        """Placeholder for a task-specific reward function.

        Raises:
            NotImplementedError: Always; this implementation provides no reward function.
        """
        message = "custom reward function is not supported yet"
        raise NotImplementedError(message)

    @torch.no_grad()
    async def _calc_advantages_and_returns(self, experience: Experience):
        """Compute rewards, advantages and returns for one experience and collect metrics.

        The experience is updated in place: ``advantages`` and ``returns`` are set,
        ``actor_target_values`` is set when ``cfg.actor_value_coef > 0``, ``kl`` is cleared,
        bookkeeping entries are removed from ``experience.info`` and the experience is moved
        to CPU.

        Args:
            experience: Experience whose ``info`` holds ``"num_actions"``,
                ``"custom_rewards"``, ``"reward"``, ``"kl"``, ``"kl_max"``,
                ``"response_length"`` and ``"total_length"``.

        Returns:
            ``(experience, metrics)`` where ``metrics`` maps metric names to Python scalars
            (plus any ``consist_*`` entries moved over from ``experience.info``).
        """
        num_actions = experience.info["num_actions"]
        custom_rewards = experience.info["custom_rewards"]

        # Dump unusually long packs for offline inspection. ``custom_rewards`` may be None
        # (it is handled as optional further down), so guard before measuring it.
        if custom_rewards is not None:
            total_reward_len = sum(len(x) for x in custom_rewards)
            if total_reward_len > 17000:
                ray.logger.info(f"Custom reward is too large - custom_rewards shape: {total_reward_len}")
                save_debug_data(
                    sequences_all=experience.sequences,
                    num_actions_all=num_actions,
                    attention_mask_all=experience.attention_mask,
                    packed_seq_lens_all=experience.packed_seq_lens,
                    action_log_probs=experience.action_log_probs,
                    base_log_probs=experience.base_action_log_probs,
                    values=experience.values,
                    custom_rewards=experience.info["custom_rewards"],
                    kl=experience.info["kl"],
                    sys_mask=experience.sys_mask,
                )

        reward = await compute_reward.remote(
            experience.info["reward"],
            self.cfg.init_kl_coef,
            experience.kl,
            custom_rewards=experience.info["custom_rewards"],
            action_mask=experience.action_mask,
            num_actions=num_actions,
            reward_clip_range=self.cfg.reward_clip_range,
            use_kl_loss=self.cfg.use_kl_loss,
        )

        if self.cfg.summary:
            # Per-token discount: zero wherever ``sys_mask`` is set.
            gamma = self.cfg.gamma * (1 - experience.sys_mask.float())
        else:
            gamma = self.cfg.gamma

        experience.advantages, experience.returns = await get_advantages_and_returns.remote(
            experience.values,
            reward,
            experience.action_mask,
            num_actions,
            gamma,
            self.cfg.lambd,
            packing=True,
        )

        # Average total reward per packed sequence.
        return_sums = reward.sum(dim=-1)
        return_sums /= len(num_actions)
        experience.info["return"] = return_sums
        experience.kl = None

        if self.cfg.actor_value_coef > 0.:
            # One return target per configured discount factor, stacked on the last dim.
            actor_target_values = []
            for value_gamma in self.cfg.actor_value_gammas:
                _, actor_target_value = await get_advantages_and_returns.remote(
                    experience.values,
                    reward,
                    experience.action_mask,
                    num_actions,
                    value_gamma,
                    self.cfg.lambd,
                    packing=True,
                )
                actor_target_values.append(actor_target_value)
            experience.actor_target_values = torch.stack(actor_target_values, dim=-1)

        avg_rewards = return_sums.mean().item()
        avg_kl = experience.info["kl"].mean().item() if experience.info["kl"] is not None else 0
        avg_kl_max = experience.info["kl_max"].mean().item() if experience.info["kl_max"] is not None else 0

        avg_response_length = experience.info["response_length"].mean().item()
        if experience.info["reward"] is not None:
            avg_orm_score = experience.info["reward"].mean().item()
        else:
            avg_orm_score = 0

        if experience.info["custom_rewards"] is not None:
            per_sequence_totals = [r.sum() for r in experience.info["custom_rewards"]]
            avg_custom_rewards = torch.stack(per_sequence_totals).mean().item()
        else:
            avg_custom_rewards = 0

        del experience.info["num_actions"]
        del experience.info["custom_rewards"]
        del experience.info["reward"]
        del experience.info["kl_max"]
        experience.to_device("cpu")

        # for replay buffer split batch
        experience.info["response_length"] = torch.Tensor(experience.info["response_length"]).mean().unsqueeze(0)
        experience.info["total_length"] = torch.Tensor(experience.info["total_length"]).mean().unsqueeze(0)

        metrics = {
            "avg_rewards": avg_rewards,
            "avg_kl": avg_kl,
            "avg_kl_max": avg_kl_max,
            "avg_response_length": avg_response_length,
            "avg_orm_score": avg_orm_score,
            "avg_custom_rewards": avg_custom_rewards,
            "avg_advantages": experience.advantages.mean().item(),
            "avg_advantages_abs": experience.advantages.abs().mean().item(),
        }

        if self.cfg.summary and self.cfg.summary_consist_coef > 0 and not self.cfg.summary_skip:
            # Move the summary-consistency statistics out of ``info`` and into the metrics.
            keys = list(experience.info.keys())
            for k in keys:
                if k.startswith("consist_"):
                    metrics[k] = experience.info[k]
                    del experience.info[k]

        return experience, metrics

    def _convert_prompts_outputs_to_batch_tensors(self, prompts: List[str], outputs: List[str]):
        # This function is used when not packing samples
        # concat all outputs to following format:
        #
        # | [PAD] [PAD] token token token | token token [EOS] [PAD] |
        # | token token token token token | token token [EOS] [PAD] |
        # | [PAD] [PAD] [PAD] token token | token token token [EOS] |
        # |<---------- prompt ----------->|<-------- answer ------->|
        max_input_len, max_output_len = 0, 0
        prompt_token_lens, response_token_lens = [], []
        inputs_token_ids, outputs_token_ids = [], []
        for prompt, output in zip(prompts, outputs):
            input_token_ids = self._tokenize(prompt, self.cfg.prompt_max_len, padding=False)
            response_token_ids = self._tokenize(output, self.cfg.generate_max_len, padding=False)

            inputs_token_ids.append(input_token_ids)
            outputs_token_ids.append(response_token_ids)

            prompt_token_len = len(input_token_ids)
            response_token_len = len(response_token_ids)
            prompt_token_lens.append(prompt_token_len)
            response_token_lens.append(response_token_len)

            max_input_len = max(max_input_len, prompt_token_len)
            max_output_len = max(max_output_len, response_token_len)

        pad_token_id, eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
        sequences = []
        for i, prompt in enumerate(prompts):
            # left padding input
            input_len = prompt_token_lens[i]
            input_ids = [pad_token_id] * (max_input_len - input_len) + list(inputs_token_ids[i])

            # right padding output
            output_len = response_token_lens[i]
            output_ids = list(outputs_token_ids[i]) + [pad_token_id] * (max_output_len - output_len)

            # replace last token with eos_token_id if it is not eos_token_id, keep the total length of output_ids
            # output_ids[output_len - 1] = eos_token_id

            # concat input and output
            sequences.append(input_ids + output_ids)

        sequences = torch.tensor(sequences)

        sequences, attention_mask, action_mask = self._process_sequences(
            sequences, max_input_len, eos_token_id, pad_token_id
        )
        return sequences, attention_mask, action_mask

    def _convert_prompts_outputs_to_batch_tensors_packing(
        self, 
        prompts: List[str], 
        outputs: List[str], 
        output_ids: Optional[List[int]],
        custom_rewards: Optional[List[torch.Tensor]],         
        packing_max_len: int
    ):
        ret_sequences = []
        ret_sequence_types = []
        ret_attention_masks = []
        ret_num_actions = []
        ret_packed_seq_lens = []
        if custom_rewards is not None:
            ret_custom_rewards = []
        else:
            ret_custom_rewards = None

        if prompts is not None:
            assert (
                len(prompts) == len(outputs) and len(prompts) > 0
            ), "prompts and outputs must have the same length and length must be greater than 0"

        def _new_instance():
            out_sequence = torch.full((packing_max_len,), torch.tensor(self.tokenizer.pad_token_id), dtype=torch.long)
            out_sequence_type = torch.zeros((packing_max_len,), dtype=torch.int) if output_ids is not None else None
            out_attention_mask = torch.zeros((packing_max_len,), dtype=torch.int)
            out_num_actions = []
            out_packed_seq_lens = []
            rewards = [] if custom_rewards else None
            seq_offset = 0
            seq_index = 0
            return (
                out_sequence,
                out_sequence_type,
                out_attention_mask,
                out_num_actions,
                out_packed_seq_lens,
                rewards,
                seq_offset,
                seq_index,
            )

        def _accumulate(
            out_sequence,
            out_sequence_type,
            out_attention_mask,
            out_num_actions,
            out_packed_seq_lens,
            rewards,
            seq_offset,
            seq_index,
            sequence,
            sequence_type,
            attention_mask,
            num_action,
            total_len,
            custom_rewards,
            i,
        ):
            out_sequence[seq_offset : seq_offset + total_len] = torch.tensor(sequence)            
            out_attention_mask[seq_offset : seq_offset + total_len] = seq_index + 1
            out_num_actions.append(num_action)
            out_packed_seq_lens.append(total_len)
            if custom_rewards:
                if len(custom_rewards[i]) != num_action:
                    ray.logger.info(f"custom rewards length {len(custom_rewards[i])} != num action {num_action}")

                rewards.append(custom_rewards[i])
            if out_sequence_type is not None:
                out_sequence_type[seq_offset : seq_offset + total_len] = torch.tensor(sequence_type)
            return seq_offset + total_len, seq_index + 1

        sequences = []
        sequence_types = []
        attention_masks = []
        num_actions = []
        total_lens = []

        if output_ids is None:
            input_token_ids = self._tokenize(prompts, self.cfg.prompt_max_len, padding=False)
            response_token_ids = self._tokenize(outputs, self.cfg.generate_max_len, padding=False)
        else:
            input_token_ids = [out[0] for out in output_ids] # output_ids is list of list of integer
            response_token_ids = [sum(out[1:], []) for out in output_ids]
            seq_types = []

            for n, out in enumerate(output_ids):
                seq_types_n = []
                for m, res in enumerate(out):
                    seq_types_n.extend([m] * len(res))
                seq_types.append(seq_types_n)

        custom_rewards_ = []
        for n, input_ids, response_ids in zip(range(len(input_token_ids)), input_token_ids, response_token_ids):
            seq = input_ids + response_ids
            if len(seq) > packing_max_len: continue
            sequences.append(seq)
            sequence_types.append(seq_types[n] if output_ids is not None else None)
            attention_masks.append(torch.ones((len(input_ids) + len(response_ids),), dtype=torch.float32))
            num_actions.append(len(response_ids))
            total_lens.append(len(input_ids) + len(response_ids))
            custom_rewards_.append(custom_rewards[n] if custom_rewards is not None else None)

        if custom_rewards is not None:
            custom_rewards = custom_rewards_
        else:
            custom_rewards = None

        # make packed sequences
        (
            out_sequence,
            out_sequence_type,
            out_attention_mask,
            out_num_actions,
            out_packed_seq_lens,
            rewards,
            seq_offset,
            seq_index,
        ) = _new_instance()
        for i, (sequence, sequence_type, attention_mask, num_action, total_len) in enumerate(
            zip(sequences, sequence_types, attention_masks, num_actions, total_lens)
        ):
            if seq_offset + total_len < packing_max_len:
                seq_offset, seq_index = _accumulate(
                    out_sequence,
                    out_sequence_type,
                    out_attention_mask,
                    out_num_actions,
                    out_packed_seq_lens,
                    rewards,
                    seq_offset,
                    seq_index,
                    sequence,
                    sequence_type,
                    attention_mask,
                    num_action,
                    total_len,
                    custom_rewards,
                    i,
                )
            elif seq_offset + total_len == packing_max_len:
                seq_offset, seq_index = _accumulate(
                    out_sequence,
                    out_sequence_type,
                    out_attention_mask,
                    out_num_actions,
                    out_packed_seq_lens,
                    rewards,
                    seq_offset,
                    seq_index,
                    sequence,
                    sequence_type,
                    attention_mask,
                    num_action,
                    total_len,
                    custom_rewards,
                    i,
                )
                valid_size = out_attention_mask.nonzero().size(0)
                ret_sequences.append(out_sequence[:valid_size].unsqueeze(0))
                ret_attention_masks.append(out_attention_mask[:valid_size].unsqueeze(0))
                ret_num_actions.append(out_num_actions)
                ret_packed_seq_lens.append(out_packed_seq_lens)
                if output_ids is not None:
                    ret_sequence_types.append(out_sequence_type[:valid_size].unsqueeze(0))
                if custom_rewards:
                    ret_custom_rewards.append(rewards)
                (
                    out_sequence,
                    out_sequence_type,
                    out_attention_mask,
                    out_num_actions,
                    out_packed_seq_lens,
                    rewards,
                    seq_offset,
                    seq_index,
                ) = _new_instance()
            elif seq_offset + total_len > packing_max_len:
                if seq_offset > 0:                    
                    valid_size = out_attention_mask.nonzero().size(0)
                    ret_sequences.append(out_sequence[:valid_size].unsqueeze(0))
                    ret_attention_masks.append(out_attention_mask[:valid_size].unsqueeze(0))
                    ret_num_actions.append(out_num_actions)
                    ret_packed_seq_lens.append(out_packed_seq_lens)
                    if output_ids is not None:
                        ret_sequence_types.append(out_sequence_type[:valid_size].unsqueeze(0))
                    if custom_rewards:
                        ret_custom_rewards.append(rewards)
                    
                    (
                        out_sequence,
                        out_sequence_type,
                        out_attention_mask,
                        out_num_actions,
                        out_packed_seq_lens,
                        rewards,
                        seq_offset,
                        seq_index,
                    ) = _new_instance()
                    seq_offset, seq_index = _accumulate(
                        out_sequence,
                        out_sequence_type,
                        out_attention_mask,
                        out_num_actions,
                        out_packed_seq_lens,
                        rewards,
                        seq_offset,
                        seq_index,
                        sequence,
                        sequence_type,
                        attention_mask,
                        num_action,
                        total_len,
                        custom_rewards,
                        i,
                    )

        if seq_offset > 0:
            valid_size = out_attention_mask.nonzero().size(0)
            ret_sequences.append(out_sequence[:valid_size].unsqueeze(0))
            ret_attention_masks.append(out_attention_mask[:valid_size].unsqueeze(0))
            ret_num_actions.append(out_num_actions)
            ret_packed_seq_lens.append(out_packed_seq_lens)
            if custom_rewards:
                ret_custom_rewards.append(rewards)
            if output_ids is not None:
                ret_sequence_types.append(out_sequence_type[:valid_size].unsqueeze(0))

        if output_ids is None:
            ret_sequence_types = [None for _ in range(len(ret_sequences))]

        return ret_sequences, ret_sequence_types, ret_attention_masks, ret_num_actions, ret_packed_seq_lens, ret_custom_rewards
    
    def _filter_summary_to_packing(self, info, packing_max_len):
        new_output_ids = []

        for (
            response_status,
            response_ids,
        ) in zip(
            info["response_status"],
            info["response_ids"],
        ):  
            if response_status[0] == 1: # correct fast response
                new_output_ids.append(response_ids[:2])
            if len(response_status) >= 4 and response_status[3] == 1: # correct summary
                if len(response_ids) < 7:
                    ray.logger.info(f"summary response status is not correct; response_status: {response_status} and response_ids: {response_ids}")
                    continue
                new_output_ids.append([response_ids[0], response_ids[7]])

        if len(new_output_ids) > 0:        
            (
                ret_sequences,
                ret_sequence_types,
                ret_attention_masks,
                ret_num_actions,
                ret_packed_seq_lens,
                _,
            ) = self._convert_prompts_outputs_to_batch_tensors_packing(
                None, None, new_output_ids, None, packing_max_len,
            )
            return ret_sequences, ret_sequence_types, ret_attention_masks, ret_num_actions, ret_packed_seq_lens
        else:
            return [], [], [], [], []
        
    def _filter_dir_summary_to_packing(self, info, packing_max_len):
        new_output_ids = []

        for (
            response_status,
            response_ids,
        ) in zip(
            info["response_status"],
            info["response_ids"],
        ):  
            if len(response_status) >= 4: # has summary
                if len(response_ids) < 7:
                    ray.logger.info(f"summary response status is not correct; response_status: {response_status} and response_ids: {response_ids}")
                    continue
                new_output_ids.append([response_ids[0], response_ids[7]])

        if len(new_output_ids) > 0:        
            (
                ret_sequences,
                ret_sequence_types,
                ret_attention_masks,
                ret_num_actions,
                ret_packed_seq_lens,
                _,
            ) = self._convert_prompts_outputs_to_batch_tensors_packing(
                None, None, new_output_ids, None, packing_max_len,
            )
            return ret_sequences, ret_sequence_types, ret_attention_masks, ret_num_actions, ret_packed_seq_lens
        else:
            return [], [], [], [], []
        
    def _get_dp_group_models(self, dp_rank: int, model_type: str = ""):
        model = getattr(self, model_type)
        if model_type == "reward_model":
            model = model[0]
        return model._actor_handlers[dp_rank]

    def _split_dp_batch(self, batch, num_dp, drop_last=False):
        # Convert batch tuple to list of lists, handling None values
        batch_lists = []
        batch_size = None
        for item in batch:
            if item is not None:
                if batch_size is None:
                    batch_size = len(item)
                batch_lists.append(item)
            else:
                batch_lists.append(None)

        if drop_last:
            dp_size = batch_size // num_dp
        else:
            dp_size = (batch_size + num_dp - 1) // num_dp
        valid_size = dp_size * num_dp

        if not drop_last:
            padding_index = None
            for i in range(len(batch_lists)):
                if batch_lists[i] is not None and (
                    isinstance(batch_lists[i], torch.Tensor) or isinstance(batch_lists[i], list)
                ):
                    padding_size = valid_size - len(batch_lists[i])
                    if padding_size > 0:
                        if padding_index is None:
                            if padding_size > len(batch_lists[i]):
                                padding_index = random.choices(range(len(batch_lists[i])), k=padding_size)
                            else:
                                padding_index = random.sample(range(len(batch_lists[i])), padding_size)
                        if isinstance(batch_lists[i], torch.Tensor):
                            batch_lists[i] = torch.cat([batch_lists[i], batch_lists[i][padding_index]], dim=0)
                        elif isinstance(batch_lists[i], list):
                            batch_lists[i] = batch_lists[i] + [batch_lists[i][j] for j in padding_index]

        for i in range(num_dp):
            # Extract micro batch for each input list
            micro_batch = []
            for batch_list in batch_lists:
                if batch_list is None:
                    micro_batch.append(None)
                elif isinstance(batch_list, torch.Tensor) or isinstance(batch_list, list):
                    micro_batch.append(batch_list[i * dp_size : (i + 1) * dp_size])
                else:
                    micro_batch.append(batch_list)
            yield tuple(micro_batch)

    def _split_dp_batch_dynamic_balance(self, batch, num_dp, balanced_values):
        batch = list(batch)
        assert len(batch) == len(balanced_values), "batch and balanced_values must have the same length"
        results = self._split_weighted_objects(zip(balanced_values, batch), num_dp)
        # re organize to the original format
        for i in range(num_dp):
            ret = [[] for _ in range(len(results[i][0]))]
            for sample in results[i]:
                for j, v in enumerate(sample):
                    ret[j].append(v)
            yield ret

    def _split_weighted_objects(self, items, n):
        result = [[] for _ in range(n)]

        heap = [(0, i) for i in range(n)]
        heapify(heap)

        sorted_items = sorted(items, key=lambda x: x[0], reverse=True)

        for weight, obj in sorted_items:
            current_sum, index = heappop(heap)
            result[index].append(obj)
            heappush(heap, (current_sum + weight, index))

        return result

    async def _split_and_run_micro_batch(self, async_fn, batch_args, micro_size):
        # Ensure batch_args is a sequence of lists with equal length
        batch_size = len(batch_args[0])
        results = []
        # Process in micro batches
        for i in range(0, batch_size, micro_size):
            # Take slice i:i+micro_size from each argument
            micro_batch_args = []
            for arg in batch_args:
                if arg is not None:
                    if not isinstance(arg, torch.Tensor) and not isinstance(arg, list):
                        micro_batch_args.append(arg)
                    elif micro_size > 1 or isinstance(arg, torch.Tensor):
                        micro_batch_args.append(arg[i : i + micro_size])
                    else:
                        micro_batch_args.append(arg[i])
                else:
                    micro_batch_args.append(None)
            results.append(await async_fn(*micro_batch_args))
        return results

    def _get_generate_function(self, dp_rank: int):
        llm = self.vllm_engines[dp_rank % len(self.vllm_engines)]

        async def generate(prompts: List[str], truncate_prompt=True, **kwargs):
            if truncate_prompt:
                prompt_token_ids = self._tokenize(prompts, self.cfg.prompt_max_len, padding=False)
            else:
                prompt_token_ids = self._tokenize(prompts, padding=False)
            outputs = await llm.generate.remote(prompt_token_ids=prompt_token_ids, **kwargs)

            out = defaultdict(list)
            for i, prompt in enumerate(prompts):
                out["stop_reason"].append(outputs[i].outputs[0].finish_reason)
                out["response"].append(outputs[i].outputs[0].text)

                if self.cfg.multi_attempt:
                    out["prompt"].append(outputs[i].prompt)
                    out["answer_status"].append(outputs[i].outputs[0].answer_status)
                    out["final_answer"].append(outputs[i].outputs[0].final_answer)
                    out["attempt_used"].append(outputs[i].outputs[0].attempt_used)

                if self.cfg.summary:
                    out["prompt"].append(outputs[i].prompt)
                    out["response_ids"].append(outputs[i].outputs[0].token_ids)
                    out["response_status"].append(outputs[i].outputs[0].response_status)
                    out["answer_status"].append(outputs[i].outputs[0].answer_status)
                    out["final_answer"].append(outputs[i].outputs[0].final_answer)
                    out["all_answers"].append(outputs[i].outputs[0].all_answers)
                    out["rewards"].append(outputs[i].outputs[0].rewards)
            return out

        return generate

    def _process_sequences(self, sequences: torch.Tensor, input_len, eos_token_id, pad_token_id):
        attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long)
        seq_length = attention_mask.size(1)

        eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1)
        sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id)

        # For Llama3 and Qwen2 models, there are some eos_tokens in the middle of the prompt.
        first_token_indices = attention_mask.long().argmax(dim=1, keepdim=True)
        mask = torch.arange(seq_length).unsqueeze(0).expand(sequences.size(0), -1).to(device=sequences.device)
        attention_mask = (mask >= first_token_indices) & (mask <= eos_indices).to(dtype=torch.long)

        # in RL, state_i (current token) + action_i (next token) -> state_i+1 (next token)
        state_seq = sequences[:, input_len - 1 : -1]
        action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id)
        action_mask[:, 0] = 1

        return sequences, attention_mask, action_mask

    def _tokenize(self, texts, max_length=99999999, padding=False, device=None, add_special_tokens=False):
        """Tokenize ``texts`` with ``encode_prompts`` and this trainer's tokenizer.

        Truncation to ``max_length`` is always enabled.

        Args:
            texts: text(s) forwarded unchanged to ``encode_prompts``.
            max_length: truncation length.
            padding: when falsy, the raw ``encode_prompts`` result is returned
                (per the original comment, the tokenized texts as a list).
                When truthy, padded ``"pt"`` tensors are requested and every
                entry of the result is moved to ``device``.
            device: target for ``.to(...)``; only used when ``padding`` is truthy.
            add_special_tokens: forwarded to ``encode_prompts``.
        """
        if padding:
            encoded = encode_prompts(
                texts,
                self.tokenizer,
                padding=True,
                add_special_tokens=add_special_tokens,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
            )
            return {name: tensor.to(device) for name, tensor in encoded.items()}

        # unpadded path: hand back whatever encode_prompts produced
        return encode_prompts(
            texts,
            self.tokenizer,
            padding=False,
            add_special_tokens=add_special_tokens,
            max_length=max_length,
            truncation=True,
        )

    def _detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    async def _offload_vllm_engines(self):
        offload_tasks = []
        for engine in self.vllm_engines:
            offload_tasks.append(engine.offload_to_cpu.remote())
        await asyncio.gather(*offload_tasks)

    async def _backload_vllm_engines(self):
        backload_tasks = []
        for engine in self.vllm_engines:
            backload_tasks.append(engine.backload_to_gpu.remote())
        await asyncio.gather(*backload_tasks)

    async def _sync_policy_weights_to_vllm(self):
        if self.cfg.colocate_all:
            await self.policy_model.async_run_method("_broadcast_to_vllm_cudaipc", self.vllm_engines)
        else:
            await self.policy_model.async_run_method("_broadcast_to_vllm", self.vllm_engines)
