from pathlib import Path
from verl.protocol import DataProto
class RolloutSkip:
    """Dump rollout outputs to disk once and replay them on later runs.

    :meth:`wrap_generate_sequences` patches ``rollout_wg.generate_sequences`` so
    that each call first tries to load a previously dumped batch from
    :attr:`curr_path_dump`. Only when nothing can be loaded is the original
    method invoked, and its output is then dumped for next time.

    The dump location depends only on the experiment/project names, the batch
    size and ``n`` (not on any step counter), so every call reads and writes
    the same path.
    """

    print_mark = "[RolloutSkip()]"

    def __init__(self, config, rollout_wg):
        """Read dump settings from ``config`` and create the dump directory.

        Args:
            config: Trainer config. It must expose ``actor_rollout_ref.rollout``
                and ``data``, both supporting ``.get`` (presumably an OmegaConf
                DictConfig -- verify against caller).
            rollout_wg: Worker group whose ``generate_sequences`` attribute is
                replaced by :meth:`wrap_generate_sequences`.
        """
        self.rollout_config = config.actor_rollout_ref.rollout
        self.exp_name = config.data.get("experiment_name", "")
        self.project_name = config.data.get("project_name", "")
        self.n = int(self.rollout_config.get("n", 0))
        # Prefer gen_batch_size; fall back to train_batch_size, then 0.
        self.gbs = int(config.data.get("gen_batch_size", config.data.get("train_batch_size", 0)))
        self.dumped_dir = Path(self.rollout_config.get("skip_dump_dir", "/tmp/verl/rollout_dump"))
        self.dumped_dir.mkdir(parents=True, exist_ok=True)
        # Warn (but continue) when the dump dir lives under a Ray session temp dir.
        if str(self.dumped_dir.absolute()).startswith("/tmp/ray/session"):
            # Adjacent literals are concatenated, so print() adds no extra spaces.
            print(
                f"\033[33m{self.print_mark} Warning: \nUsing dump path "
                f"'{self.dumped_dir.absolute()}' is not recommended "
                "as it's located in /tmp/ray/session*\033[0m",
                flush=True,
            )
        print(
            f"{self.print_mark} Rollout skip dump path set to: {self.dumped_dir.absolute()}",
            flush=True,
        )
        self._rollout_wg = rollout_wg

    @property
    def curr_path_dump(self):
        """Absolute dump path: ``<dumped_dir>/<exp>_<project>_GBS<gbs>__N<n>``."""
        return self.dumped_dir.joinpath(f"{self.exp_name}_{self.project_name}_GBS{self.gbs}__N{self.n}").absolute()

    def wrap_generate_sequences(self):
        """Replace ``self._rollout_wg.generate_sequences`` with the load-or-generate wrapper.

        Raises:
            RuntimeError: if building or installing the wrapper fails. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            # This name resolves to the module-level function, not to this method.
            self._rollout_wg.generate_sequences = wrap_generate_sequences(self, self._rollout_wg)
        except Exception as e:
            # RuntimeError accepts no keyword arguments, so no flush= here.
            raise RuntimeError(
                f"{self.print_mark} Failed to patch `actor_rollout_wg.generate_sequences()`"
            ) from e
        print(
            f"{self.print_mark} Successfully patched `actor_rollout_wg.generate_sequences()`",
            flush=True,
        )

    def try_load(self):
        """Load a previously dumped batch from :attr:`curr_path_dump`.

        Returns:
            The batch returned by ``DataProto.load_from_disk``. Returns ``None``
            when no dump exists or loading fails; the failure is printed rather
            than raised, so the caller can regenerate instead.
        """
        if not self.curr_path_dump.exists():
            print(
                f"{self.print_mark} No data dump found at {self.curr_path_dump}.",
                "The trainer will generate and automatically dump the data for this first run.",
                flush=True,
            )
            return None
        try:
            ret_batch = DataProto.load_from_disk(self.curr_path_dump)
        except Exception as e:
            # Deliberate best effort: report and signal "nothing loaded".
            print(
                f"\033[31m{self.print_mark} Failed to load pre-generated data from {self.curr_path_dump}",
                f"Error: {str(e)}\033[0m",
                flush=True,
            )
            return None
        print(
            f"\033[32m{self.print_mark} Successfully load pre-generated data from {self.curr_path_dump}\033[0m",
            flush=True,
        )
        return ret_batch

    def dump(self, outputs: "DataProto"):
        """Save ``outputs`` to :attr:`curr_path_dump` via ``outputs.save_to_disk``.

        Failures are printed rather than raised, so a failed dump does not
        propagate to the caller.
        """
        try:
            outputs.save_to_disk(self.curr_path_dump)
        except Exception as e:
            print(
                f"\033[31m{self.print_mark} Failed to dump data in {self.curr_path_dump}: {e}\033[0m",
                flush=True,
            )
            return
        print(
            f"\033[32m{self.print_mark} Successfully dump data in {self.curr_path_dump}\033[0m",
            flush=True,
        )
def wrap_generate_sequences(rolloutskip: RolloutSkip, rollout_wg):
    """Build a replacement for ``rollout_wg.generate_sequences``.

    The returned callable first asks ``rolloutskip.try_load()`` for a cached
    batch. If that yields ``None``, it calls the worker group's original
    ``generate_sequences`` (captured here, before any patching), hands the
    result to ``rolloutskip.dump()``, and returns it.
    """
    original_generate = rollout_wg.generate_sequences

    def _generate_or_replay(batch, **kwargs):
        cached = rolloutskip.try_load()
        if cached is not None:
            return cached
        fresh = original_generate(batch, **kwargs)
        rolloutskip.dump(fresh)
        return fresh

    return _generate_or_replay