"""
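Data Extraction Script (data_extract.py)

This script generates hypothesis-dropping mutations of Lean theorems stored in a
JSON file. It uses LeanVerifier to drop each theorem's hypotheses one index at a
time. It then runs `extract_goal` to recover the mutated statement and the dropped
hypothesis as standalone theorems, and saves the results to a new JSON file.
Theorems are sent to the verifier in batches; errors are logged to logs/mutation.log.

Usage:
python data_extract.py --input_file <input_json_path> --output_file <output_json_path> [--num_theorems N] [--batch_size B]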
"""

import argparse
import json
import logging
import os
from typing import Any, Dict, List, Tuple

from lean_verifier import LeanVerifier  # type: ignore

# Logging configuration: diagnostics (e.g. failed mutations/extractions) are
# written to logs/mutation.log, which is truncated on every run (filemode="w").
# logging's FileHandler does not create missing parent directories, so ensure
# the directory exists first; otherwise importing this module raises
# FileNotFoundError.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="logs/mutation.log",
    filemode="w",
)

# Default Lean header used for an item that has no "imports" field.
HEAD = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n"
# Proof text substituted for a statement's `sorry` during extraction: try
# `push_neg`, then run `extract_goal` (with pp.funBinderTypes enabled) under the
# theorem name given by `{name}`, and finally close the goal with `sorry`.
EXTRACT_GOAL = "by (try push_neg); (set_option pp.funBinderTypes true in extract_goal using {name}); sorry"

def preprocess_for_mutation(data: List[Dict[str, Any]], hyp_idx: int) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    The current version only extracts the formal statement (no full version is included) from the original theorem
    """
    total_data, total_codes = [], []
    for item in data:
        name = item['name']
        normalized_name = name.replace(".", "_")
        imports = item.get("imports", HEAD)
        heads = "import Mutation.Tactics" + "\n" + imports
        formal_statement = item['formal_theorem']
        # formal_proof = item['formal_proof'] 
        thm_name = normalized_name + f"_drop{hyp_idx}"
        thm = formal_statement.replace("sorry", f"by replMutation {hyp_idx}; sorry")
        cmd = heads + "\n\n" + thm
        total_codes.append(cmd)
        total_data.append({
            "name": thm_name,
            "imports": imports,
        })
    return total_codes, total_data

def preprocess_for_extraction(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    total_data, total_hyp_codes, total_state_codes = [], [], []
    for item in data:
        thm_name = item['name']
        imports = item.get("imports", HEAD)
        heads = imports
        # extract the dropped hypothesis
        formal_statement = item["dropped_hypothesis"]
        thm = formal_statement.replace("sorry", EXTRACT_GOAL.format(name=thm_name))
        cmd = heads + "\n\n" + thm
        total_hyp_codes.append(cmd)
        # extract the formal statement
        formal_statement = item['formal_statement']
        thm = formal_statement.replace("sorry", EXTRACT_GOAL.format(name=thm_name))
        cmd = heads + "\n\n" + thm
        total_state_codes.append(cmd)
        # save the data
        total_data.append({
            "name": thm_name,
            "imports": imports,
        })
    return total_data, total_hyp_codes, total_state_codes

def mutate_data(data: List[Dict[str, Any]], max_hyps: int = 128) -> List[Dict[str, Any]]:
    """
    Generate hypothesis-dropped variants of each theorem, one hypothesis index at a time.

    For ``drop_idx = 0, 1, ..., max_hyps - 1``, every theorem still in play is
    sent to Lean with ``replMutation drop_idx`` (built by
    ``preprocess_for_mutation``). The parser result decides what happens next:
      * success: the theorem yields one output record and stays in play;
      * ``{}`` (not mutated): the theorem is dropped;
      * ``None`` (error, logged): the theorem is dropped.
    The loop stops early as soon as no theorem could be mutated at an index.

    Args:
        data: Original theorem items (with ``name`` and ``formal_theorem``).
        max_hyps: Number of hypothesis indices to try (default 128).

    Returns:
        Mutated records with keys ``name``, ``imports``, ``formal_statement``,
        ``dropped_hypothesis``, ``full_formal_statement`` and
        ``full_dropped_hypothesis``.
    """
    mutated_data = []
    for drop_idx in range(max_hyps):
        total_codes, total_data = preprocess_for_mutation(data, drop_idx)

        with LeanVerifier() as verifier:
            response = verifier.verify_batch(total_codes, timeout=60, use_tqdm=True)

        success_data, num_failed, num_errors = [], 0, 0
        for idx, (item, result) in enumerate(zip(total_data, response)):
            # NOTE(review): parsing runs after the verifier context has exited;
            # this assumes parse_mutated_thms only inspects `result` - confirm.
            thm_info = verifier.parse_mutated_thms(item['name'], result)
            if thm_info is None:
                num_errors += 1
                logging.error(f"Mutation {item['name']} encountered error:")
                logging.error(total_codes[idx])
                logging.error(result)
            elif thm_info == {}:
                # Hypothesis `drop_idx` could not be dropped (e.g. the theorem
                # has fewer hypotheses); stop mutating this theorem.
                num_failed += 1
            else:
                mutated_data.append({
                    "name": item['name'],
                    "imports": item['imports'],
                    "formal_statement": thm_info['mutated_version'],
                    "dropped_hypothesis": thm_info['dropped_hypothesis'],
                    "full_formal_statement": thm_info['full_mutated_version'],
                    "full_dropped_hypothesis": thm_info['full_dropped_hypothesis'],
                })
                # Keep the ORIGINAL item: the next index is dropped from the
                # original theorem, not from this mutated variant.
                success_data.append(data[idx])

        logging.info(
            "Mutation idx %d: %d mutated, %d not mutated, %d errors",
            drop_idx, len(success_data), num_failed, num_errors,
        )
        if not success_data:
            print(f"No mutated data generated for idx {drop_idx}")
            break
        data = success_data
        print(f"Finish mutation for idx {drop_idx}, generate {len(success_data)} mutated data")
    return mutated_data

def extract_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Recover standalone Lean theorems for each mutated record via ``extract_goal``.

    Two verifier batches are run:
      * one for the dropped hypotheses;
      * one for the mutated statements.
    Both results are parsed with ``parse_extracted_thms``. The outcome is:
      * either parse returns ``None``: the commands and raw results are logged
        as an error;
      * either parse returns ``{}``: the record counts as failed;
      * otherwise: a record is emitted.

    Args:
        data: Mutated records (see ``preprocess_for_extraction`` for the keys used).

    Returns:
        Records with ``name``, ``imports``, ``formal_statement`` and
        ``dropped_hypothesis``. The last two are the ``formal_theorem`` fields
        returned by the parser.
    """
    meta, hyp_codes, state_codes = preprocess_for_extraction(data)
    with LeanVerifier() as verifier:
        hyp_outputs = verifier.verify_batch(hyp_codes, timeout=60, use_tqdm=True)
        state_outputs = verifier.verify_batch(state_codes, timeout=60, use_tqdm=True)

    results: List[Dict[str, Any]] = []
    failures = 0
    rows = zip(meta, hyp_outputs, state_outputs)
    for pos, (entry, hyp_out, state_out) in enumerate(rows):
        label = entry['name']
        # NOTE(review): parsing runs after the verifier context has exited;
        # this assumes parse_extracted_thms only inspects the raw output - confirm.
        state_thm = verifier.parse_extracted_thms(label, state_out)
        hyp_thm = verifier.parse_extracted_thms(label, hyp_out)

        if state_thm is None or hyp_thm is None:
            error_lines = (
                f"Extraction {label} encountered error:",
                hyp_codes[pos],
                hyp_out,
                state_codes[pos],
                state_out,
            )
            for line in error_lines:
                logging.error(line)
            continue

        if state_thm == {} or hyp_thm == {}:
            failures += 1
            continue

        results.append({
            "name": label,
            "imports": entry['imports'],
            "formal_statement": state_thm['formal_theorem'],
            "dropped_hypothesis": hyp_thm['formal_theorem'],
        })

    print(f"Extracted {len(results)} theorems, and failed for {failures} theorems")
    return results

def main():
    """
    Command-line entry point: mutate theorems, extract goals, and save the results.

    Steps:
      1. Read a JSON list of theorem items from ``--input_file``.
      2. Keep only the first ``--num_theorems`` items (``-1`` keeps all).
      3. Run ``mutate_data`` and then ``extract_data`` in chunks of
         ``--batch_size`` items.
      4. Write the extracted records to ``--output_file`` as UTF-8 JSON.
    """
    parser = argparse.ArgumentParser(
        description="Mutate Lean theorems by dropping hypotheses and extract the resulting goals."
    )
    parser.add_argument("--input_file", type=str, required=True,
                        help="JSON file containing a list of theorem items.")
    parser.add_argument("--output_file", type=str, required=True,
                        help="Where to write the extracted records (JSON).")
    parser.add_argument("--num_theorems", type=int, default=-1,
                        help="Only process the first N theorems (-1 = all).")
    parser.add_argument("--batch_size", type=int, default=10000,
                        help="Number of items sent to the verifier per chunk.")
    args = parser.parse_args()
    if args.batch_size <= 0:
        # A non-positive step would otherwise crash inside range() with an
        # unhelpful ValueError.
        parser.error("--batch_size must be a positive integer")
    batch_size = args.batch_size

    # Output is written with ensure_ascii=False and theorem text may contain
    # non-ASCII symbols, so use UTF-8 explicitly instead of the platform default.
    with open(args.input_file, 'r', encoding="utf-8") as f:
        data = json.load(f)
    if args.num_theorems != -1:
        data = data[:args.num_theorems]

    # Ceiling division: the number of chunks actually iterated over.
    num_batches = (len(data) + batch_size - 1) // batch_size
    mutated_data = []
    for batch_no, start in enumerate(range(0, len(data), batch_size), start=1):
        print(f"Mutating batch {batch_no} of {num_batches}")
        mutated_data.extend(mutate_data(data[start:start + batch_size]))

    num_batches = (len(mutated_data) + batch_size - 1) // batch_size
    extracted_data = []
    for batch_no, start in enumerate(range(0, len(mutated_data), batch_size), start=1):
        print(f"Extracting batch {batch_no} of {num_batches}")
        extracted_data.extend(extract_data(mutated_data[start:start + batch_size]))

    print("Total extracted data:", len(extracted_data))
    with open(args.output_file, 'w', encoding="utf-8") as f:
        json.dump(extracted_data, f, indent=2, ensure_ascii=False)


if __name__ == '__main__':
    main()