"""
step2_generate_formal_proof.py
This script is used to generate formal proofs for mathematical statements using configurable provers.
"""

import argparse
import json
import os
import logging
from typing import Any, Dict, List, Tuple, Callable

from datasets import load_dataset
from train_utils import load_data, load_large_data, save_json, get_prover
from logger import setup_logger

# Module-level logger; output handling is presumably configured by setup_logger() in main().
logger = logging.getLogger(__name__)

def preprocess_data(data_list, args) -> List[Dict]:
    """
    Preprocess the data in place: attach a header, rewrite `sorry` placeholders,
    and drop the raw ``imports`` column.

    Args:
        data_list: List of sample dicts. Each may carry ``imports``,
            ``formal_statement`` and ``dropped_hypothesis`` entries.
        args: Parsed CLI namespace; only ``args.default_header`` is read.

    Returns:
        The same list object, with every dict mutated in place. Missing or
        None statements are normalised to "".

    Raises:
        KeyError: If ``args.default_header`` is falsy and a sample has no
            ``imports`` entry to use as its header.
    """
    if args.default_header:
        header_str = (
            "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators "
            "Real Nat Topology Rat\n\n"
        )
    else:
        header_str = None
    missing = object()  # sentinel: distinguishes "no imports key" from imports=None
    for d in data_list:
        # Pop with a default so samples lacking ``imports`` do not crash when
        # the default header is in use (the column is not needed then).
        imports = d.pop("imports", missing)
        if header_str:
            d["header"] = header_str
        elif imports is missing:
            raise KeyError("imports")
        else:
            d["header"] = imports
        # `or ""` also normalises explicit None values, which would fail on .replace().
        # NOTE(review): this replaces every occurrence of "sorry", not only a
        # trailing placeholder; presumably statements end with ":= sorry" — confirm.
        d["formal_statement"] = (d.get("formal_statement") or "").replace("sorry", "by\n")
        d["dropped_hypothesis"] = (d.get("dropped_hypothesis") or "").replace("sorry", "by\n")
    return data_list

def flatten_for_prover(data_segment: List[Dict], key: str, start_idx: int) -> List[Dict]:
    """
    Flatten a data segment into one prover input per (problem, result) pair.

    Args:
        data_segment: Problem dicts, each with a ``results`` list (entries must
            have ``counter_example``) and a ``header``.
        key: Which statement to prove: 'formal_statement' or 'dropped_hypothesis'.
        start_idx: Offset added to the position in ``data_segment`` to form
            ``problem_id``.

    Returns:
        A flat list with exactly one entry per element of every ``results``
        list, in order. Problems whose ``key`` is missing, empty, or None get an
        empty-dict placeholder per result, so the output stays aligned with the
        ``results`` entries when proofs are merged back.
    """
    flatten_list = []
    for offset, d in enumerate(data_segment):
        # The statement is per-problem, so decide once rather than per result.
        statement = d.get(key)
        results = d["results"]
        if not statement:
            flatten_list.extend({} for _ in results)
            continue
        flatten_list.extend(
            {
                "problem_id": start_idx + offset,
                "header": d["header"],
                "formal_statement": statement,
                "counter_example": res["counter_example"],
            }
            for res in results
        )
    return flatten_list

def update_results_with_proof(
    data_segment: List[Dict],
    flatten_results: List[Dict],
    prefix: str = "state"
) -> List[Dict]:
    """
    Merge flat prover results back into each problem's ``results`` entries.

    ``flatten_results`` must be ordered like the output of ``flatten_for_prover``:
    one entry per element of every ``results`` list, problem by problem.

    Args:
        data_segment: Problem dicts whose ``results`` entries are updated in place.
        flatten_results: Prover outputs, one per (problem, result) pair.
        prefix: Key prefix for the merged fields: 'state' or 'hyp'.

    Returns:
        ``data_segment`` itself, after in-place mutation.

    Raises:
        ValueError: If the number of prover results differs from the total
            number of ``results`` entries, which would misalign the merge.
    """
    expected = sum(len(d["results"]) for d in data_segment)
    if len(flatten_results) != expected:
        raise ValueError(
            f"Expected {expected} prover results, got {len(flatten_results)}; "
            "cannot align them with the input data."
        )
    flat_iter = iter(flatten_results)
    for d in data_segment:
        for res in d["results"]:
            out = next(flat_iter)
            res.update({
                # NOTE(review): default is the string "None", not None — kept for
                # compatibility with existing outputs; confirm intent.
                f"{prefix}_prover_input": out.get("prover_input", "None"),
                f"{prefix}_prover_outputs": out.get("prover_outputs", []),
                f"{prefix}_formal_proof": out.get("formal_proof", []),
            })
    return data_segment

def formal_proof_generate(
    data_list: List[Dict],
    key: str,
    prefix: str,
    prover: Callable
) -> List[Dict]:
    """
    Generate formal proofs for one statement variant of every problem.

    All (problem, result) pairs are flattened into a single batch and passed
    to ``prover`` in one call. The outputs are then merged back into each
    problem's ``results`` entries under ``{prefix}_*`` keys.

    Args:
        data_list: Problem dicts (mutated in place).
        key: Which statement to prove: 'formal_statement' or 'dropped_hypothesis'.
        prefix: Key prefix for the merged fields: 'state' or 'hyp'.
        prover: Callable taking the flat list of prover inputs; it is expected
            to return a list of dicts aligned one-to-one with that input.

    Returns:
        A new list containing the same (updated) problem dicts.
    """
    flatten_data = flatten_for_prover(data_list, key, 0)
    flatten_results = prover(flatten_data)
    updated = update_results_with_proof(data_list, flatten_results, prefix)
    # Return a fresh list object (as before) so callers never alias ``data_list``.
    return list(updated)

def main():
    """
    Command-line entry point.

    Loads the step-1 output and preprocesses it. Generates proofs for both the
    original formal statements and the dropped-hypothesis variants, then saves
    the merged results as JSON next to the input file.
    """
    parser = argparse.ArgumentParser(description="Batch generate proofs using whole-generation prover")
    parser.add_argument("--prover_name", type=str, default="goedel", help="Prover to use")
    parser.add_argument("--prover_path", type=str, default="", help="Path to the prover")
    parser.add_argument("--default_header", type=int, choices=[0, 1], default=1, help="Whether to use default header")
    parser.add_argument("--prover_k", type=int, default=8, help="Number of samples (pass@k) to generate for each problem")
    parser.add_argument("--gpu", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling parameter")
    parser.add_argument("--input_file", type=str, default="save/step1_solve_counter_example_train_data_10_erdos_5.json", help="Path to the input dataset")
    # NOTE(review): --save_dir is parsed but never read; output is written alongside --input_file.
    parser.add_argument("--save_dir", type=str, default="save/", help="Path to the output file")
    args = parser.parse_args() # type: ignore

    # setup logger
    log_path = "logs/step2_generate_formal_proof.log"
    setup_logger(log_path)

    data_list = load_large_data(args.input_file)
    logger.info(f"Loaded dataset, total {len(data_list)} samples.")
    data_list = preprocess_data(data_list, args)
    logger.info(f"Data size after preprocessing: {len(data_list)}")

    prover = get_prover(args.prover_name, args)
    # Swap only the extension (instead of str.replace(".json", ...)) so the
    # prover suffix is always appended. Otherwise an input path without ".json"
    # could map onto itself and be overwritten by save_json.
    stem, ext = os.path.splitext(
        args.input_file.replace("step1_solve_counter_example", "step2_generate_formal_proof")
    )
    args.output_file = f"{stem}_{args.prover_name}_{args.prover_k}{ext or '.json'}"

    # formal_proof_generate's signature is (data_list, key, prefix, prover);
    # `args` must not be passed, or it collides with `key`.
    # Step 1: Generate proofs for original statements
    data_list = formal_proof_generate(data_list, key="formal_statement", prefix="state", prover=prover)
    # Step 2: Generate proofs for dropped hypothesis statements
    data_list = formal_proof_generate(data_list, key="dropped_hypothesis", prefix="hyp", prover=prover)

    logger.info("All formal proofs generated successfully.")
    # Guard against an empty dataset, where data_list[0] would raise IndexError.
    if data_list:
        logger.info("Example output:")
        logger.info(data_list[0])
    save_json(data_list, args.output_file)
    logger.info(f"Final results saved to {args.output_file}")

if __name__ == "__main__":
    main()
