from __future__ import annotations

import ast
import json
import re


def extract_solution(text: str) -> str | None:
    """
    Return whatever follows the last "solution =" marker in the text.

    The marker is matched case-insensitively and tolerates any amount of
    whitespace (including none) on either side of the '=' sign.

    Args:
        text: The text to scan for a solution marker.

    Returns:
        The stripped text after the final marker, or None when the marker
        is absent or nothing but whitespace follows it.
    """
    marker = re.compile(r'solution\s*=\s*', re.IGNORECASE)

    # Walk every occurrence and remember where the most recent one ended.
    tail_start = None
    for hit in marker.finditer(text):
        tail_start = hit.end()

    if tail_start is None:
        return None

    remainder = text[tail_start:].strip()
    # An empty remainder counts as "no solution".
    return remainder or None


def extract_balanced_brackets(text: str) -> str | None:
    """
    Return the bracketed span that starts at the first '[' in the text.

    Scanning begins at the first '[' and tracks nesting depth until the
    matching ']' is reached. Brackets are counted purely by character; quoted
    strings are not treated specially.

    Args:
        text: The text to scan.

    Returns:
        The substring from the first '[' through its matching ']' inclusive,
        or None if there is no '[' or it is never closed.
    """
    opening = text.find('[')
    if opening < 0:
        return None

    depth = 0
    for position, char in enumerate(text[opening:], start=opening):
        if char == '[':
            depth += 1
        elif char == ']':
            depth -= 1
            if depth == 0:
                # Depth back to zero: this ']' closes the first '['.
                return text[opening:position + 1]

    # Ran off the end with the first '[' still open.
    return None


def parse_int_solution(lm_text: str) -> int | None:
    """
    Read a decimal integer answer out of language-model output.

    The text following the final "solution = " marker is searched for the
    first run of decimal digits, optionally preceded by a minus sign.

    Args:
        lm_text: The full output from the language model

    Returns:
        The integer that was found, or None when there is no solution text
        or it contains no digits.
    """
    answer = extract_solution(lm_text)
    if answer:
        # First (optionally negative) decimal integer anywhere in the answer.
        found = re.search(r"-?\d+", answer)
        if found is not None:
            return int(found.group(0))

    return None


def parse_hex_solution(lm_text: str) -> int | None:
    """
    Parse a hexadecimal integer solution from LM output.

    Extracts text after the final "solution = " marker and parses the hex
    integer that the extracted text begins with. The number may carry a
    "0x" or "0X" prefix or be written as plain hex digits. Matching is
    anchored at the start of the solution text, so any leading non-hex
    character yields None.

    Args:
        lm_text: The full output from the language model

    Returns:
        The parsed integer, or None if parsing fails.
    """
    solution_text = extract_solution(lm_text)
    if not solution_text:
        return None

    # One anchored pattern covers both forms. The optional prefix accepts
    # either case of 'x'; a lowercase-only prefix would make "0X1F" fall
    # through to the plain-hex case and parse as just "0".
    # If the prefix is not followed by a hex digit, the regex backtracks and
    # matches the bare leading digits instead (e.g. "0xZZ" -> "0").
    hex_match = re.match(r'(?:0[xX])?[0-9a-fA-F]+', solution_text)
    if hex_match is None:
        return None

    # The pattern guarantees valid base-16 input (int() accepts the 0x/0X
    # prefix when base=16), so no ValueError handling is needed here.
    return int(hex_match.group(0), 16)


def _convert_strings_to_ints(obj):
    """Recursively convert string digits to integers in nested lists."""
    if isinstance(obj, list):
        return [_convert_strings_to_ints(item) for item in obj]
    elif isinstance(obj, str):
        # Try to convert string to int (handles '5', '-3', etc.)
        try:
            return int(obj)
        except ValueError:
            return obj  # Keep as string if not a valid integer
    else:
        return obj


def parse_list_solution(lm_text: str) -> list | None:
    """
    Parse a list solution from LM output.

    Extracts text after the final "solution = " marker, takes the first
    balanced bracket content, and parses it as JSON or, failing that, as a
    Python literal. String digits are automatically converted to integers.

    Malformed input never raises: every parsing failure is reported as None.

    Args:
        lm_text: The full output from the language model

    Returns:
        The parsed list with string digits converted to ints, or None if parsing fails.
    """
    solution_text = extract_solution(lm_text)
    if not solution_text:
        return None

    bracket_content = extract_balanced_brackets(solution_text)
    if not bracket_content:
        return None

    result = None

    # Try JSON first (double quotes). Besides JSONDecodeError (a ValueError
    # subclass), json.loads can raise a plain ValueError for integer literals
    # beyond the interpreter's digit limit and RecursionError for extreme
    # nesting, so catch the broader set.
    try:
        result = json.loads(bracket_content)
    except (ValueError, RecursionError):
        pass

    # Fall back to ast.literal_eval (handles single quotes, tuples, etc.).
    # It is documented to raise any of these on malformed input; e.g.
    # "[{[1]: 2}]" raises TypeError because a list is not hashable.
    if result is None:
        try:
            result = ast.literal_eval(bracket_content)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            pass

    if not isinstance(result, list):
        return None

    # Convert any string digits to integers. The conversion is recursive, so
    # treat a recursion overflow on very deep nesting as a parse failure.
    try:
        return _convert_strings_to_ints(result)
    except RecursionError:
        return None


# Prompt text with three named placeholders: {puzzle_description}, {example}
# and {puzzle_instance} (presumably filled in via str.format -- confirm at
# the call site). It asks for the answer as "solution = ...", which is the
# marker extract_solution() searches for.
# NOTE(review): the first paragraph spells "clarifying" as "clarfying". The
# string is left untouched because it is runtime text; confirm whether
# correcting it is acceptable before changing the prompt.
PROMPT_TEMPLATE = """
You are being tested on your capacity for extended reasoning, involving logical deduction and state-tracking. You will be given a puzzle and asked to solve it. The puzzle is GUARANTEED to have a solution (potentially non-unique). You are not to use tools, write code, ask to use a solver, or ask any clarfying questions. You must solve the puzzle in a single response, or you will be deemed to have failed. There is no limit to the amount of time and tokens you can take to solve the puzzle.

REQUIREMENTS:
- Solve the puzzle within a single response.
- Do not use tools, write code, ask to use a solver, or ask any clarifying questions.
- Give your final answer in the format: solution = ... (the output format will be specified in the puzzle description).

Puzzle description:
{puzzle_description}

Example:
{example}

Puzzle instance:
{puzzle_instance}

Return your solution in the format: solution = ...
"""