from transformers import AutoTokenizer

from slime.utils.mask_utils import MultiTurnLossMaskGenerator

# Public API of this module: only the rollout entry point is exported.
__all__ = ["generate_rollout"]


# NOTE(review): not referenced by generate_rollout in this file (nor are the
# AutoTokenizer / MultiTurnLossMaskGenerator imports); presumably a leftover
# placeholder for a lazily-created tokenizer — confirm no other module reads
# it before removing.
TOKENIZER = None


def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
    """CPT tokenized rollout function.

    No generation is performed: each sample's ``prompt`` is used directly as
    its token sequence (so it must already be tokenized), and every token
    after the first is trained on.

    Args:
        args: the whole args. Reads ``rollout_global_dataset`` (must be
            truthy) and ``rollout_batch_size`` (number of sample groups
            fetched from ``data_buffer``).
        rollout_id: unused by this implementation; accepted only so the
            signature matches the generic rollout-function interface.
        data_buffer: source of the samples; ``get_samples`` is called once
            with ``args.rollout_batch_size``.
        evaluation: must be False — evaluation rollouts are not supported.

    Returns:
        The list returned by ``data_buffer.get_samples``: a list of groups,
        each holding exactly one Sample. Each Sample is updated in place with
        ``tokens``, ``response_length``, ``reward`` and ``loss_mask``.

    Raises:
        AssertionError: if ``evaluation`` is truthy or
            ``args.rollout_global_dataset`` is falsy.
        ValueError: if a group does not contain exactly one sample.
    """
    assert not evaluation, "evaluation is not supported by this CPT rollout"
    assert args.rollout_global_dataset, "this CPT rollout requires args.rollout_global_dataset"

    groups = data_buffer.get_samples(args.rollout_batch_size)

    for group in groups:
        # Only one sample per group is supported; unpacking raises ValueError
        # if the group is empty or holds more than one sample.
        (sample,) = group
        token_ids = sample.prompt

        # The first token is excluded from the loss (per the original note, a
        # shape mismatch happens otherwise — presumably because next-token
        # predictions exist only for tokens[1:]). Clamping at 0 keeps an empty
        # prompt from producing a negative length.
        loss_mask = [1] * max(len(token_ids) - 1, 0)

        # NOTE: tokens aliases the prompt object rather than copying it.
        sample.tokens = token_ids
        # Derived from the mask so response_length == len(loss_mask) always.
        sample.response_length = len(loss_mask)
        sample.reward = 0
        sample.loss_mask = loss_mask

    return groups
