"""
AlphaEdit-compatible token positioning utilities for MoEdit
"""

from typing import List, Dict, Tuple, Optional
from transformers import AutoTokenizer

# Import AlphaEdit's repr_tools for exact compatibility
try:
    from rome import repr_tools
except ImportError:
    repr_tools = None

# Import warning function
try:
    from .logger import warning
except ImportError:
    def warning(msg):
        print(f"WARNING: {msg}")


def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: AutoTokenizer,
    fact_token_strategy: str,
    verbose=True,
) -> int:
    """
    AlphaEdit-compatible fact lookup index computation.

    Intended to match the index selection in AlphaEdit's compute_z.py:

    * ``"last"`` selects the final token of the prompt (index ``-1``).
    * ``"subject_<subtoken>"`` delegates to ``get_words_idxs_in_templates``
      with ``<subtoken>`` (e.g. ``"last"``, ``"first"``,
      ``"first_after_last"``). AlphaEdit's own ``rome.repr_tools`` is used
      when importable; otherwise the local port in this module is used.

    Args:
        prompt: Template string with one {} placeholder (e.g., "The capital of {} is")
        subject: Subject to substitute (e.g., "France")
        tok: Tokenizer
        fact_token_strategy: "last" or "subject_<subtoken>"
        verbose: Whether to print the index, the filled-in sentence and the
            decoded token at that index

    Returns:
        Index of the target token in the full sequence. It may be negative
        (counted from the end): "last" returns -1, and the local fallback
        returns a negative index for "subject_first".

    Raises:
        ValueError: If fact_token_strategy is not recognized.
    """
    subject_prefix = "subject_"

    if fact_token_strategy == "last":
        ret = -1
    elif fact_token_strategy.startswith(subject_prefix):
        # Prefer AlphaEdit's original repr_tools for exact compatibility and
        # fall back to the local implementation when `rome` is unavailable.
        lookup = (
            repr_tools.get_words_idxs_in_templates
            if repr_tools is not None
            else get_words_idxs_in_templates
        )
        ret = lookup(
            tok=tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=fact_token_strategy[len(subject_prefix):],
        )[0][0]
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    if verbose:
        # The filled-in sentence is only needed for this debug output.
        # Building it lazily keeps the non-verbose path independent of
        # str.format: extra brace fields in the template would make
        # format() raise.
        sentence = prompt.format(subject)
        print(
            f"Lookup index found: {ret} | Sentence: {sentence} | Token:",
            tok.decode(tok(sentence)["input_ids"][ret]),
        )

    return ret


def get_words_idxs_in_templates(
    tok: AutoTokenizer,
    context_templates: List[str],
    words: List[str],
    subtoken: str
) -> List[List[int]]:
    """
    AlphaEdit-compatible word index computation in templates.

    Meant to reproduce AlphaEdit's rome/repr_tools.py logic so both give the
    same token positions.

    Token counts are taken from encoding progressively longer strings
    (prefix, prefix + word, prefix + word + suffix) instead of encoding each
    piece on its own. The word and suffix lengths are the differences
    between those counts.

    Args:
        tok: Tokenizer (only its ``encode`` method is used)
        context_templates: Template strings, each with exactly one {} placeholder
        words: Words to substitute; ``words[i]`` fills ``context_templates[i]``
        subtoken: Which token to return ("first", "last", "first_after_last")

    Returns:
        One single-element list per template. For "last" and
        "first_after_last" the index counts from the start of the sequence.
        For "first" it is a negative index counted from the end.

    Raises:
        AssertionError: If a template does not contain exactly one "{}".
        ValueError: If subtoken is not one of the supported values.
    """
    assert all(
        template.count("{}") == 1 for template in context_templates
    ), "We currently do not support multiple fill-ins for context"

    # Per template: (prefix tokens, word tokens, suffix tokens, total tokens).
    spans = []
    for idx, template in enumerate(context_templates):
        head, tail = template.split("{}")
        word = words[idx]
        n_head = len(tok.encode(head))
        n_head_word = len(tok.encode(head + word))
        n_total = len(tok.encode(head + word + tail))
        spans.append((n_head, n_head_word - n_head, n_total - n_head_word, n_total))

    if subtoken in ("last", "first_after_last"):
        positions = []
        for n_head, n_word, n_tail, _ in spans:
            # "last" steps back onto the final word token. "first_after_last"
            # also does so when the suffix produced no tokens, because then
            # there is no token after the word.
            step_back = 1 if (subtoken == "last" or n_tail == 0) else 0
            positions.append([n_head + n_word - step_back])
        return positions

    if subtoken == "first":
        # Offset of the first word token, counted back from the sequence end.
        return [[n_head - n_total] for n_head, _, _, n_total in spans]

    raise ValueError(f"Unknown subtoken type: {subtoken}")


def find_subject_token_positions_alphaedit_style(
    tokenizer: AutoTokenizer,
    prompt: str,
    subject: str,
    inputs: Dict,
    fact_token_strategy: str = "subject_last",
    verbose: bool = False
) -> Tuple[int, int]:
    """
    AlphaEdit-style subject token position finding.

    The end position is the index that ``find_fact_lookup_idx`` returns for
    ``fact_token_strategy``.

    For "subject_last", the start position is the "subject_first" index,
    re-expressed in the same sign convention as the end position. Both are
    then non-negative (from the start), or both negative (from the end), so
    the pair describes one span. For every other strategy, start == end.

    Args:
        tokenizer: The tokenizer
        prompt: Template string with one {} placeholder
        subject: Subject string
        inputs: Tokenized inputs (kept for interface compatibility; unused)
        fact_token_strategy: Token positioning strategy
        verbose: Whether to print debug info for the end-position lookup

    Returns:
        Tuple of (start_pos, end_pos) where end_pos is the target token position

    Raises:
        ValueError: If fact_token_strategy is not recognized.
    """
    target_idx = find_fact_lookup_idx(
        prompt=prompt,
        subject=subject,
        tok=tokenizer,
        fact_token_strategy=fact_token_strategy,
        verbose=verbose
    )

    if fact_token_strategy != "subject_last":
        # "subject_first", "last" and any other strategy name a single token.
        return target_idx, target_idx

    start_idx = find_fact_lookup_idx(
        prompt=prompt,
        subject=subject,
        tok=tokenizer,
        fact_token_strategy="subject_first",
        verbose=False
    )

    # The lookup may mix conventions. The local "subject_first" path counts
    # from the end (negative), while "subject_last" counts from the start.
    # Convert start_idx to target_idx's convention, using the length of the
    # full filled-in template encoding (the same string the lookup encodes).
    if (start_idx < 0) != (target_idx < 0):
        seq_len = len(tokenizer.encode(prompt.replace("{}", subject, 1)))
        start_idx = start_idx + seq_len if start_idx < 0 else start_idx - seq_len

    return start_idx, target_idx


def get_last_subject_token_position(
    tokenizer: AutoTokenizer,
    prompt: str,
    subject: str,
    inputs: Dict,
    fact_token_strategy: str = "subject_last",
    verbose: bool = False
) -> int:
    """
    Return the target token index chosen by an AlphaEdit-style strategy.

    Drop-in replacement for MoEdit's find_subject_token_positions when only
    the final (target) position is needed. It forwards directly to
    ``find_fact_lookup_idx``.

    Args:
        tokenizer: The tokenizer
        prompt: Template string with {} placeholder
        subject: Subject string
        inputs: Tokenized inputs (accepted for signature compatibility; ignored)
        fact_token_strategy: Token positioning strategy
        verbose: Whether to print debug info

    Returns:
        Index of the last subject token (or target token based on strategy)
    """
    lookup_kwargs = {
        "prompt": prompt,
        "subject": subject,
        "tok": tokenizer,
        "fact_token_strategy": fact_token_strategy,
        "verbose": verbose,
    }
    return find_fact_lookup_idx(**lookup_kwargs)
