from llm_graph_optimizer.operations.helpers.exceptions import OperationFailed


def io_prompt(input_list: list[int]):
    """Build the direct input->answer few-shot prompt for reaching 24.

    Args:
        input_list: The numbers to combine. A sequence is rendered
            space-separated (``4 4 6 8``) to match the few-shot examples; an
            already formatted string is inserted unchanged.

    Returns:
        The prompt text, ending with the ``Input:`` line for ``input_list``.
    """
    # Interpolating a list directly would render "[4, 4, 6, 8]", which does not
    # match the "4 4 6 8" format used by every example in the prompt.
    if isinstance(input_list, str):
        numbers = input_list
    else:
        numbers = " ".join(map(str, input_list))
    return f"""Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {numbers}
"""

def cot_prompt(input_list: list[int]):
    """Build the chain-of-thought few-shot prompt for reaching 24.

    Each example shows step-by-step reductions of two numbers at a time,
    tracking the remaining numbers in ``(left: ...)``, followed by a final
    ``Answer:`` line.

    Args:
        input_list: The numbers to combine. A sequence is rendered
            space-separated (``4 4 6 8``) to match the few-shot examples; an
            already formatted string is inserted unchanged.

    Returns:
        The prompt text, ending with the ``Input:`` line for ``input_list``.
    """
    # Interpolating a list directly would render "[4, 4, 6, 8]", which does not
    # match the "4 4 6 8" format used by every example in the prompt.
    if isinstance(input_list, str):
        numbers = input_list
    else:
        numbers = " ".join(map(str, input_list))
    return f"""Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {numbers}
"""

def propose_prompt(num_examples: int, input_list: list[int]):
    """Build the chat messages asking the model to propose next steps.

    Args:
        num_examples: How many next steps the model is instructed to output.
        input_list: The numbers currently available. A sequence is rendered
            space-separated (``2 8 8 14``) to match the few-shot example; an
            already formatted string is inserted unchanged.

    Returns:
        A two-message list: a ``developer`` instruction message and a ``user``
        message holding the few-shot example followed by ``input_list``.
    """
    developer_instructions = f"Follow exactly the few shot prompt. Output exactly {num_examples} next steps."
    # Interpolating a list directly would render "[2, 8, 8, 14]", which does not
    # match the "2 8 8 14" format used by the example.
    if isinstance(input_list, str):
        numbers = input_list
    else:
        numbers = " ".join(map(str, input_list))
    user_content = f"""Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 /  2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Input: {numbers}
Possible next steps:
"""

    return [
        {"role": "developer", "content": developer_instructions},
        {"role": "user", "content": user_content},
    ]

def propose_parser(response: str):
    """Parse a "propose" completion into candidate next steps.

    Each non-empty line containing ``=`` is treated as one step, e.g.
    ``"14 - 2 = 12 (left: 8 8 12)"``. When a ``(left: ...)`` group is present
    (case-insensitive), the text before it is the expression and the integers
    inside it are the remaining numbers. Otherwise the parser falls back to
    the integers after ``=`` (up to the next ``(``), which recovers only the
    step's result, not the full set of remaining numbers.

    Lines whose remaining numbers are not all integers, or that would leave
    no numbers at all, are skipped.

    Args:
        response: Raw model output, one proposed step per line.

    Returns:
        ``{"expressions": list[str], "lefts": list[list[int]]}`` with the two
        lists aligned index-by-index.

    Raises:
        OperationFailed: If no line could be parsed into a step.
    """
    import re

    left_pattern = re.compile(r"\(left:\s*([^\)]+)\)", flags=re.IGNORECASE)
    left_separator = re.compile(r"[,\s\[\]]+")
    expressions: list[str] = []
    lefts: list[list[int]] = []

    for raw_line in response.splitlines():
        line = raw_line.strip()
        if not line or "=" not in line:
            continue

        left_match = left_pattern.search(line)
        if left_match:
            try:
                left = [int(x) for x in left_separator.split(left_match.group(1)) if x]
            except ValueError:
                # Skip lines with malformed (e.g. non-integer) left values.
                continue
            expr = line[: left_match.start()].strip()
        else:
            # Fallback: only cut at a '(' that comes AFTER '=', so parenthesised
            # expressions like "4 * (9 - 3) = 24" are kept intact.
            eq_index = line.index("=")
            paren_index = line.find("(", eq_index)
            end = paren_index if paren_index != -1 else len(line)
            try:
                left = [int(x) for x in line[eq_index + 1 : end].split()]
            except ValueError:
                continue
            expr = line[:end].strip()

        # Drop trailing punctuation (and any whitespace it was hiding).
        expr = expr.rstrip(";, ")

        # A valid step always leaves at least one number.
        if expr and left:
            expressions.append(expr)
            lefts.append(left)

    if not expressions:
        raise OperationFailed(f"Failed to parse expressions from response: {response}")
    return {"expressions": expressions, "lefts": lefts}

def value_prompt(left: list[int]):
    """Build chat messages asking whether ``left`` can still reach 24.

    The few-shot examples end each case with one of ``sure``, ``likely`` or
    ``impossible`` on its own line.

    Args:
        left: The remaining numbers. A sequence is rendered space-separated
            (``4 9 11``) to match the few-shot examples; an already formatted
            string is inserted unchanged.

    Returns:
        A two-message list: a ``system`` instruction message and a ``user``
        message holding the few-shot examples followed by ``left``.
    """
    system_message = "Follow exactly the few shot prompt."
    # Interpolating a list directly would render "[4, 9, 11]", which does not
    # match the "4 9 11" format used by every example in the prompt.
    numbers = left if isinstance(left, str) else " ".join(map(str, left))
    user_message = f"""Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{numbers}
"""
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

def value_parser(response: str):
    """Convert a value-evaluation completion into a numeric score.

    Only the last non-empty line is inspected, case-insensitively. The
    labels are checked in the order ``impossible`` (0.001), ``likely`` (1),
    ``sure`` (20); the first one found there as a whole word wins. If none
    is found, including for an empty response, the score is 0.

    Args:
        response: Raw model output for the value prompt.

    Returns:
        ``{"value": score}``.
    """
    import re

    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
    lines = [ln.strip().lower() for ln in response.strip().splitlines() if ln.strip()]
    last = lines[-1] if lines else ""
    for key, val in value_map.items():
        # Whole-word match: plain substring matching scored "unsure" as "sure"
        # (the highest value) and "unlikely" as "likely".
        if re.search(rf"\b{key}\b", last):
            return {"value": val}
    return {"value": 0}

def value_last_step_prompt(left: list[int], answer: str):
    """Build chat messages asking the model to judge a final answer.

    The few-shot examples pair an ``Input`` and an ``Answer`` with a
    ``sure``/``impossible`` judgement. The user message ends with
    ``Judge:`` and no trailing newline.

    Args:
        left: The input numbers. A sequence is rendered space-separated
            (``4 4 6 8``) to match the few-shot examples; an already formatted
            string is inserted unchanged.
        answer: The candidate answer expression, inserted verbatim.

    Returns:
        A two-message list: a ``system`` instruction message and a ``user``
        message holding the few-shot examples followed by the case to judge.
    """
    system_message = "Follow exactly the few shot prompt."
    # Interpolating a list directly would render "[4, 4, 6, 8]", which does not
    # match the "4 4 6 8" format used by every example in the prompt.
    numbers = left if isinstance(left, str) else " ".join(map(str, left))
    user_message = f"""Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge: 
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge: 
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge: 
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge: 
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge: 
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge: 
impossible
Input: {numbers}
Answer: {answer}
Judge:"""
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

def value_last_step_parser(response: str):
    """Convert a final-answer judgement completion into a numeric score.

    Only the last non-empty line is inspected, case-insensitively. The
    labels are checked in the order ``impossible`` (0.001), ``likely`` (1),
    ``sure`` (20); the first one found there as a whole word wins. If none
    is found, including for an empty response, the score is 0.

    Args:
        response: Raw model output for the last-step judgement prompt.

    Returns:
        ``{"value": score}``.
    """
    import re

    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
    lines = [ln.strip().lower() for ln in response.strip().splitlines() if ln.strip()]
    last = lines[-1] if lines else ""
    for key, val in value_map.items():
        # Whole-word match: plain substring matching scored "unsure" as "sure"
        # (the highest value) and "unlikely" as "likely".
        if re.search(rf"\b{key}\b", last):
            return {"value": val}
    return {"value": 0}