"""Completion reconstruction for different training modes."""

from typing import List, Tuple

def find_first_closing_brace_pos(text: str) -> int:
    """
    Locate the `}` that closes the first balanced `{...}` group in `text`.

    Scanning starts at the first `{`; any `}` seen before that point is
    ignored. Nested braces are tracked with a depth counter, and the scan
    stops as soon as the depth returns to zero.

    Args:
        text: The text to scan

    Returns:
        Index of the closing `}` of the first balanced group, or -1 when the
        text is empty, contains no `{`, or the first group is never closed
    """
    if not text:
        return -1

    # depth > 0 means we are inside a brace group that has been opened.
    depth = 0
    for position, symbol in enumerate(text):
        if symbol == '{':
            depth += 1
            continue
        if symbol != '}' or depth == 0:
            # Ordinary character, or a stray `}` before any `{` was opened.
            continue
        depth -= 1
        if depth == 0:
            return position

    return -1

def truncate_rollout_to_first_brace(
    rollout_text: str, 
    tokenizer
) -> List[int]:
    """
    Tokenize a rollout and cut it after the first balanced `{...}` group.

    The cut point is the character index reported by
    `find_first_closing_brace_pos`, i.e. the `}` that closes the first
    balanced brace pair. The tokenizer's offset mapping is used to translate
    that character index into a token index.

    Args:
        rollout_text: The rollout text to tokenize and truncate
        tokenizer: Object providing `encode(text, add_special_tokens=...)` and
            a call form that accepts `return_offsets_mapping=True`
            (presumably a Hugging Face fast tokenizer -- confirm with caller)

    Returns:
        An empty list for empty text; all token IDs when no balanced pair is
        found; otherwise the token IDs up to and including the token whose
        character span covers the closing `}`
    """
    if not rollout_text:
        return []

    closing_pos = find_first_closing_brace_pos(rollout_text)
    if closing_pos < 0:
        # Nothing to cut at: return the plain encoding of the whole rollout.
        return tokenizer.encode(rollout_text, add_special_tokens=False)

    # Offsets give the (start, end) character span of every token.
    encoded = tokenizer(rollout_text, add_special_tokens=False, return_offsets_mapping=True)
    ids = encoded['input_ids']
    spans = encoded['offset_mapping']

    for idx, (span_start, span_end) in enumerate(spans):
        if span_start > closing_pos:
            # No token covered the brace before this one started; keep
            # only the tokens that precede it.
            return ids[:idx]
        if closing_pos < span_end:
            # This token's span contains the brace; keep it as the last token.
            return ids[:idx + 1]

    return ids

def reconstruct_completions(
    tokenizer,
    prompt_ids: List[List[int]],
    completion_ids: List[List[List[int]]],
    rollout_texts: List[List[List[str]]] = None,
    delimiter_token_id: int = None,
    rollout_append_text: str = "",
    rollout_rewards: List[List[List[int]]] = None,
    rollout_bin_indices: List[List[List[int]]] = None,
) -> Tuple[List[List[List[int]]], List[List[List[int]]], List[List[List[int]]], List[List[List[int]]]]:
    """
    Token-based reconstruction that returns tokenized sequences and labels.

    Every (prompt, completion) pair is rebuilt by `reconstruct_completions_ij`;
    this function only iterates the nested inputs and regroups the results.

    Args:
        tokenizer: The tokenizer to use
        prompt_ids: List of prompt token IDs [prompt_idx]
        completion_ids: Token IDs for completions [prompt_idx][completion_idx]
        rollout_texts: Rollout texts [prompt_idx][completion_idx][paragraph_idx].
            Required despite the `None` default.
        delimiter_token_id: Token ID used as paragraph delimiter
        rollout_append_text: Text whose tokens are inserted before each rollout
        rollout_rewards: Rollout rewards [prompt_idx][completion_idx][paragraph_idx].
            Required despite the `None` default.
        rollout_bin_indices: Rollout bin indices [prompt_idx][completion_idx][paragraph_idx].
            Required despite the `None` default.

    Returns:
        Tuple of (input_ids, labels, bin_indices, section_labels), each indexed
        as [prompt_idx][completion_idx]:
        - input_ids: Full tokenized sequences
        - labels: Token labels
        - bin_indices: Bin index for each token
        - section_labels: Section labels for each token

    Raises:
        TypeError: If any of the per-prompt inputs is None.
        ValueError: If the per-prompt or per-completion inputs differ in length
            (zip would otherwise silently drop the surplus entries).
    """
    per_prompt_inputs = {
        "completion_ids": completion_ids,
        "rollout_texts": rollout_texts,
        "rollout_rewards": rollout_rewards,
        "rollout_bin_indices": rollout_bin_indices,
    }
    missing = [name for name, value in per_prompt_inputs.items() if value is None]
    if missing:
        raise TypeError(
            f"reconstruct_completions requires non-None values for: {', '.join(missing)}"
        )
    for name, value in per_prompt_inputs.items():
        if len(value) != len(prompt_ids):
            raise ValueError(
                f"{name} has {len(value)} entries but prompt_ids has {len(prompt_ids)}"
            )

    # The append text is identical for every rollout, so tokenize it once.
    rollout_append_tokens = tokenizer.encode(rollout_append_text, add_special_tokens=False)

    input_ids = []
    labels = []
    bin_indices = []
    section_labels = []

    for prompt_idx, (prompt_ids_i, completion_ids_i, rollout_texts_i, rollout_rewards_i, rollout_bin_indices_i) in enumerate(
        zip(prompt_ids, completion_ids, rollout_texts, rollout_rewards, rollout_bin_indices)
    ):
        per_completion_inputs = {
            "rollout_texts": rollout_texts_i,
            "rollout_rewards": rollout_rewards_i,
            "rollout_bin_indices": rollout_bin_indices_i,
        }
        for name, value in per_completion_inputs.items():
            if len(value) != len(completion_ids_i):
                raise ValueError(
                    f"{name}[{prompt_idx}] has {len(value)} entries but "
                    f"completion_ids[{prompt_idx}] has {len(completion_ids_i)}"
                )

        input_ids_i = []
        labels_i = []
        bin_indices_i = []
        section_labels_i = []

        for completion_ids_ij, rollout_text_ij, rollout_rewards_ij, rollout_bin_indices_ij in zip(
            completion_ids_i, rollout_texts_i, rollout_rewards_i, rollout_bin_indices_i
        ):
            input_ids_ij, labels_ij, bin_indices_ij, section_labels_ij = reconstruct_completions_ij(
                tokenizer,
                prompt_ids_i,
                completion_ids_ij,
                rollout_text_ij,
                delimiter_token_id,
                rollout_append_tokens,
                rollout_rewards_ij,
                rollout_bin_indices_ij,
            )
            input_ids_i.append(input_ids_ij)
            labels_i.append(labels_ij)
            bin_indices_i.append(bin_indices_ij)
            section_labels_i.append(section_labels_ij)

        input_ids.append(input_ids_i)
        labels.append(labels_i)
        bin_indices.append(bin_indices_i)
        section_labels.append(section_labels_i)

    return input_ids, labels, bin_indices, section_labels

def reconstruct_completions_ij(
    tokenizer,
    prompt_tokens: List[int],
    completion_tokens: List[int],
    rollout_texts: List[str],
    delimiter_token_id: int,
    rollout_append_tokens: List[int],
    rollout_rewards: List[int],
    rollout_bin_indices: List[int],
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """
    Reconstruction function for inj_roll_trunc mode only.

    The completion is split into paragraphs at every `delimiter_token_id`.
    After each paragraph the append tokens and that paragraph's truncated
    rollout are injected, and the paragraphs are re-joined with the delimiter.
    The reward and bin index of a paragraph are placed on the last token of
    its rollout; every other position gets -100.

    Args:
        tokenizer: Tokenizer forwarded to `truncate_rollout_to_first_brace`
        prompt_tokens: Prompt token IDs
        completion_tokens: Completion token IDs, paragraphs separated by the delimiter
        rollout_texts: One rollout text per paragraph
        delimiter_token_id: Token ID that separates paragraphs
        rollout_append_tokens: Token IDs inserted between a paragraph and its rollout
        rollout_rewards: One reward per paragraph
        rollout_bin_indices: One bin index per paragraph

    Returns:
        Tuple of (full_tokens, labels, bin_indices, section_labels), four
        lists of equal length. Section labels are 0: prompt, 1: completion,
        2: delimiter, 3: append_text, 4: rollout.

    Raises:
        AssertionError: If the number of rollout texts, rewards or bin indices
            differs from the number of paragraphs.
    """
    # Initialize common structures
    full_tokens = prompt_tokens.copy()
    labels = [-100] * len(prompt_tokens)  # No labels for prompt
    bin_indices = [-100] * len(prompt_tokens)  # No bin indices for prompt
    section_labels = [0] * len(prompt_tokens)  # 0: prompt

    delimiter_positions = [i for i, token_id in enumerate(completion_tokens) if token_id == delimiter_token_id]
    # Start index of every paragraph, followed by the end of the completion.
    paragraph_positions = [0] + [i + 1 for i in delimiter_positions] + [len(completion_tokens)]
    num_paragraphs = len(paragraph_positions) - 1

    assert len(rollout_texts) == num_paragraphs, f"Number of rollout texts ({len(rollout_texts)}) does not match number of paragraphs ({num_paragraphs})"
    assert len(rollout_rewards) == num_paragraphs, f"Number of rollout rewards ({len(rollout_rewards)}) does not match number of paragraphs ({num_paragraphs})"
    assert len(rollout_bin_indices) == num_paragraphs, f"Number of rollout bin indices ({len(rollout_bin_indices)}) does not match number of paragraphs ({num_paragraphs})"

    for k in range(num_paragraphs):
        # The `- 1` excludes the delimiter token that ends the paragraph.
        # NOTE(review): for the final paragraph the same `- 1` drops the last
        # completion token even though it is not a delimiter -- confirm this
        # is intended (e.g. a trailing EOS token).
        paragraph_completion_tokens = completion_tokens[paragraph_positions[k]:paragraph_positions[k + 1] - 1]
        full_tokens.extend(paragraph_completion_tokens)
        section_labels.extend([1] * len(paragraph_completion_tokens))  # 1: completion
        labels.extend([-100] * len(paragraph_completion_tokens))
        bin_indices.extend([-100] * len(paragraph_completion_tokens))

        truncated_rollout_tokens = truncate_rollout_to_first_brace(rollout_texts[k], tokenizer)

        full_tokens.extend(rollout_append_tokens)
        section_labels.extend([3] * len(rollout_append_tokens))  # 3: append_text
        labels.extend([-100] * len(rollout_append_tokens))
        bin_indices.extend([-100] * len(rollout_append_tokens))

        full_tokens.extend(truncated_rollout_tokens)
        section_labels.extend([4] * len(truncated_rollout_tokens))  # 4: rollout
        if truncated_rollout_tokens:
            # Reward and bin index live on the final rollout token only.
            labels.extend([-100] * (len(truncated_rollout_tokens) - 1))
            bin_indices.extend([-100] * (len(truncated_rollout_tokens) - 1))
            labels.append(int(rollout_rewards[k]))
            bin_indices.append(rollout_bin_indices[k])
        # With an empty rollout there is no token to carry the reward;
        # appending it anyway would make labels/bin_indices one element longer
        # than full_tokens and shift every subsequent label.

        if k < num_paragraphs - 1:
            full_tokens.append(delimiter_token_id)
            section_labels.append(2)  # 2: delimiter
            labels.append(-100)
            bin_indices.append(-100)

    return full_tokens, labels, bin_indices, section_labels
