#!/usr/bin/env python3

"""
Read mutation records printed by the Lean `mutation` executable from stdin,
check them with LeanVerifier, and save the cleaned dataset as a JSON file.

Usage: lake exe mutation Mutation.example | python3 scripts/data_create.py --dir output_file
"""

import sys
import json
import os
import argparse
from typing import List, Dict, Any, Tuple
from lean_verifier import LeanVerifier  # type: ignore

# Import header placed at the top of every generated Lean snippet.
HEAD = "import Mathlib"
# Proof text substituted for `sorry` in generated statements; `{name}` is the
# declaration name under which `extract_goal` prints the goal.
EXTRACT_GOAL = (
    "by (try push_neg); "
    "(set_option pp.funBinderTypes true in extract_goal using {name}); "
    "sorry"
)

def parse_json(json_str: str) -> List[Dict[str, str]]:
    """
    Parse the record stream printed by the Lean ``mutation`` executable.

    Despite its name this is not a JSON parser. The input is a sequence of
    records separated by ``--<<SEP>>--`` (optionally wrapped in a single pair
    of square brackets). Each record has the textual layout::

        name: <name>
        original_version: <statement>
        dropped_hypothesis_0: <statement>
        mutated_version_0: <statement>
        dropped_hypothesis_1: <statement>
        mutated_version_1: <statement>
        ...

    Args:
        json_str: Raw text read from the Lean executable.

    Returns:
        One dict per non-blank record with the keys ``name``,
        ``original_version`` and, for every consecutive index ``i`` starting
        at 0 whose markers are present, ``dropped_hypothesis_{i}`` and
        ``mutated_version_{i}``. All values are whitespace-stripped.

    Raises:
        ValueError: If a non-blank record lacks the ``name:`` or
            ``original_version:`` field, or its markers are out of order.
    """
    text = json_str.strip()
    # Remove exactly one surrounding '[' ... ']' pair. str.strip("[]") would
    # also remove brackets that belong to the first/last statement (e.g. a goal
    # ending in a list literal such as `[1]`).
    if text.startswith("["):
        text = text[1:]
    if text.endswith("]"):
        text = text[:-1]

    data: List[Dict[str, str]] = []
    for example in text.split("--<<SEP>>--"):
        if not example.strip():
            continue  # blank chunk, e.g. around a trailing separator
        if "name:" not in example or "original_version:" not in example:
            raise ValueError(f"Invalid JSON string: {example}")
        content = example.split("name:", 1)[1]
        name_val, content = content.split("original_version:", 1)
        record: Dict[str, str] = {"name": name_val.strip()}

        # Walk the dropped_hypothesis_i / mutated_version_i markers in order.
        # Text between two markers belongs to the key of the earlier marker,
        # so each value is only known once the following marker is found.
        curr_key = "original_version"
        i = 0
        while True:
            drop_marker = f"dropped_hypothesis_{i}:"
            mut_marker = f"mutated_version_{i}:"
            if drop_marker not in content or mut_marker not in content:
                break
            value, content = content.split(drop_marker, 1)
            record[curr_key] = value.strip()
            drop_val, content = content.split(mut_marker, 1)
            record[f"dropped_hypothesis_{i}"] = drop_val.strip()
            curr_key = f"mutated_version_{i}"
            i += 1
        # Whatever follows the last marker is the value of the last key seen;
        # storing it after the loop guarantees it is never dropped.
        record[curr_key] = content.strip()
        data.append(record)
    return data

def remove_duplicate_name(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Keep only the first record for each distinct ``'name'``.

    Args:
        data: Mutation records; every record must have a ``'name'`` key.

    Returns:
        A new list, in input order, containing the first record seen for each
        name (the record objects themselves are not copied).
    """
    # dict preserves insertion order, and setdefault never overwrites, so the
    # first record per name wins and its position is kept.
    first_by_name: Dict[Any, Dict[str, Any]] = {}
    for record in data:
        first_by_name.setdefault(record['name'], record)
    return list(first_by_name.values())

def _inject_extract_goal(statement: str, decl_name: str) -> str:
    """Replace the final ``sorry`` in *statement* with ``EXTRACT_GOAL`` for *decl_name*.

    Only the last occurrence is replaced so that each declaration name is
    requested from ``extract_goal`` exactly once. A statement without
    ``sorry`` is returned unchanged.
    """
    head, sep, tail = statement.rpartition("sorry")
    if not sep:
        return statement
    return head + EXTRACT_GOAL.format(name=decl_name) + tail


def preprocess(data: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one Lean snippet per (dropped hypothesis, mutated theorem) pair.

    For every record and every index ``i = 0, 1, ...`` for which both
    ``dropped_hypothesis_{i}`` and ``mutated_version_{i}`` are present
    (stopping at the first index missing either key), the final ``sorry`` of
    each statement is replaced by ``EXTRACT_GOAL`` under these names:

    * mutated theorem: ``<name with '.' replaced by '_'>_drop{i}``
    * dropped hypothesis: the same name with ``_hyp`` appended.

    Each snippet is ``HEAD``, the hypothesis, and the mutated theorem,
    separated by blank lines.

    Args:
        data: Records in the shape returned by ``parse_json``; each needs a
            ``'name'`` key.

    Returns:
        ``(total_codes, total_data)``: two parallel lists holding the Lean
        source of each snippet and, at the same index,
        ``{"name": <original record name>, "imports": HEAD}``.
    """
    total_codes: List[str] = []
    total_data: List[Dict[str, Any]] = []
    for item in data:
        name = item['name']
        normalized_name = name.replace(".", "_")
        i = 0
        # No fixed upper bound: stop at the first index missing either key.
        while f"dropped_hypothesis_{i}" in item and f"mutated_version_{i}" in item:
            thm_name = f"{normalized_name}_drop{i}"
            hyp_name = f"{thm_name}_hyp"
            thm = _inject_extract_goal(item[f"mutated_version_{i}"], thm_name)
            hyp = _inject_extract_goal(item[f"dropped_hypothesis_{i}"], hyp_name)
            total_codes.append(HEAD + "\n\n" + hyp + "\n\n" + thm)
            total_data.append({
                "name": name,
                "imports": HEAD,
            })
            i += 1
    return total_codes, total_data

def main() -> None:
    """
    Read mutation records from stdin, verify them with Lean, and save the result.

    Pipeline:
    1. Parse the ``--<<SEP>>--``-separated stream printed by the Lean
       ``mutation`` executable (``parse_json``).
    2. Drop records with repeated names (``remove_duplicate_name``).
    3. Build one Lean snippet per dropped hypothesis (``preprocess``).
    4. Run the snippets through ``LeanVerifier``.
    5. Write every pair that ``parse_thms`` resolves to the JSON file given
       by ``--dir``.

    On any error the message is printed to stderr and the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser(description='Process and save mutation data from Lean theorem prover')
    parser.add_argument('--dir', required=True, help='Path of the output JSON file (used as given)')
    args = parser.parse_args()

    # Existing behaviour: make sure ./datasets exists.
    os.makedirs('./datasets', exist_ok=True)
    # Also create the directory of the requested output file, so open() below
    # does not fail only after the (slow) verification has finished.
    out_dir = os.path.dirname(args.dir)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print("start processing...")
    try:
        # Lean statements may contain non-ASCII symbols; decode stdin as UTF-8
        # explicitly instead of relying on the locale's default encoding.
        data_str = sys.stdin.buffer.read().decode('utf-8')
        print("data load finished!")

        data = parse_json(data_str)
        data = remove_duplicate_name(data)
        total_codes, total_data = preprocess(data)

        cleaned_data = []
        with LeanVerifier() as verifier:
            response = list(verifier.verify_batch(total_codes, timeout=60))
            if len(response) != len(total_data):
                print(
                    f"Warning: expected {len(total_data)} verification results, "
                    f"got {len(response)}; unmatched items are ignored",
                    file=sys.stderr,
                )
            # Parse while the verifier is still open; nothing visible here
            # guarantees parse_thms works after the context has exited.
            for item, result in zip(total_data, response):
                # NOTE(review): both name arguments receive the original record
                # name, while preprocess() asks extract_goal for
                # '<normalized name>_drop{i}' and '..._hyp' — confirm which
                # names parse_thms expects.
                thm, hyp = verifier.parse_thms(item['name'], item['name'], result)
                if thm is not None and hyp is not None:
                    cleaned_data.append({
                        # NOTE(review): 'name' stores the same value as
                        # 'formal_statement'; confirm this is intended.
                        "name": thm,
                        "imports": item['imports'],
                        "formal_statement": thm,
                        "dropped_hypothesis": hyp,
                    })

        # ensure_ascii=False writes raw Unicode, so the file encoding must be
        # pinned rather than left to the platform default.
        with open(args.dir, 'w', encoding='utf-8') as f:
            json.dump(cleaned_data, f, indent=2, ensure_ascii=False)

    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON format - {str(e)}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
        sys.exit(1)

if __name__ == '__main__':
    main()