"""Zero-shot / few-shot inference for Kimina-Prover via vLLM.

Shares the same parser, prompt format, and inference parameters as the
R5 script (gpu_inference_paperorig_random5_numina.py) for fair comparison.
The only difference is example handling:
  - ZS (--use_examples_in_prompt 0): no examples in prompt
  - FS (--use_examples_in_prompt 1): hardcoded 3 examples
"""

from tqdm import tqdm
import argparse
import json
import os
import re


def parse_args():
    """Parse the command-line options that configure an inference run.

    Returns:
        argparse.Namespace: carries ``port``, ``num_samples_per_task``,
        ``model_id``, ``method_tag``, ``eval_dir``, ``dataset_path`` and
        ``use_examples_in_prompt``.
    """
    # One (flag, add_argument keyword arguments) entry per option, in the
    # order they should appear in ``--help``.
    option_table = [
        ("--port", {"type": int, "default": 8000, "help": "Port for vLLM"}),
        (
            "--num_samples_per_task",
            {
                "type": int,
                "default": 32,
                "help": "Number of samples to generate per task",
            },
        ),
        (
            "--model_id",
            {
                "type": str,
                "default": "AI-MO/Kimina-Prover-RL-1.7B",
                "help": "Model identifier",
            },
        ),
        (
            "--method_tag",
            {
                "type": str,
                "default": "OffTheShelfWithEg",
                "help": "Tag to identify evaluation method",
            },
        ),
        (
            "--eval_dir",
            {
                "type": str,
                "default": "results_miniF2F",
                "help": "Directory to store inference results",
            },
        ),
        (
            "--dataset_path",
            {
                "type": str,
                "default": "./datasets_validation/minif2f/dataset.jsonl",
                "help": "Path to the dataset file",
            },
        ),
        (
            "--use_examples_in_prompt",
            {
                "type": int,
                "choices": [0, 1],
                "default": 1,
                "help": "Set 1 to use hardcoded examples in prompt, 0 to disable",
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        description="Run ZS/FS inference with Kimina-Prover via vLLM"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# ── Shared parser (identical to R5 script) ──────────────────────────────────

def _clean_lean_code(code):
    """Strip tag/fence wrappers and trailing prose from a Lean snippet.

    Used by the fallback extraction strategies, whose captures may still
    carry ``<formal_proof>``/``<formal_theorem>`` tags, Markdown fences, or
    natural-language commentary after the code.

    Args:
        code: Raw text believed to start with Lean code.

    Returns:
        The cleaned snippet with surrounding whitespace removed.
    """
    # Applied in order: drop everything from a closing tag onwards, then a
    # leading opening tag, a leading fence, and a trailing fence.
    wrapper_patterns = (
        (r"</formal_(theorem|proof)>.*", re.DOTALL),
        (r"^<formal_(theorem|proof)>\s*", 0),
        (r"^```(lean4?)?\s*", 0),
        (r"\s*```\s*$", 0),
    )
    for pattern, flags in wrapper_patterns:
        code = re.sub(pattern, "", code, flags=flags)

    prose_opener = re.compile(
        r"^(This|Note|The above|Here|I |We |In |My |Above|Proof|QED|Q\.E\.D)"
    )

    kept = []
    for raw_line in code.split("\n"):
        candidate = raw_line.strip()
        looks_like_prose = (
            bool(candidate)
            and not candidate.startswith("--")
            and prose_opener.match(candidate) is not None
        )
        if looks_like_prose:
            # The first prose-looking line ends the code; discard the rest.
            break
        kept.append(raw_line)
    return "\n".join(kept).strip()


def extract_fl_proof(text):
    """Extract Lean code from a model response using ordered fallbacks.

    ``<think>...</think>`` blocks are removed first. The strategies are then
    tried in order and the first one that yields a result wins:

    1. The last `````lean4`` fenced block.
    2. The last fenced block (bare, or tagged ``lean``/``lean4``) containing
       ``import`` or ``theorem``.
    3. The contents of ``<formal_proof>`` / ``<formal_theorem>`` tags.
    4. Everything from the first ``import Mathlib``.
    5. Everything from the first ``import Aesop``.
    6. Everything from the first ``theorem <name>``.

    Args:
        text: Raw model output; any falsy value yields ``""``.

    Returns:
        The extracted Lean source, or ``""`` when nothing matched.

    NOTE: the module docstring says this parser is shared with the R5
    script; behavioural fixes made here should be mirrored there so the two
    runs stay comparable.
    """
    if not text:
        return ""
    text = str(text)

    # Strip <think>...</think> blocks (Kimina thinking mode)
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

    # Strategy 1: ```lean4 ... ```
    lean4_blocks = re.findall(r"```lean4\s*(.*?)```", text, re.DOTALL)
    if lean4_blocks:
        return lean4_blocks[-1].strip()

    # An optional lean/lean4 language tag directly after the opening fence.
    # It is consumed outside the capture group so that a "```lean" block does
    # not come back as "lean\n<code>".
    fenced_block = r"```(?i:lean4?\b)?\s*(.*?)```"

    # Strategy 2: ``` ... ``` containing lean keywords
    for block in reversed(re.findall(fenced_block, text, re.DOTALL)):
        if "import" in block or "theorem" in block:
            return block.strip()

    # Strategy 3: <formal_proof>/<formal_theorem> tag
    for tag in ("formal_proof", "formal_theorem"):
        start_delim = f"<{tag}>"
        end_delim = f"</{tag}>"
        if start_delim in text and end_delim in text:
            inner = text.split(start_delim)[-1].split(end_delim)[0].strip()
            if inner:
                inner_fenced = re.findall(fenced_block, inner, re.DOTALL)
                if inner_fenced:
                    return inner_fenced[-1].strip()
                return _clean_lean_code(inner)

    # Strategies 4 and 5: start at a known import line. Mathlib keeps
    # priority over Aesop regardless of which appears first in the text.
    for header in ("import Mathlib", "import Aesop"):
        position = text.find(header)
        if position != -1:
            return _clean_lean_code(text[position:])

    # Strategy 6: theorem keyword. A single word-anchored search, so the
    # extracted code always starts at a real ``theorem`` token.
    theorem_match = re.search(r"\btheorem\s+\w+.*", text, re.DOTALL)
    if theorem_match:
        return _clean_lean_code(theorem_match.group(0))

    return ""


# ── Shared infrastructure (identical to R5 script) ──────────────────────────

def _load_jsonl_records(path):
    """Read a JSONL file into a list of objects, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        # A blank or whitespace-only line (e.g. a trailing empty line) is not
        # valid JSON and would make json.loads raise.
        return [json.loads(line) for line in f if line.strip()]


def prepare_jsonl_to_evaluation(dataset_jsonl_path, eval_path, model_name, method_tag, num_samples):
    """Load the resumable output file, creating it from the dataset if needed.

    If ``<eval_path>/<model_name>-<method_tag>_output.jsonl`` already exists
    it is loaded as-is so the run can resume. Otherwise the dataset is read,
    every record gets ``INFERENCE_DONE = "no"`` plus empty result columns for
    samples ``1..num_samples``, and the result is written to that file.

    Args:
        dataset_jsonl_path: Path of the source dataset (one JSON object per line).
        eval_path: Directory holding the output file; it must already exist.
        model_name: Model name used in the output file name.
        method_tag: Method tag used in the output file name.
        num_samples: Number of per-sample column groups to create.

    Returns:
        A ``(records, output_file_path)`` tuple.
    """
    eval_filename = f"{model_name}-{method_tag}_output.jsonl"
    eval_file_path = os.path.join(eval_path, eval_filename)

    if os.path.exists(eval_file_path):
        data = _load_jsonl_records(eval_file_path)
        print("Output file already exists, resuming!")
        return data, eval_file_path

    data = _load_jsonl_records(dataset_jsonl_path)

    result_columns = (
        "LLM_Output#",
        "LLM_Syntax?#",
        "LLM_SyntaxError#",
        "LLM_Semantics?#",
        "LLM_SemanticsError#",
    )
    for data_instance in data:
        data_instance["INFERENCE_DONE"] = "no"
        for i in range(1, num_samples + 1):
            for col_base in result_columns:
                data_instance[f"{col_base}{i}"] = ""

    with open(eval_file_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(data_instance) + "\n" for data_instance in data)

    print(f"Modified file saved to {eval_file_path}")
    return data, eval_file_path


def save_jsonl(list_of_dicts, file_path):
    """Atomically write ``list_of_dicts`` to ``file_path`` as JSON Lines.

    The records are first written to ``<file_path>.tmp`` and then moved over
    the destination with ``os.replace``. This file is the resume checkpoint
    and is rewritten after every item, so an interruption or serialization
    error mid-write must not leave it truncated.

    Args:
        list_of_dicts: Iterable of JSON-serializable records, one per line.
        file_path: Destination path; replaced only after a complete write.

    Raises:
        TypeError: If a record is not JSON-serializable (destination untouched).
        OSError: If the temporary file cannot be written or moved.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(record) + "\n" for record in list_of_dicts)
        os.replace(tmp_path, file_path)
    finally:
        # Only present if the write or the replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ── Prompt construction ─────────────────────────────────────────────────────

# Hardcoded 3-shot examples (same as original paper)
# Few-shot block inserted into the user prompt when --use_examples_in_prompt
# is 1. Each example is an informal statement, an informal proof, and a
# Lean 4 proof wrapped in <formal_proof> tags and a ```lean4 fence.
#
# NOTE(review): this is a regular (non-raw) string, so every doubled backslash
# collapses to a single one at runtime. "\\sqrt" becomes "\sqrt", which is the
# expected LaTeX, but "\\u2211" becomes the six literal characters "\u2211"
# rather than the Unicode symbol. The Lean code the model sees therefore
# contains escape sequences instead of symbols. TODO confirm whether this is
# intentional (e.g. to match the prompt used by the R5 script) before
# changing it, since any edit alters the prompt and the comparison.
HARDCODED_EXAMPLES = '''### Example 1
<informal_statement>
Find the sum of all positive integers $n$ such that $\\sqrt{n^2+85n+2017}$ is an integer.
</informal_statement>

<informal_proof>
We are looking for all positive integers $n$ such that $\\sqrt{n^2+85n+2017}$ is an integer. Let this integer be $k$. Thus, we have the equation $k^2 = n^2+85n+2017$.

To find the values of $n$, we algebraically manipulate the equation. We aim to express it in a form that allows us to use factorization. The key step is to transform this equation into:
$(2k - 2n - 85)(2k + 2n + 85) = 843$

Let $a = 2k - 2n - 85$ and $b = 2k + 2n + 85$. Then $ab = 843$.
We consider pairs of factors $(a, b)$ of 843 such that $ab = 843$ and $b > a$.
The factor pairs yield $n=168$ and $n=27$.
Sum $= 168 + 27 = 195$.
</informal_proof>

<formal_proof>
```lean4
import Mathlib
/-- Find the sum of all positive integers $n$ such that $\\sqrt{n^2+85n+2017}$ is an integer. -/
theorem algebra_53827 :
\\u2211\\u1da0 n \\u2208 {n : \\u2115 | 0 < n \\u2227 \\u2203 k, k^2 = (\\u2191n^2 : \\u2124) + 85 * n + 2017}, n = 195 := by
  have ta {a b : \\u2124} (h : a * b = 843) :
      a = 1 \\u2228 a = 3 \\u2228 a = 281 \\u2228 a = 843 \\u2228 a = -1 \\u2228 a = -3 \\u2228 a = -281 \\u2228 a = -843 := by
    have ha : a.natAbs \\u2208 (843).divisors := by
      simp; use b.natAbs; rw [\\u2190Int.natAbs_mul, h]; rfl
    simp only [(by native_decide : (843 : \\u2115).divisors = { 1, 3, 281, 843 }), Finset.mem_insert,
      Finset.mem_singleton] at ha
    omega
  have enum : {n : \\u2115 | 0 < n \\u2227 \\u2203 k, k^2 = (\\u2191n^2 : \\u2124) + 85 * n + 2017} = \\u2191({168, 27} : Finset \\u2115) := by
    ext n
    simp only [Set.mem_setOf_eq, Finset.coe_insert, Finset.coe_singleton, Set.mem_insert_iff,
      Set.mem_singleton_iff]
    constructor <;> intro h
    \\u00b7 obtain \\u27e8a, k, s\\u27e9 := h
      have wkw : (2 * k - 2 * n - 85) * (2 * k + 2 * n + 85) = 843 := by ring_nf; linarith
      obtain ald | ald | ald | ald | ald | ald | ald | ald := ta wkw
      all_goals zify at *
      all_goals rw [ald] at wkw
      all_goals omega
    \\u00b7 obtain rfl | rfl | rfl | rfl := h
      all_goals norm_num
      exists 211
      exists 71
  rw [enum, finsum_mem_coe_finset]
  decide
```
</formal_proof>

### Example 2
<informal_statement>
Prove that for each positive integer $k$ there exists a number base $b$ along with $k$ triples of Fibonacci numbers $(F_u,F_v,F_w)$ such that when they are written in base $b$, their concatenation is also a Fibonacci number written in base $b$.
</informal_statement>

<informal_proof>
We choose the base $b = 1$ and indices $u = v = w = 1$. Then $F_1 = 1$ and $F_3 = 2$, so $F_1 \\cdot 1^{F_1} + F_1 \\cdot 1^{F_1} = 2 = F_3$. This works for all $k$ triples.
</informal_proof>

<formal_proof>
```lean4
import Mathlib

theorem number_theory_62389 (k : \\u2115) (hk : 0 < k) :
    \\u2203 b : \\u2115, \\u2200 i \\u2208 Finset.range k, \\u2203 u v w : \\u2115,
      Nat.fib u > 0 \\u2227 Nat.fib v > 0 \\u2227 Nat.fib w > 0 \\u2227
      Nat.fib (u + v + w) = Nat.fib u * b ^ (Nat.fib v) + Nat.fib v * b ^ (Nat.fib w) := by
  use 1
  intro i hi
  use 1; use 1; use 1
  aesop
```
</formal_proof>

### Example 3
<informal_statement>
Find prime numbers $p , q , r$ such that $p+q^2+r^3=200$. Give all the possibilities.
</informal_statement>

<informal_proof>
By parity, one of $p,q,r$ must be 2. Checking all cases with bounds $q \\le 13$, $r \\le 5$:
- $r=2$: $p + q^2 = 192$ yields $(167,5,2)$, $(71,11,2)$, $(23,13,2)$
- $q=2$: $p + r^3 = 196$ yields $(71,2,5)$
- $p=2$: no solutions
</informal_proof>

<formal_proof>
```lean4
import Mathlib

theorem number_theory_54583 (p q r : \\u2115):
p.Prime \\u2227 q.Prime \\u2227 r.Prime \\u2227 (p + q^2 + r^3 = 200)
\\u2194
(p = 167 \\u2227 q = 5 \\u2227 r = 2) \\u2228 (p = 71 \\u2227 q = 11 \\u2227 r = 2) \\u2228 (p = 23 \\u2227 q = 13 \\u2227 r = 2) \\u2228 (p = 71 \\u2227 q = 2 \\u2227 r = 5) := by
  constructor
  intro \\u27e8pp,qp,rp,h\\u27e9
  have pge2 := pp.two_le; have qge2 := qp.two_le; have rge2 := rp.two_le
  have : q \\u2264 15 := by by_contra qg; push_neg at qg; have : 15^2 < q^2 := by nlinarith; omega
  have : r \\u2264 6 := by by_contra rg; push_neg at rg; have : 6^3 < r^3 := by apply pow_lt_pow_left\\u2080; exact rg; norm_num; norm_num; omega
  have : Even p \\u2228 Even q \\u2228 Even r := by
    by_contra allodd; push_neg at allodd; simp at allodd
    obtain \\u27e8op,oq,or\\u27e9 := allodd
    have t4: Odd (p+q^2+r^3) := by exact Even.add_odd (Odd.add_odd op (Odd.pow oq)) (Odd.pow or)
    rw [h] at t4; have : \\u00ac Odd 200 := by decide; contradiction
  obtain ep | eq | er := this
  have pe := (pp.even_iff).mp ep; simp [pe] at h \\u22a2; interval_cases q <;> interval_cases r <;> norm_num at h \\u22a2
  have qe := (qp.even_iff).mp eq; simp [qe] at h \\u22a2; interval_cases p <;> interval_cases r <;> norm_num at h pp rp \\u22a2
  have re := (rp.even_iff).mp er; simp [re] at h \\u22a2; interval_cases p <;> interval_cases q <;> norm_num at h pp qp \\u22a2
  intro h
  obtain \\u27e8pe,qe,re\\u27e9 | \\u27e8pe,qe,re\\u27e9 | \\u27e8pe,qe,re\\u27e9 | \\u27e8pe,qe,re\\u27e9 := h <;>
  simp [pe,qe,re] <;> norm_num
```
</formal_proof>'''


def wrap_prompt_in_query(informal_statement, informal_proof, example_block):
    """Assemble the user prompt for a single autoformalization request.

    The wording and layout are fixed; only the three arguments vary. Passing
    an empty ``example_block`` gives the zero-shot prompt.

    Args:
        informal_statement: Natural-language problem statement.
        informal_proof: Natural-language proof to formalize.
        example_block: Few-shot examples text, or ``""`` for none.

    Returns:
        The complete prompt string.
    """
    # The prompt text, including its leading indentation and exact wording,
    # must stay unchanged so prompts remain comparable across runs.
    prompt_template = '''
    You task is to take as input an informal proof in natural language and autoformalize it in Lean 4 with a header.
    Think step-by-step and ensure that the output formal theorem is compilabile with Lean 4 (version 4.15.0).

    {example_block}

    Here is the **actual** informal proof in natural language:
    <informal_statement>
    {informal_statement}
    </informal_statement>

    <informal_proof>
    {informal_proof}
    </informal_proof>

    Now first think step-by-step for the actual output and autoformalize it in Lean 4 with a header. Importantly, enclose the final formal proof in Lean 4 inside the following tags:

    <formal_proof>
    ```lean4
    (Provide your entire Lean 4 proof with header here)
    ```
    </formal_proof>
    '''
    return prompt_template.format(
        example_block=example_block,
        informal_statement=informal_statement,
        informal_proof=informal_proof,
    )


# ── Main inference loop ─────────────────────────────────────────────────────

def inference_on_dataset(args):
    """Run inference over the dataset and checkpoint results after each item.

    For every record not yet marked ``INFERENCE_DONE == "yes"`` the function
    builds the prompt, requests ``args.num_samples_per_task`` completions
    from the OpenAI-compatible server at ``localhost:args.port``, extracts
    the Lean code from each completion, stores it in ``LLM_Output#<k>``, and
    rewrites the output JSONL so an interrupted run can resume.

    A failed request is retried up to three times with a 10-second pause.
    If every attempt fails, the item is filled with error placeholders and
    still marked done, so a resumed run will not retry it.

    Args:
        args: Namespace from ``parse_args`` providing ``eval_dir``,
            ``model_id``, ``method_tag``, ``dataset_path``,
            ``num_samples_per_task``, ``use_examples_in_prompt`` and ``port``.
    """
    import time

    from openai import OpenAI

    max_retries_per_item = 3
    retry_delay_seconds = 10
    num_samples = args.num_samples_per_task
    system_prompt = "You are an expert in mathematics. Your task is to convert informal, natural-language proofs into correct Lean 4 formalizations."

    os.makedirs(args.eval_dir, exist_ok=True)

    model_name = args.model_id.split("/")[-1].replace(".", "-")
    log_file = os.path.join(args.eval_dir, f"{model_name}-{args.method_tag}_LOG.txt")
    # Every run, including a resumed one, starts with an empty log file.
    with open(log_file, "w", encoding="utf-8"):
        pass

    def log(text):
        # Append-and-close per message so the log survives a crash.
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(text)

    list_of_data_dicts, final_save_path = prepare_jsonl_to_evaluation(
        args.dataset_path,
        args.eval_dir,
        model_name=model_name,
        method_tag=args.method_tag,
        num_samples=num_samples,
    )

    # Build example block (ZS: empty, FS: hardcoded examples)
    if args.use_examples_in_prompt:
        example_block = "Here are a few examples:\n" + HARDCODED_EXAMPLES + "\n"
    else:
        example_block = ""

    client = OpenAI(
        api_key="EMPTY",
        base_url=f"http://localhost:{args.port}/v1",
        timeout=3600,
    )

    for data_num, data_item in enumerate(tqdm(list_of_data_dicts), start=1):
        item_header = f"==================NUM{data_num}==================\n\n"
        if data_item["INFERENCE_DONE"] == "yes":
            log(item_header)
            continue

        informal_statement = str(data_item["informal_statement"]).strip()
        informal_proof = str(data_item["informal_proof"]).strip()

        input_text = wrap_prompt_in_query(informal_statement, informal_proof, example_block)

        log(
            item_header
            + "informal_statement: \n" + informal_statement + "\n"
            + "informal_proof: \n" + informal_proof + "\n"
        )

        model_outs = None
        for attempt in range(1, max_retries_per_item + 1):
            try:
                chat_responses = client.chat.completions.create(
                    model=args.model_id,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": input_text},
                    ],
                    n=num_samples,
                    max_tokens=12000,
                    temperature=0.6,
                    top_p=0.95,
                )
                model_outs = [c.message.content for c in chat_responses.choices]
                break
            except Exception as e:
                # Deliberately broad: any request failure should be logged and
                # retried rather than abort the whole run.
                log(f"[RETRY] Attempt {attempt}/{max_retries_per_item} failed: {e}\n")
                if attempt < max_retries_per_item:
                    time.sleep(retry_delay_seconds)

        if model_outs is None:
            log(f"[SKIP] All retries exhausted for NUM{data_num}, filling with error placeholders.\n")
            model_outs = [None] * num_samples
        elif len(model_outs) != num_samples:
            # Explicit handling instead of an assert, which is removed under
            # ``python -O`` and would surface later as an IndexError. Keep
            # what was returned and mark missing slots as "no output".
            log(
                f"[WARN] Expected {num_samples} outputs for NUM{data_num} "
                f"but received {len(model_outs)}; padding/truncating.\n"
            )
            model_outs = (model_outs + [None] * num_samples)[:num_samples]

        responses = []
        for model_out_idx, model_out in enumerate(model_outs):
            if model_out is None:
                parsed_model_out = "ERROR [No output returned]"
            else:
                parsed_model_out = extract_fl_proof(model_out)
                if not parsed_model_out.strip():
                    parsed_model_out = "ERROR [No text within tags]"
            responses.append(parsed_model_out)
            log(
                f"Generated output# {model_out_idx}/{num_samples}:\n"
                f"rawModelOut: {model_out}\n"
                f"parsedModelOut: {parsed_model_out}\n"
            )

        for trial_num, response in enumerate(responses, start=1):
            data_item[f"LLM_Output#{trial_num}"] = response
        data_item["INFERENCE_DONE"] = "yes"
        # Checkpoint after every item so the run can be resumed.
        save_jsonl(list_of_data_dicts, final_save_path)

        log("Wrote all LLM outputs!!\n" + "=====================================\n\n")

    print("Inference and saving complete.")


if __name__ == "__main__":
    inference_on_dataset(parse_args())
