import json
import re
import os
import ast
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import argparse
from tqdm import tqdm

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


# Matches text that looks like ``assert <identifier>(<args>) ==``: the keyword
# ``assert``, optional whitespace, an identifier, a parenthesised argument
# list, optional whitespace and then ``==``. DOTALL lets ``.`` cross newlines,
# so the argument list may span several lines.
CALL_ASSERT_RE = re.compile(r'assert\s*[A-Za-z_]\w*\s*\(.*\)\s*==', re.DOTALL)


def assert_to_call_src(s: str) -> str:
    """Return the source of the call on the left of an ``assert <call> <op> ...``.

    ``s`` is stripped and parsed as Python source; only the first statement is
    examined. A result is produced only when that statement is an ``assert``
    whose test is a comparison and whose left operand is a call expression.

    Returns:
        The call rendered by ``ast.unparse``. If rendering fails, the text
        captured by a regex over the raw input (which requires ``==``) is
        returned instead. In every other situation, including any parsing
        error, the empty string is returned.
    """
    try:
        statements = ast.parse(s.strip(), mode='exec').body
        if not statements:
            return ""

        first = statements[0]
        if not isinstance(first, ast.Assert):
            return ""

        condition = first.test
        if not isinstance(condition, ast.Compare):
            return ""

        lhs = condition.left
        if not isinstance(lhs, ast.Call):
            return ""

        try:
            return ast.unparse(lhs)
        except Exception:
            # Rendering the node failed; recover the call text straight from
            # the raw input instead.
            found = re.search(r'assert\s*([A-Za-z_]\w*\s*\(.*\))\s*==', s, flags=re.DOTALL)
            if found:
                return found.group(1).strip()
            return ""
    except Exception:
        # Unparseable (or non-string) input is treated as "no call found".
        return ""


# Hint that an example describes stdin-style input: the English word "Input"
# as a whole word (case-insensitive), or the Chinese term "输入" anywhere.
# "输入" is deliberately NOT wrapped in \b: CJK characters count as word
# characters in Python's Unicode regex, so in unspaced Chinese text such as
# "样例输入" or "输入格式" there is no word boundary next to it.
_STDIN_HINT_RE = re.compile(r"\bInput\b|输入", re.IGNORECASE)


def detect_format_from_examples(example_list: List[str]) -> Optional[str]:
    """Guess the test-case format from example test cases.

    Args:
        example_list: Example strings; ``None``/empty entries are treated as
            empty text.

    Returns:
        ``"CALL"`` if any example matches ``CALL_ASSERT_RE`` (checked first),
        ``"STDIN"`` if any example contains the literal ``STDIN:``, the word
        "Input" or the Chinese term "输入", and ``None`` when the list is
        empty or nothing matches.
    """
    if not example_list:
        return None

    texts = [s or "" for s in example_list]

    # Assert-style call examples take priority over stdin hints.
    if any(CALL_ASSERT_RE.search(text) for text in texts):
        return "CALL"
    if any("STDIN:" in text or _STDIN_HINT_RE.search(text) for text in texts):
        return "STDIN"
    return None


def has_function_definition(code: str) -> bool:
    """Report whether ``code`` contains a line that starts a ``def`` statement.

    A line qualifies when, after optional leading whitespace, it has ``def``,
    a name and an opening parenthesis. ``None`` or empty input yields False.
    """
    source = code or ""
    found = re.search(r'^\s*def\s+[A-Za-z_]\w*\s*\(', source, flags=re.MULTILINE)
    return found is not None


def decide_format_from_examples_or_code(example_list: List[str], actor_code: str) -> str:
    """Choose ``"CALL"`` or ``"STDIN"`` as the test-case format.

    The examples are consulted first via ``detect_format_from_examples``. If
    they give no answer, the code decides: ``"CALL"`` when
    ``has_function_definition`` finds a function definition in ``actor_code``,
    otherwise ``"STDIN"``.
    """
    detected = detect_format_from_examples(example_list)
    if detected:
        return detected

    # No usable hint in the examples; fall back to inspecting the code.
    if has_function_definition(actor_code):
        return "CALL"
    return "STDIN"


def _normalize_call_source(call_src: str) -> str:
    """Put a call expression on one line without altering its meaning.

    Whitespace runs are collapsed to single spaces, but only if that leaves
    the parsed expression unchanged. Blind collapsing would also rewrite
    whitespace inside string-literal arguments (``f("a  b")`` would become
    ``f("a b")``) and thereby change the test case. When collapsing would
    change the expression, the original is re-rendered with ``ast.unparse``.
    Text that is not a valid expression is collapsed as a best effort.
    """
    stripped = call_src.strip()
    collapsed = re.sub(r"\s+", " ", call_src).strip()

    try:
        tree = ast.parse(stripped, mode="eval")
    except (SyntaxError, ValueError, RecursionError):
        # Not a parseable expression: nothing to protect, keep the simple form.
        return collapsed

    try:
        if ast.dump(ast.parse(collapsed, mode="eval")) == ast.dump(tree):
            return collapsed
    except (SyntaxError, ValueError, RecursionError):
        # Collapsing broke the expression (e.g. a comment swallowed a line).
        pass

    try:
        return ast.unparse(tree)
    except Exception:
        # No renderer available: prefer the exact original text over a
        # collapsed form that is known to differ semantically.
        return stripped


def extract_testcase_unified(completion: str) -> str:
    """Extract the first ``<testcase>...</testcase>`` body in canonical form.

    Lines that begin with a Markdown code fence (three or more backticks) are
    removed from the body before it is interpreted.

    Args:
        completion: Raw model output; may be empty or ``None``.

    Returns:
        * ``"CALL: <call>"`` when the body starts with ``CALL:`` followed by a
          payload, or when the body is an assert from which
          ``assert_to_call_src`` can extract a call. The call is placed on one
          line with string-literal arguments left intact, and the prefix is
          always ``CALL:`` plus exactly one space.
        * The body unchanged when it starts with ``STDIN:``.
        * ``""`` when there is no ``<testcase>`` block or the body is in none
          of the recognised forms.
    """
    if not completion:
        return ""
    m = re.search(r"<testcase>(.*?)</testcase>", completion, re.DOTALL | re.IGNORECASE)
    if not m:
        return ""

    content = m.group(1).strip()
    # Drop code-fence lines such as ``` or ```python.
    content = re.sub(r"^\s*`{3,}.*?$", "", content, flags=re.MULTILINE).strip()

    if re.match(r"^\s*CALL:\s*.+", content):
        # The match guarantees the (stripped) content begins with "CALL:".
        payload = content.split("CALL:", 1)[1]
        return f"CALL: {_normalize_call_source(payload)}"
    if content.startswith("STDIN:"):
        # Raw stdin must keep its exact line structure.
        return content

    call = assert_to_call_src(content)
    if call:
        return f"CALL: {_normalize_call_source(call)}"
    return ""


def parse_extracted_testcase(testcase_str: str) -> Optional[Dict[str, str]]:
    """Split a ``CALL:``/``STDIN:`` test-case string into type and payload.

    Returns:
        ``{"type": "CALL", "payload": ...}`` with surrounding whitespace
        stripped from the payload, or ``{"type": "STDIN", "payload": ...}``
        with only leading newlines removed. ``None`` is returned for empty
        input or when neither prefix is present.
    """
    if not testcase_str:
        return None

    # (prefix, reported type, payload clean-up) in order of precedence.
    formats = (
        ("CALL:", "CALL", lambda text: text.strip()),
        ("STDIN:", "STDIN", lambda text: text.lstrip("\n")),
    )
    for prefix, kind, clean in formats:
        if testcase_str.startswith(prefix):
            remainder = testcase_str[len(prefix):]
            return {"type": kind, "payload": clean(remainder)}
    return None


def build_chat_prompt(question: str, code: str, example_list: Optional[List[str]] = None, max_examples: int = 2) -> List[Dict[str, str]]:
    fmt = decide_format_from_examples_or_code(example_list or [], code)
    examples_text = ""
    if example_list and max_examples > 0:
        limited_examples = example_list[:max_examples]
        examples_text = "\n\nTest case examples (for format reference only):\n" + "\n".join(limited_examples)
    if fmt == "CALL":
        format_instruction = """
Then output exactly ONE failing test case inside <testcase> tags using this format:
<testcase>
CALL: func_name(arg1, arg2, kw=val)
</testcase>"""
    else:
        format_instruction = """
Then output exactly ONE failing test case inside <testcase> tags using this format:
<testcase>
STDIN:
<raw input here>
</testcase>"""
    user_content = f"""Analyze the given problem and generated code to find a test case that would cause the code to fail.

Problem: {question}

Generated code:
```python
{code}
