from datetime import datetime
import json
import multiprocessing as mp
import time
from typing import Any, List, Union

from pisa_client import IsabelleFatalError, IsabelleRuntimeError, initialise_env, PisaEnv
from utils.pisa_server_control import start_server, close_server
from utils.filters import get_split
from queue import Queue
import glob
import os
import pytz
import argparse
from tqdm import tqdm
from func_timeout import FunctionTimedOut
import logging
import threading
import traceback

DEBUG = False
N_RETRY = 5

def get_escaped_name(problem_name, existed_dict):
    """Return the first ``"<k>$#$<problem_name>"`` key (k = 1, 2, ...) not in *existed_dict*.

    The numeric prefix distinguishes repeated occurrences of the same statement
    text; *existed_dict* can be any container supporting ``in``.
    """
    occurrence = 1
    while True:
        candidate = f"{occurrence}$#${problem_name}"
        if candidate not in existed_dict:
            return candidate
        occurrence += 1

def clean_up_isabelle_env(env: "PisaEnv", process_id: int):
    """Close the PISA environment and stop its server subprocess.

    :param env: environment to ``exit()``; skipped when None.
    :param process_id: id returned by ``start_server``; skipped when None.

    The server is closed in a ``finally`` so that a failing ``env.exit()``
    (e.g. the connection is already broken) does not leak the server process.
    Any exception from ``env.exit()`` still propagates afterwards.
    """
    try:
        if env is not None:
            env.exit()
    finally:
        if process_id is not None:
            close_server(process_id)

class MultilevelDataExtractor:
    """Split Isabelle proofs extracted through PISA into per-level sub-proofs.

    Typically one instance is created per worker; ``rank`` only selects the
    logger name (``process-<rank>``) in ``__init__``.
    """

    # Separators in the "PISA extract data" dump. The backslash is a literal
    # character; raw strings produce exactly the same text as the previous
    # non-raw literals without relying on invalid escape sequences.
    _TRANSITION_SEP = r"<\TRANSEP>"
    _STATE_SEP = r"<\STATESEP>"

    def __init__(self, rank, debug) -> None:
        self.logger = logging.getLogger(f'process-{rank}')
        # When True, extract_a_file_from_params does not start a PISA server itself.
        self.debug = debug

    @staticmethod
    def _is_problem_statement(action):
        """Return True if *action* starts a lemma/theorem (but not a ``lemmas`` declaration)."""
        return action.startswith(("lemma", "theorem")) and not action.startswith("lemmas")

    @staticmethod
    def _is_comment(transition_text):
        """Return True for ``(* ... *)`` comments and ``text``/``txt`` cartouche blocks."""
        if transition_text.startswith("(*") and transition_text.endswith("*)"):
            return True
        return (transition_text.startswith("text \\<open>") or transition_text.startswith("txt \\<open>")) \
            and transition_text.endswith("\\<close>")

    def analyse_file_string(self, whole_file_string):
        """Parse the raw PISA dump of one theory file into per-problem records.

        The dump is a list of transitions separated by the TRANSEP marker; each
        transition is ``state``, ``action`` and ``proof_level`` separated by the
        STATESEP marker. Empty or whitespace-only chunks are ignored.

        :return: the list built by process_transitions.
        """
        state_action_proof_level_tuples = []
        problem_names = []
        for transition in whole_file_string.split(self._TRANSITION_SEP):
            # A whitespace-only chunk (e.g. a newline after the last separator)
            # cannot be unpacked into three fields, so skip it like an empty one.
            if not transition.strip():
                continue
            state, action, proof_level = transition.split(self._STATE_SEP)
            state = state.strip()
            action = action.strip()
            proof_level = int(proof_level.strip())
            # Hammer results are not collected by this extractor.
            hammer_results = "NA"
            if self._is_problem_statement(action):
                problem_names.append(action)
            state_action_proof_level_tuples.append((state, action, proof_level, hammer_results))

        return self.process_transitions(problem_names, state_action_proof_level_tuples)

    def process_transitions(self, problem_names, transitions):
        """Group transitions by the lemma/theorem whose proof they belong to.

        :param problem_names: statement texts (``lemma ...``/``theorem ...``).
        :param transitions: ``(state, action, proof_level, hammer_results)``
            tuples in file order.
        :return: one dict per problem with keys ``problem_name``,
            ``full_proof_text``, ``transitions``, ``split`` and ``count`` (the
            1-based occurrence number of that statement text, which tells
            duplicated statements apart).
        :raises AssertionError: if the grouping is inconsistent with *problem_names*.
        """
        # Leave in only the actual problem names
        problem_names = [name.strip() for name in problem_names]
        problem_names = [name for name in problem_names if self._is_problem_statement(name)]
        if not problem_names:
            return []
        problem_name_set = set(problem_names)  # O(1) membership in the loop below

        # Filter out comments
        good_transitions = [t for t in transitions if not self._is_comment(t[1].strip())]

        # Keep only transitions inside a proof: a problem statement opens a
        # proof, and any later transition at level 0 closes it.
        current_problem_name = None
        problem_name_to_transitions = {}
        proof_open = False
        for transition in good_transitions:
            _, transition_text, proof_level, _ = transition
            if transition_text in problem_name_set:
                current_problem_name = get_escaped_name(transition_text, problem_name_to_transitions)
                assert proof_level == 0, transition
                problem_name_to_transitions[current_problem_name] = [transition]
                proof_open = True
            elif proof_level == 0:
                proof_open = False
            elif proof_open:
                problem_name_to_transitions[current_problem_name].append(transition)

        # Sanity checks: exactly one entry per statement occurrence, escaped the same way.
        assert None not in problem_name_to_transitions
        assert len(problem_name_to_transitions) == len(problem_names)
        escaped_problem_names = set()
        for problem_name in problem_names:
            escaped_problem_names.add(get_escaped_name(problem_name, escaped_problem_names))
        assert set(problem_name_to_transitions) == escaped_problem_names

        problems = []
        for escaped_name, problem_transitions in problem_name_to_transitions.items():
            full_proof_text = "\n".join(t[1] for t in problem_transitions)
            # Split only on the first marker: the statement text may contain "$#$" itself.
            count, problem_name = escaped_name.split("$#$", 1)
            split = get_split(problem_name)
            problems.append(
                {
                    "problem_name": problem_name,
                    "full_proof_text": full_proof_text,
                    "transitions": problem_transitions,
                    "split": split,
                    "count": int(count),
                }
            )
        return problems

    def prepare_initial_state(self, env: "PisaEnv", problem):
        """Position *env* right after the statement of *problem*.

        Calls ``env.proceed_to_line(name, "after", count)`` (``count`` picks the
        occurrence of a duplicated statement), stores the returned state string
        in ``self.init_state`` and then calls ``env.initialise()``.

        :raises AssertionError: if the returned state is empty.
        """
        self.init_state = env.proceed_to_line(problem["problem_name"], "after", problem["count"])
        assert len(self.init_state) != 0
        env.initialise()

    def get_transitions(self, env: "PisaEnv", proof_steps: List[str]):
        """Replay index-tagged *proof_steps* and record every step's transition.

        ``proof_steps[0]`` (the statement) is skipped; the replay starts from
        ``self.init_state``. Each later step has its ``_<index>`` tag removed
        before it is sent to *env*.

        :return: dict mapping each tagged step to a dict with ``tactic``,
            ``tactic_mark``, ``state_before``, ``state_after``, ``proof_level``,
            ``reward``, ``done`` and ``history`` (the tagged steps before it).
        :raises AssertionError: if a non-empty state does not start with
            ``"proof"``, or the replay does not end with ``done`` and reward 1.
        """
        env.clone_to_new_name("temp_0")
        result_transitions = {}
        state_before = self.init_state
        # Defaults so that a proof without steps fails the final check with an
        # AssertionError instead of an UnboundLocalError.
        reward, done = 0, False
        for idx, step in enumerate(proof_steps[1:]):
            actual_step = step.rpartition("_")[0]
            # Presumably reads state "temp_<idx>" and stores the result as "temp_<idx+1>".
            state_string, reward, done, _, proof_level = env.step_to_top_level_state(
                actual_step, f"temp_{idx}", f"temp_{idx+1}", return_proof_level=True, tactic_time=500000)
            if len(state_string) > 0:
                assert state_string.strip().startswith("proof"), f"error in verifying the proof steps. \nStep: {step}\n" + \
                    f"state: {state_string} \n"

            result_transitions[step] = {
                "tactic": actual_step,
                "tactic_mark": step,
                "state_before": state_before,
                "state_after": state_string,
                "proof_level": proof_level,
                "reward": reward,
                "done": done,
                "history": proof_steps[:idx+1],
            }
            state_before = state_string
        assert done is True and reward == 1
        return result_transitions

    def verify_the_steps(self, env: "PisaEnv", proof_steps: List[str]):
        """Replay plain (untagged) *proof_steps* and assert the proof closes.

        :return: True.
        :raises AssertionError: if a non-empty state does not start with
            ``"proof"``, or the replay does not end with ``done`` and reward 1.
        """
        env.clone_to_new_name("temp_0")
        # Defaults so a step-less proof fails the final assertion cleanly.
        reward, done = 0, False
        for idx, step in enumerate(proof_steps[1:]):
            state_string, reward, done, _, proof_level = env.step_to_top_level_state(
                step, f"temp_{idx}", f"temp_{idx+1}", return_proof_level=True, tactic_time=500000)
            if len(state_string) > 0:
                assert state_string.strip().startswith("proof"), f"error in verifying the proof steps. \nStep: {step}\n" + \
                    f"state: {state_string} \n" + \
                    f"the proof:\n" + "\n".join(proof_steps)
        assert done is True and reward == 1
        return True

    def collect_steps(self, env: "PisaEnv", proof_steps: List[str]):
        """Replay plain *proof_steps* and report the first failing step, if any.

        A step whose state starts with "Step error: Timeout after" is retried
        once with ``tactic_time=500000``.

        :return: ``(step, state_string)`` for the first step whose non-empty
            state does not start with ``"proof"``, otherwise ``(None, None)``.
        """
        env.clone_to_new_name("temp_0")
        for idx, step in enumerate(proof_steps[1:]):
            state_string, reward, done, _, proof_level = env.step_to_top_level_state(
                step, f"temp_{idx}", f"temp_{idx+1}", return_proof_level=True)

            # retry with more time
            if state_string.strip().startswith("Step error: Timeout after"):
                state_string, reward, done, _, proof_level = env.step_to_top_level_state(
                    step, f"temp_{idx}", f"temp_{idx+1}", return_proof_level=True, tactic_time=500000)
            if len(state_string) > 0 and not state_string.strip().startswith("proof"):
                return (step, state_string)
        return None, None

    def extract_single_level_proof2(self, transitions):
        """Extract the proof of one level, collapsing deeper blocks to "sorry".

        ``transitions[0]`` is the action label that opens this level. Actions at
        an accepted proof level are kept; the first action deeper than every
        accepted level inserts a single ``"sorry"`` and the block opened by the
        action just before it is collected for separate extraction.

        :return: ``(current_level_proof, sorry_data)`` where ``sorry_data`` has
            one transition list (from get_action_label_transition) per "sorry".
        """
        # initialize with the action_label
        action_label_level = transitions[0][2]
        accept_levels = [action_label_level, action_label_level + 1]
        current_level_proof = [transitions[0][1]]
        sorry_flag = False
        sorry_data = []
        temp_state_stack = []

        # -- step 1: deal with special action label, which increases more than 1 level
        # subgoal focus the goal into the first subgoal in apply mode, and increases the proof level by 3
        if transitions[0][1].strip().startswith("subgoal"):
            accept_levels.extend([action_label_level + 2, action_label_level + 3])

        # -- step 2: deal with sorry extraction
        for idx, transition in enumerate(transitions[1:]):
            state, action, proof_level, hammer_results = transition

            if proof_level in accept_levels:
                current_level_proof.append(action)
                sorry_flag = False

                # Leaving the scope of an earlier `guess`: drop the levels it added.
                while len(temp_state_stack) > 0 and proof_level <= temp_state_stack[-1][1]:
                    _, l = temp_state_stack.pop()
                    accept_levels = accept_levels[:accept_levels.index(l)+1]

                # Widen/narrow the accepted levels around nested blocks that stay on this level.
                if action.strip().startswith("proof"):
                    accept_levels.append(max(accept_levels) + 1)

                if action.strip().startswith("qed"):
                    accept_levels.pop()

                if action.strip().startswith("{"):
                    accept_levels.append(max(accept_levels) + 1)
                    accept_levels.append(max(accept_levels) + 1)

                if action.strip().startswith("}"):
                    accept_levels.pop()
                    accept_levels.pop()

                if action.strip().startswith("guess"):
                    accept_levels.append(max(accept_levels) + 1)
                    accept_levels.append(max(accept_levels) + 1)
                    accept_levels.append(max(accept_levels) + 1)
                    temp_state_stack.append((action, proof_level))

            elif proof_level > max(accept_levels) and sorry_flag is False:
                current_level_proof.append("sorry")
                # the action before sorry: enumerate runs over transitions[1:],
                # so transitions[idx] precedes the current transition
                action_label = transitions[idx][1]
                action_label_transitions = self.get_action_label_transition(transitions, action_label)
                sorry_data.append(action_label_transitions)
                sorry_flag = True
            else:
                continue
        return current_level_proof, sorry_data

    def get_action_label_transition(self, transitions, action_label):
        """Return the block opened by *action_label*: it plus every following deeper transition.

        The block ends at the first later transition whose level is not deeper
        than the label's. *transitions* is not modified.

        :raises AssertionError: if the label is found but opens no deeper transition.
        """
        action_label_level = -1
        begin_flag = False
        ret_transition = []
        # A level -1 sentinel closes a block that runs to the end of the list.
        # It is added to a copy so the caller's list is never mutated, even
        # when the assertion below fires.
        padded = list(transitions) + [("dummy", "dummy", -1, "dummy")]
        for transition in padded:
            _, action, proof_level, _ = transition
            if action == action_label:
                begin_flag = True
                action_label_level = proof_level
                ret_transition.append(transition)
                continue

            if begin_flag:
                if proof_level > action_label_level:
                    ret_transition.append(transition)
                else:
                    assert len(ret_transition) > 1
                    break
        return ret_transition

    def reconstruct_proof(self, action_labeled_proof, init_action_label, normalize_show_thesis=True):
        """Re-inline every level into the proof of *init_action_label*.

        Each ``<label>, "sorry"`` pair whose label has its own level proof is
        replaced by that label followed by its level's steps, until nothing
        changes. When *normalize_show_thesis* is True the ``_<index>`` tags are
        stripped from the result (the name is kept for compatibility).

        :raises AssertionError: if a label is not followed by "sorry", or not
            every level in *action_labeled_proof* was used.
        """
        proof = action_labeled_proof[init_action_label]
        accessed_action_label = {init_action_label}
        while True:
            change_flag = False
            for idx, action in enumerate(proof):
                if action in action_labeled_proof and action not in accessed_action_label:
                    accessed_action_label.add(action)
                    assert proof[idx + 1] == "sorry"
                    proof = proof[:idx+1] + action_labeled_proof[action][1:] + proof[idx + 2:]
                    change_flag = True
                    break
            if not change_flag:
                break
        assert len(accessed_action_label) == len(action_labeled_proof)
        if normalize_show_thesis:
            return self.remove_index_from_action(proof)
        return proof

    def add_index_to_action(self, transitions):
        """
        Add index to each action, which make sure no action are duplicated.

        Returns new ``[state, f"{action}_{idx}", proof_level, hammer_results]`` lists.
        """
        return [
            [state, f"{action}_{idx}", proof_level, hammer_results]
            for idx, (state, action, proof_level, hammer_results) in enumerate(transitions)
        ]

    def remove_index_from_action(self, proof_steps):
        """
        Inverse operation of add_index_to_action: strip the trailing ``_<index>``
        from every action; ``"sorry"`` entries are kept unchanged.
        """
        return [action if action == "sorry" else action.rpartition("_")[0] for action in proof_steps]

    def extract_level_data(self, transitions, env, problem_name):
        """Split one problem's proof into per-level proofs keyed by action label.

        The original proof is first replayed with collect_steps only to log the
        first failing step (extraction continues regardless). Actions are then
        made unique with add_index_to_action and levels are peeled off
        breadth-first: each level's proof collapses deeper blocks to "sorry" and
        queues those blocks for their own extraction.

        :return: dict mapping an index-tagged action label to its level's proof.
        :raises AssertionError: if a label repeats, or re-inlining the levels
            does not reproduce the original (tagged) proof exactly.
        """
        self.logger.info("extracting " + problem_name.replace('\n', ' '))
        err_step, err = self.collect_steps(env, [t[1] for t in transitions])
        if err_step is not None:
            self.logger.info(f"Original proof error in step: {err_step} and error: {err}")
        transitions = self.add_index_to_action(transitions)
        transition_queue = Queue()
        transition_queue.put(transitions)
        action_labeled_proof = {}

        while not transition_queue.empty():
            current_transition = transition_queue.get()
            action_label = current_transition[0][1]
            proof, sorry_data = self.extract_single_level_proof2(current_transition)
            assert action_label not in action_labeled_proof, "action label contains duplication!"
            action_labeled_proof[action_label] = proof
            for d in sorry_data:
                transition_queue.put(d)

        # Round-trip check: re-inlining all levels must give back the original proof.
        final_reconstructed_proof = self.reconstruct_proof(action_labeled_proof, f"{problem_name}_0", normalize_show_thesis=False)
        assert len(final_reconstructed_proof) == len(transitions), f"The final reconstructed proof (length {len(final_reconstructed_proof)}) is not equal to the original proof (length {len(transitions)})!"
        assert all(final_reconstructed_proof[idx] == transitions[idx][1] for idx in range(len(transitions))), \
            "The final reconstructed proof is not equal to the original proof!\n" + \
            "The final reconstructed proof: \n" + "\n".join(final_reconstructed_proof) + "\n" + \
            "The original proof: \n" + "\n".join([transition[1] for transition in transitions])
        return action_labeled_proof

    def shrink_sorry(self, proof):
        """Fold each ``<action>_<i>`` followed by ``"sorry"`` into ``"<action> sorry_<i>"``.

        Bare ``"sorry"`` entries are dropped; other actions are kept unchanged.
        """
        new_proof = []
        for idx, action in enumerate(proof):
            if action == "sorry":
                continue
            if idx < len(proof) - 1 and proof[idx + 1] == "sorry":
                text, _, num = action.rpartition("_")
                action = f"{text} sorry_{num}"
            new_proof.append(action)
        return new_proof

    def sequentialize_proof(self, env, action_labeled_proof, problem_name, step_size=1, context_type="last_step"):
        """Replay the proof level by level and collect every step's transition.

        Starting from the top-level proof, each round replays the current proof
        (deeper blocks collapsed to ``<action> sorry``) with get_transitions and
        then expands every collapsed step with its own level's proof, until no
        collapsed step remains. ``step_size`` and ``context_type`` are currently
        unused and kept only for interface compatibility.

        :return: dict mapping tagged steps to annotated transitions; the first
            recording of a step wins.
        """
        proof = action_labeled_proof[f"{problem_name}_0"]
        multilevel_transitions = {}

        while True:
            change_flag = False
            proof = self.shrink_sorry(proof)
            for step, annotated_step in self.get_transitions(env, proof).items():
                multilevel_transitions.setdefault(step, annotated_step)
            # Walk backwards so the splices below do not shift unvisited indices.
            for i in reversed(range(len(proof))):
                if "sorry" in proof[i]:
                    label, _, num = proof[i].rpartition("_")
                    assert label.endswith("sorry"), f"action don't ends with sorry: {label}"
                    # Drop the " sorry" suffix added by shrink_sorry and restore the index tag.
                    label = label[:-6].strip() + "_" + num
                    assert label in action_labeled_proof, f"label don't in dict: {label}"
                    proof = proof[:i] + action_labeled_proof[label] + proof[i+1:]
                    change_flag = True
            if not change_flag:
                break
        return multilevel_transitions

    def create_level_data(self, problem, env):
        """Build the multilevel transitions for one problem and reset *env* afterwards."""
        self.prepare_initial_state(env, problem)
        action_labeled_proof = self.extract_level_data(problem["transitions"], env, problem["problem_name"])
        multilevel_proof = self.sequentialize_proof(env, action_labeled_proof, problem["problem_name"])
        env.reset_problem()
        return multilevel_proof

    def extract_a_file_from_params(
        self,
        rank,
        jar_path,
        isabelle_path,
        working_directory,
        theory_file_path,
        saving_path,
        error_path,
        sub_saving_path,
        sub_error_path,
        specific_problem="",
        hparms=None,
    ):
        """Extract multilevel data for every problem of one theory file.

        Starts a PISA server on port ``8000 + rank`` (unless ``self.debug``),
        opens the theory, parses its transitions and writes
        ``{"hparams": hparms, "problems": [...]}`` to *saving_path*. The server
        and environment are always cleaned up before returning.

        :return: ``(require_sleep, error)``. ``error`` is None on success
            (including when *saving_path* already exists and DEBUG is off).
            ``require_sleep`` is True for failures the caller may back off and
            retry; an IsabelleRuntimeError writes *error_path* and returns
            ``(False, message)``.
        """
        if DEBUG is False and os.path.isfile(saving_path):
            # Already extracted earlier; report success so callers that unpack
            # ``(require_sleep, error)`` keep working.
            return False, None

        env = None
        port = 8000 + rank
        server_subprocess_id = None
        try:
            # -- step 1: start the server and open the theory file
            try:
                if not self.debug:
                    server_subprocess_id = start_server(jar_path, port,
                        outputfile=sub_saving_path, errorfile=sub_error_path)
                env = initialise_env(
                    port=port,
                    isa_path=isabelle_path,
                    theory_file_path=theory_file_path,
                    working_directory=working_directory,
                    logger=self.logger,
                )
            except FunctionTimedOut as e:
                print(f"Rank {rank}: Timeout in initializing: {str(e)}", flush=True)
                self.logger.info("Timeout in extracting data")
                return True, str(e)
            except IsabelleRuntimeError as e:
                print(f"Rank {rank}: Isabelle runtime failure with error: {str(e)}", flush=True)
                self.logger.info(f"Isabelle runtime failure with error: {str(e)}")
                with open(error_path, "w") as error_file:
                    json.dump({"error": str(e), "stack_trace": traceback.format_exc()}, error_file)
                return False, str(e)
            except Exception as e:
                print(f"Rank {rank}: Other error in initializing environment and extracting: {e}", flush=True)
                print(f"Rank {rank}: {traceback.format_exc()}", flush=True)
                self.logger.info(f"Rank {rank}: Error in initializing environment and extracting: {e}")
                self.logger.info(f"Rank {rank}: {traceback.format_exc()}")
                return True, str(e)

            # -- step 2: dump, parse and annotate every problem in the file.
            # FunctionTimedOut is listed explicitly in case it does not derive from Exception.
            try:
                whole_file_string = env.post("PISA extract data")
                if self._STATE_SEP not in whole_file_string:
                    self.logger.info(f"isabelle parse failed with error: {whole_file_string}")
                    print(f"Rank {rank}: isabelle parse failed with error:", whole_file_string.replace("\n", " "), flush=True)
                    return True, f"isabelle parse failed with error: {whole_file_string}"
                problems = self.analyse_file_string(whole_file_string)
                for problem in problems:
                    # when the problem name contains duplication, this might cause trouble
                    if DEBUG and len(specific_problem) > 0 and specific_problem not in problem["problem_name"]:
                        continue
                    problem["multilevel_proof"] = self.create_level_data(problem, env)
            except (Exception, FunctionTimedOut) as e:
                print(f"Rank {rank}: Error in multilevel data extracting: {e}", flush=True)
                print(f"Rank {rank}: {traceback.format_exc()}", flush=True)
                self.logger.info(f"Rank {rank}: Error in multilevel data extracting: {e}")
                self.logger.info(f"Rank {rank}: {traceback.format_exc()}")
                return True, str(e)

            # -- step 3: write the result. Write-then-rename so that an
            # interrupted dump never leaves a partial file that later runs
            # would treat as finished (they skip existing saving paths).
            self.logger.info(f"Success in extracting data, putting into: {saving_path}")
            output_data = {
                "hparams": hparms,
                "problems": problems,
            }
            tmp_path = f"{saving_path}.tmp"
            with open(tmp_path, "w") as output_file:
                json.dump(output_data, output_file, indent=4)
            os.replace(tmp_path, saving_path)
            return False, None
        finally:
            clean_up_isabelle_env(env, server_subprocess_id)


def _load_extraction_params(params_path):
    """Read a parameter JSON file and map it onto extract_a_file_from_params keyword arguments."""
    with open(params_path) as params_file:
        params = json.load(params_file)
    kwargs = {
        key: params[key]
        for key in (
            "jar_path",
            "isabelle_path",
            "working_directory",
            "theory_file_path",
            "saving_path",
            "error_path",
            "sub_saving_path",
            "sub_error_path",
        )
    }
    kwargs["hparms"] = params
    return kwargs


def run(rank: int, params_path: Union[mp.Queue, str], debug: bool = False, specific_problem=""):
    """
    Extract multilevel data for one theory file, or for every file in a shared queue.

    :param rank: Rank of the process in the subprocess pool (also used for its logger).
    :param params_path: Either the path of one JSON parameter file, or a
        multiprocessing queue of ``(param_file_path, retries_left)`` pairs shared
        by all workers. Failures that ask for a back-off are re-queued while
        retries remain.
    :param debug: Passed to MultilevelDataExtractor (no server is started when True).
    :param specific_problem: Only used in single-file mode.
    :return: None
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    logger = logging.getLogger(f'process-{rank}')
    extractor = MultilevelDataExtractor(rank, debug=debug)
    if isinstance(params_path, str):
        print(f"Rank {rank} is extracting {params_path}", flush=True)
        logger.info(f"Rank {rank} is extracting {params_path}")
        extractor.extract_a_file_from_params(
            rank=rank,
            specific_problem=specific_problem,
            **_load_extraction_params(params_path),
        )
        return None

    # Use the queue that was passed in rather than a module-level global, so
    # this also works when workers are started with the "spawn" method.
    work_queue = params_path
    n_error = 0
    while True:
        try:
            # get() with a timeout instead of empty()+get(): with several workers
            # another process may take the last item between the two calls,
            # which would leave this worker blocked forever.
            queued_path, left_retry = work_queue.get(timeout=1)
        except Empty:
            break

        # Load the parameters
        extraction_kwargs = _load_extraction_params(queued_path)
        left_retry -= 1
        attempt = N_RETRY - left_retry
        print(f"Rank {rank} is extracting {queued_path} on try {attempt}", flush=True)
        logger.info(f"Rank {rank} is extracting {queued_path} on try {attempt}")
        try:
            outcome = extractor.extract_a_file_from_params(rank=rank, **extraction_kwargs)
        except Exception as e:
            # Top-level boundary: keep the worker alive and retry later.
            print(f"Rank {rank}: Unexpected error while extracting {queued_path}: {e}", flush=True)
            logger.info(f"Rank {rank}: Unexpected error while extracting {queued_path}: {e}")
            logger.info(f"Rank {rank}: {traceback.format_exc()}")
            outcome = (True, str(e))
        # Some versions of the extractor return None when the output already exists.
        require_sleep, error = outcome if outcome is not None else (False, None)

        if error is None:
            print(f"Rank {rank} extract complete! {queued_path}!", flush=True)
            logger.info(f"Rank {rank} extract complete! {queued_path}!")
            continue

        print(f"Rank {rank} extract {queued_path} failed on try {attempt}!", flush=True)
        logger.info(f"Rank {rank} extract {queued_path} failed on try {attempt}!")
        if require_sleep:
            if left_retry > 0:
                work_queue.put((queued_path, left_retry))
            # Back-off cycles through 5, 25, 125, 625, 3125 and then 1 second(s).
            n_error = (n_error + 1) % 6
            sleep_time = 5**(n_error)
            print(f"Rank {rank} extract {queued_path} failed on try {attempt}!, will sleep for {sleep_time} seconds", flush=True)
            logger.info(f"Rank {rank} extract {queued_path} failed on try {attempt}!, will sleep for {sleep_time} seconds")
            time.sleep(sleep_time)



if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Extracting multilevel data from theory files.')
    parser.add_argument('--jar-path', '-jp', type=str, help='Path to the pisa jar file',
                        default="/data2/wanghaiming/project/pisa_data/Portal-to-ISAbelle/target/scala-2.13/PISA-assembly-0.1.jar")
    parser.add_argument('--isabelle-path', '-ip', type=str, help='Path to the Isabelle installation',
                        default="/home/wanghaiming_p21/Isabelle2022")
    parser.add_argument('--theorem_source', '-efd', type=str, help='Where the source .thy files are',
                        default="/data2/wanghaiming/Isabelle2022/src/HOL/")
    parser.add_argument('--saving-directory', '-sd', type=str, help='Where to save the output_thy.json file',
                        default="/data2/wanghaiming/project/pisa_data/Portal-to-ISAbelle/HOL_extractions")
    parser.add_argument('--number-of-prover-processes', '-npp', type=int, help='Number of prover processes',
                        default=1)
    args = parser.parse_args()

    # set up one log file per worker rank
    start_time = datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y%m%d_%H%M%S")
    log_directory = f'logs/extract_multilevel_data/{start_time}_logs'
    os.makedirs(log_directory, exist_ok=True)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    for rank in range(args.number_of_prover_processes):
        logger = logging.getLogger(f'process-{rank}')
        handler = logging.FileHandler(f"{log_directory}/rank_{rank}.log")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        if DEBUG:
            logger.addHandler(logging.StreamHandler())

    # Create the saving directories
    os.makedirs(args.saving_directory, exist_ok=True)
    output_data_path = os.path.join(args.saving_directory, "data")
    output_param_path = os.path.join(args.saving_directory, "params")
    os.makedirs(output_data_path, exist_ok=True)
    os.makedirs(output_param_path, exist_ok=True)

    files = glob.glob(args.theorem_source.rstrip("/") + '/**/*.thy', recursive=True)
    param_paths = mp.Queue()
    # Counted explicitly: multiprocessing.Queue.qsize() is not implemented on
    # every platform (it raises NotImplementedError on macOS).
    n_queued = 0

    for file_path in tqdm(files):
        identifier = file_path.replace("/", "_")
        bits = file_path.split("/")

        # -- step 1: figure out the working directory
        if "thys" in file_path:
            # Use the directory that contains the theory file.
            working_directory = "/".join(bits[:-1])
        elif "src/HOL" in file_path:
            # src/HOL itself, or its first-level subdirectory containing the file.
            hol_index = bits.index("HOL")
            working_directory = "/".join(bits[:-1][:hol_index + 2])
        else:
            assert False, f"The file path don't contain `thys` or `src/HOL`: {file_path}"

        # -- step 2: setup the saving path
        saving_path = f"{output_data_path}/{identifier}_output.json"
        error_path = f"{output_data_path}/{identifier}_error.json"
        sub_saving_path = f"{output_data_path}/{identifier}_subout.json"
        sub_error_path = f"{output_data_path}/{identifier}_suberr.json"

        # Skip files that already succeeded or hit a non-retryable Isabelle error.
        if os.path.exists(saving_path) or os.path.exists(error_path):
            continue

        params = {
            "jar_path": args.jar_path,
            "isabelle_path": args.isabelle_path,
            "working_directory": working_directory,
            "theory_file_path": file_path,
            "saving_path": saving_path,
            "error_path": error_path,
            "sub_saving_path": sub_saving_path,
            "sub_error_path": sub_error_path
        }
        param_path = os.path.join(output_param_path, f"{identifier}.json")
        with open(param_path, "w") as param_file:
            json.dump(params, param_file)

        param_paths.put((param_path, N_RETRY))
        n_queued += 1

    print(f"Extracting {n_queued} files in total.")
    if DEBUG:
        args.number_of_prover_processes = 1
        run(0,
            "/data2/wanghaiming/project/pisa_data/Portal-to-ISAbelle/afp_extractions15/params/_data2_wanghaiming_afp-2022-12-06_thys_Independence_CH_CH.thy.json",
            debug=True,
            specific_problem=""
        )

    if args.number_of_prover_processes == 1:
        run(0, param_paths, DEBUG)
    else:
        processes = []
        for rank in range(args.number_of_prover_processes):
            p = mp.Process(target=run, args=(rank, param_paths, DEBUG))
            processes.append(p)
            p.start()
        # completing process
        for p in processes:
            p.join()