"""Script for evaluating the prover on theorems extracted by LeanDojo.
"""
import os
import uuid
import json
import pickle
import hashlib
import argparse
from loguru import logger
from typing import List, Tuple, Optional

from common import set_logger
from multilevel_isabelle.src.main.python.pisa_client import Theorem
from prover.proof_search_isa_multilevel import Status, DistributedProver


def _get_theorems(
    formal_system: str,
    data_path: str,
    split: str,
    file_path: str,
    full_name: str,
    name_filter: str,
    num_theorems: int,
) -> List[Theorem]:
    """Return the theorems selected for evaluation.

    Thin wrapper: every argument except ``formal_system`` is forwarded
    unchanged to ``_isa_get_theorems_from_files``. ``formal_system`` is
    accepted for interface compatibility but is not consulted here.
    """
    return _isa_get_theorems_from_files(
        data_path, split, file_path, full_name, name_filter, num_theorems
    )

def _isa_get_theorems_from_files(
    data_path: str,
    split: str,
    file_path: Optional[str],
    full_name: Optional[str],
    name_filter: Optional[str],
    num_theorems: Optional[int],
) -> List[Theorem]:
    """Load, filter and deterministically order theorems from a split file.

    Reads ``<data_path>/<split>.json``, which is iterated as a sequence of
    records providing the keys ``"file_path"``, ``"full_name"`` and ``"count"``.

    Args:
        data_path: Directory containing the ``<split>.json`` file.
        split: Name of the split; selects which JSON file is read.
        file_path: If not None, keep only records whose ``"file_path"`` equals it.
        full_name: If not None, keep only records whose ``"full_name"`` equals it.
        name_filter: If not None, keep only records whose MD5 hex digest of
            ``"full_name"`` starts with this prefix.
        num_theorems: If not None, keep at most this many theorems (taken
            after sorting).

    Returns:
        The selected theorems, sorted by the MD5 hex digest of
        ``"<file_path>:<full_name>"`` so that the order (and therefore any
        ``num_theorems`` truncation) is reproducible across runs.
    """
    # Use a context manager so the file handle is closed promptly.
    with open(os.path.join(data_path, f"{split}.json")) as f:
        data = json.load(f)

    theorems = []
    for t in data:
        if file_path is not None and t["file_path"] != file_path:
            continue
        if full_name is not None and t["full_name"] != full_name:
            continue
        if name_filter is not None and not hashlib.md5(
            t["full_name"].encode()
        ).hexdigest().startswith(name_filter):
            continue
        # The original passed keyword arguments to ``list.append`` and used an
        # undefined ``IsaTheorem`` name, which raises at runtime.
        # NOTE(review): assumes the imported ``Theorem`` accepts ``file_path``,
        # ``full_name`` and ``count`` as keyword fields — confirm against
        # pisa_client.
        theorems.append(
            Theorem(
                file_path=t["file_path"],
                full_name=t["full_name"],
                count=t["count"],
            )
        )

    # Hash-based ordering: stable across runs but independent of file order.
    theorems = sorted(
        theorems,
        key=lambda t: hashlib.md5(
            (str(t.file_path) + ":" + t.full_name).encode()
        ).hexdigest(),
    )
    if num_theorems is not None:
        theorems = theorems[:num_theorems]
    logger.info(f"{len(theorems)} theorems loaded from {data_path}")

    return theorems

def evaluate(
    formal_system: str,
    data_path: str,
    jar_path: str,
    isabelle_path: str,
    exp_id: Optional[str] = None,
    split: str = "val",
    file_path: Optional[str] = None,
    full_name: Optional[str] = None,
    name_filter: Optional[str] = None,
    num_theorems: Optional[int] = None,
    ckpt_path: Optional[str] = None,
    indexed_corpus_path: Optional[str] = None,
    tactic: Optional[str] = None,
    module: Optional[str] = None,
    num_sampled_tactics: int = 64,
    timeout: int = 600,
    num_cpus: int = 1,
    with_gpus: bool = False,
    verbose: bool = False,
) -> float:
    """Run proof search over the selected theorems and return Pass@1.

    Args:
        formal_system: Forwarded to ``_get_theorems``.
        data_path: Directory holding the ``<split>.json`` theorem file.
        jar_path: Stored under ``"jar_path"`` in the repo dict given to the prover.
        isabelle_path: Stored under ``"isa_path"`` in the repo dict given to the prover.
        exp_id: Prefix of the results pickle; a random UUID is used when None.
        split: Name of the data split to load.
        file_path, full_name, name_filter, num_theorems: Optional theorem
            filters, forwarded to ``_get_theorems``.
        ckpt_path, indexed_corpus_path, tactic, module, num_cpus: Forwarded
            positionally to ``DistributedProver``.
        num_sampled_tactics, timeout, with_gpus: Forwarded to
            ``DistributedProver`` as keyword arguments.
        verbose: Passed to ``set_logger`` and to the prover as ``debug``.

    Returns:
        ``num_proved / (num_proved + num_failed)``, or NaN when no result was
        counted as proved or failed. ``None`` results are counted as
        discarded and excluded from the ratio.

    Side effects:
        Writes the raw results to ``<exp_id>_results.pickle`` in the current
        working directory.
    """
    set_logger(verbose)

    theorems = _get_theorems(
        formal_system, data_path, split, file_path, full_name, name_filter, num_theorems
    )

    repo = {
        "jar_path": jar_path,
        "isa_path": isabelle_path,
    }

    # Search for proofs using multiple concurrent provers.
    prover = DistributedProver(
        ckpt_path,
        indexed_corpus_path,
        tactic,
        module,
        num_cpus,
        with_gpus=with_gpus,
        timeout=timeout,
        num_sampled_tactics=num_sampled_tactics,
        debug=verbose,
    )
    results = prover.search_unordered(repo, theorems)

    # Calculate the result statistics.
    num_proved = num_failed = num_discarded = 0
    for r in results:
        if r is None:
            num_discarded += 1
        elif r.status == Status.PROVED:
            num_proved += 1
        else:
            num_failed += 1

    logger.info(
        f"Evaluation done! {num_proved} theorems proved, {num_failed} theorems failed, {num_discarded} non-theorems discarded"
    )

    # Avoid ZeroDivisionError when nothing was proved or failed.
    if num_proved + num_failed == 0:
        pass_1 = float("nan")
    else:
        pass_1 = num_proved / (num_proved + num_failed)

    # Save the results.
    if exp_id is None:
        exp_id = str(uuid.uuid4())
    pickle_path = f"{exp_id}_results.pickle"
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(pickle_path, "wb") as f:
        pickle.dump(results, f)
    logger.info(f"Results saved to {pickle_path}")

    return pass_1


def main() -> None:
    """Parse command-line arguments, run ``evaluate`` and log Pass@1.

    Exits through ``parser.error`` (status 2) when neither a checkpoint path
    nor a tactic is available.
    """
    parser = argparse.ArgumentParser(
        description="Script for evaluating the prover on theorems extracted by LeanDojo."
    )
    parser.add_argument(
        "--formal_system",
        type=str,
        choices=["lean", "isabelle"],
        default="lean",
        help="The formal system to use for evaluation.",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        required=True,
        help="Path to the data extracted by LeanDojo/IsaDojo (e.g., data/leandojo_benchmark/random).",
    )
    parser.add_argument("--jar-path", type=str, help="Path to the jar file.", default="multilevel_isabelle/target/scala-2.13/multilevel_isabelle-assembly-0.1.0.jar")
    # NOTE(review): machine-specific default path; callers on other hosts
    # must pass --isabelle-path explicitly.
    parser.add_argument("--isabelle-path", type=str, help="Path to the Isabelle installation.", default="/data2/wanghaiming/Isabelle2022")
    parser.add_argument("--exp-id", type=str, help="Experiment ID used for logging.")
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "val", "test"],
        default="val",
    )
    # `file_path`, `full_name`, `name_filter`, and `num_theorems` can be used to filter theorems.
    parser.add_argument("--file-path", type=str)
    # NOTE(review): this default restricts evaluation to one specific lemma
    # unless --full-name is overridden; it looks like a debugging leftover.
    parser.add_argument("--full-name", type=str, default='''"lemma var_order_string_le[sepref_import_param]:\n  \\<open>((<), var_order') \\<in> string_rel \\<rightarrow> string_rel \\<rightarrow> bool_rel\\<close>"''')
    parser.add_argument("--name-filter", type=str)
    parser.add_argument("--num-theorems", type=int)
    # NOTE(review): machine-specific default checkpoint; because it is
    # non-empty, the ckpt/tactic check below only fires if it is overridden
    # with an empty value.
    parser.add_argument(
        "--ckpt_path",
        type=str,
        help="Checkpoint of the tactic generator.",
        default="/hpc2hdd/home/zyang398/wanghaiming/MetaMath_run_ToDppixc/train_llm_neox_isabelle/checkpoints_isabelle_marchdata_multilevel_lr3e-4_wd0.1_epoch_10/checkpoint-34800"
    )
    parser.add_argument(
        "--indexed-corpus-path",
        type=str,
        help="Path to a pickled indexed corpus. Not required for models w/o retrieval.",
    )
    parser.add_argument("--tactic", type=str, help="The tactic to evaluate.")
    parser.add_argument("--module", type=str, help="The module to import the tactic.")
    parser.add_argument(
        "--num-sampled-tactics",
        type=int,
        default=8,
        help="Number of tactics to sample at each node during proof search.",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=600,
        help="Maximum number of seconds the proof search can take.",
    )
    parser.add_argument(
        "--num-cpus", type=int, default=1, help="The number of concurrent provers."
    )
    parser.add_argument(
        "--with-gpus", action="store_true", help="Use GPUs for proof search."
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Set the logging level to DEBUG."
    )
    args = parser.parse_args()

    # Explicit validation instead of `assert`, which is stripped under -O.
    if not (args.ckpt_path or args.tactic):
        parser.error("one of --ckpt_path or --tactic must be provided")

    logger.info(f"PID: {os.getpid()}")
    logger.info(args)

    # Keyword arguments keep the call correct even if the parameter order of
    # `evaluate` changes.
    pass_1 = evaluate(
        formal_system=args.formal_system,
        data_path=args.data_path,
        jar_path=args.jar_path,
        isabelle_path=args.isabelle_path,
        exp_id=args.exp_id,
        split=args.split,
        file_path=args.file_path,
        full_name=args.full_name,
        name_filter=args.name_filter,
        num_theorems=args.num_theorems,
        ckpt_path=args.ckpt_path,
        indexed_corpus_path=args.indexed_corpus_path,
        tactic=args.tactic,
        module=args.module,
        num_sampled_tactics=args.num_sampled_tactics,
        timeout=args.timeout,
        num_cpus=args.num_cpus,
        with_gpus=args.with_gpus,
        verbose=args.verbose,
    )

    logger.info(f"Pass@1: {pass_1}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
