from typing import Optional, Tuple

import torch


def get_fixed_positions(
    tokenizer,
    full_text: str,
    region_start: Optional[int] = None,
    region_end: Optional[int] = None,
    start_gap: int = 0,
    end_gap: int = 0,
    verbose: bool = False,
    skip_special_tokens: bool = True,
) -> Tuple[torch.Tensor, list[bool]]:
    """
    Classify every token of ``full_text`` as fixed or variable.

    The variable region is the character span ``[region_start, region_end)``
    of ``full_text``. Token character offsets are obtained by decoding the
    tokens one at a time and summing the lengths of the decoded pieces, so the
    result is only meaningful when token-by-token decoding reproduces
    ``full_text``.

    Args:
        tokenizer: Object providing ``encode(text, add_special_tokens=...)``
            and ``decode(list_of_ids)``.
        full_text: The complete text to analyze.
        region_start: Character offset where the variable region starts
            (None means the region starts at the beginning of the text).
        region_end: Character offset where the variable region ends,
            exclusive (None means the region extends to the end of the text).
        start_gap: Tolerance, in characters, around ``region_start``: a token
            ending at most ``start_gap`` characters after ``region_start`` is
            still fixed, and a token starting at most ``start_gap`` characters
            before it may still be variable.
        end_gap: Same tolerance around ``region_end``.
        verbose: If True, print every token with its classification, and
            print the tokens seen so far when a boundary error occurs.
        skip_special_tokens: If True, the text is encoded with
            ``add_special_tokens=True`` and tokens decoding to ``<bos>``,
            ``<eos>`` or ``<pad>`` are marked fixed without advancing the
            character position. If False, the text is encoded without
            special tokens and no token is skipped.

    Returns:
        Tuple of (token_ids, fixed_positions) where fixed_positions is a list of booleans
        indicating which positions should remain fixed (True) vs. variable (False)

    Raises:
        ValueError: If a token straddles the boundary between the fixed and
            variable regions by more than the configured gaps allow.
    """
    # Convert full text to token ids
    full_ids = torch.tensor(
        tokenizer.encode(full_text, add_special_tokens=skip_special_tokens)
    )

    # Resolve open-ended boundaries to concrete character offsets
    if region_start is None:
        region_start = 0
    if region_end is None:
        region_end = len(full_text)

    special_markers = ("<bos>", "<eos>", "<pad>")
    fixed_positions: list[bool] = []
    # Decoded text per token, kept so the verbose output never re-decodes
    token_texts: list[str] = []
    token_start_pos = 0

    # Iterate over plain Python ints rather than 0-d tensors
    for i, token_id in enumerate(full_ids.tolist()):
        token_text = tokenizer.decode([token_id])
        token_texts.append(token_text)

        # Special tokens have no counterpart in full_text: mark them fixed
        # and do not advance the character position.
        if skip_special_tokens and token_text.strip() in special_markers:
            fixed_positions.append(True)
            continue

        token_end_pos = token_start_pos + len(token_text)

        if (
            token_start_pos >= region_end - end_gap
            or token_end_pos <= region_start + start_gap
        ):
            # Token is fully outside variable region
            fixed_positions.append(True)
        elif (
            token_start_pos >= region_start - start_gap
            and token_end_pos <= region_end + end_gap
        ):
            # Token is fully within variable region
            fixed_positions.append(False)
        else:
            if verbose:
                print("\nTokens up to error:")
                for j, prev_token_text in enumerate(token_texts):
                    print(f"Token {j}: '{prev_token_text}'")
                print(f"\nProblem occurred at token {i}: '{token_text}'")
            raise ValueError(
                f"Token '{token_text}' spans boundary between fixed and variable regions"
            )

        token_start_pos = token_end_pos

    if verbose:
        # Debug print of tokens and their fixed positions
        for i, (token_text, is_fixed) in enumerate(zip(token_texts, fixed_positions)):
            print(f"Token {i}: '{token_text}' - Fixed: {is_fixed}")

    return full_ids, fixed_positions


def process_text_with_placeholder(
    tokenizer,
    template,
    variable_context,
    start_gap=0,
    end_gap=0,
    verbose=False,
    skip_special_tokens=True,
):
    """
    Render ``template`` with ``variable_context`` and classify its tokens.

    The first plain ``{0}`` replacement field of ``template`` is the variable
    region; everything else is fixed. The rendered text and the character
    span of the substituted value are handed to ``get_fixed_positions``.

    Args:
        tokenizer: Tokenizer forwarded to ``get_fixed_positions``.
        template: ``str.format`` template containing a ``{0}`` placeholder.
        variable_context: Value substituted for the placeholder.
        start_gap: Forwarded to ``get_fixed_positions``.
        end_gap: Forwarded to ``get_fixed_positions``.
        verbose: Forwarded to ``get_fixed_positions``.
        skip_special_tokens: Forwarded to ``get_fixed_positions``.

    Returns:
        Whatever ``get_fixed_positions`` returns for the rendered text.

    Raises:
        ValueError: If the template contains no ``{0}`` placeholder (or only
            a brace-escaped one such as ``{{0}}``).
    """
    from string import Formatter

    # Find the placeholder
    placeholder = "{0}"
    if template.find(placeholder) == -1:
        raise ValueError(f"Placeholder '{placeholder}' not found in text")

    full_text = template.format(variable_context)

    # Offsets must refer to the *rendered* text: str.format collapses "{{" and
    # "}}" and expands any other replacement fields, so the index of "{0}" in
    # the raw template is not, in general, where the value ends up.
    placeholder_start = None
    offset = 0
    for literal, field, spec, conversion in Formatter().parse(template):
        offset += len(literal)  # literal text is already unescaped by parse()
        if field is None:
            continue
        if field == "0" and not spec and conversion is None:
            placeholder_start = offset
            break
        # A different replacement field precedes the placeholder: account for
        # the width it renders to.
        other_field = (
            "{"
            + field
            + (f"!{conversion}" if conversion else "")
            + (f":{spec}" if spec else "")
            + "}"
        )
        offset += len(other_field.format(variable_context))

    if placeholder_start is None:
        # "{0}" only occurred inside escaped braces, so nothing was substituted
        raise ValueError(f"Placeholder '{placeholder}' not found in text")

    # Width of the value as rendered (equals len(variable_context) for str)
    placeholder_end = placeholder_start + len(placeholder.format(variable_context))

    return get_fixed_positions(
        tokenizer,
        full_text,
        placeholder_start,
        placeholder_end,
        start_gap,
        end_gap,
        verbose,
        skip_special_tokens,
    )
