from pathlib import Path

from verl.protocol import DataProto


class RolloutSkip:
    """
    RolloutSkip skips sequence generation during rollout by attempting to load previously dumped data.
    If no dumped data is found, it generates new sequences and saves them to disk.

    Args:
        config: The configuration object containing rollout settings.
        rollout_wg: The worker group that handles the rollout process.

    Note:
        When rollout.n or rollout.gen_batch_size differ from previous runs,
        new sequences will be generated and saved with different filenames.

        The dump location depends only on the experiment name, project name,
        generation batch size and ``n`` (see ``curr_path_dump``); it does not
        depend on the batch passed to ``generate_sequences``.
    """

    # Prefix used on every message this class prints.
    print_mark = "[RolloutSkip()]"

    def __init__(self, config, rollout_wg):
        """Read rollout settings from ``config`` and prepare the dump directory.

        Creates the dump directory (and parents) if it does not exist yet and
        prints the resolved location. The worker group is only stored here; it
        is patched later by ``wrap_generate_sequences``.
        """
        self.rollout_config = config.actor_rollout_ref.rollout
        # NOTE(review): both names are read from ``config.data``; if these keys
        # live under a different config section they silently fall back to "".
        self.exp_name = config.data.get("experiment_name", "")
        self.project_name = config.data.get("project_name", "")

        self.n = int(self.rollout_config.get("n", 0))
        # Fall back to the training batch size when no generation batch size is set.
        self.gbs = int(config.data.get("gen_batch_size", config.data.get("train_batch_size", 0)))

        self.dumped_dir = Path(self.rollout_config.get("skip_dump_dir", "/tmp/verl/rollout_dump"))
        self.dumped_dir.mkdir(parents=True, exist_ok=True)

        abs_dump_dir = self.dumped_dir.absolute()

        # Check if path is in Ray temporary directory
        if str(abs_dump_dir).startswith("/tmp/ray/session"):
            print(
                f"\033[33m{self.print_mark} Warning: \nUsing dump path ",
                f"'{abs_dump_dir}' is not recommended ",
                "as it's located in /tmp/ray/session*\033[0m",
                flush=True,
            )

        print(
            f"{self.print_mark} Rollout skip dump path set to: ",
            f"{abs_dump_dir}",
            flush=True,
        )

        self._rollout_wg = rollout_wg

    @property
    def curr_path_dump(self):
        """Absolute path of the dump for the current experiment/project/GBS/N combination."""
        return self.dumped_dir.joinpath(f"{self.exp_name}_{self.project_name}_GBS{self.gbs}__N{self.n}").absolute()

    def wrap_generate_sequences(self):
        """Replace ``generate_sequences`` on the stored worker group with the load-or-generate wrapper.

        Raises:
            RuntimeError: If patching fails; the original exception is chained as the cause.
        """
        try:
            # Resolves to the module-level ``wrap_generate_sequences`` function,
            # not this method (method names are not in scope inside the body).
            self._rollout_wg.generate_sequences = wrap_generate_sequences(self, self._rollout_wg)
            print(
                f"{self.print_mark} Successfully patched `actor_rollout_wg.generate_sequences()`",
                flush=True,
            )
        except Exception as e:
            # Exception constructors accept no keyword arguments, so the message
            # is the only argument; it is an f-string so the mark is interpolated.
            raise RuntimeError(
                f"{self.print_mark} Failed to patch `actor_rollout_wg.generate_sequences()`"
            ) from e

    def try_load(self):
        """Try to load previously dumped data from ``curr_path_dump``.

        Returns:
            The object returned by ``DataProto.load_from_disk`` on success, or
            ``None`` when no dump exists or loading raised (the error is
            printed, not propagated, so the caller can regenerate).
        """
        dump_path = self.curr_path_dump

        if not dump_path.exists():
            print(
                f"{self.print_mark} No data dump found at {dump_path}.",
                "The trainer will generate and automatically dump the data for this first run.",
                flush=True,
            )
            return None

        try:
            # * Load
            ret_batch = DataProto.load_from_disk(dump_path)
            print(
                f"\033[32m{self.print_mark} Successfully load pre-generated data from {dump_path}\033[0m",
                flush=True,
            )
            return ret_batch
        except Exception as e:
            # Deliberately best-effort: report the failure and signal "nothing loaded".
            print(
                f"\033[31m{self.print_mark} Failed to load pre-generated data from {dump_path}",
                f"Error: {str(e)}\033[0m",
                flush=True,
            )
            return None

    def dump(self, outputs: DataProto):
        """Save ``outputs`` to ``curr_path_dump`` via ``outputs.save_to_disk``.

        Failures are printed rather than raised, so a failed dump does not
        interrupt the caller.
        """
        dump_path = self.curr_path_dump
        try:
            outputs.save_to_disk(dump_path)
            print(
                f"\033[32m{self.print_mark} Successfully dump data in {dump_path}\033[0m",
                flush=True,
            )
        except Exception as e:
            print(
                f"\033[31m{self.print_mark} Failed to dump data in {dump_path}: {e}\033[0m",
                flush=True,
            )


def wrap_generate_sequences(rolloutskip: RolloutSkip, rollout_wg):
    """Build a load-or-generate replacement for ``rollout_wg.generate_sequences``.

    The worker group's current ``generate_sequences`` is captured once, when
    this function is called. The returned callable first asks ``rolloutskip``
    for previously dumped data and returns it if available; otherwise it calls
    the captured generator, hands the result to ``rolloutskip.dump`` and
    returns it.

    Args:
        rolloutskip: Object providing ``try_load()`` and ``dump(outputs)``.
        rollout_wg: Worker group whose ``generate_sequences`` is wrapped.

    Returns:
        A function ``(batch, **kwargs)`` with the behavior described above.
    """
    original_generate = rollout_wg.generate_sequences

    def warp_fn(batch, **kwargs):
        cached = rolloutskip.try_load()
        if cached is not None:
            # A dump was loaded: skip generation entirely.
            return cached

        # No usable dump: generate, then persist for later runs.
        fresh = original_generate(batch, **kwargs)
        rolloutskip.dump(fresh)
        return fresh

    return warp_fn
