"""
Implementation of fully async training in SkyRL.

For details, see https://skyrl.readthedocs.io/en/latest/tutorials/fully_async.html.

High-level notes:
- The global_step in each training loop iteration denotes the "current step being worked on", so
`global_step - 1` is the number of steps the model has finished training.
- We do not do any cross-epoch asynchrony here, so all the control logic, such as the
  generation workers and the data buffer, is initialized per epoch. The async dataloader
  and staleness manager are also reset / validated at the end of each epoch.
"""

import asyncio
import os
import torch
import traceback
import sys
from loguru import logger
from skyrl_train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl_train.utils import Timer
from skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.generators.base import GeneratorOutput
from skyrl_train.utils.trainer_utils import ResumeMode, build_dataloader
from skyrl_train.utils.io import io
from skyrl_train.generators.utils import prepare_generator_input, concatenate_generator_outputs
from skyrl_train.inference_engines.utils import get_sampling_params_for_backend
from dataclasses import dataclass
from torchdata.stateful_dataloader import StatefulDataLoader
from typing import List, Tuple, Iterable, Set
import inspect


@dataclass
class GeneratedOutputGroup:
    """
    One finished group of rollouts, plus the bookkeeping the fully async trainer needs for it.

    Attributes:
        generator_output (GeneratorOutput): The generator output for a single group, i.e. the
            `n_samples_per_prompt` rollouts produced for one and the same prompt.

        uid (str): Identifier of the group. Underneath, it is the index of the prompt in
            `train_dataloader.dataset`.

        global_step_when_scheduled (int): The trainer's global step at the moment the group was
            scheduled for generation; compared against the step it is trained on to measure staleness.
    """

    # Rollouts for a single prompt.
    generator_output: GeneratorOutput
    # Prompt identifier shared by every rollout of the group.
    uid: str
    # Step at scheduling time, used for staleness accounting.
    global_step_when_scheduled: int


@dataclass
class _RolloutStat:
    """
    Cumulative trajectory counters consulted by `_AsyncStalenessManager` for staleness control.

    The counters are global rather than per-epoch: they keep accumulating across all epochs.

    Attributes:
        submitted (int): Trajectories handed to the generation workers so far. Only used for
            logging purposes (e.g. a submitted / accepted ratio reveals how many trajectories
            failed). Non-decreasing during training.
        accepted (int): Trajectories whose generation has finished, whether already consumed by
            the training worker or about to be. Non-decreasing during training.
        running (int): Trajectories the generation workers are generating right now.

    For details, see https://skyrl.readthedocs.io/en/latest/tutorials/fully_async.html#async-staleness-manager
    """

    submitted: int = 0  # total handed out to generation workers
    accepted: int = 0  # total finished generation
    running: int = 0  # currently in flight


class _AsyncStalenessManager:
    """
    A controller that manages the capacity of the generation workers based on staleness control.

    The goal is to never submit more trajectories to the generation workers than the training worker
    can consume, so that the trajectories are not too stale (relative to max_staleness_steps).
    This is enforced via a capacity rule, not a hard **per-group** staleness guarantee: we bound
    the **aggregate** number of groups that can be ahead of training so that, in **steady state**,
    staleness remains within the configured budget of `max_staleness_steps`.

    In pathological cases (e.g., very long-running trajectories), an individual group may take
    more than `max_staleness_steps` of training steps of time to finish generation. For such rare
    cases, we still accept the trajectory and log the staleness metrics with a warning.

    The key capacity formula is implemented in `_compute_capacity_unlocked`. For details and caveats,
    see https://skyrl.readthedocs.io/en/latest/tutorials/fully_async.html#async-staleness-manager.

    Reference:
    - Modeled after AReal's StalenessManager: https://github.com/inclusionAI/AReaL/blob/b755c4447c2fff97889d8828293ee85f17a806f9/areal/core/staleness_manager.py
    - The idea of this controller is from section 5.1 of AReal's paper: https://arxiv.org/pdf/2505.24298v3
    """

    def __init__(self, max_concurrent_generation_groups: int, mini_batch_size: int, max_staleness_steps: int):
        """
        Args:
            max_concurrent_generation_groups: Upper bound on the number of groups generated at the same time.
            mini_batch_size: Number of groups the training worker consumes per training step.
            max_staleness_steps: How many training steps generation is allowed to run ahead of training.
        """
        self.max_concurrent_generation_groups = max_concurrent_generation_groups
        self.mini_batch_size = mini_batch_size
        self.max_staleness_steps = max_staleness_steps

        # Control logics. `_cond` guards `_stat` and `_current_global_step`.
        self._stat = _RolloutStat()
        self._cond = asyncio.Condition()

        # The current version that is being worked on, i.e. FullyAsyncRayPPOTrainer.global_step.
        # `self._current_global_step - 1` is the number of steps the model has finished training.
        self._current_global_step = 1

    def load_state_from_checkpoint(self, global_step: int) -> None:
        """
        Load the state from a checkpoint.

        Args:
            global_step: The step training resumes at, i.e. `global_step - 1` steps are already finished.
        """
        self._current_global_step = global_step
        # trainer has already consumed (and hence submitted) this many trajectories.
        self._stat.accepted = (global_step - 1) * self.mini_batch_size
        self._stat.submitted = self._stat.accepted

    async def validate_state_at_epoch_end(self, global_step: int) -> None:
        """
        Check that the current version and accepted rollouts are consistent with the global step.

        Args:
            global_step: The global step we are about to train on (after incrementing).

        Raises:
            AssertionError: if rollouts are still running, some submitted rollouts were never accepted,
                the accepted count does not match the number of consumed rollouts, or the tracked version
                disagrees with `global_step`.
        """
        async with self._cond:
            assert self._stat.running == 0, "We expect no rollouts are running at end of an epoch."
            assert self._stat.submitted == self._stat.accepted, (
                "We expect all submitted rollouts to be accepted at end of an epoch. "
                f"Got {self._stat.submitted} != {self._stat.accepted}."
            )
            consumed = (global_step - 1) * self.mini_batch_size
            assert (
                self._stat.accepted == consumed
            ), f"Unexpected number of accepted rollouts. Got {self._stat.accepted} != {consumed}."
            assert (
                self._current_global_step == global_step
            ), f"Unexpected current version. Got {self._current_global_step} != {global_step}."

    def _compute_capacity_unlocked(self) -> int:
        """
        Return how many more groups may be submitted right now (a value <= 0 means none).

        Expected to be called with `self._cond` held, as `acquire_submission_slot` does. The result is:

            min(max_concurrent_generation_groups - running,
                (max_staleness_steps + current_global_step) * mini_batch_size - (accepted + running))
        """
        # NOTE(Charlie): do not need a self._current_global_step + 1 here unlike AReal because our
        # `_current_global_step` is "the version being worked on", not already finished steps.
        consumer_capacity = (self.max_staleness_steps + self._current_global_step) * self.mini_batch_size
        producer_staleness_capacity = consumer_capacity - (self._stat.accepted + self._stat.running)
        producer_concurrency_capacity = self.max_concurrent_generation_groups - self._stat.running
        return min(producer_concurrency_capacity, producer_staleness_capacity)

    async def acquire_submission_slot(self) -> None:
        """Block until there is capacity, then reserve a slot (increments submitted and running).

        This method always uses the latest current version tracked internally, which is
        updated by `notify_capacity_change(new_global_step)`.
        """
        async with self._cond:
            while self._compute_capacity_unlocked() <= 0:
                await self._cond.wait()
            # Reserve slot
            self._stat.submitted += 1
            self._stat.running += 1

    async def on_rollout_accepted(self) -> None:
        """Mark one running rollout as finished and wake up workers waiting for capacity."""
        async with self._cond:
            self._stat.accepted += 1
            self._stat.running -= 1
            self._cond.notify_all()

    async def on_rollout_rejected(self) -> None:
        """
        Called when a generation is not accepted, or generation worker runs into error while generating a trajectory.

        Releases the running slot without counting the rollout as accepted. Currently, we do not call this
        method but instead raise errors. We might need to use this when we want to filter out trajectories.
        """
        async with self._cond:
            self._stat.running -= 1
            self._cond.notify_all()

    async def notify_capacity_change(self, new_global_step: int) -> None:
        """Record a new current version (e.g. after a training step) and wake up waiting workers."""
        async with self._cond:
            self._current_global_step = int(new_global_step)
            self._cond.notify_all()


class _AsyncDataloader:
    """
    A train dataloader wrapper that accommodates the need of fully async training, including:
    - Serialized dataloader iteration with an asyncio lock, as there are multiple parallel generation workers
      polling data.
    - Records consumed data UIDs for checkpointing to avoid training on the same data upon resuming.
    - Set the effective dataloader length to be divisible by mini-batch size, since we cannot rely on `drop_last`
      because the batch size is 1 in fully async training.
    """

    def __init__(self, train_dataloader: "StatefulDataLoader", mini_batch_size: int):
        self._train_dataloader = train_dataloader
        # Snapshot used to rewind the dataloader after loading a checkpoint and at every epoch end.
        self._train_dataloader_initial_state = train_dataloader.state_dict()
        self._effective_dataloader_length = len(self._train_dataloader) // mini_batch_size * mini_batch_size
        self._iter = enumerate(self._train_dataloader)
        self._lock: asyncio.Lock = asyncio.Lock()
        self._consumed_data_uids: Set[str] = set()
        # Set once the (trimmed) dataloader is exhausted, so later polls return None without pulling more data.
        self._exhausted: bool = False

    def load_state_from_checkpoint(self, consumed_data_uids_set: Set[str]) -> None:
        """
        Load the state from a checkpoint.

        Args:
            consumed_data_uids_set: UIDs already trained on in the current epoch; they are skipped when polling.
        """
        # Copy so that clearing our set at epoch end never mutates the caller's object.
        self._consumed_data_uids = set(consumed_data_uids_set)

        # Reset in case the dataloader loaded the state from the checkpoint, which we do not want.
        self._train_dataloader.load_state_dict(self._train_dataloader_initial_state)
        # Iterate from the restored state, exactly as `reset_at_epoch_end` does.
        self._iter = enumerate(self._train_dataloader)
        self._exhausted = False

    async def reset_at_epoch_end(self) -> None:
        """Rewind the dataloader to its initial state and forget all consumed UIDs."""
        async with self._lock:
            self._train_dataloader.load_state_dict(self._train_dataloader_initial_state)  # reset to initial state
            self._iter = enumerate(self._train_dataloader)
            self._consumed_data_uids.clear()
            self._exhausted = False

    async def get_next_non_consumed_data(self):
        """
        Return the next batch of training data that has not been consumed yet.

        If we loaded from a checkpoint, the already-consumed data is skipped. Returns None once the effective
        (trimmed) dataloader length is reached or the underlying dataloader is exhausted.
        """
        async with self._lock:
            if self._exhausted:
                return None
            # Keep polling until we get a non-consumed data or the dataloader is exhausted.
            for iter_idx, rand_prompts in self._iter:
                if iter_idx >= self._effective_dataloader_length:
                    break
                if rand_prompts[0]["uid"] not in self._consumed_data_uids:
                    return rand_prompts
            self._exhausted = True
            return None

    async def mark_consumed_uids(self, uids: Iterable[str]) -> None:
        """
        Record `uids` as consumed (trained on).

        Raises:
            AssertionError: if any UID was already recorded as consumed.
        """
        async with self._lock:
            for uid in uids:
                assert uid not in self._consumed_data_uids, "Duplicate UID found in mini-batch"
                self._consumed_data_uids.add(uid)

    def get_consumed_uids_list(self) -> List[str]:
        """Return the consumed UIDs as a list, in no particular order."""
        return list(self._consumed_data_uids)


class FullyAsyncRayPPOTrainer(RayPPOTrainer):

    def __init__(self, *args, **kwargs):
        """
        Build a fully async PPO trainer.

        Accepts the same arguments as `RayPPOTrainer`. `cfg` (given as a keyword or as the first positional
        argument) is read before the base initializer runs. The async knobs stored from it are used by this
        class's overrides (e.g. `_build_train_dataloader_and_compute_training_steps` reads `self.mini_batch_size`),
        which presumably run during base initialization.
        """
        if "cfg" in kwargs:
            cfg = kwargs["cfg"]
        else:
            cfg = args[0] if args else None
        assert cfg is not None, "cfg must be provided to FullyAsyncRayPPOTrainer"

        # Async-specific knobs.
        async_cfg = cfg.trainer.fully_async
        self.num_parallel_generation_workers = async_cfg.num_parallel_generation_workers
        self.mini_batch_size = cfg.trainer.policy_mini_batch_size
        self.max_staleness_steps = async_cfg.max_staleness_steps

        # Fewer workers than one mini-batch wastes throughput; more workers than the staleness budget
        # (mini_batch_size * (max_staleness_steps + 1)) could never all be busy under the capacity constraint.
        assert (
            self.mini_batch_size
            <= self.num_parallel_generation_workers
            <= self.mini_batch_size * (self.max_staleness_steps + 1)
        ), (
            "Invalid num_parallel_generation_workers, the following must hold: "
            "mini_batch_size <= num_parallel_generation_workers <= mini_batch_size * (max_staleness_steps + 1). Got: "
            f"{self.mini_batch_size=}, {self.num_parallel_generation_workers=}, {self.max_staleness_steps=}"
        )

        super().__init__(*args, **kwargs)

        # Configurations that fully async training cannot handle (yet).
        trainer_cfg = self.cfg.trainer
        assert (
            trainer_cfg.train_batch_size == trainer_cfg.policy_mini_batch_size
        ), "train_batch_size must equal policy_mini_batch_size for fully async training"
        assert (
            trainer_cfg.algorithm.dynamic_sampling.type is None
        ), "dynamic sampling is not supported for fully async training yet."
        generator_cfg = self.cfg.generator
        assert (
            not generator_cfg.batched
        ), "batched is not supported for fully async training since a batched generate() call does not support pause/continue."
        assert generator_cfg.async_engine, "async_engine must be True for fully async training."
        # TODO(Charlie): we can support it, just multi-turn partial rollout but synchronous.
        assert not self.colocate_all, "colocate_all is not supported for async training yet."

        # TODO(Charlie): need to assert we are doing TIS and returning logprobs

        # Async-specific state: the lock-protected dataloader wrapper and the staleness-based capacity controller.
        self.async_train_dataloader = _AsyncDataloader(self.train_dataloader, self.mini_batch_size)
        self._staleness_manager = _AsyncStalenessManager(
            max_concurrent_generation_groups=self.num_parallel_generation_workers,
            mini_batch_size=self.mini_batch_size,
            max_staleness_steps=self.max_staleness_steps,
        )

    def _build_train_dataloader_and_compute_training_steps(self):
        """
        Overrides to build dataloader for fully async training. See `_AsyncDataloader` for more details.

        Steps per epoch is the number of full mini-batches the dataloader can supply. The remainder is dropped,
        matching `_AsyncDataloader`'s effective length.

        Raises:
            AssertionError: if the dataloader cannot supply even one full mini-batch.
        """
        self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True, is_fully_async=True)
        self.num_steps_per_epoch = len(self.train_dataloader) // self.mini_batch_size
        self.total_training_steps = self.num_steps_per_epoch * self.cfg.trainer.epochs
        logger.info(f"Length of train_dataloader: {len(self.train_dataloader)}")
        logger.info(f"Number of steps per epoch: {self.num_steps_per_epoch}")
        logger.info(f"Total training steps: {self.total_training_steps}")
        # Fail fast: with zero steps per epoch, `train()` would divide by `num_steps_per_epoch` and crash with an
        # opaque ZeroDivisionError, and no data would ever be trained on.
        assert self.num_steps_per_epoch > 0, (
            f"The train dataloader has {len(self.train_dataloader)} entries, fewer than "
            f"policy_mini_batch_size={self.mini_batch_size}, so not even one full mini-batch can be formed."
        )

    async def train(self):
        """
        Main fully async training loop for PPO.

        Each epoch launches `num_parallel_generation_workers` generation tasks that keep pushing finished rollout
        groups into a bounded buffer, throttled by `_AsyncStalenessManager`. This coroutine repeatedly:
        - waits for a full mini-batch from that buffer,
        - trains on it,
        - syncs the updated weights to the inference engines,
        - handles eval, logging and checkpointing.

        `self.global_step` is the step currently being worked on.
        """
        self.global_step = 0

        # Load checkpoint state if resumption is enabled. Also load the data UIDs that are already trained on.
        if self.resume_mode != ResumeMode.NONE:
            with Timer("load_checkpoints"):
                self.global_step, _, loaded_consumed_data_uids_set = self.load_checkpoints()
                logger.info(f"Resumed training from global_step {self.global_step}")
                if self.global_step > 0:
                    self._restore_async_state_from_checkpoint(loaded_consumed_data_uids_set)

        # Initialize weight sync state
        with Timer("init_weight_sync_state"):
            self.init_weight_sync_state()

        # sync weights to inference engines
        with Timer("sync_weights_to_inference_engines"):
            await self.async_sync_policy_weights_to_inference_engines()

        # Eval before training
        if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
            with Timer("eval", self.all_timings):
                eval_metrics = await self.eval()
                self.tracker.log(eval_metrics, step=self.global_step)

        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Step Progress")
        start_epoch = self.global_step // self.num_steps_per_epoch
        self.global_step += 1  # start training at global_step 1
        for epoch in range(start_epoch, self.cfg.trainer.epochs):
            # 0. Per-epoch prologue. Note that we do not do any cross-epoch asynchrony here.

            # Buffer of completed generation, size bounded by capacity - consumed = B * (max_staleness_steps + 1)
            generation_output_group_buffer = asyncio.Queue[GeneratedOutputGroup](
                maxsize=self.mini_batch_size * (self.max_staleness_steps + 1)
            )

            # Maintain self.num_parallel_generation_workers concurrent group-generation workers
            generator_tasks = [
                asyncio.create_task(self._run_generate_for_a_group_loop(generation_output_group_buffer))
                for _ in range(self.num_parallel_generation_workers)
            ]

            for _ in range(self.global_step, (1 + epoch) * self.num_steps_per_epoch + 1):
                with Timer("step", self.all_timings):
                    # 1. Wait until we have enough groups buffered.
                    cur_generation_group_mini_batch: List[GeneratedOutputGroup] = []
                    with Timer("wait_for_generation_buffer", self.all_timings):
                        buffer_pbar = tqdm(
                            total=self.mini_batch_size,
                            initial=0,
                            desc="Generation Buffer Progress",
                            position=1,
                        )
                        # NOTE(Charlie): we currently trim the train_dataloader to make it perfectly divisible by
                        # self.mini_batch_size, and assume that all trajectories succeed (just like sync training),
                        # so we always get a full mini-batch. Otherwise (e.g. want to drop stale trajectories), we
                        # should handle the case where the dataloader is exhausted and the buffer is empty, or
                        # else this loop will never exit.
                        while len(cur_generation_group_mini_batch) < self.mini_batch_size:
                            # We do finish-time FIFO here (not schedule-time FIFO)
                            cur_generation_group_mini_batch.append(await generation_output_group_buffer.get())
                            buffer_pbar.update(1)
                            buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()})
                        buffer_pbar.close()

                    # 2. Post-process the generated groups, aggregating to a single GeneratorOutput, and convert to training format.
                    with Timer("convert_to_training_input", self.all_timings):
                        training_input = await asyncio.to_thread(
                            self.convert_generation_group_mini_batch_to_training_input, cur_generation_group_mini_batch
                        )

                    # 3. Run training and update consumed UIDs.
                    with Timer("run_training", self.all_timings):
                        status = await self._run_training(training_input)
                        await self.async_train_dataloader.mark_consumed_uids(
                            [g.uid for g in cur_generation_group_mini_batch]
                        )

                    # 4. After training: pause generation, sync weights, resume.
                    with Timer("sync_weights", self.all_timings):
                        await self.inference_engine_client.pause_generation()
                        await self.async_sync_policy_weights_to_inference_engines()
                        await self.inference_engine_client.resume_generation()

                # 5. Eval if needed. This runs *before* the metrics of this step are logged, so that eval metrics
                # are logged under this step. Otherwise they would spill into the next step's log, and the eval
                # metrics of the final step would never be logged at all.
                # NOTE(Charlie): eval does not overlap with training, but can overlap with generation. Is it fine?
                if self.cfg.trainer.eval_interval > 0 and (
                    self.global_step % self.cfg.trainer.eval_interval == 0
                    or self.global_step == self.total_training_steps
                ):
                    with Timer("eval", self.all_timings):
                        eval_metrics = await self.eval()
                        self.all_metrics.update(eval_metrics)

                # 6. Set logs for this training step.
                logger.info(status)
                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                self.tracker.log(self.all_metrics, step=self.global_step)
                self.all_metrics = {}
                pbar.update(1)

                # 7. Checkpointing if needed.
                if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
                    with Timer("save_checkpoints", self.all_timings):
                        await asyncio.to_thread(self.save_checkpoints)
                if self.cfg.trainer.hf_save_interval > 0 and self.global_step % self.cfg.trainer.hf_save_interval == 0:
                    with Timer("save_hf_model", self.all_timings):
                        await asyncio.to_thread(self.save_models)
                self.tracker.log({"timing/" + k: v for k, v in self.all_timings.items()}, step=self.global_step)
                self.all_timings = {}
                self.global_step += 1

                # 8. Notify generation workers that the capacity has increased, unblocking them.
                await self._staleness_manager.notify_capacity_change(self.global_step)
                steps_completed_in_epoch = (self.global_step - 1) % self.num_steps_per_epoch
                if steps_completed_in_epoch == 0:
                    # At the end of an epoch, modulo becomes 0 but we've consumed a full epoch.
                    steps_completed_in_epoch = self.num_steps_per_epoch
                expected_consumed_in_epoch = self.mini_batch_size * steps_completed_in_epoch
                actual_consumed_in_epoch = len(self.async_train_dataloader.get_consumed_uids_list())
                assert actual_consumed_in_epoch == expected_consumed_in_epoch, (
                    "Unexpected number of consumed data UIDs. Got: "
                    f"{actual_consumed_in_epoch} != {expected_consumed_in_epoch}"
                )

            # 9. Per-epoch epilogue.
            if self.cfg.trainer.update_ref_every_epoch and self.ref_model is not None:
                with Timer("update_ref_with_policy", self.all_timings):
                    await asyncio.to_thread(self.update_ref_with_policy)

            # Cancel generator tasks for this epoch. With `return_exceptions=True`, gather does not raise for
            # worker failures; they are returned as results, so surface them instead of dropping them silently.
            for t in generator_tasks:
                t.cancel()
            worker_results = await asyncio.gather(*generator_tasks, return_exceptions=True)
            for worker_result in worker_results:
                if isinstance(worker_result, BaseException) and not isinstance(worker_result, asyncio.CancelledError):
                    logger.error(f"Generation worker ended with an exception: {worker_result!r}")

            # Per-epoch reset/validation for data loading and staleness management
            assert all(
                t.done() for t in generator_tasks
            ), "Generator tasks must be done before resetting the dataloader manager and validating the staleness manager."
            assert (
                generation_output_group_buffer.qsize() == 0
            ), f"We expect all generation output to be consumed by the training worker at end of an epoch, got {generation_output_group_buffer.qsize()}."
            await self.async_train_dataloader.reset_at_epoch_end()
            await self._staleness_manager.validate_state_at_epoch_end(self.global_step)

            # End of an epoch.
        pbar.close()
        if self.cfg.trainer.ckpt_interval > 0:
            with Timer("save_checkpoints", self.all_timings):
                await asyncio.to_thread(self.save_checkpoints)
                logger.info("Saved final checkpoint.")
        if self.cfg.trainer.hf_save_interval > 0:
            with Timer("save_hf_model", self.all_timings):
                await asyncio.to_thread(self.save_models)
                logger.info("Saved final model.")
        logger.info("Training done!")

    def _restore_async_state_from_checkpoint(self, loaded_consumed_data_uids_set: Set[str]) -> None:
        """
        Restore the staleness manager and async dataloader after loading a checkpoint at `self.global_step` (> 0).

        Args:
            loaded_consumed_data_uids_set: Data UIDs recorded as consumed in the checkpoint.

        Raises:
            AssertionError: if the number of loaded UIDs does not match the steps completed in the epoch.
        """
        # +1 due to we haven't incremented `self.global_step` yet (it still counts finished steps).
        self._staleness_manager.load_state_from_checkpoint(self.global_step + 1)

        steps_completed_in_epoch = self.global_step % self.num_steps_per_epoch
        at_epoch_boundary = steps_completed_in_epoch == 0
        if at_epoch_boundary and len(loaded_consumed_data_uids_set) > 0:
            # The training loop saves checkpoints before the per-epoch reset, so a checkpoint taken on the last
            # step of an epoch holds the UIDs of that whole epoch.
            steps_completed_in_epoch = self.num_steps_per_epoch
        expected_consumed_in_epoch = self.mini_batch_size * steps_completed_in_epoch
        assert len(loaded_consumed_data_uids_set) == expected_consumed_in_epoch, (
            "Unexpected number of consumed data UIDs. Got: "
            f"{len(loaded_consumed_data_uids_set)} != {expected_consumed_in_epoch}"
        )

        if at_epoch_boundary:
            # Training resumes at the start of a fresh epoch, and the recorded UIDs belong to the finished one.
            # Carrying them over would make every prompt of the new epoch look consumed. All generation workers
            # would then exit immediately, and the training loop would wait on an empty buffer forever.
            loaded_consumed_data_uids_set = set()
        self.async_train_dataloader.load_state_from_checkpoint(loaded_consumed_data_uids_set)

    async def _run_training(self, training_input: TrainingInputBatch):
        """
        Run one training step on an already-converted mini-batch and return the training status.

        The step runs in this order:
        - forward pass for log-probs / values / rewards (in a worker thread);
        - optional KL-in-reward penalty;
        - advantage and return computation, plus optional batch-normalization of advantages;
        - optional data dump;
        - critic/policy update (in a worker thread).
        """
        # TODO(Charlie): share this code with the one-step-off async trainer.
        with Timer("fwd_logprobs_values_reward", self.all_timings):
            batch = await asyncio.to_thread(self.fwd_logprobs_values_reward, training_input)

        algorithm_cfg = self.cfg.trainer.algorithm

        # Fold the KL divergence into the rewards when configured.
        if algorithm_cfg.use_kl_in_reward:
            with Timer("apply_reward_kl_penalty", self.all_timings):
                batch = self.apply_reward_kl_penalty(batch)

        with Timer("compute_advantages_and_returns", self.all_timings):
            batch = self.compute_advantages_and_returns(batch)
            # Fields that are not needed past this point.
            batch.pop("rewards")
            batch.metadata.pop("uids")
            if algorithm_cfg.advantage_batch_normalize:
                batch = normalize_advantages_dict(batch)

        if self.cfg.trainer.dump_data_batch:
            with Timer("dump_data_batch"):
                self.dump_data(batch, file_name=f"global_step_{self.global_step}_training_input")

        with Timer("train_critic_and_policy", self.all_timings):
            return await asyncio.to_thread(self.train_critic_and_policy, batch)

    async def _run_generate_for_a_group_loop(self, generation_output_group_buffer: asyncio.Queue):
        """
        Generator worker: repeatedly pulls the next prompt, waits for a staleness-control slot, generates one
        group of rollouts for it, and enqueues the result as a `GeneratedOutputGroup`.

        The worker returns once the async dataloader is exhausted. It is expected to be cancelled only after it
        has finished (at the end of an epoch). Any other error is logged and terminates the process with
        `sys.exit(1)`. The trainer blocks until it has a full mini-batch, so a silently dead worker would hang
        training.

        Args:
            generation_output_group_buffer: Bounded queue of finished groups consumed by `train()`.
        """
        # True only between acquiring a staleness slot and marking its rollout accepted. Initialized before the
        # `try` so the exception handlers can always read it.
        holding_slot = False
        try:
            # Loop-invariant: a workaround to disable tqdm for generators that support it (e.g.
            # SkyRLGymGenerator.generate), which would otherwise blast the console with each worker's progress bar.
            generate_kwargs = (
                {"disable_tqdm": True} if "disable_tqdm" in inspect.signature(self.generator.generate).parameters else {}
            )
            while True:
                # 0. Pull next batch from dataloader. If returns None, then dataloader is exhausted.
                rand_prompts = await self.async_train_dataloader.get_next_non_consumed_data()
                if rand_prompts is None:
                    return

                # 1. Prepare generator input
                assert len(rand_prompts) == 1
                generator_input, uids = prepare_generator_input(
                    rand_prompts,
                    self.cfg.generator.n_samples_per_prompt,
                    get_sampling_params_for_backend(self.cfg.generator.backend, self.cfg.generator.sampling_params),
                    self.cfg.environment.env_class,
                    "train",
                    self.global_step,
                )
                assert all(uid == uids[0] for uid in uids), "Expect all uids to be the same"

                # 2. Acquire capacity slot (may block on staleness control).
                await self._staleness_manager.acquire_submission_slot()
                holding_slot = True

                # 3. Generate one rollout group
                global_step_at_start = self.global_step  # for staleness control
                cur_generator_output: GeneratorOutput = await self.generator.generate(
                    generator_input, **generate_kwargs
                )

                # 4. Enqueue the completed group and mark accepted to free capacity slot.
                try:
                    generation_output_group_buffer.put_nowait(
                        GeneratedOutputGroup(
                            generator_output=cur_generator_output,
                            uid=uids[0],
                            global_step_when_scheduled=global_step_at_start,
                        )
                    )
                except asyncio.QueueFull:
                    raise AssertionError("Generation buffer should never be full given staleness control.")
                await self._staleness_manager.on_rollout_accepted()
                holding_slot = False
        except asyncio.CancelledError:
            if holding_slot:
                # The slot is not released here. The running count stays > 0, so the end-of-epoch
                # staleness-manager validation (running == 0) will flag the inconsistency.
                logger.error(
                    "Generation worker was cancelled while a rollout was still in flight. "
                    "Generation workers should only be cancelled when they finish running."
                )
            return
        except Exception as e:
            logger.error(f"Generator worker errored out with exception: {e}")
            logger.error(f"Traceback: \n{traceback.format_exc()}")
            if holding_slot:
                logger.error("The generation worker errored out while a rollout was in flight (staleness slot held).")
            # Always exit, whether or not a slot is held: an exception raised inside this task would otherwise go
            # unnoticed and the training loop would wait forever for the missing group.
            sys.exit(1)

    async def async_sync_policy_weights_to_inference_engines(self):
        """
        Broadcast the current policy weights to the inference engines.

        Invokes the policy model's `broadcast_to_inference_engines` through its `pass_through` entry point, handing
        it `self.inference_engine_client`, and returns whatever `async_run_method` returns.
        """
        method_args = ("pass_through", "broadcast_to_inference_engines", self.inference_engine_client)
        return await self.policy_model.async_run_method(*method_args)

    def convert_generation_group_mini_batch_to_training_input(
        self, cur_generation_group_mini_batch: List[GeneratedOutputGroup]
    ) -> TrainingInputBatch:
        """Given a mini-batch of generated groups, concatenate them into a single GeneratorOutput, then convert to a TrainingInputBatch.

        For each group, the staleness is ``self.global_step - group.global_step_when_scheduled``. A warning is
        logged (and counted) whenever it exceeds ``self.max_staleness_steps``. Rollout metrics and staleness
        statistics are merged into ``self.all_metrics``.

        Args:
            cur_generation_group_mini_batch: Non-empty list of generated groups to train on.

        Returns:
            The ``TrainingInputBatch`` produced by ``self.convert_to_training_input``.

        Raises:
            ValueError: If ``cur_generation_group_mini_batch`` is empty.
        """
        if not cur_generation_group_mini_batch:
            # Without this guard the staleness statistics below would divide by zero.
            raise ValueError("Cannot convert an empty generation group mini-batch to training input.")

        generator_outputs = []
        uids = []
        stalenesses = []
        staleness_violation_count = 0
        for cur_generated_output_group in cur_generation_group_mini_batch:
            cur_staleness = self.global_step - cur_generated_output_group.global_step_when_scheduled
            stalenesses.append(cur_staleness)
            generator_outputs.append(cur_generated_output_group.generator_output)
            # One uid per sample. Size it from this group's own outputs so that `uids` stays aligned with
            # the concatenated generator output even if groups do not all contain the same number of samples.
            group_size = len(cur_generated_output_group.generator_output["response_ids"])
            uids.extend([cur_generated_output_group.uid] * group_size)

            # Check staleness violation.
            if cur_staleness > self.max_staleness_steps:
                # TODO(Charlie): should we drop, drop and resample, or just log?
                logger.warning(
                    "Staleness control violated despite using AsyncStalenessManager: "
                    f"cur_staleness={cur_staleness}, max_staleness_steps={self.max_staleness_steps}.\n"
                    "If this happens too often, consider increasing max_staleness_steps, adjusting "
                    "trainer.fully_async.num_parallel_generation_workers, or adjusting generation-training GPU allocation.\n"
                    "See https://skyrl.readthedocs.io/en/latest/tutorials/fully_async.html#async-staleness-manager for more details."
                )
                staleness_violation_count += 1

        generator_output = concatenate_generator_outputs(generator_outputs)
        assert generator_output["rollout_metrics"] is not None, "Rollout metrics should be non-null."
        self.all_metrics.update(generator_output["rollout_metrics"])

        # Log staleness statistics for this step.
        num_groups = len(stalenesses)
        self.all_metrics.update(
            {
                "async/staleness_mean": sum(stalenesses) / num_groups,
                "async/staleness_max": max(stalenesses),
                "async/staleness_min": min(stalenesses),
                # Fraction of groups that were scheduled before the step currently being trained.
                "async/staleness_ratio": sum(1 for s in stalenesses if s > 0) / num_groups,
                "async/staleness_violation_count": staleness_violation_count,
            }
        )

        # Convert rewards to per-token form and compute reward metrics before training conversion.
        generator_output = self.postprocess_generator_output(generator_output, uids)

        # Print one example just for debugging.
        vis = self.tokenizer.decode(generator_output["response_ids"][0])
        logger.info(f"Example generated: {vis}")

        return self.convert_to_training_input(generator_output, uids)

    def save_checkpoints(self):
        """
        Save a checkpoint and, next to it, the UIDs of all data consumed so far.

        The base class writes the model, dataloader path, trainer_state, and latest_ckpt_global_step.txt.
        Fully-async training also has to persist the consumed UIDs. Without them, a resumed run has no
        way to tell which data has already been trained on.
        """
        # Snapshot the consumed UIDs before the base save runs, to prevent a race condition.
        uid_snapshot = self.async_train_dataloader.get_consumed_uids_list()
        super().save_checkpoints()

        # Write the extra state into the step folder for the current global step.
        step_dir = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
        state_file = os.path.join(step_dir, "fully_async_state.pt")
        with io.open_file(state_file, "wb") as handle:
            torch.save({"consumed_uids": uid_snapshot}, handle)
        logger.info(f"Saved fully-async state to {state_file}")

    def load_checkpoints(self) -> Tuple[int, str, "Set[str] | None"]:
        """
        Load the base checkpoint without loading the dataloader state, and load the fully-async state.

        Returns:
            A tuple ``(global_step, checkpoint_path, consumed_data_uids)``.
            - When the base loader reports ``global_step == 0`` (nothing to resume from),
              ``consumed_data_uids`` is ``None``.
            - Otherwise it is the set of data UIDs that ``save_checkpoints`` recorded in
              ``fully_async_state.pt``.

        Raises:
            FileNotFoundError: If ``fully_async_state.pt`` is missing from the checkpoint folder.
            KeyError: If the saved fully-async state has no ``consumed_uids`` entry.
        """
        global_step, checkpoint_path = super().load_checkpoints()
        if global_step == 0:
            return 0, checkpoint_path, None

        fully_async_state_path = os.path.join(checkpoint_path, "fully_async_state.pt")
        # Explicit raises instead of `assert`, so these checks still run under `python -O`.
        if not io.exists(fully_async_state_path):
            raise FileNotFoundError(f"Fully-async state file not found at {fully_async_state_path}")
        with io.open_file(fully_async_state_path, "rb") as f:
            # weights_only=False performs a full (pickle-based) load; only resume from trusted checkpoints.
            fully_async_state = torch.load(f, map_location="cpu", weights_only=False)

        if "consumed_uids" not in fully_async_state:
            raise KeyError("consumed_uids key not found in fully-async state")
        consumed_data_uids_set = set(fully_async_state["consumed_uids"])
        logger.info(f"Loaded fully-async state with {len(consumed_data_uids_set)} consumed UIDs")
        return global_step, checkpoint_path, consumed_data_uids_set
