# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Post-processing function for vLLM rollout outputs.
"""

import re
import random
import math
from typing import List, Any


def truncate(text: str) -> str:
    """
    Cut a response off right after its first "Step 7" marker.

    The marker is matched with a regular expression: the word "Step"
    (preceded by a word boundary), optional whitespace, then the digit 7.
    Everything after the first such marker is dropped and a fixed
    "out of reasoning steps" message is appended in its place.

    Args:
        text: Response text to inspect

    Returns:
        The text up to and including the marker followed by the fixed
        message, or the input unchanged when no marker is found
    """
    marker = re.search(r"\bStep\s*7", text)
    if marker is None:
        # No marker present: nothing to cut.
        return text

    suffix = ": I have used all 6 reasoning steps as instructed and cannot proceed further. I need external assistance."
    # Keep everything through the end of the marker, discard the rest.
    return text[: marker.end()] + suffix


def post_process_rollout(
    outputs: List[Any],
    tokenizer,
    max_response_length: int = 1024,
    rng: random.Random = random,
) -> List[List[int]]:
    """
    Post-processing function for rollout outputs.

    For every output, only the first completion (``out.outputs[0]``) is used:
    - If its text contains the literal "Step 7", it is truncated via
      ``truncate`` and re-tokenized with an independent 50% probability
      (one ``rng.random()`` draw per such output).
    - Otherwise (or if re-tokenization fails) the original token IDs are kept.
    - An EOS token is appended to non-empty responses that do not end in one.
    - Responses longer than ``max_response_length`` are cut so that the last
      kept token is EOS.

    The token IDs stored on the input objects are never modified; each
    returned list is a fresh copy.

    Args:
        outputs: List of vLLM output objects from inference_engine.generate();
            each must expose ``outputs[0].text`` and ``outputs[0].token_ids``
        tokenizer: Tokenizer object providing ``encode`` and ``eos_token_id``
        max_response_length: Maximum allowed response length in tokens;
            values <= 0 yield empty responses whenever truncation is needed
        rng: Random number generator exposing ``random()`` (defaults to the
            ``random`` module)

    Returns:
        List of token ID lists, one per input output, ready for training
    """
    import logging  # local import keeps this function self-contained

    logger = logging.getLogger(__name__)
    eos_id = tokenizer.eos_token_id
    # A negative limit is treated like zero so slicing never wraps around.
    limit = max(max_response_length, 0)

    responses: List[List[int]] = []

    for out in outputs:
        sample = out.outputs[0]

        # Copy up front: the EOS append below must not mutate the caller's
        # output object, and token_ids may be an immutable sequence.
        token_ids = list(sample.token_ids)

        # The RNG is only consulted for samples that contain the marker, so
        # the sequence of draws depends solely on those samples.
        if "Step 7" in sample.text and rng.random() < 0.5:
            modified = truncate(sample.text)
            try:
                token_ids = list(
                    tokenizer.encode(modified, add_special_tokens=False)
                )
            except Exception:
                # Best-effort: keep the original tokens, but leave a trace
                # instead of failing silently.
                logger.warning(
                    "Re-tokenization of truncated response failed; "
                    "falling back to original token ids.",
                    exc_info=True,
                )

        # Ensure EOS token is present (skipped when the tokenizer has none,
        # so that None never ends up inside a token list).
        if eos_id is not None and token_ids and token_ids[-1] != eos_id:
            token_ids.append(eos_id)

        # Enforce the length limit, keeping EOS as the final token.
        if len(token_ids) > limit:
            if limit == 0:
                token_ids = []
            elif eos_id is None:
                token_ids = token_ids[:limit]
            else:
                token_ids = token_ids[: limit - 1] + [eos_id]

        responses.append(token_ids)

    return responses