import argparse
import json
import logging
import os
import random
import re

from tqdm import tqdm

from lean_verifier import LeanVerifier

# Logging configuration.
# ``basicConfig(filename=...)`` opens a FileHandler immediately, which raises
# FileNotFoundError if the parent directory is missing, so make sure ``logs/``
# exists before configuring logging.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="logs/style_convert.log",
    filemode="w",  # overwrite the previous run's log
)

def insert_tactic(statement: str, num_steps: int = 5) -> str:
    """
    Rewrite a Lean snippet so that each standalone ``have`` becomes ``have_with_extraction``.

    Processing is line by line:
    1. Lines whose first non-blank characters are ``--`` (single-line Lean
       comments) are dropped.
    2. In every remaining line, each whole-word ``have`` is renamed to
       ``have_with_extraction`` (presumably a tactic provided by
       ``Mutation.Tactics`` -- TODO confirm). Only whole words are rewritten, so
       identifiers that merely contain ``have`` (``haveI``, ``h_have``,
       ``behave``) are left intact instead of being turned into unknown names.

    The ``to_theorem`` wrapping of consecutive steps is NOT implemented.
    ``num_steps`` is accepted only for backward compatibility and is unused.

    Args:
        statement (str): The original Lean source (lines separated by ``\\n``).
        num_steps (int, optional): Currently unused. Defaults to 5.

    Returns:
        str: The rewritten source, joined with ``\\n``.
    """
    kept_lines = [line for line in statement.split("\n") if not line.strip().startswith("--")]
    # ``\b`` keeps the rewrite to the standalone keyword; a plain str.replace
    # would also hit substrings inside longer identifiers.
    rewritten = [re.sub(r"\bhave\b", "have_with_extraction", line) for line in kept_lines]
    return "\n".join(rewritten)

def convert_style(data: list[dict]) -> list[dict]:
    """
    Split each theorem's proof into extracted sub-theorems using Lean.

    For every item, ``"import Mutation.Tactics"``, ``item["imports"]``,
    ``item["formal_statement"]`` and ``item["formal_proof"]`` are joined with
    newlines and passed through ``insert_tactic``. All sources are verified in
    one ``LeanVerifier`` session (``timeout=300``). Each response is then parsed
    with ``verifier.parse_convert_thms("_have_extracted_", res)``.

    Args:
        data: Records with ``name``, ``imports``, ``formal_statement`` and
            ``formal_proof`` keys.

    Returns:
        One dict per extracted theorem, with keys ``name`` (``<item name>_<idx>``),
        ``imports``, ``formal_theorem``, ``full_formal_theorem`` and
        ``formal_proof``. Items whose conversion fails (``parse_convert_thms``
        returns ``None``) are logged at ERROR level and skipped.
    """
    total_theorems = [
        insert_tactic(
            "\n".join(
                [
                    "import Mutation.Tactics",
                    item["imports"],
                    item["formal_statement"],
                    item["formal_proof"],
                ]
            )
        )
        for item in data
    ]

    total_results = []
    with LeanVerifier() as verifier:
        results = verifier.verify_batch(total_theorems, timeout=300, use_tqdm=True)
        # Parse while the verifier session is still open (check_proof and
        # check_full_proof call parse_results inside their ``with`` block too).
        for res, item in tqdm(zip(results, data), desc="Parse the Style conversion results", total=len(data)):
            thm_info = verifier.parse_convert_thms("_have_extracted_", res)
            if thm_info is None:
                logging.error("The following theorem is not converted correctly:")
                logging.error(item['formal_statement'])
                logging.error(item['formal_proof'])
                # A failed response (e.g. a timeout) may lack "messages"/"severity";
                # tolerate that so diagnostic logging cannot abort the whole batch.
                response = res.get("response") if isinstance(res, dict) else None
                messages = response.get("messages", []) if isinstance(response, dict) else []
                errors = [
                    msg for msg in messages
                    if isinstance(msg, dict) and msg.get("severity") == "error"
                ]
                logging.error(f"The response: {errors if errors else response}")
                continue
            for idx, thm in enumerate(thm_info):
                thm_name = f"{item['name']}_{idx}"  # keep names consistent with mutated theorems
                total_results.append(
                    {
                        "name": thm_name,
                        "imports": item["imports"],
                        "formal_theorem": thm["formal_theorem"].replace(thm["name"], thm_name),
                        "full_formal_theorem": thm["full_formal_theorem"].replace(thm["name"], thm_name),
                        "formal_proof": thm["formal_proof"],
                    }
                )
    logging.info(f"Style conversion finished. Total theorems: {len(total_results)}")
    return total_results

def check_proof(data: list[dict]) -> list[dict]:
    """
    Keep only the theorems whose ``formal_theorem`` compiles with its proof.

    For each item, the Lean source is ``item["imports"]``, a newline, and
    ``item["formal_theorem"]`` with every ``sorry`` replaced by ``by``, a
    newline, and ``item["formal_proof"]``. All sources are checked in a single
    ``LeanVerifier`` session (``timeout=60``).

    Args:
        data: Records with ``imports``, ``formal_theorem`` and ``formal_proof``
            keys (``name`` is optional and only used for logging).

    Returns:
        The records whose parsed result has ``has_error == False``, in input
        order. These records are mutated in place: ``use_full`` is set to
        ``False`` and ``full_formal_theorem`` is cleared. Other records are
        logged at WARNING level and dropped.
    """
    total_theorems = [
        item["imports"] + "\n" + item["formal_theorem"].replace("sorry", "by" + "\n" + item["formal_proof"])
        for item in data
    ]

    with LeanVerifier() as verifier:
        responses = verifier.verify_batch(total_theorems, timeout=60, use_tqdm=True)
        results = verifier.parse_results(responses)
        # Keep the raw response next to each parsed result for failure logging.
        for res, response in zip(results, responses):
            res['response'] = response

    valid_thms = []
    for item, res in zip(data, results):
        if res["has_error"] == False:
            item["use_full"] = False
            item["full_formal_theorem"] = ""
            valid_thms.append(item)
        else:
            # Only one validation runs here (unlike check_full_proof), so the
            # message must not claim that "both" validations failed.
            logging.warning(f"Warning: Theorem {item.get('name', 'unknown')} failed validation")
            logging.warning(f"The theorem: {item['formal_theorem']}")
            logging.warning(f"The proof: {item['formal_proof']}")
            logging.warning(f"The response: {res['response']}")

    return valid_thms

def check_full_proof(data: list[dict]) -> list[dict]:
    """
    Keep theorems that compile either as ``formal_theorem`` or as ``full_formal_theorem``.

    Two verification passes run, each in its own ``LeanVerifier`` session
    (``timeout=60``). The first pass checks ``formal_theorem`` and the second
    checks ``full_formal_theorem``. In both passes every ``sorry`` is replaced by
    ``by``, a newline, and ``formal_proof``, and the result is prefixed with
    ``imports``.

    Args:
        data: Records with ``imports``, ``formal_theorem``,
            ``full_formal_theorem`` and ``formal_proof`` keys (``name`` is
            optional and only used for logging).

    Returns:
        The accepted records, mutated in place. ``use_full`` is ``False`` when
        the short statement verified and ``True`` when only the full statement
        did. Records that fail both passes are logged at WARNING level and
        dropped.
    """

    def _assemble(record: dict, statement_key: str) -> str:
        # Substitute the proof for the ``sorry`` placeholder and prepend imports.
        return record["imports"] + "\n" + record[statement_key].replace("sorry", "by" + "\n" + record["formal_proof"])

    def _verify(sources: list) -> list:
        # One verifier session per pass; attach each raw response to its parsed result.
        with LeanVerifier() as verifier:
            raw = verifier.verify_batch(sources, timeout=60, use_tqdm=True)
            parsed = verifier.parse_results(raw)
            for entry, response in zip(parsed, raw):
                entry['response'] = response
        return parsed

    short_results = _verify([_assemble(record, "formal_theorem") for record in data])
    full_results = _verify([_assemble(record, "full_formal_theorem") for record in data])

    accepted = []
    for record, short_res, full_res in zip(data, short_results, full_results):
        if short_res["has_error"] == False:
            record["use_full"] = False
        elif full_res["has_error"] == False:
            record["use_full"] = True
        else:
            logging.warning(f"Warning: Theorem {record.get('name', 'unknown')} failed both validations")
            logging.warning(f"The theorem: {record['formal_theorem']}")
            logging.warning(f"The proof: {record['formal_proof']}")
            logging.warning(f"The response: {short_res['response']}")
            logging.warning(f"The response for full statement: {full_res['response']}")
            continue
        accepted.append(record)

    return accepted
    
def _parse_bool(value) -> bool:
    """
    Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` calls ``bool(<str>)``, which is ``True`` for
    every non-empty string (including ``"False"``), so an explicit parser is needed.

    Args:
        value: The raw command-line string (or an already-boolean value).

    Returns:
        ``True`` for ``1/true/t/yes/y/on`` and ``False`` for ``0/false/f/no/n/off``
        (case-insensitive; surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """
    Command-line entry point: convert, validate and save theorems.

    Reads a JSON list of theorem records from ``--input_file`` and keeps the
    first ``--num_theorems`` of them (``-1`` keeps all). The records are
    converted with ``convert_style`` in chunks of ``--batch_size`` and then
    validated, also in chunks, with ``check_full_proof`` (when ``--use_full``
    is set) or ``check_proof``. The surviving records are written to
    ``--output_file`` as UTF-8 JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")
    parser.add_argument("--num_theorems", type=int, default=10, help="Number of theorems to convert (-1 for all)")
    parser.add_argument("--batch_size", type=int, default=10000, help="Number of batch processed")
    # ``nargs="?"`` + ``const=True`` accepts both ``--use_full`` and the legacy
    # ``--use_full True/False`` form; ``type=bool`` would treat "False" as True.
    parser.add_argument(
        "--use_full",
        type=_parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Use the full formal theorem (--use_full or --use_full true/false)",
    )
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(args.input_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    args.num_theorems = len(data) if args.num_theorems == -1 else args.num_theorems
    data = data[:args.num_theorems]

    converted_data = []
    # Iterate over the records actually loaded (num_theorems may exceed len(data))
    # and report the true batch count via ceil division.
    num_convert_batches = -(-len(data) // args.batch_size)
    for batch_idx, start in enumerate(range(0, len(data), args.batch_size), start=1):
        print(f"Converting batch {batch_idx} of {num_convert_batches}")
        converted_data.extend(convert_style(data[start:start + args.batch_size]))
    print(f"Total converted theorems: {len(converted_data)}")

    # Check the proof is valid
    check_func = check_full_proof if args.use_full else check_proof
    print(f"Checking the proof is valid with {check_func.__name__}")
    valid_data = []
    num_check_batches = -(-len(converted_data) // args.batch_size)
    for batch_idx, start in enumerate(range(0, len(converted_data), args.batch_size), start=1):
        print(f"Checking batch {batch_idx} of {num_check_batches}")
        valid_data.extend(check_func(converted_data[start:start + args.batch_size]))
    print(f"Total valid theorems: {len(valid_data)}")

    # ensure_ascii=False emits raw Unicode, so the file must be written as UTF-8.
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(valid_data, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    main()
