import argparse
import json
import logging
import re
import resource
import shutil
import subprocess
import tarfile
import tempfile
import time
from pathlib import Path

from tqdm import tqdm

# Matches a Markdown fenced code block whose opening fence is optionally tagged
# isabelle / lean / lean4 / rocq / coq (case-insensitive). Group 1 captures the
# block contents; DOTALL lets the contents span several lines.
code_block_pattern = re.compile(
    r"```(?:isabelle|lean4?|rocq|coq)?\s*(.*?)```", re.DOTALL | re.IGNORECASE
)
# Matches the literal "Admitted." placeholder that a generated proof replaces.
rocq_proof_placeholder_pattern = re.compile(r"Admitted\.", re.DOTALL)
# Matches a "(* ... *)" comment using the shortest possible match, so a nested
# comment ends at the first "*)" encountered.
rocq_comment_pattern = re.compile(r"\(\*.*?\*\)", re.DOTALL)


def setup_logging(log_to_file: bool):
    """Send INFO-level log records to ``rocq_proof_verification.log`` on request.

    When ``log_to_file`` is false nothing is configured. Otherwise the log file
    is opened in write mode (truncating any previous content) and a notice is
    printed to stdout.
    """
    if not log_to_file:
        return

    file_config = {
        "level": logging.INFO,
        "format": "%(asctime)s - %(levelname)s - %(message)s",
        "filename": "rocq_proof_verification.log",
        "filemode": "w",
    }
    logging.basicConfig(**file_config)
    print("Logging to rocq_proof_verification.log")


def limit_memory():
    """Set both the soft and hard RLIMIT_DATA of the calling process to 32 GiB."""
    max_data_bytes = 32 * 1024**3
    limits = (max_data_bytes, max_data_bytes)
    resource.setrlimit(resource.RLIMIT_DATA, limits)


def copy_benchmark_files(
    benchmark_dir: Path | str, why3_dir: Path | str | None = None
) -> tempfile.TemporaryDirectory:
    # benchmark-no-lake.tar contains frama_c, pearl directories without lake files
    # it also includes dune-workspace, and why3 files
    benchmark_dir = Path(benchmark_dir)
    tar_file = benchmark_dir / "benchmark-no-lake.tar"

    temp_dir = tempfile.TemporaryDirectory()

    if tar_file.exists():
        print(f"Extracting {tar_file} to {temp_dir.name}...")
        with tarfile.open(tar_file, "r") as tar:
            tar.extractall(path=temp_dir.name)
    else:
        if why3_dir is None:
            raise ValueError(
                "why3_dir must be provided if benchmark-no-lake.tar does not exist."
            )
        for dir_name in ["frama_c", "pearl"]:
            src_dir = benchmark_dir / dir_name
            dst_dir = Path(temp_dir.name) / dir_name
            print(f"Copying {src_dir} to {dst_dir}...")
            shutil.copytree(src_dir, dst_dir)
        # write dune-workspace file
        dune_workspace_path = Path(temp_dir.name) / "dune-workspace"
        dune_workspace_path.write_text("(lang dune 3.0)\n\n(context default)")
        # copy why3 files
        shutil.copytree(why3_dir, Path(temp_dir.name) / "generation")
        # replace Why3/Base.v with an actual file instead of a symlink
        project_root = why3_dir.parent
        why3_base_src = project_root / "lib" / "Rocq4Why3" / "Why3" / "Base.v"
        temp_why3_path = Path(temp_dir.name) / "generation" / "rocq" / "Why3" / "Base.v"
        temp_why3_path.unlink(missing_ok=True)
        temp_why3_path.write_text(why3_base_src.read_text())

    return temp_dir


def remove_rocq_comments(rocq_code: str):
    """Return ``rocq_code`` with every ``(* ... *)`` comment removed.

    Comments may nest. The text is scanned as two-character tokens, so the
    characters of one delimiter are never reused for another: in ``(*)`` the
    ``(*`` opens a comment and the ``)`` is ordinary comment text rather than
    the end of a ``*)``. A ``*)`` found outside any comment is kept verbatim,
    and an unterminated comment swallows the rest of the input.

    String literals are not recognised, so a ``(*`` inside a string is still
    treated as the start of a comment.
    """
    kept = []
    depth = 0  # current comment nesting level; 0 means outside any comment
    pos = 0
    length = len(rocq_code)
    while pos < length:
        token = rocq_code[pos : pos + 2]
        if token == "(*":
            depth += 1
            pos += 2
        elif token == "*)" and depth > 0:
            depth -= 1
            pos += 2
        else:
            if depth == 0:
                kept.append(rocq_code[pos])
            pos += 1
    return "".join(kept)


def get_problem_id(record: dict) -> str:
    """Derive the dotted problem id from a record's file path.

    The path is taken from ``record["header"]`` (with ``lean_standalone``
    rewritten to ``rocq``) or, failing that, from ``record["index"]``. The part
    after the first ``data/why3/`` is kept, a trailing ``.v`` is dropped and
    every ``/`` becomes ``.``.

    Raises:
        ValueError: If neither key is present, or the path does not contain
            ``data/why3/``.
    """
    marker = "data/why3/"

    if "header" in record:
        source_path: str = record["header"].replace("lean_standalone", "rocq")
    elif "index" in record:
        source_path = record["index"]
    else:
        raise ValueError("Record must contain either 'header' or 'index' field.")

    # str.index raises ValueError when the marker is absent.
    cut = source_path.index(marker) + len(marker)
    return source_path[cut:].removesuffix(".v").replace("/", ".")


def get_rocq_problem_path(problem_id: str, base_dir: str | Path) -> Path:
    base_dir = Path(base_dir)
    if not base_dir.exists():
        raise ValueError(f"Base directory {base_dir} does not exist.")
    if not problem_id.startswith("frama_c.") and not problem_id.startswith("pearl."):
        raise ValueError(
            f"Problem ID {problem_id} must start with 'frama_c.' or 'pearl.'."
        )
    relative_path = problem_id.replace(".", "/") + ".v"
    full_path = base_dir / relative_path

    return full_path


def remove_think_blocks(response: str) -> str:
    """Delete every ``<think>...</think>`` span (any letter case, may span lines)."""
    think_span = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
    return think_span.sub("", response)


def extract_rocq_proof_body(rocq_code: str) -> str:
    """Return the proof portion of ``rocq_code``, or ``""`` if none is found.

    Two shapes are tried in order, both matched greedily across lines:
    the span from the first ``Proof.`` to the last ``Qed.``, then the span
    from the first ``:= by`` to the last ``.``.
    """
    proof_shapes = (
        r"Proof\..*Qed\.",  # "Proof." ... "Qed."
        r":=\s*by\s+.*\.",  # ":= by ..." up to the final period
    )
    for shape in proof_shapes:
        found = re.search(shape, rocq_code, re.DOTALL)
        if found is not None:
            return found.group(0)
    return ""


def extract_rocq_code(response: str, header: str):
    """Splice the proof found in ``response`` into the Rocq source ``header``.

    The last ``Admitted.`` in ``header`` is replaced by the proof body taken
    from ``response`` and the resulting text is returned stripped.

    A missing code block in the response is NOT an error: the text from the
    last occurrence of the lemma line, or else the whole response, is used as
    the proof and any problem surfaces later when the file is built.

    Raises:
        ValueError: If ``header`` contains no ``Admitted.`` placeholder.
        ValueError: If ``header`` has no line starting with ``Lemma`` or
            ``Theorem``.
        ValueError: If the comment-free proof body contains the word
            ``sorry``, ``admit`` or ``Admitted``.
    """
    # 1. locate the last placeholder in the header
    last_placeholder = None
    for last_placeholder in rocq_proof_placeholder_pattern.finditer(header):
        pass
    if last_placeholder is None:
        raise ValueError(
            "No 'Admitted.' pattern found in rocq file to replace with proof body"
        )

    # 2. keep what precedes and what follows the placeholder
    prefix = header[: last_placeholder.start()]
    suffix = header[last_placeholder.end() :]

    # 3. take the last Lemma/Theorem line; its stripped text minus the final two
    #    characters is used to find the statement inside the response
    statement_lines = [
        found.group(0)
        for found in re.finditer(r"^(Lemma|Theorem).*", header, re.MULTILINE)
    ]
    if not statement_lines:
        raise ValueError("No lemma/theorem line found in header.")
    statement_stem = statement_lines[-1].strip()[:-2]

    # 4. pick the candidate code out of the response, ignoring <think> blocks
    visible = remove_think_blocks(response)
    fenced_blocks = code_block_pattern.findall(visible)
    if fenced_blocks:
        # 4.1 the contents of the last fenced code block
        candidate = fenced_blocks[-1].strip()
    else:
        # 4.2 the text from the last occurrence of the lemma line, or
        # 4.3 the whole response when the lemma line does not occur
        #     (this may include surrounding prose, so the build may fail)
        anchor = visible.rfind(statement_stem)
        candidate = (visible if anchor == -1 else visible[anchor:]).strip()

    # 5. narrow the candidate down to the proof body; when no recognisable
    #    proof shape is present, use the candidate minus trailing fence markers
    proof = extract_rocq_proof_body(candidate)
    if not proof:
        proof = candidate
        for fence in ("```", "```rocq", "```coq"):
            proof = proof.removesuffix(fence)
        proof = proof.strip()

    # 6. drop comments, then reject proofs that give up
    proof = remove_rocq_comments(proof).strip()
    if re.search(r"\b(sorry|admit|Admitted)\b", proof):
        raise ValueError("Proof body contains 'sorry', 'admit', or 'Admitted'.")

    # 7. put the proof where the placeholder was; a body that does not start
    #    with "Proof." is separated from the preceding text by one space
    separator = "" if proof.startswith("Proof.") else " "
    return (prefix + separator + proof + suffix).strip()


def _output_as_text(data) -> str:
    """Normalise captured process output (``None``, ``bytes`` or ``str``) to ``str``."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        return data.decode(errors="replace")
    return data


def build(
    full_name: str, cwd: str, timeout_seconds: float = 600
) -> tuple[int, str, str, float]:
    """Build the ``.vo`` target of one problem with ``dune build``.

    Args:
        full_name: Dotted problem id, resolved with ``get_rocq_problem_path``.
        cwd: Directory in which dune is run; the target path is relative to it.
        timeout_seconds: Wall-clock limit passed to ``subprocess.run``.

    Returns:
        ``(status, stdout, stderr, duration)`` where status is ``1`` on
        success, ``-1`` when dune exits non-zero and ``-124`` on timeout. The
        duration is the elapsed seconds on success, ``-inf`` on failure and
        ``timeout_seconds`` on timeout.

    Raises:
        ValueError: Propagated from ``get_rocq_problem_path`` for an invalid
            id or a missing ``cwd``.
    """
    v_file = get_rocq_problem_path(full_name, base_dir=cwd)
    vo_target = v_file.relative_to(cwd).with_suffix(".vo")
    command = ["dune", "build", str(vo_target)]
    logging.info("Command: %s", " ".join(command))

    start = time.time()
    try:
        result = subprocess.run(
            command,
            check=True,
            text=True,
            capture_output=True,
            cwd=cwd,
            timeout=timeout_seconds,
            # Applied in the child before exec, so the limit affects dune only.
            preexec_fn=limit_memory,
        )
    except subprocess.CalledProcessError as e:
        logging.info("Error occurred while building %s:", full_name)
        logging.info("Return code: %s", e.returncode)
        logging.info("Stdout: %s", e.stdout)
        logging.info("Stderr: %s", e.stderr)
        return -1, e.stdout, e.stderr, -float("inf")
    except subprocess.TimeoutExpired as e:
        logging.info(
            "Timed out after %s seconds while building %s", timeout_seconds, full_name
        )
        # TimeoutExpired carries bytes (or None) even when text=True was
        # requested, so convert to keep the declared str return type.
        return (
            -124,
            _output_as_text(e.stdout),
            _output_as_text(e.stderr),
            timeout_seconds,
        )
    return 1, result.stdout, result.stderr, time.time() - start


def batch_build(
    problem_id_list: list[str],
    cwd: str,
    output_file_path: Path,
    timeout_seconds: int = 600,
) -> dict:
    """Build every problem not yet recorded in ``output_file_path``.

    The output file is created with empty sections when missing. Problems
    already listed under ``build_result`` are skipped. After every tenth new
    build the statuses and durations gathered so far are merged into the file;
    the caller is expected to write the final state.

    Returns:
        A mapping from problem id to a record. Records read from the file hold
        ``status`` and ``duration``; records built in this call additionally
        hold ``stdout`` and ``stderr``.
    """
    if not output_file_path.exists():
        blank_state = {
            "build_result": {},
            "build_time_seconds": {},
            "excluded_errors": [],
            "missing_problems": [],
        }
        output_file_path.write_text(json.dumps(blank_state))

    saved_state: dict = json.loads(output_file_path.read_text())
    finished_ids = set(saved_state.get("build_result", {}).keys())
    pending_ids = [pid for pid in problem_id_list if pid not in finished_ids]

    def saved_records() -> dict:
        # Records for everything the output file already contains.
        return {
            pid: {
                "status": saved_state["build_result"][pid],
                "duration": saved_state["build_time_seconds"][pid],
            }
            for pid in finished_ids
        }

    if not pending_ids:
        return saved_records()

    if finished_ids:
        print(f"Resuming to build remaining {len(pending_ids)} problems...")
    previous_records = saved_records()

    fresh_records = {}
    passed = 0
    progress = tqdm(pending_ids)
    for pid in progress:
        status, stdout, stderr, duration = build(pid, cwd, timeout_seconds)
        fresh_records[pid] = {
            "status": status,
            "stdout": stdout,
            "stderr": stderr,
            "duration": duration,
        }
        if status == 1:
            passed += 1
        progress.set_description(f"Build Pass: {passed}/{len(fresh_records)}")

        # checkpoint: merge what we have into the file after every tenth build
        if len(fresh_records) % 10 != 0:
            continue
        on_disk: dict = json.loads(output_file_path.read_text())
        on_disk["build_result"].update(
            (name, record["status"]) for name, record in fresh_records.items()
        )
        on_disk["build_time_seconds"].update(
            (name, record["duration"]) for name, record in fresh_records.items()
        )
        output_file_path.write_text(json.dumps(on_disk))

    return previous_records | fresh_records


def pass_at_n(records: list[dict], n: int) -> int:
    """Count the distinct problems with status ``1`` in any of the first ``n`` records.

    Each record must provide a ``build_result`` mapping of problem id to status.
    """
    solved_ids = set()
    for attempt in records[:n]:
        for pid, status in attempt["build_result"].items():
            if status == 1:
                solved_ids.add(pid)
    return len(solved_ids)


def get_stats(model_name: str, total_responses: int = 8) -> dict:
    """Aggregate the per-response verification results of one model.

    Reads the problem list from ``test_set.rocq.lst`` and every existing
    ``evaluation/raw/<model>/rocq-pass<N>/<model>_rocq_<idx>_<N>_verification_result.json``
    for ``idx`` in ``range(total_responses)``. Prints summary counts and writes
    the aggregate to ``evaluation/raw/<model>/<model>_rocq_<N>_pass<N>.json``.

    In the per-problem lists, position ``idx`` belongs to response ``idx``.
    A status of ``-1`` and a duration of ``0.0`` are the defaults for responses
    without a result; ``-2`` marks a response listed under ``excluded_errors``.

    Args:
        model_name: Model whose results are aggregated.
        total_responses: Number of responses generated per problem.

    Returns:
        The aggregated result dictionary that was written to disk.
    """
    with open("test_set.rocq.lst", "r") as f:
        test_set_list = [line.strip() for line in f if line.strip()]
    all_problems = [get_problem_id({"header": sample}) for sample in test_set_list]

    json_file_paths = [
        Path(
            f"evaluation/raw/{model_name}/rocq-pass{total_responses}/"
            f"{model_name}_rocq_{idx}_{total_responses}_verification_result.json"
        )
        for idx in range(total_responses)
    ]
    # Keyed by response index, so a missing file leaves a gap in the keys.
    verification_results = {
        idx: json.loads(p.read_text())
        for idx, p in enumerate(json_file_paths)
        if p.exists()
    }
    print("Found", len(verification_results), "evaluation records")
    trial_num = len(verification_results)
    tried_problems = {
        pid
        for result in verification_results.values()
        for pid in result["build_result"]
    }
    solved = {
        pid
        for result in verification_results.values()
        for pid, res in result["build_result"].items()
        if res == 1
    }
    print(f"Unique solved problems: {len(solved)} ({len(tried_problems)} tried)")
    print(f"Total problems: {len(all_problems)}")

    pass_at_k = {}
    ordered_results = list(verification_results.values())
    for n in [1, 4, 8]:
        if n > trial_num:
            break
        solved_within_n = pass_at_n(ordered_results, n)
        # An empty test set would otherwise divide by zero.
        pass_at_k[n] = solved_within_n / len(all_problems) if all_problems else 0.0

    # Size the per-problem lists by the highest response index present rather
    # than by the number of files found: with a gap (e.g. files 0 and 2 only)
    # the index 2 would not fit a list of length 2.
    slot_count = max(verification_results, default=-1) + 1
    all_problems_list = sorted(all_problems)
    aggregated_build_results = {
        problem: [-1] * slot_count for problem in all_problems_list
    }
    aggregated_duration = {problem: [0.0] * slot_count for problem in all_problems_list}
    duration_of_passed = []
    for idx, result in verification_results.items():
        build_status = result["build_result"]
        for full_name, res in build_status.items():
            if full_name in aggregated_build_results:
                aggregated_build_results[full_name][idx] = res
        for full_name, dur in result["build_time_seconds"].items():
            if full_name not in aggregated_build_results:
                continue
            aggregated_duration[full_name][idx] = dur
            # Failed builds carry a non-positive duration and are not averaged.
            if dur > 0 and build_status.get(full_name) == 1:
                duration_of_passed.append(dur)
        for full_name in result.get("excluded_errors", []):
            if full_name not in aggregated_build_results:
                continue
            aggregated_build_results[full_name][idx] = -2  # excluded error
            aggregated_duration[full_name][idx] = 0.0

    avg_duration = (
        sum(duration_of_passed) / len(duration_of_passed) if duration_of_passed else 0.0
    )
    print(f"Average duration of passed builds: {avg_duration:.2f}s")

    output_dir = Path(f"evaluation/raw/{model_name}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_filename = f"{model_name}_rocq_{total_responses}_pass{total_responses}.json"
    output_path = output_dir / output_filename
    aggregated_result = {
        "model": model_name,
        "total_problems": len(all_problems),
        "tried_problems": len(tried_problems),
        "solved_problems": len(solved),
        "pass_at_k": pass_at_k,
        "trial_num": trial_num,
        "average_duration_of_passed_builds_seconds": avg_duration,
        "build_results": aggregated_build_results,
        "durations": aggregated_duration,
    }
    output_path.write_text(json.dumps(aggregated_result))
    print(f"Aggregated results saved to {output_path}")
    # The signature promises a dict; previously nothing was returned.
    return aggregated_result


def get_response_file_path(
    model_name: str, total_responses: int, version: str = "0922"
) -> Path:
    """Return the path of the JSONL response file for a model and response count.

    The layout is
    ``data/why3/responses/<version>/responses-<model>-rocq-<total>.jsonl``.
    """
    file_name = f"responses-{model_name}-rocq-{total_responses}.jsonl"
    segments = ("data", "why3", "responses", version, file_name)
    return Path("/".join(segments))


def eval_rocq_only_hammer(
    temp_dir: tempfile.TemporaryDirectory, test_set: set[str], timeout: int
):
    """Try to prove every test problem with the ``hammer`` tactic alone.

    Each problem file under ``temp_dir`` gets its placeholder replaced by a
    fixed ``Proof. hammer. Qed.`` script and a Hammer import preamble. A
    sibling ``dune`` file that does not yet contain `` Hammer `` has ``Hammer``
    inserted after ``(theories``. All problems are then built, and the
    statuses and durations are written to
    ``evaluation/raw/rocq_hammer_only_results.json``.
    """
    hammer_response = "```coq\nProof.\n  hammer.\nQed.\n```"
    hammer_preamble = "From Hammer Require Import Hammer.\nSet Hammer ATPLimit 30.\n"

    staged_ids = set()
    for pid in test_set:
        source_path = get_rocq_problem_path(pid, base_dir=temp_dir.name)
        statement = source_path.read_text()
        patched = extract_rocq_code(hammer_response, statement)
        source_path.write_text(hammer_preamble + patched)

        # make the Hammer theory visible to dune for this directory
        dune_path = source_path.parent / "dune"
        if dune_path.exists():
            stanza = dune_path.read_text()
            if " Hammer " not in stanza:
                dune_path.write_text(
                    stanza.replace("(theories ", "(theories Hammer ")
                )
        staged_ids.add(pid)

    output_file_path = Path("evaluation/raw/rocq_hammer_only_results.json")
    build_result = batch_build(
        sorted(staged_ids), temp_dir.name, output_file_path, timeout_seconds=timeout
    )

    statuses = {name: record["status"] for name, record in build_result.items()}
    durations = {name: record["duration"] for name, record in build_result.items()}
    verification_result = {
        "build_result": statuses,
        "build_time_seconds": durations,
    }
    output_file_path.write_text(json.dumps(verification_result))
    print(f"Verification result saved to {output_file_path}")


def main():
    """Command-line entry point: print statistics or verify one response index.

    With ``--stats`` the work is delegated to ``get_stats``. Otherwise
    ``--index`` is required; the benchmark is staged in a temporary directory
    and, unless ``--only_hammer`` hands over to ``eval_rocq_only_hammer``, the
    selected response of every test-set problem is spliced into the staged
    sources, built with ``batch_build`` and the outcome is written to
    ``evaluation/raw/<model>/rocq-pass<N>/<model>_rocq_<index>_<N>_verification_result.json``.

    Raises:
        ValueError: If ``--index`` is missing when ``--stats`` is not given.
    """
    args = parse_args()

    model_name = args.model
    total_responses = args.total_responses
    print(f"Model name: {model_name} (pass@{total_responses})")

    if args.stats:
        return get_stats(model_name, total_responses)

    # setup logging
    setup_logging(args.log)

    timeout = args.timeout
    response_index = args.index
    if response_index is None:
        raise ValueError("Response index must be specified when not showing stats.")
    print(f"Response index: {response_index} / {total_responses}")

    output_dir = Path(f"evaluation/raw/{model_name}/rocq-pass{total_responses}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_filename = (
        f"{model_name}_rocq_{response_index}_{total_responses}_verification_result.json"
    )
    output_file_path = output_dir / output_filename
    print(f"Output will be saved to {output_file_path}")
    # Deleting the file makes batch_build start from scratch instead of resuming.
    if args.no_resume and output_file_path.exists():
        print(f"Output file {output_file_path} will be overwritten")
        output_file_path.unlink()

    # load the test set list
    with open("test_set.rocq.lst", "r") as f:
        test_set_list = [line.strip() for line in f if line.strip()]
    test_set = {get_problem_id({"header": sample}) for sample in test_set_list}
    print(f"Loaded {len(test_set_list)} test entries from test_set.rocq.lst")

    # load the response data (one JSON object per non-empty line)
    response_file_path = get_response_file_path(model_name, total_responses)
    response_data = [
        json.loads(line) for line in response_file_path.read_text().splitlines() if line
    ]
    print(f"Loaded {len(response_data)} response entries from {response_file_path}")

    # copy benchmark directories to the verification directory
    # (no why3_dir is passed, so benchmark-no-lake.tar must exist in data/why3)
    temp_dir = copy_benchmark_files("data/why3")

    if args.only_hammer:
        return eval_rocq_only_hammer(temp_dir, test_set, timeout)

    pid_to_build = set()
    pid_to_exclude = set()
    pid_not_in_test_set = set()
    for entry in response_data:
        # for backward compatibility, support both "response" and "responses" fields
        # NOTE(review): if "response" holds a single string rather than a list,
        # indexing it yields one character — confirm the schema of older files.
        response = entry.get("response", entry.get("responses", []))[response_index]
        problem_id = get_problem_id(entry)

        if problem_id not in test_set:
            pid_not_in_test_set.add(problem_id)
            continue

        # The statement is read from the original tree; the combined file is
        # written only into the temporary copy.
        original_file_path = get_rocq_problem_path(problem_id, base_dir="data/why3")
        formal_statement_header = original_file_path.read_text()
        temp_file_path = get_rocq_problem_path(problem_id, base_dir=temp_dir.name)

        try:
            combined_code = extract_rocq_code(response, formal_statement_header)
            temp_file_path.write_text(combined_code)
            # hold: problem_id in test_set
            pid_to_build.add(problem_id)
        except ValueError:
            # hold: problem_id in test_set, but no valid proof found in response
            pid_to_exclude.add(problem_id)

    # Test-set problems for which the response file had no entry at all.
    pid_missing = test_set - pid_to_build - pid_to_exclude
    print(f"Collected {len(pid_to_build)} problems to build.")
    print(f"Excluded {len(pid_to_exclude)} problems with no valid proof in response.")

    if pid_not_in_test_set:
        print(
            f"Warning: {len(pid_not_in_test_set)} problem IDs in responses but not in test set."
        )
    if pid_missing:
        print(
            f"Warning: {len(pid_missing)} problem IDs in test set but not in responses."
        )

    verification_result = {
        "build_result": {},
        "build_time_seconds": {},
        "excluded_errors": sorted(pid_to_exclude),
        "missing_problems": sorted(pid_missing),
    }
    # Record the excluded/missing lists before building; when resuming, keep the
    # build results already stored in the file and refresh only these two lists.
    if not output_file_path.exists():
        output_file_path.write_text(json.dumps(verification_result))
    else:
        existing_data = json.loads(output_file_path.read_text())
        existing_data["excluded_errors"] = sorted(pid_to_exclude)
        existing_data["missing_problems"] = sorted(pid_missing)
        output_file_path.write_text(json.dumps(existing_data))

    build_result = batch_build(
        sorted(pid_to_build), temp_dir.name, output_file_path, timeout_seconds=timeout
    )

    # Final write: statuses and durations of every record returned by batch_build.
    verification_result["build_result"] = {
        full_name: res["status"] for full_name, res in build_result.items()
    }
    verification_result["build_time_seconds"] = {
        full_name: res["duration"] for full_name, res in build_result.items()
    }
    output_file_path.write_text(json.dumps(verification_result))
    print(f"Verification result saved to {output_file_path}")


def parse_args():
    """Parse the command line (``sys.argv``) and return the argparse namespace.

    One positional argument selects the model; the options are ``--stats``,
    ``--index``, ``--total_responses``, ``--timeout``, ``--no_resume``,
    ``--log`` and ``--only_hammer``.
    """
    supported_models = [
        "DeepSeek-Prover-V2-671B",
        "DeepSeek-Prover-V2-7B",
        "DeepSeek-V3.1",
        "Goedel-Prover-V2-8B",
        "Goedel-Prover-V2-32B",
        "Qwen3-235B-A22B",
        "Qwen3-32B",
        "K2-think",
        "Kimina-Prover-72B",
        "GPT-o4-mini-high",
    ]
    # (flag, keyword arguments) in the order they appear in --help
    option_table = [
        (
            "--stats",
            {
                "action": "store_true",
                "help": "Show statistics of the verification results.",
            },
        ),
        (
            "--index",
            {"type": int, "help": "Response index to use (0-based)."},
        ),
        (
            "--total_responses",
            {
                "type": int,
                "default": 8,
                "help": "Total number of responses generated for each problem.",
            },
        ),
        (
            "--timeout",
            {
                "type": int,
                "default": 600,
                "help": "Timeout in seconds for each build process (default: 600s).",
            },
        ),
        (
            "--no_resume",
            {
                "action": "store_true",
                "help": "Do not resume from existing output file; overwrite it.",
            },
        ),
        (
            "--log",
            {
                "action": "store_true",
                "help": "Enable detailed logging to rocq_proof_verification.log.",
            },
        ),
        (
            "--only_hammer",
            {
                "action": "store_true",
                "help": "Only run the hammer tactic without any generated proof.",
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        description="Automate Rocq proof verification process."
    )
    cli.add_argument(
        "model",
        type=str,
        choices=supported_models,
        help="Model name used to generate the responses.",
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
