import itertools
import math
import os
import socket
import time
from datetime import timedelta
from typing import Callable, Dict, List

import deepspeed
import ray
import torch
import torch.distributed
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.trainer import get_scheduler

from openrlhf.datasets import PromptDataset, SFTDataset
from openrlhf.models import Actor
from openrlhf.models.utils import compute_approx_kl, masked_mean
from openrlhf.trainer import BasePPOTrainer
from openrlhf.trainer.ppo_utils import Experience, RemoteExperienceMaker
from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call
from openrlhf.utils import blending_datasets, get_tokenizer
from openrlhf.utils.deepspeed import DeepspeedStrategy
from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states
from openrlhf.utils.distributed_sampler import DistributedSampler
from openrlhf.utils.distributed_util import init_process_group, torch_dist_barrier_and_cuda_sync
from openrlhf.utils.logging_utils import init_logger
from openrlhf.utils.remote_rm_utils import remote_rm_fn_ray

logger = init_logger(__name__)

from .launcher import BasePPORole
from .utils import get_physical_gpu_id


class ActorPPOTrainer(BasePPOTrainer):
    def __init__(
        self,
        *args,
        vllm_engines: List = None,
        remote_rm_url: List[str] = None,
        critic_train_remote: bool = False,
        **kwargs,
    ):
        """PPOTrainer for ray.

        Sets up logging (wandb or TensorBoard, rank 0 only), the remote
        experience maker, and, when vLLM engines are used without CUDA IPC,
        the process group used to push updated actor weights to those engines.

        Args:
            *args: Positional arguments forwarded unchanged to ``BasePPOTrainer.__init__``.
            vllm_engines (List, optional): vllm engines for text generation, if not specified, generate text by actor model directly. Defaults to None.
            remote_rm_url (List[str], optional): remote reward model URLs; stored on the trainer and handed to the experience maker. Defaults to None.
            critic_train_remote (bool, optional): whether this actor should trigger corresponding critic model training. Defaults to False.
            **kwargs: Keyword arguments forwarded unchanged to ``BasePPOTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.remote_rm_url = remote_rm_url
        self.vllm_engines = vllm_engines
        self.critic_train_remote = critic_train_remote

        # wandb/tensorboard setting
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                wandb.login(key=self.strategy.args.use_wandb)
            wandb.init(
                entity=self.strategy.args.wandb_org,
                project=self.strategy.args.wandb_project,
                group=self.strategy.args.wandb_group,
                name=self.strategy.args.wandb_run_name,
                config=self.strategy.args.__dict__,
                reinit=True,
            )

            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            # evaluate() logs its x-axis value under "eval/global_step", so that is
            # the key the eval charts have to be plotted against.
            wandb.define_metric("eval/global_step")
            wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True)

        # Initialize TensorBoard writer if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, self.strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

        self.experience_maker = RemoteExperienceMaker(
            self.actor,
            self.critic,
            self.reward_model,
            self.initial_model,
            self.tokenizer,
            self.prompt_max_len,
            self.kl_ctl,
            self.strategy,
            self.remote_rm_url,
            self.reward_fn,
            vllm_engines=self.vllm_engines,
            packing_samples=self.strategy.args.packing_samples,
        )

        # CUDA IPC replaces the broadcast group only for the NCCL backend with all
        # models colocated.
        backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl")
        self.use_cuda_ipc = False
        if backend == "nccl" and self.strategy.args.colocate_all_models:
            self.use_cuda_ipc = True

        # Create torch group with deepspeed rank 0 and all vllm ranks
        # to update vllm engine's weights after each training stage.
        #
        # Say we have 3 vllm engines and each of them has 4 GPUs,
        # then the torch group is:
        # [    0,      1, 2, 3, 4,  5, 6, 7, 8,  9, 10, 11, 12]
        # |ds rank 0 |  engine-0  |  engine-1  |   engine-2   |
        #
        # For ZeRO-1/2:
        #   1. Broadcast parameters from rank 0 to all vllm engines
        # For ZeRO-3:
        #   1. AllGather parameters to rank 0
        #   2. Broadcast parameters from rank 0 to all vllm engines
        if self.vllm_engines is not None and not self.use_cuda_ipc and torch.distributed.get_rank() == 0:
            # Public export of the helper previously reached through ray._private.
            master_address = ray.util.get_node_ip_address()
            # Binding to port 0 lets the OS pick a free port; the socket is closed
            # again before the port number is handed to the process group.
            with socket.socket() as sock:
                sock.bind(("", 0))
                master_port = sock.getsockname()[1]

            vllm_num_engines, vllm_tensor_parallel_size = (
                self.strategy.args.vllm_num_engines,
                self.strategy.args.vllm_tensor_parallel_size,
            )
            # +1 for this process (rank 0 of the update group).
            world_size = vllm_num_engines * vllm_tensor_parallel_size + 1

            use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False)
            group_name = "openrlhf"
            # Start group initialisation on every engine first (non-blocking); each
            # engine is given the rank offset of its first worker.
            refs = [
                engine.init_process_group.remote(
                    master_address,
                    master_port,
                    i * vllm_tensor_parallel_size + 1,
                    world_size,
                    group_name,
                    backend=backend,
                    use_ray=use_ray,
                )
                for i, engine in enumerate(self.vllm_engines)
            ]
            # Then join the same group locally as rank 0.
            if use_ray:
                import ray.util.collective as collective

                collective.init_collective_group(world_size=world_size, rank=0, backend=backend, group_name=group_name)
                self._model_update_group = group_name
            else:
                self._model_update_group = init_process_group(
                    backend=backend,
                    init_method=f"tcp://{master_address}:{master_port}",
                    world_size=world_size,
                    rank=0,
                    group_name=group_name,
                )

            # Wait until every engine has finished joining.
            ray.get(refs)

        torch_dist_barrier_and_cuda_sync()

    def fit(
        self,
        args,
        prompts_dataloader,
        pretrain_dataloader,
        eval_dataloader,
        consumed_samples=0,
        num_update_steps_per_episodes=1,
    ) -> None:
        """Run the PPO loop over all remaining episodes.

        For every prompt batch: build experiences into the replay buffer, run
        one ``ppo_train`` round, update the KL controller, then log and
        optionally evaluate/checkpoint.

        Args:
            args: Training arguments (batch sizes, episode/epoch counts and
                eval/save steps). ``eval_steps`` and ``save_steps`` are
                overwritten in place when they are ``-1``.
            prompts_dataloader: Yields ``(_, prompts, labels)`` batches for rollouts.
            pretrain_dataloader: Stored on the trainer as-is.
            eval_dataloader: Stored on the trainer as-is.
            consumed_samples: Samples already consumed; used to restore the
                step counter, the starting episode and the sampler offset.
            num_update_steps_per_episodes: Used to derive the number of
                rollouts per episode.
        """
        # Left-to-right chain of floor divisions, spelled out one step at a time.
        rollouts_per_episode = num_update_steps_per_episodes * args.train_batch_size
        rollouts_per_episode = rollouts_per_episode // args.max_epochs
        rollouts_per_episode = rollouts_per_episode // args.rollout_batch_size
        rollouts_per_episode = rollouts_per_episode // args.n_samples_per_prompt

        # -1 means "use the default" for both eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = rollouts_per_episode  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        self.prompts_dataloader = prompts_dataloader
        self.pretrain_dataloader = pretrain_dataloader
        self.eval_dataloader = eval_dataloader

        # Restore the step counter, the first episode and the in-episode offset
        rollouts_done = consumed_samples // args.rollout_batch_size
        steps = rollouts_done + 1
        start_episode = rollouts_done // rollouts_per_episode
        consumed_samples = consumed_samples % (rollouts_per_episode * args.rollout_batch_size)

        for episode in range(start_episode, args.num_episodes):
            sampler = self.prompts_dataloader.sampler
            if isinstance(sampler, DistributedSampler):
                # Only the episode being resumed skips already-consumed samples.
                resume_offset = consumed_samples if episode == start_episode else 0
                sampler.set_epoch(episode, consumed_samples=resume_offset)

            pbar = tqdm(
                range(len(self.prompts_dataloader)),
                desc=f"Episode [{episode + 1}/{args.num_episodes}]",
                disable=not self.strategy.is_rank_0(),
            )

            for _, rand_prompts, labels in self.prompts_dataloader:
                experiences = self.experience_maker.make_experience_list(
                    rand_prompts, labels, **self.generate_kwargs
                )
                for idx, experience in enumerate(experiences):
                    if idx == 0:
                        # Print the first decoded sequence of the batch as a sample.
                        first_sequence = experience.sequences[0].unsqueeze(0)
                        output = self.tokenizer.batch_decode(first_sequence, skip_special_tokens=True)
                        self.strategy.print(output)
                    self.replay_buffer.append(experience)

                # These two estimators skip the buffer-level advantage normalization.
                if self.args.advantage_estimator not in ("group_norm", "dr_grpo"):
                    self.replay_buffer.normalize(
                        self.strategy, "advantages", divide_by_std=not self.args.no_advantage_std_norm
                    )
                status = self.ppo_train(steps)
                self.replay_buffer.clear()

                if "kl" in status:
                    self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
                pbar.set_postfix(status)

                # logs/checkpoints
                client_states = {"consumed_samples": steps * args.rollout_batch_size}
                self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)

                pbar.update()
                steps += 1

        # Shut down whichever logger was created.
        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    def ppo_train(self, global_steps):
        """Run one PPO round: optional remote critic training, actor training, vLLM weight sync.

        Args:
            global_steps: Current global step; the actor is only trained once
                this exceeds ``self.freezing_actor_steps``.

        Returns:
            dict: Merged status from the critic (if trained remotely) and the actor.
        """
        strategy_args = self.strategy.args

        # 1. make sure every experience maker has finished
        self.experience_maker.flush()
        torch_dist_barrier_and_cuda_sync()
        status = {}

        # 2. kick off remote critic training
        if self.critic_train_remote:
            # sync for deepspeed_enable_sleep
            if strategy_args.deepspeed_enable_sleep:
                ray.get(self.critic.reload_states.remote())

            critic_future = self.critic.fit.remote()

            # With colocation or sleep mode the critic result is awaited right away.
            if strategy_args.colocate_all_models or strategy_args.deepspeed_enable_sleep:
                status.update(ray.get(critic_future))
            if strategy_args.deepspeed_enable_sleep:
                ray.get(self.critic.offload_states.remote())

        if strategy_args.colocate_all_models:
            torch_dist_barrier_and_cuda_sync()

        # 3. train the actor
        if global_steps > self.freezing_actor_steps:
            if strategy_args.deepspeed_enable_sleep:
                self.reload_states()

            actor_status = self.ppo_train_actor(global_steps)
            status.update(actor_status)

            if strategy_args.deepspeed_enable_sleep:
                self.offload_states()

            torch.cuda.empty_cache()

            # 4. push the new weights to the vllm engines
            if self.vllm_engines is not None:
                if strategy_args.vllm_enable_sleep:
                    batch_vllm_engine_call(self.vllm_engines, "wake_up")

                torch_dist_barrier_and_cuda_sync()
                self._broadcast_to_vllm()

                if strategy_args.vllm_enable_sleep:
                    batch_vllm_engine_call(self.vllm_engines, "sleep")
                    torch_dist_barrier_and_cuda_sync()

        # 5. wait for the remote critic when it was not awaited in step 2 for colocation
        if self.critic_train_remote and not strategy_args.colocate_all_models:
            status.update(ray.get(critic_future))
        torch_dist_barrier_and_cuda_sync()

        return status

    def ppo_train_actor(self, global_steps):
        """Train the actor on the replay buffer for ``self.max_epochs`` epochs.

        Args:
            global_steps: Current global step (not used inside this method).

        Returns:
            dict: Per-key mean of the step statuses, or an empty dict when no
            training step ran.
        """
        torch.cuda.empty_cache()
        # replay buffer may be empty at first, so the loader is rebuilt for every training round
        dataloader = DataLoader(
            self.replay_buffer,
            batch_size=self.replay_buffer.sample_batch_size,
            # shuffle only when no ring attention group is configured
            shuffle=self.strategy.ring_attn_group is None,
            drop_last=True,
            pin_memory=self.dataloader_pin_memory,
            collate_fn=self.replay_buffer.collate_fn,
        )
        device = torch.cuda.current_device()

        # progress-bar label -> status key, in display order
        actor_fields = (
            ("pg", "policy_loss"),
            ("rm", "reward"),
            ("ret", "return"),
            ("glen", "response_length"),
            ("tlen", "total_length"),
            ("kl", "kl"),
            ("act_lr", "actor_lr"),
        )
        critic_fields = (
            ("cri", "critic_loss"),
            ("vals", "values"),
            ("cri_lr", "critic_lr"),
        )

        status_list = []
        status_mean = {}
        for epoch in range(self.max_epochs):
            pbar = tqdm(
                dataloader,
                desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
                disable=not self.strategy.is_rank_0(),
            )
            for experience in pbar:
                experience.to_device(device)
                status = self.training_step(experience)

                # for DP: kl is averaged with response_length as the weight
                if "kl" in status:
                    status["kl"] *= status["response_length"]
                    status = self.strategy.all_reduce(status)
                    status["kl"] /= status["response_length"]

                short_status = {}
                if "policy_loss" in status:
                    short_status = {label: status[key] for label, key in actor_fields}
                if "critic_loss" in status:
                    for label, key in critic_fields:
                        short_status[label] = status[key]
                if "ptx_loss" in status:
                    short_status["ptx"] = status["ptx_loss"]

                status_list.append(status)
                pbar.set_postfix(short_status)

        if status_list:
            # The first status dict is reused as the accumulator.
            status_mean, *remaining = status_list
            for step_status in remaining:
                for key, value in step_status.items():
                    status_mean[key] += value
            num_steps = len(status_list)
            for key in status_mean:
                status_mean[key] /= num_steps
        torch.cuda.empty_cache()
        return status_mean

    def training_step(self, experience: Experience) -> Dict[str, float]:
        """Run one optimizer step of the actor on a single experience batch.

        Backpropagates the policy loss (plus optional aux and KL terms), then,
        when a pretrain dataloader is set, a second weighted ptx loss, and
        finally steps the optimizer and updates the EMA model if present.

        Args:
            experience: Batch holding sequences, masks, old log-probs,
                advantages, base log-probs and an ``info`` dict.

        Returns:
            Dict[str, float]: ``policy_loss``, ``actor_lr``, optionally
            ``ptx_loss``, plus one scalar per ``experience.info`` entry
            (``kl`` is weighted by ``response_length``).
        """
        self.actor.train()

        mask = experience.action_mask

        # actor forward pass
        action_log_probs, output = self.actor(
            experience.sequences,
            mask,
            attention_mask=experience.attention_mask,
            return_output=True,
            ring_attn_group=self.strategy.ring_attn_group,
            packed_seq_lens=None,
        )

        # policy loss
        actor_loss = self.actor_loss_fn(
            action_log_probs,
            experience.action_log_probs,
            experience.advantages,
            action_mask=mask,
        )

        if not self.args.use_kl_loss:
            kl_loss = 0
        else:
            if self.initial_model is None:
                # No reference model: the KL term is all zeros.
                kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)
            else:
                kl = compute_approx_kl(
                    action_log_probs,
                    experience.base_action_log_probs,
                    kl_estimator=self.args.kl_estimator,
                )
            kl_loss = masked_mean(kl, mask, dim=-1).mean()
            experience.info["kl"] = kl_loss.item()

        # mixtral
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
        self.strategy.backward(loss, self.actor, self.actor_optim)

        # ptx loss: second backward pass on a pretrain batch
        use_ptx = self.pretrain_dataloader is not None
        if use_ptx:
            batch = next(self.pretrain_dataloader)
            device = torch.cuda.current_device()
            inputs = batch[1].squeeze(1).to(device)
            ptx_attention_mask = batch[2].squeeze(1).to(device)
            # Positions outside the attention mask get the loss function's ignore index.
            label = torch.where(
                ptx_attention_mask.bool(),
                inputs,
                self.ptx_loss_fn.IGNORE_INDEX,
            )

            output = self.actor(inputs, attention_mask=ptx_attention_mask, return_output=True)
            ptx_loss = self.ptx_loss_fn(output["logits"], label)
            # mixtral
            aux_loss = output.aux_loss if self.aux_loss else 0
            loss = ptx_loss + aux_loss * self.args.aux_loss_coef
            self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)

        self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
        if self.ema_model:
            self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")

        # status
        status = {"policy_loss": actor_loss.detach().item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
        if use_ptx:
            status["ptx_loss"] = ptx_loss.item()
        for key, value in experience.info.items():
            if key != "kl":
                status[key] = value.mean().item()
                continue
            # kl is averaged with response_length as the weight
            lengths = experience.info["response_length"]
            status[key] = ((value * lengths).sum() / lengths.sum()).item()
        return status

    def _broadcast_to_vllm(self):
        """Send the current actor weights to every vLLM engine, one parameter at a time.

        Two transports are used, selected by ``self.use_cuda_ipc``:

        * broadcast: rank 0 asks each engine to receive the parameter, then
          broadcasts it over ``self._model_update_group`` (Ray collective or
          torch.distributed, depending on ``vllm_sync_with_ray``);
        * CUDA IPC: every rank clones the parameter and publishes an IPC handle
          keyed by its physical GPU id; rank 0 forwards the merged handles to
          the engines.

        When prefix caching is enabled, rank 0 also asks each engine to reset
        its prefix cache and waits for that at the end.
        """
        use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
        cache_reset_refs = []
        if use_prefix_cache and torch.distributed.get_rank() == 0:
            # clear prefix cache
            for engine in self.vllm_engines:
                cache_reset_refs.append(engine.reset_prefix_cache.remote())

        torch.cuda.empty_cache()
        model = self.actor.model.module
        count, num_params = 0, len(list(model.named_parameters()))
        for name, param in model.named_parameters():
            count += 1  # 1-based index; engines are told to empty_cache at the last param

            # broadcast
            if not self.use_cuda_ipc:
                use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False)
                # Fire all vllm engines for broadcast
                if torch.distributed.get_rank() == 0:
                    # Under ZeRO-3 the full shape is read from ds_shape instead of shape.
                    shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
                    refs = [
                        engine.update_weight.remote(
                            name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
                        )
                        for engine in self.vllm_engines
                    ]

                # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
                with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
                    if torch.distributed.get_rank() == 0:
                        if use_ray:
                            import ray.util.collective as collective

                            collective.broadcast(param.data, 0, group_name=self._model_update_group)
                        else:
                            torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
                        # Wait until every engine has finished receiving this parameter.
                        ray.get(refs)
            # CUDA IPC
            else:
                from torch.multiprocessing.reductions import reduce_tensor

                # For ZeRO-3, gather the sharded parameter first; the weight is then
                # shared with the vllm engines through CUDA IPC handles (no broadcast).
                with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
                    weight = param.data.clone()
                    ipc_handle = reduce_tensor(weight)

                    # Key each rank's handle by its physical GPU id and collect them on all ranks.
                    ipc_handle = {get_physical_gpu_id(): ipc_handle}
                    ipc_handle_list = [None] * torch.distributed.get_world_size()
                    torch.distributed.all_gather_object(ipc_handle_list, ipc_handle)

                    if torch.distributed.get_rank() == 0:
                        # Merge the per-rank dicts into one {gpu_id: handle} mapping.
                        ipc_handles = {}
                        for d in ipc_handle_list:
                            ipc_handles.update(d)

                        shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
                        refs = [
                            engine.update_weight_cuda_ipc.remote(
                                name,
                                dtype=param.dtype,
                                shape=shape,
                                ipc_handles=ipc_handles,
                                empty_cache=count == num_params,
                            )
                            for engine in self.vllm_engines
                        ]
                        ray.get(refs)
                    # All ranks wait here, so no rank leaves this block before rank 0
                    # has finished the engine update.
                    torch_dist_barrier_and_cuda_sync()

        # Wait for the prefix-cache resets issued at the top (rank 0 only).
        if cache_reset_refs:
            ray.get(cache_reset_refs)
        torch.cuda.empty_cache()
        torch_dist_barrier_and_cuda_sync()

    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None):
        """Log metrics, run evaluation and save a checkpoint when their step intervals hit.

        Args:
            args: Training arguments providing ``logging_steps``, ``eval_steps``,
                ``save_steps`` and the evaluation settings.
            global_step: Current global step, compared against each interval.
            step_bar: Progress bar of the caller (not used inside this method).
            logs_dict: Metrics to log under the ``train/`` prefix. Defaults to
                an empty dict.
            client_states: Extra state passed on to ``_save_checkpoint``.
                Defaults to an empty dict.
        """
        # Fresh dicts per call: a literal ``{}`` default would be one object
        # shared by every call that omits the argument.
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {
                    "train/%s" % k: v
                    for k, v in {
                        **logs_dict,
                        "global_step": global_step,
                    }.items()
                }
                if self.experience_maker.perf_stats is not None:
                    logs.update({f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()})
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)
                if self.experience_maker.perf_stats is not None:
                    for k, v in self.experience_maker.perf_stats.items():
                        self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step)

        # TODO: Add evaluation mechanism for PPO
        if global_step % args.eval_steps == 0 and self.eval_dataloader and len(self.eval_dataloader) > 0:
            self.evaluate(self.eval_dataloader, global_step, args.eval_temperature, args.eval_n_samples_per_prompt)
        # save ckpt
        # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self._save_checkpoint(args, tag, client_states)

    def _save_checkpoint(self, args, tag, client_states):
        """Save the actor checkpoint (and trigger the remote critic's) under ``tag``.

        Args:
            args: Training arguments providing ``ckpt_path``, ``max_ckpt_num``,
                ``max_ckpt_mem`` and ``enable_ema``.
            tag: Checkpoint tag, also used to name the HF export directory.
            client_states: Extra state stored with the DeepSpeed checkpoint.
        """
        # Handle of the remote critic save; stays None when no remote save is started.
        critic_ref = None

        if not self.disable_ds_ckpt:
            # Start the critic save first so it can run while the actor is saved.
            if self.critic_train_remote:
                critic_ref = self.critic.save_checkpoint.remote(tag)
            actor_ckpt_dir = os.path.join(args.ckpt_path, "_actor")
            self.strategy.save_ckpt(
                self.actor.model,
                actor_ckpt_dir,
                tag,
                args.max_ckpt_num,
                args.max_ckpt_mem,
                client_states,
            )

        if self.save_hf_ckpt:
            model_to_export = self.ema_model if args.enable_ema else self.actor
            self.strategy.save_model(
                model_to_export,
                self.tokenizer,
                os.path.join(args.ckpt_path, f"{tag}_hf"),
            )

        # wait for the remote critic save, if one was started
        if critic_ref is not None:
            ray.get(critic_ref)
        torch_dist_barrier_and_cuda_sync()

    def evaluate(self, eval_dataloader, global_step, temperature=0.6, n_samples_per_prompt=1):
        """Evaluate model performance on eval dataset.

        Generates ``n_samples_per_prompt`` samples per prompt, scores them with
        the custom reward function or a remote reward model, and logs pass@k
        (max reward per prompt) and pass@1 (mean reward per prompt) for each
        data source.

        Args:
            eval_dataloader: DataLoader containing evaluation prompts, labels and data sources
            global_step: Current training step for logging
            temperature: Sampling temperature passed to generation for evaluation
            n_samples_per_prompt: Number of samples to generate per prompt for pass@k calculation
        """
        start_time = time.time()
        if self.strategy.is_rank_0():
            logger.info(f"⏰ Evaluation start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

        # vLLM wakeup when vllm_enable_sleep
        if self.strategy.args.vllm_enable_sleep:
            batch_vllm_engine_call(self.vllm_engines, "wake_up")
            torch_dist_barrier_and_cuda_sync()

        # Only run evaluation on ring attention rank0
        if self.strategy.ring_attn_group is None or self.strategy.ring_attn_rank == 0:

            with torch.no_grad():
                # First collect all prompts and labels
                all_prompts = []
                all_labels = []
                all_datasources = []

                for datasources, prompts, labels in eval_dataloader:
                    all_prompts.extend(prompts)
                    all_labels.extend(labels)
                    all_datasources.extend(datasources)

                # Generate samples and calculate rewards
                generate_kwargs = self.generate_kwargs.copy()
                generate_kwargs["temperature"] = temperature
                generate_kwargs["n_samples_per_prompt"] = n_samples_per_prompt
                samples = self.experience_maker.generate_samples(all_prompts, all_labels, **generate_kwargs)
                queries = [self.tokenizer.batch_decode(seq, skip_special_tokens=False) for seq in samples.sequences]

                # duplicate prompts and labels for each sample (linear-time flatten)
                all_prompts = [prompt for prompt in all_prompts for _ in range(n_samples_per_prompt)]
                all_labels = [label for label in all_labels for _ in range(n_samples_per_prompt)]

                # Calculate rewards
                if self.experience_maker.custom_reward_func:
                    rewards = self.experience_maker.custom_reward_func.remote(queries, all_prompts, all_labels)
                else:
                    # Spread ranks round-robin over the configured remote reward models.
                    rank = torch.distributed.get_rank() // self.strategy.ring_attn_size
                    rm = self.remote_rm_url[rank % len(self.remote_rm_url)]
                    rewards = remote_rm_fn_ray.remote(rm, queries=queries, prompts=all_prompts, labels=all_labels)
                rewards = ray.get(rewards)

                # Reshape rewards to (num_prompts, n_samples_per_prompt)
                rewards = rewards.reshape(-1, n_samples_per_prompt)

                # Collect local statistics for each data source.
                # The accumulators use fixed field names rather than
                # f"pass{n_samples_per_prompt}": with n_samples_per_prompt == 1
                # that name equals "pass1", so max and mean were added into the
                # same slot and the reported pass@1 was inflated.
                metric_fields = ("pass_k", "pass1", "count")
                local_metrics = {}  # {datasource: {"pass_k": 0, "pass1": 0, "count": 0}}

                for i, datasource in enumerate(all_datasources):
                    entry = local_metrics.setdefault(datasource, dict.fromkeys(metric_fields, 0))

                    # Calculate pass@k and pass@1
                    prompt_rewards = rewards[i]
                    entry["pass_k"] += prompt_rewards.max().float().item()
                    entry["pass1"] += prompt_rewards.mean().float().item()
                    entry["count"] += 1

                # All gather metrics from all ranks
                gathered_metrics = [None] * (self.strategy.world_size // self.strategy.ring_attn_size)
                if self.strategy.ring_attn_group is not None:
                    # Only rank 0 in ring attention group gathers metrics
                    torch.distributed.all_gather_object(
                        gathered_metrics, local_metrics, group=self.experience_maker.ring_rank0_group
                    )
                else:
                    torch.distributed.all_gather_object(gathered_metrics, local_metrics)

                # Only rank0 processes the gathered metrics
                if self.strategy.is_rank_0():
                    # Combine metrics from all ranks
                    global_metrics = {}
                    for rank_metrics in gathered_metrics:
                        for datasource, metrics in rank_metrics.items():
                            totals = global_metrics.setdefault(datasource, dict.fromkeys(metric_fields, 0))
                            for field in metric_fields:
                                totals[field] += metrics[field]

                    # Calculate global averages. For n_samples_per_prompt == 1 both
                    # log keys coincide and carry the same value (max == mean).
                    logs = {}
                    for datasource, metrics in global_metrics.items():
                        logs[f"eval_{datasource}_pass{n_samples_per_prompt}"] = metrics["pass_k"] / metrics["count"]
                        logs[f"eval_{datasource}_pass1"] = metrics["pass1"] / metrics["count"]

                    # Log to wandb/tensorboard
                    if self._wandb is not None:
                        logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": global_step}.items()}
                        self._wandb.log(logs)
                    elif self._tensorboard is not None:
                        for k, v in logs.items():
                            self._tensorboard.add_scalar(f"eval/{k}", v, global_step)

        if self.strategy.args.vllm_enable_sleep:
            batch_vllm_engine_call(self.vllm_engines, "sleep")

        torch.cuda.empty_cache()

        end_time = time.time()
        duration = end_time - start_time
        if self.strategy.is_rank_0():
            # Drop the fractional seconds from the printed duration.
            time_str = str(timedelta(seconds=duration)).split(".")[0]
            logger.info(f"✨ Evaluation completed in {time_str}")

    def reload_states(self):
        """Reload the actor model's DeepSpeed states via ``reload_deepspeed_states``.

        ``ppo_train`` calls this before actor training when
        ``deepspeed_enable_sleep`` is set.
        """
        reload_deepspeed_states(self.actor.model)

    def offload_states(self):
        """Offload the actor model's DeepSpeed states via ``offload_deepspeed_states``.

        ``ppo_train`` calls this after actor training when
        ``deepspeed_enable_sleep`` is set.
        """
        offload_deepspeed_states(self.actor.model)


@ray.remote(num_gpus=1)
class ActorModelRayActor(BasePPORole):
    """Ray actor reserving one GPU that builds the policy (actor) model and trains it with PPO.

    Call ``init_model_from_pretrained`` first: ``fit``, ``save_model`` and ``max_steps`` read
    attributes (model, tokenizer, dataloaders, step counts) that it sets.
    """

    def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
        """Build the actor model, tokenizer, optimizer, scheduler and datasets, then resume from a checkpoint.

        Sets ``self.tokenizer``, ``self.actor``, ``self.actor_optim``, ``self.actor_scheduler``,
        ``self.ema_model``, ``self.num_update_steps_per_episodes``, ``self._max_steps`` and
        ``self.consumed_samples`` (plus the dataloaders created by ``prepare_datasets``).

        Args:
            strategy: DeepSpeed strategy whose ``args`` namespace supplies every hyper-parameter read here.
            pretrain: Model identifier forwarded unchanged to ``Actor`` and ``get_tokenizer``.
        """
        args = strategy.args

        # ``or 0`` makes an explicit ``vllm_num_engines=None`` mean "no vLLM engines"
        # instead of raising TypeError on ``None > 0``.
        uses_vllm = (getattr(args, "vllm_num_engines", 0) or 0) > 0
        # To prevent hanging during NCCL synchronization of weights between DeepSpeed and vLLM.
        # see https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445
        if uses_vllm and getattr(args, "vllm_sync_backend", "nccl") == "nccl":
            os.environ["NCCL_CUMEM_ENABLE"] = "0"

        self._setup_distributed(strategy)

        actor = Actor(
            pretrain,
            use_flash_attention_2=strategy.args.flash_attn,
            bf16=strategy.args.bf16,
            load_in_4bit=strategy.args.load_in_4bit,
            lora_rank=strategy.args.lora_rank,
            lora_alpha=strategy.args.lora_alpha,
            target_modules=strategy.args.target_modules,
            lora_dropout=strategy.args.lora_dropout,
            ds_config=strategy.get_ds_train_config(is_actor=True),
            packing_samples=strategy.args.packing_samples,
            temperature=strategy.args.temperature,
            use_liger_kernel=strategy.args.use_liger_kernel,
        )
        strategy.print(actor)

        # configure tokenizer
        self.tokenizer = get_tokenizer(
            pretrain, actor.model, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer
        )

        # Optional EMA copy of the actor, built with the eval DeepSpeed config (offload enabled).
        if args.enable_ema:
            ema_model = Actor(
                pretrain,
                use_flash_attention_2=strategy.args.flash_attn,
                bf16=strategy.args.bf16,
                load_in_4bit=strategy.args.load_in_4bit,
                ds_config=strategy.get_ds_eval_config(offload=True),
                packing_samples=strategy.args.packing_samples,
            )
        else:
            ema_model = None

        # configure optimizer
        actor_optim = strategy.create_optimizer(
            actor, lr=args.actor_learning_rate, betas=strategy.args.adam_betas, weight_decay=args.l2
        )

        # prepare_datasets (must precede the scheduler: the step count depends on the dataset length)
        self.prepare_datasets()

        # configure scheduler
        # Evaluated left to right: ((num_prompts * n_samples_per_prompt) // train_batch_size) * max_epochs.
        self.num_update_steps_per_episodes = (
            len(self.prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
        )
        max_steps = math.ceil(args.num_episodes * self.num_update_steps_per_episodes)
        self._max_steps = max_steps

        actor_scheduler = get_scheduler(
            "cosine_with_min_lr",
            actor_optim,
            num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
            num_training_steps=max_steps,
            # The schedule's minimum LR is 10% of the configured actor learning rate.
            scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1},
        )

        if args.gradient_checkpointing:
            actor.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
            )

        # prepare models/optimizers...
        self.actor, self.actor_optim, self.actor_scheduler = strategy.prepare(
            (actor, actor_optim, actor_scheduler),
            is_rlhf=True,
        )

        if ema_model:
            ema_model._offload = True
            self.ema_model = strategy.prepare(ema_model, is_rlhf=True)
        else:
            self.ema_model = None

        # load checkpoint
        self.consumed_samples = 0
        ckpt_path = os.path.join(args.ckpt_path, "_actor")
        if args.load_checkpoint and os.path.exists(ckpt_path):
            _, states = strategy.load_ckpt(self.actor.model, ckpt_path)
            self.consumed_samples = states["consumed_samples"]
            strategy.print(f"Loaded the checkpoint: {ckpt_path}, consumed_samples: {self.consumed_samples}")

        # initial offload
        if strategy.args.deepspeed_enable_sleep:
            offload_deepspeed_states(self.actor.model)

    def prepare_datasets(self):
        """Create the prompt, optional eval and optional pretrain dataloaders.

        Reads ``self.strategy`` and ``self.tokenizer``; sets ``self.prompts_dataset``,
        ``self.prompts_dataloader``, ``self.eval_dataloader`` (``None`` without ``args.eval_dataset``)
        and ``self.pretrain_dataloader`` (``None`` without ``args.pretrain_data``).
        """
        strategy = self.strategy
        args = self.strategy.args

        # prepare datasets
        train_data = blending_datasets(
            args.prompt_data,
            args.prompt_data_probs,
            strategy,
            args.seed,
            max_count=args.max_samples,
        )

        # Create train dataset, capped at ``max_samples`` rows
        train_data = train_data.select(range(min(args.max_samples, len(train_data))))
        self.prompts_dataset = PromptDataset(train_data, self.tokenizer, strategy, input_template=args.input_template)
        # Batch size is the rollout batch divided by (world_size // ring_attn_size).
        # NOTE(review): the two positional booleans are presumably (pin_memory, shuffle);
        # confirm against DeepspeedStrategy.setup_dataloader.
        self.prompts_dataloader = strategy.setup_dataloader(
            self.prompts_dataset,
            args.rollout_batch_size // (strategy.world_size // strategy.ring_attn_size),
            True,
            True,
        )

        # Create eval dataset if eval data exists
        if getattr(args, "eval_dataset", None):
            eval_data = blending_datasets(
                args.eval_dataset,
                None,  # No probability sampling for eval datasets
                strategy,
            )
            eval_data = eval_data.select(range(min(args.max_samples, len(eval_data))))
            eval_dataset = PromptDataset(eval_data, self.tokenizer, strategy, input_template=args.input_template)
            # Batch size 1; the last positional flag differs from the train loader (False here).
            self.eval_dataloader = strategy.setup_dataloader(
                eval_dataset,
                1,
                True,
                False,
            )
        else:
            self.eval_dataloader = None

        if args.pretrain_data:
            pretrain_data = blending_datasets(
                args.pretrain_data,
                args.pretrain_data_probs,
                strategy,
                args.seed,
            )
            # Fall back to prompt + generation length when ``max_len`` is unset/zero.
            pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
            pretrain_dataset = SFTDataset(
                # Use at most max_epochs * num_prompts * n_samples_per_prompt pretrain rows.
                pretrain_data.select(
                    range(
                        min(
                            len(pretrain_data), args.max_epochs * len(self.prompts_dataset) * args.n_samples_per_prompt
                        )
                    )
                ),
                self.tokenizer,
                pretrain_max_len,
                strategy,
                pretrain_mode=True,
            )
            # ``itertools.cycle`` makes the pretrain stream endless so ``next()`` never raises StopIteration.
            self.pretrain_dataloader = itertools.cycle(
                iter(
                    strategy.setup_dataloader(
                        pretrain_dataset,
                        args.micro_train_batch_size,
                        True,
                        True,
                        pretrain_dataset.collate_fn,
                    )
                )
            )
        else:
            self.pretrain_dataloader = None

    def max_steps(self):
        """Return the maximum number of steps."""
        return self._max_steps

    def fit(
        self,
        critic_model: ray.actor.ActorHandle,
        initial_model: ray.actor.ActorHandle,
        reward_model: List[ray.actor.ActorHandle],
        remote_rm_url: List[str] = None,
        reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None,
        vllm_engines: List[ray.actor.ActorHandle] = None,
        critic_train_remote: bool = False,
    ):
        """Train actor model with prompt datasets.

        Builds an ``ActorPPOTrainer`` around the already-prepared actor and runs its ``fit``.

        Args:
            critic_model: Handle forwarded to the trainer as the critic.
            initial_model: Handle forwarded to the trainer as the initial model.
            reward_model: Handles forwarded to the trainer as the reward model(s).
            remote_rm_url: Forwarded to the trainer unchanged.
            reward_fn: Forwarded to the trainer unchanged.
            vllm_engines: vLLM engine handles. When not ``None`` and a checkpoint exists and was
                requested, ``trainer._broadcast_to_vllm()`` is called before training starts.
            critic_train_remote: Forwarded to the trainer unchanged.
        """
        strategy = self.strategy
        args = self.strategy.args

        # configure Trainer
        trainer = ActorPPOTrainer(
            strategy,
            self.actor,
            critic_model,
            reward_model,
            initial_model,
            ema_model=self.ema_model,
            actor_optim=None,
            critic_optim=None,
            actor_scheduler=self.actor_scheduler,
            critic_scheduler=None,
            remote_rm_url=remote_rm_url,
            reward_fn=reward_fn,
            vllm_engines=vllm_engines,
            max_epochs=args.max_epochs,
            micro_train_batch_size=args.micro_train_batch_size,
            micro_rollout_batch_size=args.micro_rollout_batch_size,
            gradient_checkpointing=args.gradient_checkpointing,
            critic_train_remote=critic_train_remote,
            tokenizer=self.tokenizer,
            prompt_max_len=args.prompt_max_len,
            value_clip=args.value_clip,
            eps_clip=args.eps_clip,
            gamma=args.gamma,
            lambd=args.lambd,
            init_kl_coef=args.init_kl_coef,
            kl_target=args.kl_target,
            ema_beta=0.992,
            ptx_coef=args.ptx_coef,
            max_norm=args.max_norm,
            # for GPT generation
            do_sample=True,
            max_new_tokens=args.generate_max_len,
            max_length=args.max_len,
            temperature=args.temperature,
            top_p=args.top_p,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            save_hf_ckpt=args.save_hf_ckpt,
            disable_ds_ckpt=args.disable_ds_ckpt,
        )

        # broadcast checkpoint
        # Same path/flag check that ``init_model_from_pretrained`` used when loading the checkpoint.
        ckpt_path = os.path.join(args.ckpt_path, "_actor")
        if args.load_checkpoint and os.path.exists(ckpt_path) and vllm_engines is not None:
            # vLLM wakeup when vllm_enable_sleep
            if self.strategy.args.vllm_enable_sleep:
                batch_vllm_engine_call(vllm_engines, "wake_up")
            torch_dist_barrier_and_cuda_sync()

            trainer._broadcast_to_vllm()

            # vLLM offload when vllm_enable_sleep
            if self.strategy.args.vllm_enable_sleep:
                batch_vllm_engine_call(vllm_engines, "sleep")
                torch_dist_barrier_and_cuda_sync()

        trainer.fit(
            args,
            self.prompts_dataloader,
            self.pretrain_dataloader,
            self.eval_dataloader,
            self.consumed_samples,
            self.num_update_steps_per_episodes,
        )

    def save_model(self):
        """Save the final model and tokenizer to ``args.save_path``.

        The EMA model is saved when ``args.enable_ema`` is set; otherwise the actor is saved.
        """
        args = self.strategy.args

        # save model checkpoint after fitting on only rank0
        self.strategy.save_model(
            self.ema_model if args.enable_ema else self.actor,
            self.tokenizer,
            args.save_path,
        )


# import itertools
# import math
# import os
# import socket
# from typing import Callable, Dict, List

# import deepspeed
# import ray
# import torch
# import torch.distributed
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers.trainer import get_scheduler

# from openrlhf.datasets import PromptDataset, SFTDataset
# from openrlhf.models import Actor
# from openrlhf.models.ring_attn_utils import pad_sequences, unpad_sequences
# from openrlhf.models.utils import compute_approx_kl, masked_mean, unpacking_samples
# from openrlhf.trainer import BasePPOTrainer
# from openrlhf.trainer.ppo_utils import Experience, RemoteExperienceMaker
# from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call
# from openrlhf.utils import blending_datasets, get_tokenizer
# from openrlhf.utils.deepspeed import DeepspeedStrategy
# from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states
# from openrlhf.utils.distributed_sampler import DistributedSampler
# from openrlhf.utils.distributed_util import init_process_group

# from .launcher import BasePPORole
# from .utils import get_physical_gpu_id

# from openrlhf.utils.logging_utils import init_logger
# import time
# from openrlhf.utils.remote_rm_utils import remote_rm_fn_ray

# logger = init_logger(__name__)


# class ActorPPOTrainer(BasePPOTrainer):
#     def __init__(
#         self,
#         *args,
#         vllm_engines: List = None,
#         remote_rm_url: List[str] = None,
#         critic_train_remote: bool = False,
#         **kwargs,
#     ):
#         """PPOTrainer for ray.

#         Args:
#             vllm_engines (List, optional): vllm engines for text generation, if not specified, generate text by actor model directly. Defaults to None.
#             critic_train_remote (bool, optional): whether this actor should trigger the corresponding critic model training. Defaults to False.
#         """
#         super().__init__(*args, **kwargs)
#         self.remote_rm_url = remote_rm_url
#         self.vllm_engines = vllm_engines
#         self.critic_train_remote = critic_train_remote

#         # wandb/tensorboard setting
#         self._wandb = None
#         self._tensorboard = None
#         if self.strategy.args.use_wandb and self.strategy.is_rank_0():
#             import wandb

#             self._wandb = wandb
#             if not wandb.api.api_key:
#                 wandb.login(key=self.strategy.args.use_wandb)
#             wandb.init(
#                 entity=self.strategy.args.wandb_org,
#                 project=self.strategy.args.wandb_project,
#                 group=self.strategy.args.wandb_group,
#                 name=self.strategy.args.wandb_run_name,
#                 config=self.strategy.args.__dict__,
#                 reinit=True,
#             )

#             wandb.define_metric("train/global_step")
#             wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
#             wandb.define_metric("eval/epoch")
#             wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)

#         # Initialize TensorBoard writer if wandb is not available
#         if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
#             from torch.utils.tensorboard import SummaryWriter

#             os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
#             log_dir = os.path.join(self.strategy.args.use_tensorboard, self.strategy.args.wandb_run_name)
#             self._tensorboard = SummaryWriter(log_dir=log_dir)

#         self.experience_maker = RemoteExperienceMaker(
#             self.actor,
#             self.critic,
#             self.reward_model,
#             self.initial_model,
#             self.tokenizer,
#             self.prompt_max_len,
#             self.kl_ctl,
#             self.strategy,
#             self.remote_rm_url,
#             self.reward_fn,
#             vllm_engines=self.vllm_engines,
#             packing_samples=self.strategy.args.packing_samples,
#         )

#         backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl")
#         self.use_cuda_ipc = False
#         if backend == "nccl" and self.strategy.args.colocate_all_models:
#             self.use_cuda_ipc = True

#         # Create torch group with deepspeed rank 0 and all vllm ranks
#         # to update vllm engine's weights after each training stage.
#         #
#         # Say we have 3 vllm engines and each of them has 4 GPUs,
#         # then the torch group is:
#         # [    0,      1, 2, 3, 4,  5, 6, 7, 8,  9, 10, 11, 12]
#         # |ds rank 0 |  engine-0  |  engine-1  |   engine-2   |
#         #
#         # For ZeRO-1/2:
#         #   1. Broadcast parameters from rank 0 to all vllm engines
#         # For ZeRO-3:
#         #   1. AllGather parameters to rank 0
#         #   2. Broadcast parameters from rank 0 to all vllm engines
#         if self.vllm_engines is not None and not self.use_cuda_ipc and torch.distributed.get_rank() == 0:
#             master_address = ray._private.services.get_node_ip_address()
#             with socket.socket() as sock:
#                 sock.bind(("", 0))
#                 master_port = sock.getsockname()[1]

#             vllm_num_engines, vllm_tensor_parallel_size = (
#                 self.strategy.args.vllm_num_engines,
#                 self.strategy.args.vllm_tensor_parallel_size,
#             )
#             world_size = vllm_num_engines * vllm_tensor_parallel_size + 1

#             use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False)
#             group_name = "openrlhf"
#             refs = [
#                 engine.init_process_group.remote(
#                     master_address,
#                     master_port,
#                     i * vllm_tensor_parallel_size + 1,
#                     world_size,
#                     group_name,
#                     backend=backend,
#                     use_ray=use_ray,
#                 )
#                 for i, engine in enumerate(self.vllm_engines)
#             ]
#             if use_ray:
#                 import ray.util.collective as collective

#                 collective.init_collective_group(world_size=world_size, rank=0, backend=backend, group_name=group_name)
#                 self._model_update_group = group_name
#             else:
#                 self._model_update_group = init_process_group(
#                     backend=backend,
#                     init_method=f"tcp://{master_address}:{master_port}",
#                     world_size=world_size,
#                     rank=0,
#                     group_name=group_name,
#                 )

#             ray.get(refs)

#         torch.distributed.barrier()

#     def fit(
#         self,
#         args,
#         prompts_dataloader,
#         pretrain_dataloader,
#         consumed_samples=0,
#         num_update_steps_per_episodes=1,
#     ) -> None:
#         num_rollouts_per_episodes = (
#             num_update_steps_per_episodes
#             * args.train_batch_size
#             // args.max_epochs
#             // args.rollout_batch_size
#             // args.n_samples_per_prompt
#         )

#         # get eval and save steps
#         if args.eval_steps == -1:
#             args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
#         if args.save_steps == -1:
#             args.save_steps = float("inf")  # do not save ckpt

#         self.prompts_dataloader = prompts_dataloader
#         self.pretrain_dataloader = pretrain_dataloader

#         # Restore step and start_epoch
#         steps = consumed_samples // args.rollout_batch_size + 1
#         start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
#         consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

#         for episode in range(start_episode, args.num_episodes):
#             if isinstance(self.prompts_dataloader.sampler, DistributedSampler):
#                 self.prompts_dataloader.sampler.set_epoch(
#                     episode, consumed_samples=0 if episode > start_episode else consumed_samples
#                 )
#             pbar = tqdm(
#                 range(self.prompts_dataloader.__len__()),
#                 desc=f"Episode [{episode + 1}/{args.num_episodes}]",
#                 disable=not self.strategy.is_rank_0(),
#             )

#             for rand_prompts, labels in self.prompts_dataloader:

#                 for i, experience in enumerate(
#                     self.experience_maker.make_experience_list(rand_prompts, labels, **self.generate_kwargs)
#                 ):
#                     if i == 0:
#                         output = self.tokenizer.batch_decode(
#                             experience.sequences[0].unsqueeze(0), skip_special_tokens=True
#                         )
#                         self.strategy.print(output)
#                     self.replay_buffer.append(experience)

#                 if self.args.advantage_estimator not in ["group_norm", "dr_grpo"]:
#                     self.replay_buffer.normalize(
#                         self.strategy, "advantages", divide_by_std=not self.args.no_advantage_std_norm
#                     )
#                 status = self.ppo_train(steps)
#                 self.replay_buffer.clear()

#                 if "kl" in status:
#                     self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
#                 pbar.set_postfix(status)

#                 # logs/checkpoints
#                 client_states = {"consumed_samples": steps * args.rollout_batch_size}
#                 self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)

#                 pbar.update()
#                 steps = steps + 1

#         if self._wandb is not None and self.strategy.is_rank_0():
#             self._wandb.finish()
#         if self._tensorboard is not None and self.strategy.is_rank_0():
#             self._tensorboard.close()

#     def ppo_train(self, global_steps):
#         # 1. ensure all experience makers done
#         self.experience_maker.flush()
#         torch.distributed.barrier()
#         status = {}

#         # 2. trigger remote critic model training
#         if self.critic_train_remote:
#             # sync for deepspeed_enable_sleep
#             if self.strategy.args.deepspeed_enable_sleep:
#                 ray.get(self.critic.reload_states.remote())

#             critic_status_ref = self.critic.fit.remote()

#             if self.strategy.args.colocate_all_models or self.strategy.args.deepspeed_enable_sleep:
#                 status.update(ray.get(critic_status_ref))
#             if self.strategy.args.deepspeed_enable_sleep:
#                 ray.get(self.critic.offload_states.remote())

#         if self.strategy.args.colocate_all_models:
#             torch.distributed.barrier()

#         # 3. actor model training
#         if global_steps > self.freezing_actor_steps:
#             if self.strategy.args.deepspeed_enable_sleep:
#                 self.reload_states()

#             status.update(self.ppo_train_actor(global_steps))

#             if self.strategy.args.deepspeed_enable_sleep:
#                 self.offload_states()

#             torch.cuda.empty_cache()

#             # 4. broadcast weights to vllm engines
#             if self.vllm_engines is not None:
#                 if self.strategy.args.vllm_enable_sleep:
#                     batch_vllm_engine_call(self.vllm_engines, "wake_up")

#                 torch.distributed.barrier()
#                 torch.cuda.synchronize()
#                 self._broadcast_to_vllm()

#                 if self.strategy.args.vllm_enable_sleep:
#                     batch_vllm_engine_call(self.vllm_engines, "sleep")
#                     torch.distributed.barrier()
#                     torch.cuda.synchronize()

#         # 5. wait remote critic model training done
#         if self.critic_train_remote and not self.strategy.args.colocate_all_models:
#             status.update(ray.get(critic_status_ref))
#         torch.distributed.barrier()

#         return status

#     def ppo_train_actor(self, global_steps):
#         torch.cuda.empty_cache()
#         # replay buffer may be empty at first, we should rebuild at each training
#         dataloader = DataLoader(
#             self.replay_buffer,
#             batch_size=self.replay_buffer.sample_batch_size,
#             shuffle=False if self.strategy.ring_attn_group is not None else True,
#             drop_last=True,
#             pin_memory=self.dataloader_pin_memory,
#             collate_fn=self.replay_buffer.collate_fn,
#         )
#         device = torch.cuda.current_device()

#         status_list = []
#         status_mean = {}
#         for epoch in range(self.max_epochs):
#             pbar = tqdm(
#                 dataloader,
#                 desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
#                 disable=not self.strategy.is_rank_0(),
#             )
#             for experience in pbar:
#                 experience.to_device(device)
#                 status = self.training_step(experience)

#                 # for DP
#                 # weighted mean for kl
#                 if "kl" in status:
#                     status["kl"] *= status["response_length"]
#                     status = self.strategy.all_reduce(status)
#                     status["kl"] /= status["response_length"]

#                 short_status = {}

#                 if "policy_loss" in status:
#                     short_status = {
#                         "pg": status["policy_loss"],
#                         "rm": status["reward"],
#                         "ret": status["return"],
#                         "glen": status["response_length"],
#                         "tlen": status["total_length"],
#                         "kl": status["kl"],
#                         "act_lr": status["actor_lr"],
#                     }

#                 if "critic_loss" in status:
#                     short_status["cri"] = status["critic_loss"]
#                     short_status["vals"] = status["values"]
#                     short_status["cri_lr"] = status["critic_lr"]

#                 if "ptx_loss" in status:
#                     short_status["ptx"] = status["ptx_loss"]

#                 status_list.append(status)
#                 pbar.set_postfix(short_status)

#         if status_list:
#             status_mean = status_list[0]
#             for m in status_list[1:]:
#                 for k, v in m.items():
#                     status_mean[k] += v
#             for k in status_mean.keys():
#                 status_mean[k] /= len(status_list)
#         torch.cuda.empty_cache()
#         return status_mean

#     def training_step(self, experience: Experience) -> Dict[str, float]:
#         self.actor.train()

#         # TODO: this is a bad indicator to say that data is packed...
#         if isinstance(experience.sequences, list):
#             sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0)
#             old_action_log_probs = torch.cat(experience.action_log_probs, dim=0).unsqueeze(0)
#             advantages = torch.cat(experience.advantages, dim=0).unsqueeze(0)
#             num_actions = [v.numel() for v in experience.advantages]
#             packed_seq_lens = [s.numel() for s in experience.sequences]
#             attention_mask = torch.cat(
#                 [torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)], dim=0
#             ).unsqueeze(0)
#             # pad seq makes the sequence a multiple of ring_attention_size.
#             if self.strategy.ring_attn_group is not None:
#                 pad_len, sequences, attention_mask, num_actions, packed_seq_lens = pad_sequences(
#                     sequences, attention_mask, num_actions, packed_seq_lens, self.strategy.ring_attn_group
#                 )
#             if self.args.use_kl_loss and experience.base_action_log_probs is not None:
#                 base_action_log_probs = torch.cat(experience.base_action_log_probs, dim=0).unsqueeze(0)
#         else:
#             sequences = experience.sequences
#             old_action_log_probs = experience.action_log_probs
#             advantages = experience.advantages
#             num_actions = experience.action_mask.size(1)
#             packed_seq_lens = None
#             attention_mask = experience.attention_mask
#             if self.args.use_kl_loss and experience.base_action_log_probs is not None:
#                 base_action_log_probs = experience.base_action_log_probs

#         # actor loss
#         action_log_probs, output = self.actor(
#             sequences,
#             num_actions,
#             attention_mask=attention_mask,
#             return_output=True,
#             ring_attn_group=self.strategy.ring_attn_group,
#             logps_allgather=True,
#             packed_seq_lens=packed_seq_lens,
#         )
#         # unpad sequence ensures that pad tokens do not contribute to the loss calculation.
#         if self.strategy.ring_attn_group is not None:
#             assert pad_len is not None
#             sequences, attention_mask, num_actions, packed_seq_lens, action_log_probs, _, _ = unpad_sequences(
#                 pad_len=pad_len,
#                 sequences=sequences,
#                 attention_mask=attention_mask,
#                 num_actions=num_actions,
#                 packed_seq_lens=packed_seq_lens,
#                 action_log_probs=action_log_probs,
#                 ring_attn_group=self.strategy.ring_attn_group,
#             )

#         # loss function
#         actor_loss = self.actor_loss_fn(
#             action_log_probs,
#             old_action_log_probs,
#             advantages,
#             action_mask=experience.action_mask,
#         )

#         if self.args.use_kl_loss:
#             if self.initial_model is not None:
#                 kl = compute_approx_kl(
#                     action_log_probs,
#                     base_action_log_probs,
#                     experience.action_mask,
#                     kl_estimator=self.args.kl_estimator,
#                 )
#             else:
#                 kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)

#             if not self.args.packing_samples:
#                 kl_mean = masked_mean(kl, experience.action_mask, dim=-1)
#             else:
#                 # convert tensor into list of tensors so that it's easier to manipulate
#                 # within dataset.

#                 kl = unpacking_samples(kl, num_actions)
#                 kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device)

#             kl_loss = kl_mean.mean()
#             experience.info["kl"] = kl_loss.item()
#         else:
#             kl_loss = 0

#         # mixtral
#         if self.aux_loss:
#             aux_loss = output.aux_loss
#         else:
#             aux_loss = 0
#         loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
#         self.strategy.backward(loss, self.actor, self.actor_optim)

#         # ptx loss
#         if self.pretrain_dataloader is not None:
#             data = next(self.pretrain_dataloader)
#             inputs = data[1].squeeze(1).to(torch.cuda.current_device())
#             attention_mask = data[2].squeeze(1).to(torch.cuda.current_device())
#             label = torch.where(
#                 attention_mask.bool(),
#                 inputs,
#                 self.ptx_loss_fn.IGNORE_INDEX,
#             )

#             output = self.actor(inputs, attention_mask=attention_mask, return_output=True)
#             ptx_log_probs = output["logits"]

#             # loss function
#             ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)
#             # mixtral
#             if self.aux_loss:
#                 aux_loss = output.aux_loss
#             else:
#                 aux_loss = 0
#             loss = ptx_loss + aux_loss * self.args.aux_loss_coef
#             self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)

#         self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
#         if self.ema_model:
#             self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")

#         # status
#         status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
#         if self.pretrain_dataloader is not None:
#             status["ptx_loss"] = ptx_loss.item()
#         for k, v in experience.info.items():
#             if k == "kl":
#                 status[k] = (
#                     (v * experience.info["response_length"]).sum() / experience.info["response_length"].sum()
#                 ).item()
#             else:
#                 status[k] = v.mean().item()
#         return status

#     def _broadcast_to_vllm(self):
#         use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
#         cache_reset_refs = []
#         if use_prefix_cache and torch.distributed.get_rank() == 0:
#             # clear prefix cache
#             for engine in self.vllm_engines:
#                 cache_reset_refs.append(engine.reset_prefix_cache.remote())

#         torch.cuda.empty_cache()
#         model = self.actor.model.module
#         count, num_params = 0, len(list(model.named_parameters()))
#         for name, param in model.named_parameters():
#             count += 1  # empty_cache at last param

#             # broadcast
#             if not self.use_cuda_ipc:
#                 use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False)
#                 # Fire all vllm engines for broadcast
#                 if torch.distributed.get_rank() == 0:
#                     shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
#                     refs = [
#                         engine.update_weight.remote(
#                             name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
#                         )
#                         for engine in self.vllm_engines
#                     ]

#                 # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
#                 with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
#                     if torch.distributed.get_rank() == 0:
#                         if use_ray:
#                             import ray.util.collective as collective

#                             collective.broadcast(param.data, 0, group_name=self._model_update_group)
#                         else:
#                             torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
#                         ray.get(refs)
#             # CUDA IPC
#             else:
#                 from torch.multiprocessing.reductions import reduce_tensor

#                 # For ZeRO-3, allgather the sharded parameter; each rank then exports a CUDA IPC handle that rank 0 passes to all vllm engines
#                 with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
#                     weight = param.data.clone()
#                     ipc_handle = reduce_tensor(weight)

#                     ipc_handle = {get_physical_gpu_id(): ipc_handle}
#                     ipc_handle_list = [None] * torch.distributed.get_world_size()
#                     torch.distributed.all_gather_object(ipc_handle_list, ipc_handle)

#                     if torch.distributed.get_rank() == 0:
#                         ipc_handles = {}
#                         for d in ipc_handle_list:
#                             ipc_handles.update(d)

#                         shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
#                         refs = [
#                             engine.update_weight_cuda_ipc.remote(
#                                 name,
#                                 dtype=param.dtype,
#                                 shape=shape,
#                                 ipc_handles=ipc_handles,
#                                 empty_cache=count == num_params,
#                             )
#                             for engine in self.vllm_engines
#                         ]
#                         ray.get(refs)
#                     torch.distributed.barrier()
#                     torch.cuda.synchronize()

#         if cache_reset_refs:
#             ray.get(cache_reset_refs)
#         torch.cuda.empty_cache()
#         torch.distributed.barrier()

#     def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}):
#         if global_step % args.logging_steps == 0:
#             # wandb
#             if self._wandb is not None and self.strategy.is_rank_0():
#                 logs = {
#                     "train/%s" % k: v
#                     for k, v in {
#                         **logs_dict,
#                         "global_step": global_step,
#                     }.items()
#                 }
#                 if self.experience_maker.perf_stats is not None:
#                     logs.update({f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()})
#                 self._wandb.log(logs)
#             # TensorBoard
#             elif self._tensorboard is not None and self.strategy.is_rank_0():
#                 for k, v in logs_dict.items():
#                     self._tensorboard.add_scalar(f"train/{k}", v, global_step)
#                 if self.experience_maker.perf_stats is not None:
#                     for k, v in self.experience_maker.perf_stats.items():
#                         self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step)

#         # TODO: Add evaluation mechanism for PPO
#         if global_step % args.eval_steps == 0 and self.eval_dataloader and len(self.eval_dataloader) > 0:
#             self.evaluate(self.eval_dataloader, global_step, args.eval_temperature, args.eval_n_samples_per_prompt) 
#             # pass
#         # save ckpt
#         # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric
#         if global_step % args.save_steps == 0:
#             tag = f"global_step{global_step}"
#             self._save_checkpoint(args, tag, client_states)

#     def _save_checkpoint(self, args, tag, client_states):
#         # call remote critic
#         if not self.disable_ds_ckpt:
#             if self.critic_train_remote:
#                 ref = self.critic.save_checkpoint.remote(tag)
#             self.strategy.save_ckpt(
#                 self.actor.model,
#                 os.path.join(args.ckpt_path, "_actor"),
#                 tag,
#                 args.max_ckpt_num,
#                 args.max_ckpt_mem,
#                 client_states,
#             )
#         if self.save_hf_ckpt:
#             save_path = os.path.join(args.ckpt_path, f"{tag}_hf")
#             self.strategy.save_model(
#                 self.ema_model if args.enable_ema else self.actor,
#                 self.tokenizer,
#                 save_path,
#             )
#         # wait
#         if not self.disable_ds_ckpt:
#             if self.critic_train_remote:
#                 ray.get(ref)
#         torch.distributed.barrier()

#     def reload_states(self):
#         reload_deepspeed_states(self.actor.model)

#     def offload_states(self):
#         offload_deepspeed_states(self.actor.model)




#     ###################################################################################################
#     #TODO: evaluation of PPO
#     def evaluate(self, eval_dataloader, global_step, temperature=0.6, n_samples_per_prompt=1):
#             """Evaluate model performance on eval dataset.

#             Args:
#                 eval_dataloader: DataLoader containing evaluation prompts, labels and data sources
#                 global_step: Current training step for logging
#                 temperature: Sampling temperature used when generating evaluation samples
#                 n_samples_per_prompt: Number of samples to generate per prompt for pass@k calculation
#             """
#             start_time = time.time()
#             logger.info(f"⏰ Evaluation start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

#             # vLLM wakeup when vllm_enable_sleep
#             if self.strategy.args.vllm_enable_sleep:
#                 from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call

#                 batch_vllm_engine_call(self.vllm_engines, "wake_up")

#             with torch.no_grad():
#                 # First collect all prompts and labels
#                 all_prompts = []
#                 all_labels = []
#                 all_datasources = []

#                 for datasources, prompts, labels in eval_dataloader:
#                     all_prompts.extend(prompts)
#                     all_labels.extend(labels)
#                     all_datasources.extend(datasources)

#                 # Generate samples and calculate rewards
#                 generate_kwargs = self.generate_kwargs.copy()
#                 generate_kwargs["temperature"] = temperature
#                 generate_kwargs["n_samples_per_prompt"] = n_samples_per_prompt
#                 samples = self.experience_maker.generate_samples(all_prompts, all_labels, **generate_kwargs)
#                 queries_list = self.tokenizer.batch_decode(samples.sequences, skip_special_tokens=False)

#                 # duplicate prompts and labels for each sample
#                 all_prompts = sum([[prompt] * n_samples_per_prompt for prompt in all_prompts], [])
#                 all_labels = sum([[label] * n_samples_per_prompt for label in all_labels], [])

#                 # Calculate rewards
#                 if self.experience_maker.custom_reward_func:
#                     # Let Ray automatically distribute the workload across available resources
#                     batch_size = self.strategy.args.micro_rollout_batch_size
#                     num_chunks = (len(queries_list) + batch_size - 1) // batch_size
#                     r_refs = []
#                     for i in range(num_chunks):
#                         start_idx = i * batch_size
#                         end_idx = min((i + 1) * batch_size, len(queries_list))
#                         r = self.experience_maker.custom_reward_func.remote(
#                             queries_list[start_idx:end_idx],
#                             all_prompts[start_idx:end_idx],
#                             all_labels[start_idx:end_idx],
#                         )
#                         r_refs.append(r)
#                 else:
#                     # Distribute data across different remote reward function servers
#                     num_servers = len(self.remote_rm_url)
#                     batch_size = (len(queries_list) + num_servers - 1) // num_servers
#                     r_refs = []
#                     for i in range(num_servers):
#                         start_idx = i * batch_size
#                         end_idx = min((i + 1) * batch_size, len(queries_list))
#                         rm = self.remote_rm_url[i]
#                         r = remote_rm_fn_ray.remote(
#                             rm,
#                             queries=queries_list[start_idx:end_idx],
#                             prompts=all_prompts[start_idx:end_idx],
#                             labels=all_labels[start_idx:end_idx],
#                         )
#                         r_refs.append(r)

#                 # Reshape rewards to (num_prompts, n_samples_per_prompt)
#                 rewards = ray.get(r_refs)
#                 rewards = torch.cat(rewards, dim=0).reshape(-1, n_samples_per_prompt)

#                 # Collect local statistics for each data source
#                 global_metrics = {}  # {datasource: {"pass{n_samples_per_prompt}": 0, "pass1": 0, "count": 0}}

#                 for i, datasource in enumerate(all_datasources):
#                     if datasource not in global_metrics:
#                         global_metrics[datasource] = {f"pass{n_samples_per_prompt}": 0, "pass1": 0, "count": 0}

#                     # Calculate pass@k and pass@1
#                     prompt_rewards = rewards[i]
#                     if n_samples_per_prompt > 1:
#                         global_metrics[datasource][f"pass{n_samples_per_prompt}"] += prompt_rewards.max().float().item()
#                     global_metrics[datasource]["pass1"] += prompt_rewards.mean().float().item()
#                     global_metrics[datasource]["count"] += 1

#                 # Calculate global averages
#                 logs = {}
#                 for datasource, metrics in global_metrics.items():
#                     logs[f"eval_{datasource}_pass{n_samples_per_prompt}"] = (
#                         metrics[f"pass{n_samples_per_prompt}"] / metrics["count"]
#                     )
#                     logs[f"eval_{datasource}_pass1"] = metrics["pass1"] / metrics["count"]

#                 # Log to wandb/tensorboard
#                 if self._wandb is not None:
#                     logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": global_step}.items()}
#                     self._wandb.log(logs)
#                 elif self._tensorboard is not None:
#                     for k, v in logs.items():
#                         self._tensorboard.add_scalar(f"eval/{k}", v, global_step)

#             if self.strategy.args.vllm_enable_sleep:
#                 batch_vllm_engine_call(self.vllm_engines, "sleep")

#             end_time = time.time()
#             duration = end_time - start_time
#             time_str = str(timedelta(seconds=duration)).split(".")[0]
#             logger.info(f"✨ Evaluation completed in {time_str}")
#     ###################################################################################################













































# @ray.remote(num_gpus=1)
# class ActorModelRayActor(BasePPORole):
#     def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
#         args = strategy.args

#         if getattr(args, "vllm_num_engines", 0) > 0:
#             # To prevent hanging during NCCL synchronization of weights between DeepSpeed and vLLM.
#             # see https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445
#             if getattr(args, "vllm_sync_backend", "nccl") == "nccl":
#                 os.environ["NCCL_CUMEM_ENABLE"] = "0"

#         self._setup_distributed(strategy)

#         actor = Actor(
#             pretrain,
#             use_flash_attention_2=strategy.args.flash_attn,
#             bf16=strategy.args.bf16,
#             load_in_4bit=strategy.args.load_in_4bit,
#             lora_rank=strategy.args.lora_rank,
#             lora_alpha=strategy.args.lora_alpha,
#             target_modules=strategy.args.target_modules,
#             lora_dropout=strategy.args.lora_dropout,
#             ds_config=strategy.get_ds_train_config(is_actor=True),
#             packing_samples=strategy.args.packing_samples,
#             temperature=strategy.args.temperature,
#             use_liger_kernel=strategy.args.use_liger_kernel,
#         )
#         strategy.print(actor)

#         # configure tokenizer
#         self.tokenizer = get_tokenizer(
#             pretrain, actor.model, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer
#         )
#         # self.model.resize_token_embeddings(len(self.tokenizer))
#         # TODO: resize the tokenizer and the model's token embeddings (model = actor.model)

#         if args.enable_ema:
#             ema_model = Actor(
#                 pretrain,
#                 use_flash_attention_2=strategy.args.flash_attn,
#                 bf16=strategy.args.bf16,
#                 load_in_4bit=strategy.args.load_in_4bit,
#                 ds_config=strategy.get_ds_eval_config(offload=True),
#                 packing_samples=strategy.args.packing_samples,
#             )
#         else:
#             ema_model = None

#         # configure optimizer
#         actor_optim = strategy.create_optimizer(
#             actor, lr=args.actor_learning_rate, betas=strategy.args.adam_betas, weight_decay=args.l2
#         )

#         # prepare_datasets
#         self.prepare_datasets()

#         # configure scheduler
#         self.num_update_steps_per_episodes = (
#             len(self.prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
#         )
#         max_steps = math.ceil(args.num_episodes * self.num_update_steps_per_episodes)
#         self._max_steps = max_steps

#         actor_scheduler = get_scheduler(
#             "cosine_with_min_lr",
#             actor_optim,
#             num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
#             num_training_steps=max_steps,
#             scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1},
#         )

#         if args.gradient_checkpointing:
#             actor.gradient_checkpointing_enable(
#                 gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
#             )

#         # prepare models/optimizers...
#         self.actor, self.actor_optim, self.actor_scheduler = strategy.prepare(
#             (actor, actor_optim, actor_scheduler),
#             is_rlhf=True,
#         )

#         if ema_model:
#             ema_model._offload = True
#             self.ema_model = strategy.prepare(ema_model, is_rlhf=True)
#         else:
#             self.ema_model = None

#         # load checkpoint
#         self.consumed_samples = 0
#         ckpt_path = os.path.join(args.ckpt_path, "_actor")
#         if args.load_checkpoint and os.path.exists(ckpt_path):
#             _, states = strategy.load_ckpt(self.actor.model, ckpt_path)
#             self.consumed_samples = states["consumed_samples"]
#             strategy.print(f"Loaded the checkpoint: {ckpt_path}, consumed_samples: {self.consumed_samples}")

#         # initial offload
#         if strategy.args.deepspeed_enable_sleep:
#             offload_deepspeed_states(self.actor.model)

#     def prepare_datasets(self):
#         strategy = self.strategy
#         args = self.strategy.args

#         # prepare datasets
#         prompts_data = blending_datasets(
#             args.prompt_data,
#             args.prompt_data_probs,
#             strategy,
#             args.seed,
#             max_count=args.max_samples,
#             return_eval=False,
#             train_split=args.prompt_split,
#         )
#         prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data))))
#         self.prompts_dataset = PromptDataset(
#             prompts_data, self.tokenizer, strategy, input_template=args.input_template
#         )
#         self.prompts_dataloader = strategy.setup_dataloader(
#             self.prompts_dataset,
#             args.rollout_batch_size // (strategy.world_size // strategy.ring_attn_size),
#             True,
#             True,
#         )

#         if args.pretrain_data:
#             pretrain_data = blending_datasets(
#                 args.pretrain_data,
#                 args.pretrain_data_probs,
#                 strategy,
#                 args.seed,
#                 return_eval=False,
#                 train_split=args.pretrain_split,
#             )
#             pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
#             pretrain_dataset = SFTDataset(
#                 pretrain_data.select(
#                     range(
#                         min(
#                             len(pretrain_data), args.max_epochs * len(self.prompts_dataset) * args.n_samples_per_prompt
#                         )
#                     )
#                 ),
#                 self.tokenizer,
#                 pretrain_max_len,
#                 strategy,
#                 pretrain_mode=True,
#             )
#             self.pretrain_dataloader = itertools.cycle(
#                 iter(
#                     strategy.setup_dataloader(
#                         pretrain_dataset,
#                         args.micro_train_batch_size,
#                         True,
#                         True,
#                         pretrain_dataset.collate_fn,
#                     )
#                 )
#             )
#         else:
#             self.pretrain_dataloader = None

#     def max_steps(self):
#         """Return the maximum number of steps."""
#         return self._max_steps

#     def fit(
#         self,
#         critic_model: ray.actor.ActorHandle,
#         initial_model: ray.actor.ActorHandle,
#         reward_model: List[ray.actor.ActorHandle],
#         remote_rm_url: List[str] = None,
#         reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None,
#         vllm_engines: List[ray.actor.ActorHandle] = None,
#         critic_train_remote: bool = False,
#     ):
#         """Train actor model with prompt datasets."""
#         strategy = self.strategy
#         args = self.strategy.args

#         # configure Trainer
#         trainer = ActorPPOTrainer(
#             strategy,
#             self.actor,
#             critic_model,
#             reward_model,
#             initial_model,
#             ema_model=self.ema_model,
#             actor_optim=None,
#             critic_optim=None,
#             actor_scheduler=self.actor_scheduler,
#             critic_scheduler=None,
#             remote_rm_url=remote_rm_url,
#             reward_fn=reward_fn,
#             vllm_engines=vllm_engines,
#             max_epochs=args.max_epochs,
#             micro_train_batch_size=args.micro_train_batch_size,
#             micro_rollout_batch_size=args.micro_rollout_batch_size,
#             gradient_checkpointing=args.gradient_checkpointing,
#             critic_train_remote=critic_train_remote,
#             tokenizer=self.tokenizer,
#             prompt_max_len=args.prompt_max_len,
#             value_clip=args.value_clip,
#             eps_clip=args.eps_clip,
#             gamma=args.gamma,
#             lambd=args.lambd,
#             init_kl_coef=args.init_kl_coef,
#             kl_target=args.kl_target,
#             ema_beta=0.992,
#             ptx_coef=args.ptx_coef,
#             max_norm=args.max_norm,
#             # for GPT generation
#             do_sample=True,
#             max_new_tokens=args.generate_max_len,
#             max_length=args.max_len,
#             temperature=args.temperature,
#             top_p=args.top_p,
#             pad_token_id=self.tokenizer.pad_token_id,
#             eos_token_id=self.tokenizer.eos_token_id,
#             save_hf_ckpt=args.save_hf_ckpt,
#             disable_ds_ckpt=args.disable_ds_ckpt,
#         )

#         # broadcast checkpoint
#         ckpt_path = os.path.join(args.ckpt_path, "_actor")
#         if args.load_checkpoint and os.path.exists(ckpt_path) and not vllm_engines is None:
#             # vLLM wakeup when vllm_enable_sleep
#             if self.strategy.args.vllm_enable_sleep:
#                 batch_vllm_engine_call(vllm_engines, "wake_up")
#             torch.distributed.barrier()
#             torch.cuda.synchronize()

#             trainer._broadcast_to_vllm()

#             # vLLM offload when vllm_enable_sleep
#             if self.strategy.args.vllm_enable_sleep:
#                 batch_vllm_engine_call(vllm_engines, "sleep")
#                 torch.distributed.barrier()
#                 torch.cuda.synchronize()

#         trainer.fit(
#             args,
#             self.prompts_dataloader,
#             self.pretrain_dataloader,
#             self.consumed_samples,
#             self.num_update_steps_per_episodes,
#         )

#     def save_model(self):
#         args = self.strategy.args

#         # save the model checkpoint after fitting, on rank 0 only
#         self.strategy.save_model(
#             self.ema_model if args.enable_ema else self.actor,
#             self.tokenizer,
#             args.save_path,
#         )





