from typing import Optional, Tuple

import torch


def get_fixed_positions(
    tokenizer,
    full_text: str,
    prefix_marker: Optional[str] = None,
    postfix_marker: Optional[str] = None,
    start_gap: int = 0,
    end_gap: int = 0,
    verbose: bool = False,
) -> Tuple[torch.Tensor, list[bool]]:
    """
    Determines which tokens in the full text should be fixed vs. variable based on markers.

    The variable region is the span of characters between the end of
    ``prefix_marker`` and the start of ``postfix_marker``. Every token is
    decoded on its own and placed on a character axis by summing the lengths
    of the decoded pieces, then classified by where it falls relative to that
    region.

    NOTE(review): the position tracking assumes that concatenating the
    per-token decodes (special tokens excluded) lines up with ``full_text``;
    confirm this holds for the tokenizer in use.

    Args:
        tokenizer: Object providing ``encode(text)`` and ``decode(ids)``
        full_text: The complete text to analyze
        prefix_marker: Text marking the start of the variable region (None if variable region starts at beginning)
        postfix_marker: Text marking the end of the variable region (None if variable region extends to the end).
            It is searched for starting at the end of the prefix marker, so an
            earlier occurrence of the same text cannot end the region before it begins.
        start_gap: Tolerance, in characters, around the start boundary. A token
            may begin up to ``start_gap`` characters before the boundary and
            still count as variable, or end up to ``start_gap`` characters
            after it and still count as fixed.
        end_gap: Same tolerance, applied around the end boundary.
        verbose: If True, print every token together with its fixed flag.

    Returns:
        Tuple of (token_ids, fixed_positions) where fixed_positions is a list of booleans
        indicating which positions should remain fixed (True) vs. variable (False)

    Raises:
        ValueError: If a marker cannot be found, or if a token straddles a
            region boundary by more than the allowed gap. In the latter case
            the tokens seen so far are printed before raising.
    """
    # Convert full text to token ids
    full_ids = torch.tensor(tokenizer.encode(full_text))

    # Find the boundaries of the variable section
    variable_start_max = 0  # Default to beginning of text
    if prefix_marker is not None:
        prefix_pos = full_text.find(prefix_marker)
        if prefix_pos < 0:
            raise ValueError(
                f"Could not find prefix marker '{prefix_marker}' in the text"
            )
        variable_start_max = prefix_pos + len(prefix_marker)

    variable_end_min = len(full_text)  # Default to end of text
    if postfix_marker is not None:
        # Search only after the prefix marker: the region's end must not
        # precede its start, even if the marker text also appears earlier.
        variable_end_min = full_text.find(postfix_marker, variable_start_max)
        if variable_end_min < 0:
            raise ValueError(
                f"Could not find postfix marker '{postfix_marker}' in the text"
            )

    special_tokens = ("<bos>", "<eos>", "<pad>")
    fixed_positions: list[bool] = []
    # Decoded text of every token, kept so the debug dumps below do not have
    # to decode again.
    token_texts: list[str] = []

    # Process each token
    token_start_pos = 0

    for i, token_id in enumerate(full_ids):
        # Get the text for this token
        token_text = tokenizer.decode([token_id])
        token_texts.append(token_text)

        # Special tokens are always fixed and do not advance the character
        # position.
        if token_text.strip() in special_tokens:
            fixed_positions.append(True)
            continue

        token_end_pos = token_start_pos + len(token_text)

        # The variable test comes first, so a token that satisfies both tests
        # (possible only with non-zero gaps) is treated as variable.
        if (
            token_start_pos >= variable_start_max - start_gap
            and token_end_pos <= variable_end_min + end_gap
        ):
            # Token is within the variable region (up to the allowed gaps)
            fixed_positions.append(False)
        elif (
            token_start_pos >= variable_end_min - end_gap
            or token_end_pos <= variable_start_max + start_gap
        ):
            # Token is outside the variable region (up to the allowed gaps)
            fixed_positions.append(True)
        else:
            # Token spans boundary between fixed and variable regions
            print("\nTokens up to error:")
            for j, prev_token_text in enumerate(token_texts):
                print(f"Token {j}: '{prev_token_text}'")
            print(f"\nProblem occurred at token {i}: '{token_text}'")
            raise ValueError(
                f"Token '{token_text}' spans boundary between fixed and variable regions"
            )

        token_start_pos = token_end_pos

    if verbose:
        # Debug print of tokens and their fixed positions
        for i, (token_text, is_fixed) in enumerate(zip(token_texts, fixed_positions)):
            print(f"Token {i}: '{token_text}' - Fixed: {is_fixed}")

    return full_ids, fixed_positions


def generate_new_text(
    model, tokenizer, input_text, temperature=0.0, max_tokens=10, verbose=False
):
    """
    Generate a continuation of ``input_text`` and return only the new part.

    Args:
        model: Object whose ``generate`` method accepts the prompt plus the
            ``max_new_tokens``, ``temperature`` and ``verbose`` keyword arguments
        tokenizer: Not used by this function
        input_text: The prompt handed to ``model.generate``
        temperature: Forwarded to ``model.generate``
        max_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``
        verbose: Forwarded to ``model.generate``

    Returns:
        The generated output with its first ``len(input_text)`` elements removed.
        NOTE(review): this assumes the output begins with ``input_text``
        verbatim - confirm against the model's ``generate`` contract.
    """
    generation_options = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "verbose": verbose,
    }
    full_output = model.generate(input_text, **generation_options)

    # Drop the prompt-length head so that only the continuation remains
    prompt_length = len(input_text)
    return full_output[prompt_length:]
