import argparse
import json
import random
import re
import resource
import shutil
import subprocess
import tarfile
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

from tqdm import tqdm

# Matches a Markdown fenced code block whose opening fence is optionally tagged
# with a prover name (isabelle / lean / lean4 / rocq / coq, case-insensitive).
# Group 1 captures the text between the fences.
code_block_pattern = re.compile(
    r"```(?:isabelle|lean4?|rocq|coq)?\s*(.*?)```", re.DOTALL | re.IGNORECASE
)
# Matches the ":= sorry" proof placeholder, allowing any whitespace (including
# newlines) between ":=" and "sorry".
lean_proof_placeholder_pattern = re.compile(r":=\s*sorry", re.DOTALL)
# Matches either a "--" line comment that is followed by a newline (the newline
# itself is not consumed) or a "/- ... -/" block comment (non-greedy, so nested
# block comments are not balanced).
# NOTE(review): nothing in the visible part of this file references this
# pattern; comment stripping is done by remove_lean_comments() — confirm
# whether it is still needed.
lean_comment_pattern = re.compile(r"--[^\n]*?(?=\n)|/-*.*?-/", re.DOTALL)


# Conventions:
# A. problem_id
#   - a dot-separated string representing the relative path of a Lean file
#     (every "/" of the path is replaced by ".")
#   - e.g., frama_c.contiki_memb.memb_free_Why3_ide_vcg.lean.memb_free_Why3_ide_VCmemb_free_loop_inv_established_goal3
#     (the "lean" component in the middle is a directory name, not a file suffix)
#   - it always starts with either "frama_c." or "pearl.", and doesn't have the ".lean" suffix

# verification result dump file protocol:
# A. we use different status codes to indicate different results:
#   - code 1: build succeeded, the generated proof by NTP is correct
#   - code -1: build failed, the generated proof by NTP is incorrect
#   - code -2: excluded error, the vc was excluded from verification due to errors in generated proof
#   - code -124: build timed out (default to 600 seconds)
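# B. the output json file for a batch of responses contains the following fields:
#   - "build_result": a dict mapping problem_id to status code
#   - "build_time_seconds": a dict mapping problem_id to build time in seconds
#     (the timeout value for timed-out builds, -Infinity for failed builds)
#   - "excluded_errors": a list of problem_ids that were excluded from verification due to errors in generated proof
#   - "missing_problems": a list of problem_ids that are in the test set but have no entry in the responses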
# C. the output json file for a batch of responses is named as:
#   - {model_name}_{prover}_{response_index}_{total_responses}_verification_result.json
#   - e.g., DeepSeek-Prover-V2-671B_lean_13_24_verification_result.json
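# D. the aggregated output json file for a model across multiple responses contains the following fields:
#   - "build_results": a dict mapping problem_id to a list of status codes for each response
#     (-1 where a response has no recorded result for the problem)
#   - "durations": a dict mapping problem_id to a list of build times in seconds for each response
#     (0.0 where a response has no recorded result for the problem)
#   - summary fields: "model", "total_problems", "tried_problems", "solved_problems",
#     "pass_at_k", "trial_num", "average_duration_of_passed_builds_seconds"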
# E. the aggregated output json file for a model is named as:
#   - {model_name}_{prover}_{total_responses}_pass{total_responses}.json
#   - e.g., DeepSeek-Prover-V2-671B_lean_8_pass8.json
# F. the output json files are stored in:
#   - for each batch of responses: evaluation/raw/{model_name}/{prover}-pass{total_responses}/
#   - for aggregated results: evaluation/raw/{model_name}/


def limit_memory():
    """Cap the address space of the calling process at 32 GiB.

    Sets both the soft and the hard ``RLIMIT_AS`` limit. It is used as the
    ``preexec_fn`` of the build subprocess, so the cap applies to the child.
    """
    max_address_space_bytes = 32 * 1024**3  # 32 GiB
    limits = (max_address_space_bytes, max_address_space_bytes)
    resource.setrlimit(resource.RLIMIT_AS, limits)


def copy_benchmark_files(benchmark_dir: Path, dest_dir: Path):
    """Create a private working copy of the benchmark in ``dest_dir``.

    If ``dest_dir`` already exists the function returns immediately and
    touches nothing (this is what makes resumed runs cheap). Otherwise the
    benchmark is materialised either by extracting
    ``benchmark_dir/benchmark.tar`` or, when that archive is absent, by
    copying the ``frama_c``, ``pearl`` and ``.lake`` directories one by one.
    Finally ``dest_dir/lakefile.toml`` is created if missing and its relative
    path to ``generation/lean/`` is rewritten for the deeper directory level.

    Args:
        benchmark_dir: Directory holding the benchmark sources (and optionally
            ``benchmark.tar`` and ``lakefile.toml``).
        dest_dir: Directory to populate; created by this function.

    Raises:
        FileNotFoundError: If a directory to copy or every candidate
            ``lakefile.toml`` source is missing.
    """
    if dest_dir.exists():
        return

    # Extracting from benchmark.tar, when available, is usually faster than
    # copying the directory trees.
    tar_file = benchmark_dir / "benchmark.tar"
    if tar_file.exists():
        print(f"Extracting {tar_file} to {dest_dir}...")
        # NOTE(review): extractall() runs without a member filter, so the
        # archive is assumed to be trusted.
        with tarfile.open(tar_file, "r") as tar:
            tar.extractall(path=dest_dir)
    else:
        for name in ("frama_c", "pearl", ".lake"):
            src_dir = benchmark_dir / name
            dst_dir = dest_dir / name
            if not dst_dir.exists():
                print(f"Copying {src_dir} to {dst_dir}...")
                shutil.copytree(src_dir, dst_dir)

    lake_toml_path = dest_dir / "lakefile.toml"
    if not lake_toml_path.exists():
        # Prefer the lakefile shipped with the benchmark being copied; fall
        # back to the historical hard-coded location for compatibility.
        lake_toml_source = benchmark_dir / "lakefile.toml"
        if not lake_toml_source.exists():
            lake_toml_source = Path("data/why3/lakefile.toml")
        shutil.copy(lake_toml_source, lake_toml_path)

    # The working copy sits two levels deeper than the original benchmark, so
    # the relative path to generation/lean/ needs two more "../" segments.
    lake_toml_content = lake_toml_path.read_text()
    lake_toml_content = lake_toml_content.replace(
        '"../../generation/lean/"', '"../../../../generation/lean/"'
    )
    lake_toml_path.write_text(lake_toml_content)


def remove_lean_comments(lean_code: str):
    """Return ``lean_code`` with Lean comments stripped.

    A single pass over the characters tracks one state value:
    ``-1`` while inside a ``--`` line comment, ``0`` outside any comment, and
    a positive nesting depth while inside ``/- ... -/`` block comments.
    Only characters seen in state ``0`` are kept; the newline that ends a
    line comment is kept, the delimiters of block comments are dropped.
    String literals are not recognised, so comment markers inside them are
    treated as comments too.
    """
    kept = []
    depth = 0
    total = len(lean_code)
    for pos, ch in enumerate(lean_code):
        following = lean_code[pos + 1] if pos + 1 < total else ""
        in_line_comment = depth == -1

        if depth == 0 and ch == "-" and following == "-":
            # "--" seen outside any comment: a line comment starts here
            depth = -1
        elif ch == "/" and following == "-" and not in_line_comment:
            # "/-" opens a (possibly nested) block comment
            depth += 1
        elif in_line_comment and ch == "\n":
            # the newline terminates the line comment (and is kept below)
            depth = 0

        if depth == 0:
            kept.append(ch)
            continue

        # still inside a comment: "-/" closes one block-comment level
        if ch == "/" and pos > 0 and lean_code[pos - 1] == "-" and depth != -1:
            depth -= 1
    return "".join(kept)


def get_problem_id(record: dict) -> str:
    """Derive the dot-separated problem ID from a response/test-set record.

    The file path is read from ``record["header"]`` (legacy format, in which
    ``lean_standalone`` is rewritten to ``lean``) or, failing that, from
    ``record["index"]``. Everything up to and including the first
    ``data/why3/`` is dropped, a trailing ``.lean`` is removed, and each
    ``/`` becomes ``.``.

    Raises:
        ValueError: If the record has neither field, or the path does not
            contain ``data/why3/``.
    """
    marker = "data/why3/"

    if "header" in record:
        # "header" data format, kept for backward compatibility
        source_path = record["header"].replace("lean_standalone", "lean")
    elif "index" in record:
        source_path = record["index"]
    else:
        raise ValueError("Record must contain either 'header' or 'index' field.")

    _, found_marker, relative_path = source_path.partition(marker)
    if not found_marker:
        raise ValueError("substring not found")

    return relative_path.removesuffix(".lean").replace("/", ".")


def get_lean_problem_path(problem_id: str, base_dir: str | Path) -> Path:
    base_dir = Path(base_dir)
    if not base_dir.exists():
        raise ValueError(f"Base directory {base_dir} does not exist.")
    if not problem_id.startswith("frama_c.") and not problem_id.startswith("pearl."):
        raise ValueError(
            f"Problem ID {problem_id} must start with 'frama_c.' or 'pearl.'."
        )
    relative_path = problem_id.replace(".", "/") + ".lean"
    full_path = base_dir / relative_path

    return full_path


def remove_think_blocks(response: str) -> str:
    """Delete every ``<think>...</think>`` span from ``response``.

    Matching is case-insensitive, non-greedy and spans newlines. An opening
    tag without a matching closing tag is left untouched.
    """
    think_block = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
    return think_block.sub("", response)


def extract_lean_code(response: str, header: str):
    """Splice the proof found in ``response`` into the Lean file ``header``.

    The last ``:= sorry`` placeholder of ``header`` is replaced by
    ``:= <proof body>``, where the proof body is taken from the model
    response, and the resulting file content is returned (stripped).

    If the response contains no fenced code block this function does NOT
    fail: it falls back to the text starting at the last occurrence of the
    lemma line, or to the whole response, and leaves it to the build to
    reject a bad proof.

    Raises:
        ValueError: In exactly three situations:
            1. the header has no ``:= sorry`` placeholder;
            2. the header has no line starting with ``lemma``/``theorem``;
            3. the proof body (comments removed) contains ``sorry`` or
               ``admit``.
    """
    # 1./2. find the last placeholder and cut the header around it
    placeholder_spans = [
        m.span() for m in lean_proof_placeholder_pattern.finditer(header)
    ]
    if not placeholder_spans:
        raise ValueError(
            "No ':= sorry' pattern found in lean file to replace with proof body"
        )
    placeholder_start, placeholder_end = placeholder_spans[-1]
    before_proof = header[:placeholder_start]
    after_proof = header[placeholder_end:]

    # 3. the last lemma/theorem line of the header
    statement_lines = [
        m.group(0) for m in re.finditer(r"^(lemma|theorem).*", header, re.MULTILINE)
    ]
    if not statement_lines:
        raise ValueError("No lemma/theorem line found in header.")
    lemma_line = statement_lines[-1].strip()

    # 4. pick the candidate code out of the response (think blocks removed)
    answer = remove_think_blocks(response)
    fenced_blocks = code_block_pattern.findall(answer)
    if fenced_blocks:
        # 4.1 the last fenced code block wins
        extracted_code = fenced_blocks[-1].strip()
    else:
        # 4.2 text from the last occurrence of the lemma line, or
        # 4.3 the whole response when the lemma line is absent; this may
        #     include prose around the code, in which case the build fails
        lemma_pos = answer.rfind(lemma_line)
        candidate = answer if lemma_pos == -1 else answer[lemma_pos:]
        extracted_code = candidate.strip()

    # 5. the proof body is whatever follows the first ":="; without any ":="
    #    only trailing fence remnants are trimmed
    assignment = re.search(r":=\s*(.*)", extracted_code, re.DOTALL)
    if assignment is not None:
        proof_body = assignment.group(1).strip()
    else:
        proof_body = extracted_code
        for fence in ("```", "```lean4", "```lean"):
            proof_body = proof_body.removesuffix(fence)
        proof_body = proof_body.strip()

    # 6. strip comments first so a commented-out "sorry" is not rejected
    proof_body = remove_lean_comments(proof_body).strip()
    if re.search(r"\b(sorry|admit)\b", proof_body) is not None:
        raise ValueError("Proof body contains 'sorry' or 'admit'.")

    # 7. put the proof where the placeholder was
    return (before_proof + f":= {proof_body}" + after_proof).strip()


def build(
    full_name: str, cwd: str, timeout_seconds: float = 600
) -> tuple[int, str, str, float]:
    """Run ``lake build <full_name>`` in ``cwd`` and classify the outcome.

    The ``lake`` binary is taken from ``~/.elan/bin/lake``; if that file does
    not exist, a ``lake`` found on ``PATH`` is used instead. The child
    process runs with the memory cap installed by ``limit_memory``.

    Args:
        full_name: Build target passed to ``lake build`` (a problem ID).
        cwd: Working directory of the build (the Lake project root).
        timeout_seconds: Wall-clock limit for the build.

    Returns:
        ``(status, stdout, stderr, duration)`` where ``status`` is ``1`` on
        success, ``-1`` when ``lake`` exits non-zero and ``-124`` on timeout.
        ``duration`` is the elapsed seconds on success, ``-inf`` on failure
        and ``timeout_seconds`` on timeout.

    Raises:
        FileNotFoundError: If no ``lake`` executable can be started (raised
            by ``subprocess.run``).
    """
    lake_bin_path = Path("~").expanduser() / ".elan" / "bin" / "lake"
    if lake_bin_path.exists():
        lake = str(lake_bin_path)
    else:
        # Fall back to PATH; keep the elan path as last resort so the error
        # for a missing toolchain still names the expected location.
        lake = shutil.which("lake") or str(lake_bin_path)
    command = [lake, "build", full_name]

    # A monotonic clock is immune to system clock adjustments during the
    # (potentially long) build.
    start = time.monotonic()
    try:
        result = subprocess.run(
            command,
            check=True,
            text=True,
            capture_output=True,
            cwd=cwd,
            timeout=timeout_seconds,
            preexec_fn=limit_memory,
        )
    except subprocess.CalledProcessError as e:
        # -inf is the protocol's duration marker for failed builds
        return -1, e.stdout, e.stderr, -float("inf")
    except subprocess.TimeoutExpired as e:
        # NOTE(review): subprocess.run kills only the lake process itself on
        # timeout; processes spawned by lake may outlive it — confirm.
        return -124, e.stdout, e.stderr, timeout_seconds
    return 1, result.stdout, result.stderr, time.monotonic() - start

def _batch_build(
    problem_id_list: list[str],
    cwd: str,
    output_file_path: Path,
    timeout_seconds: int = 600,
) -> dict:
    """Build the given problems one after another.

    After every build the progress bar description is refreshed, and each
    time the number of recorded results is a multiple of 10 the results so
    far are merged into the JSON file at ``output_file_path`` (which must
    already exist and contain the ``build_result`` and
    ``build_time_seconds`` dicts).

    Returns:
        A dict mapping problem ID to ``{"status": ..., "duration": ...}``.
    """
    records = {}
    passed = 0
    progress = tqdm(problem_id_list)
    for pid in progress:
        status, _stdout, _stderr, duration = build(pid, cwd, timeout_seconds)
        records[pid] = {"status": status, "duration": duration}
        if status == 1:
            passed += 1
        progress.set_description(f"Build Pass: {passed}/{len(records)}")

        # checkpoint only on every 10th recorded result
        if len(records) % 10 != 0:
            continue
        snapshot: dict = json.loads(output_file_path.read_text())
        statuses = {name: rec["status"] for name, rec in records.items()}
        durations = {name: rec["duration"] for name, rec in records.items()}
        snapshot["build_result"].update(statuses)
        snapshot["build_time_seconds"].update(durations)
        output_file_path.write_text(json.dumps(snapshot))

    return records


def _batch_build_parallel(
    problem_id_list: list[str],
    cwd: str,
    output_file_path: Path,
    timeout_seconds: int = 600,
    max_workers: int = 16,
) -> dict:
    """Build the given problems concurrently in a process pool.

    The problems are submitted in random order. As each build finishes the
    progress bar description is refreshed, and each time the number of
    recorded results is a multiple of 10 the results so far are merged into
    the JSON file at ``output_file_path`` (which must already exist and
    contain the ``build_result`` and ``build_time_seconds`` dicts).

    A build whose worker raises is recorded as failed (status ``-1``,
    duration ``-inf``) and the exception is reported, so that an environment
    problem is not mistaken for an incorrect proof.

    Args:
        problem_id_list: Problem IDs to build; the list is not modified.
        cwd: Working directory of the builds (the Lake project root).
        output_file_path: JSON file receiving the periodic checkpoints.
        timeout_seconds: Per-build timeout.
        max_workers: Size of the process pool.

    Returns:
        A dict mapping problem ID to ``{"status": ..., "duration": ...}``.
    """
    records = {}
    pass_num = 0
    # Shuffle a copy so the caller's list keeps its order.
    shuffled_ids = random.sample(problem_id_list, len(problem_id_list))
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_pid = {
            executor.submit(build, pid, cwd, timeout_seconds): pid
            for pid in shuffled_ids
        }
        for future in (
            pbar := tqdm(as_completed(future_to_pid), total=len(shuffled_ids))
        ):
            pid = future_to_pid[future]
            try:
                status, _, _, duration = future.result()
            except Exception as exc:
                # Keep the run going, but make the cause visible.
                tqdm.write(
                    f"Warning: build of {pid} raised {exc!r}; recorded as failed (-1)."
                )
                records[pid] = {"status": -1, "duration": -float("inf")}
            else:
                records[pid] = {"status": status, "duration": duration}
                pass_num += status == 1
            pbar.set_description(f"Build Pass: {pass_num}/{len(records)}")

            # periodically save intermediate results to the output file
            if len(records) % 10 == 0:
                existing_data: dict = json.loads(output_file_path.read_text())
                existing_data["build_result"].update(
                    {full_name: res["status"] for full_name, res in records.items()}
                )
                existing_data["build_time_seconds"].update(
                    {full_name: res["duration"] for full_name, res in records.items()}
                )
                output_file_path.write_text(json.dumps(existing_data))

    return records


def batch_build(
    problem_id_list: list[str],
    cwd: str,
    output_file_path: Path,
    timeout_seconds: int = 600,
    use_parallel: bool = True,
    max_workers: int = 16,
) -> dict:
    """Build the problems that have no result yet and return all results.

    ``output_file_path`` is created with an empty result skeleton when it is
    missing. Problems already present in its ``build_result`` dict are
    skipped, which lets an interrupted run resume. The remaining problems
    are built either in parallel or sequentially.

    Returns:
        A dict mapping problem ID to ``{"status": ..., "duration": ...}``
        covering every problem already stored in the file plus the newly
        built ones (new results win on conflict).
    """
    if not output_file_path.exists():
        empty_result = {
            "build_result": {},
            "build_time_seconds": {},
            "excluded_errors": [],
            "missing_problems": [],
        }
        output_file_path.write_text(json.dumps(empty_result))

    existing_data: dict = json.loads(output_file_path.read_text())
    existing_pid = set(existing_data.get("build_result", {}).keys())

    def stored_records() -> dict:
        # results recorded by a previous (possibly interrupted) run
        return {
            pid: {
                "status": existing_data["build_result"][pid],
                "duration": existing_data["build_time_seconds"][pid],
            }
            for pid in existing_pid
        }

    pending = [pid for pid in problem_id_list if pid not in existing_pid]
    if not pending:
        # nothing left to build
        return stored_records()

    if existing_pid:
        print(f"Resuming to build remaining {len(pending)} problems...")
    existing_records = stored_records()

    if use_parallel:
        new_records = _batch_build_parallel(
            pending, cwd, output_file_path, timeout_seconds, max_workers
        )
    else:
        new_records = _batch_build(pending, cwd, output_file_path, timeout_seconds)

    return {**existing_records, **new_records}


def get_stats(model_name: str, total_responses: int = 8) -> dict:
    """Aggregate the per-response verification results of one model.

    Reads the test set from ``test_set.lean.lst`` and every existing
    ``evaluation/raw/{model_name}/lean-pass{total_responses}/`` result file
    for response indices ``0 .. total_responses - 1``, prints a summary and
    writes the aggregate to
    ``evaluation/raw/{model_name}/{model_name}_lean_{total_responses}_pass{total_responses}.json``.

    The per-problem lists in the aggregate have one slot per result file
    that was found, in ascending response-index order. A slot defaults to
    status ``-1`` / duration ``0.0`` and is set to ``-2`` / ``0.0`` for
    problems listed under ``excluded_errors``. The solved and tried counts
    are taken over every problem ID in the result files.

    Args:
        model_name: Name of the model whose results are aggregated.
        total_responses: Number of responses generated per problem.

    Returns:
        The aggregated result dict that was written to disk.
    """
    with open("test_set.lean.lst", "r") as f:
        test_set_list = [line.strip() for line in f if line.strip()]
    all_problems = [get_problem_id({"header": sample}) for sample in test_set_list]

    result_dir = Path(f"evaluation/raw/{model_name}/lean-pass{total_responses}")
    # Keep only the result files that exist, in response-index order. The
    # position in this list (not the response index) addresses the slots
    # below, so a missing file cannot cause an out-of-range access.
    verification_results = []
    for idx in range(total_responses):
        result_path = (
            result_dir
            / f"{model_name}_lean_{idx}_{total_responses}_verification_result.json"
        )
        if result_path.exists():
            verification_results.append(json.loads(result_path.read_text()))
    print("Found", len(verification_results), "evaluation records")
    trial_num = len(verification_results)

    tried_problems = {
        pid for result in verification_results for pid in result["build_result"]
    }
    solved = {
        pid
        for result in verification_results
        for pid, res in result["build_result"].items()
        if res == 1
    }
    # Guard against an empty test set, for the printout and the saved value alike.
    pass_at_k = len(solved) / len(all_problems) if all_problems else 0.0
    print(f"Unique solved problems: {len(solved)} ({len(tried_problems)} tried)")
    print(f"Total problems: {len(all_problems)}")
    print(f"Pass@{trial_num}: {pass_at_k:.2%}")

    all_problems_list = sorted(all_problems)
    aggregated_build_results = {
        problem: [-1] * trial_num for problem in all_problems_list
    }
    aggregated_duration = {problem: [0.0] * trial_num for problem in all_problems_list}
    duration_of_passed = []
    for slot, result in enumerate(verification_results):
        build_result = result["build_result"]
        for full_name, res in build_result.items():
            if full_name in aggregated_build_results:
                aggregated_build_results[full_name][slot] = res
        for full_name, dur in result["build_time_seconds"].items():
            if full_name not in aggregated_build_results:
                continue
            aggregated_duration[full_name][slot] = dur
            # failed builds carry a non-positive duration and are skipped
            if dur > 0 and build_result.get(full_name) == 1:
                duration_of_passed.append(dur)
        for full_name in result.get("excluded_errors", []):
            if full_name not in aggregated_build_results:
                continue
            aggregated_build_results[full_name][slot] = -2  # excluded error
            aggregated_duration[full_name][slot] = 0.0

    avg_duration = (
        sum(duration_of_passed) / len(duration_of_passed) if duration_of_passed else 0.0
    )
    print(f"Average duration of passed builds: {avg_duration:.2f}s")

    output_dir = Path(f"evaluation/raw/{model_name}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_filename = f"{model_name}_lean_{total_responses}_pass{total_responses}.json"
    output_path = output_dir / output_filename
    aggregated_result = {
        "model": model_name,
        "total_problems": len(all_problems),
        "tried_problems": len(tried_problems),
        "solved_problems": len(solved),
        "pass_at_k": pass_at_k,
        "trial_num": trial_num,
        "average_duration_of_passed_builds_seconds": avg_duration,
        "build_results": aggregated_build_results,
        "durations": aggregated_duration,
    }
    output_path.write_text(json.dumps(aggregated_result))
    print(f"Aggregated results saved to {output_path}")
    return aggregated_result


def get_response_file_path(
    model_name: str, total_responses: int, version: str = "0910"
) -> Path:
    """Return the path of the JSONL file holding a model's Lean responses.

    The file lives under ``data/why3/responses/{version}/`` and is named
    after the model and the number of responses per problem.
    """
    directory = f"data/why3/responses/{version}"
    file_name = f"responses-{model_name}-lean-{total_responses}.jsonl"
    return Path(f"{directory}/{file_name}")


def main():
    """Command-line entry point: verify one response index of one model.

    With ``--stats`` only the aggregation in ``get_stats`` is run. Otherwise
    the function
      1. loads the test set and the model's responses,
      2. prepares a working copy of the benchmark,
      3. writes, for every test-set problem with an extractable proof, the
         Lean file with the proof spliced in,
      4. builds those files and stores the result JSON, and
      5. removes the working copy unless ``--no_cleanup`` is given.

    Raises:
        ValueError: If ``--index`` is missing when ``--stats`` is not given.
    """
    args = parse_args()

    model_name = args.model
    total_responses = args.total_responses
    print(f"Model name: {model_name} (pass@{total_responses})")

    # statistics mode: aggregate existing results and stop
    if args.stats:
        return get_stats(model_name, total_responses)

    timeout = args.timeout
    max_workers = args.max_workers
    response_index = args.index
    if response_index is None:
        raise ValueError("Response index must be specified when not showing stats.")
    print(f"Response index: {response_index} / {total_responses}")

    # max_workers == 0 selects the sequential build
    use_parallel = max_workers > 0
    # NOTE(review): this working directory depends only on the model name, so
    # runs for different response indices of the same model share it and would
    # overwrite each other's files if executed concurrently — confirm that
    # runs are sequential.
    verification_dir = Path(f"data/why3/responses/{model_name}")
    output_dir = Path(f"evaluation/raw/{model_name}/lean-pass{total_responses}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_filename = (
        f"{model_name}_lean_{response_index}_{total_responses}_verification_result.json"
    )
    output_file_path = output_dir / output_filename
    print(f"Output will be saved to {output_file_path}")
    # without resume, previous results are discarded up front
    if args.no_resume and output_file_path.exists():
        print(f"Output file {output_file_path} will be overwritten")
        output_file_path.unlink()

    # load the test set list
    with open("test_set.lean.lst", "r") as f:
        test_set_list = [line.strip() for line in f if line.strip()]
    test_set = {get_problem_id({"header": sample}) for sample in test_set_list}
    print(f"Loaded {len(test_set_list)} test entries from test_set.lean.lst")

    # load the response data
    response_file_path = get_response_file_path(model_name, total_responses)
    response_data = [
        json.loads(line) for line in response_file_path.read_text().splitlines() if line
    ]
    print(f"Loaded {len(response_data)} response entries from {response_file_path}")

    # copy benchmark directories to the verification directory
    # (a no-op when the verification directory already exists)
    copy_benchmark_files(Path("data/why3"), Path(verification_dir))

    # problem_ids from responses, and problem_ids from test set
    # 1. intersection: we will build these problems
    #  - 1.1 does not have any valid proof in the response: we will exclude these problems from build
    #        we will try our best to extract the proof body, as long as it does not contain "sorry" or "admit"
    #        we will record these problems in "excluded_errors" field in the output json file
    # 2. in responses but not in test set: we will warn about these problems, but not build them
    # 3. in test set but not in responses: we will warn about these problems, but not build them
    pid_to_build = set()
    pid_to_exclude = set()
    pid_not_in_test_set = set()
    for entry in response_data:
        # for backward compatibility, support both "response" and "responses" fields
        response = entry.get("response", entry.get("responses", []))[response_index]
        problem_id = get_problem_id(entry)

        if problem_id not in test_set:
            pid_not_in_test_set.add(problem_id)
            continue

        # the pristine problem file provides the header; the proof is written
        # to the same relative location inside the working copy
        original_file_path = get_lean_problem_path(problem_id, base_dir="data/why3")
        formal_statement_header = original_file_path.read_text()
        temp_file_path = get_lean_problem_path(problem_id, base_dir=verification_dir)

        try:
            combined_code = extract_lean_code(response, formal_statement_header)
            temp_file_path.write_text(combined_code)
            # hold: problem_id in test_set
            pid_to_build.add(problem_id)
        except ValueError:
            # hold: problem_id in test_set, but no valid proof found in response
            pid_to_exclude.add(problem_id)

    # test-set problems that never appeared in the responses
    pid_missing = test_set - pid_to_build - pid_to_exclude
    print(f"Collected {len(pid_to_build)} problems to build.")
    print(f"Excluded {len(pid_to_exclude)} problems with no valid proof in response.")

    if pid_not_in_test_set:
        print(
            f"Warning: {len(pid_not_in_test_set)} problem IDs in responses but not in test set."
        )
    if pid_missing:
        print(
            f"Warning: {len(pid_missing)} problem IDs in test set but not in responses."
        )

    verification_result = {
        "build_result": {},
        "build_time_seconds": {},
        "excluded_errors": sorted(pid_to_exclude),
        "missing_problems": sorted(pid_missing),
    }
    # record the excluded/missing lists before building; when resuming, the
    # build results already stored in the file are kept
    if not output_file_path.exists():
        output_file_path.write_text(json.dumps(verification_result))
    else:
        existing_data = json.loads(output_file_path.read_text())
        existing_data["excluded_errors"] = sorted(pid_to_exclude)
        existing_data["missing_problems"] = sorted(pid_missing)
        output_file_path.write_text(json.dumps(existing_data))

    # the main build process starts here
    build_result = batch_build(
        sorted(pid_to_build),
        verification_dir,
        output_file_path,
        timeout_seconds=timeout,
        use_parallel=use_parallel,
        max_workers=max_workers,
    )

    # final write: replaces the periodic checkpoints with the complete result
    verification_result["build_result"] = {
        full_name: res["status"] for full_name, res in build_result.items()
    }
    verification_result["build_time_seconds"] = {
        full_name: res["duration"] for full_name, res in build_result.items()
    }
    output_file_path.write_text(json.dumps(verification_result))
    print(f"Verification result saved to {output_file_path}")

    if not args.no_cleanup:
        print(f"Cleaning up temporary directory {verification_dir}...")
        shutil.rmtree(verification_dir)


def parse_args():
    """Parse the command line of the verification script.

    Returns:
        The ``argparse.Namespace`` with the fields ``model``, ``stats``,
        ``index``, ``total_responses``, ``timeout``, ``max_workers``,
        ``no_cleanup`` and ``no_resume``.
    """
    supported_models = [
        "DeepSeek-Prover-V2-671B",
        "DeepSeek-Prover-V2-7B",
        "Goedel-Prover-V2-8B",
        "Goedel-Prover-V2-32B",
        "Qwen3-235B-A22B",
        "Qwen3-32B",
        "GPT-o4-mini-high",
    ]
    # (argument name, add_argument keyword options), in registration order
    argument_specs = [
        (
            "model",
            {
                "type": str,
                "choices": supported_models,
                "help": "Model name used to generate the responses.",
            },
        ),
        (
            "--stats",
            {
                "action": "store_true",
                "help": "Show statistics of the verification results.",
            },
        ),
        (
            "--index",
            {"type": int, "help": "Response index to use (0-based)."},
        ),
        (
            "--total_responses",
            {
                "type": int,
                "default": 8,
                "help": "Total number of responses generated for each problem.",
            },
        ),
        (
            "--timeout",
            {
                "type": int,
                "default": 600,
                "help": "Timeout in seconds for each build process (default: 600s).",
            },
        ),
        (
            "--max_workers",
            {
                "type": int,
                "default": 16,
                "help": "Maximum number of parallel workers for building; 0 means no parallelism (default: 16).",
            },
        ),
        (
            "--no_cleanup",
            {
                "action": "store_true",
                "help": "Do not clean up the temporary verification directory.",
            },
        ),
        (
            "--no_resume",
            {
                "action": "store_true",
                "help": "Do not resume from existing output file; overwrite it.",
            },
        ),
    ]

    parser = argparse.ArgumentParser(
        description="Automate Lean proof verification process."
    )
    for name, options in argument_specs:
        parser.add_argument(name, **options)
    return parser.parse_args()


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
