"""
step3_check_data_label.py
This script is used to evaluate formal proofs using Lean server and generate data labels.
"""

import sys
import argparse
import json
import os
import time
from tqdm import tqdm
from typing import List, Dict, Any, Optional, Tuple
import logging

from lean_verifier import LeanVerifier  # type: ignore
from train_utils import load_large_data, save_json
from logger import setup_logger

logger = logging.getLogger(__name__)

def preprocess_data(data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Split every counterexample of every problem into two evaluation records.

    For each entry of a problem's "results" list, one record is produced for
    its "state_formal_proof" value and one for its "hyp_formal_proof" value
    (an empty list when the key is absent). Each record carries the owning
    "problem_id", the entry's position as "counterex_idx", and the proofs
    under "full_proof".

    Returns:
        (state records, hypothesis records), both in input order.
    """

    def _record(owner: Dict[str, Any], position: int, proofs: Any) -> Dict[str, Any]:
        # Shared shape of a state / hypothesis evaluation record.
        return {
            "problem_id": owner["problem_id"],
            "counterex_idx": position,
            "full_proof": proofs,
        }

    state_records: List[Dict[str, Any]] = []
    hyp_records: List[Dict[str, Any]] = []

    for problem in data:
        for position, attempt in enumerate(problem["results"]):
            state_proofs = attempt.get("state_formal_proof", [])
            hyp_proofs = attempt.get("hyp_formal_proof", [])
            state_records.append(_record(problem, position, state_proofs))
            hyp_records.append(_record(problem, position, hyp_proofs))

    logger.info(f"Preprocessed {len(state_records)} state formal problems and {len(hyp_records)} hyp formal problems.")
    return state_records, hyp_records

def flatten_for_evaluation(data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Dict[int, tuple]]:
    """
    Concatenate the proofs of all records into one flat list.

    Returns:
        (codes, problem_id_map) where `codes` holds every proof in input
        order and `problem_id_map[i]` is the tuple
        (problem_id, counterex_idx, position inside "full_proof") that
        `codes[i]` came from.
    """
    flat_codes: List[Dict[str, Any]] = []
    origin_of: Dict[int, tuple] = {}

    for record in data:
        owner_id = record["problem_id"]
        cex_position = record["counterex_idx"]
        for proof_position, code in enumerate(record["full_proof"]):
            # The next free slot in the flat list is the global index.
            origin_of[len(flat_codes)] = (owner_id, cex_position, proof_position)
            flat_codes.append(code)

    return flat_codes, origin_of

def check_proofs(
    codes: List[Dict[str, Any]],
    start_idx: int = 0,
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
    """
    Evaluate proofs in a single batch using Lean server.

    Args:
        codes: Flattened proofs to verify.
        start_idx: Position in `codes` to start from; proofs before it are
            skipped. With the default of 0 the whole list is verified.

    Returns:
        A pair (results, next_idx). `results` always has exactly one entry
        per proof in `codes[start_idx:]`, each exposing "is_valid_no_sorry".
        `next_idx` is always None in this implementation. When there is
        nothing to evaluate the pair is ([], None).

    If verification, parsing, or the summary statistics raise, the error is
    logged and every pending proof is reported as not valid.
    """
    if not codes or start_idx >= len(codes):
        logger.warning("No problems to evaluate or start_idx out of range.")
        return [], None

    # Only the not-yet-evaluated tail is sent; a negative index is clamped so
    # it cannot silently select a suffix of the list.
    pending = codes[max(start_idx, 0):]

    total_results: List[Dict[str, Any]] = []
    with LeanVerifier() as verifier:
        try:
            response = verifier.verify_batch(pending, timeout=60, use_tqdm=True)
            batch_results = verifier.parse_results(response)
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(batch_results) != len(pending):
                raise ValueError(
                    f"Expected {len(pending)} results, got {len(batch_results)}"
                )

            # Check error rate
            error_rate = sum(t.get("has_error", True) for t in batch_results) / len(batch_results)
            if error_rate >= 1.0:
                logger.warning("Error rate == 1.0")
                logger.warning(f"Example proof: {pending[0]} \n response: {response[0].get('response', '')}")

            # Indexing (not .get) doubles as a check that every result
            # carries the key that downstream merging reads.
            passed_count = sum(1 for r in batch_results if r['is_valid_no_sorry'])
            logger.info(f"Batch pass: {passed_count}/{len(pending)}")

        except Exception as e:
            logger.exception(f"Proofs evaluation failed: {e}")
            # One fresh dict per proof so entries are not aliases of each
            # other, and nothing from a partially processed batch is kept:
            # the result count must match the number of pending proofs.
            total_results = [{"is_valid_no_sorry": False} for _ in pending]
        else:
            # Results are published only after the whole batch was handled
            # without error, so they can never be combined with fallbacks.
            total_results = list(batch_results)

    return total_results, None

def merge_evaluation_results(
    total_results: List[Dict[str, Any]],
    problem_id_map: Dict[int, tuple],
    data_list: List[Dict[str, Any]],
    prefix: str = "state",
) -> List[Dict[str, Any]]:
    """
    Write per-proof verdicts back into the problems they belong to.

    For the i-th result, `problem_id_map[i]` names the problem, the entry of
    its "results" list, and the proof position. That entry receives
    "<prefix>_passed" (one boolean per proof, False until a verdict is
    written) and "<prefix>_passed_rate" (fraction of True values).
    Results whose problem is unknown, or whose problem has an empty
    "results" list, are ignored.

    `data_list` is modified in place and also returned.
    """
    proofs_key = f"{prefix}_formal_proof"
    passed_key = f"{prefix}_passed"
    rate_key = f"{prefix}_passed_rate"

    # Index problems that actually have results by their id.
    problems_by_id: Dict[Any, Dict[str, Any]] = {}
    for record in data_list:
        if len(record["results"]) > 0:
            problems_by_id[record["problem_id"]] = record

    for flat_idx, verdict in enumerate(total_results):
        problem_id, counterex_idx, proof_idx = problem_id_map[flat_idx]

        if problem_id in problems_by_id:
            entry = problems_by_id[problem_id]["results"][counterex_idx]
            proof_count = len(entry[proofs_key])

            # Reuse flags written by earlier results of the same entry.
            flags = entry.get(passed_key, [False] * proof_count)
            flags[proof_idx] = verdict["is_valid_no_sorry"]

            entry[passed_key] = flags
            entry[rate_key] = sum(flags) / proof_count

    return data_list

def evaluate_generation(
    data_list: Dict[str, Any],
    key: str,
    prefix: str,
) -> List[Dict[str, Any]]:
    """
    Evaluate one family of formal proofs and merge the verdicts.

    Args:
        data_list: Mapping holding the records to evaluate under `key` and
            the problems to annotate under "original_data".
        key: 'state_formal_data' or 'hyp_formal_data'
        prefix: 'state' or 'hyp'; selects the "<prefix>_..." fields that the
            merge step reads and writes.

    Returns:
        The "original_data" list with the verdicts merged in.
    """
    formal_data = data_list[key]
    all_results = data_list["original_data"]

    # Flatten data for evaluation
    codes, problem_id_map = flatten_for_evaluation(formal_data)

    # Evaluate proofs
    total_results: List[Dict[str, Any]] = []
    eval_start_idx = 0
    while True:
        results, next_idx = check_proofs(codes, start_idx=eval_start_idx)
        total_results.extend(results)
        if next_idx is None:
            break
        logger.warning(f"Server restarted, continue from idx {next_idx}")
        # Resume from where the checker stopped. Without this the loop would
        # re-submit from the old index forever and duplicate results.
        # NOTE(review): assumes `results` covers the proofs before
        # `next_idx` - confirm if check_proofs starts returning a resume index.
        eval_start_idx = next_idx
        time.sleep(5)

    return merge_evaluation_results(
        total_results, problem_id_map, all_results, prefix=prefix
    )

def main():
    """
    Command-line entry point: label the proofs of one input file.

    Loads the input, evaluates the state proofs and then the hypothesis
    proofs, and saves the annotated data. The output path is the input path
    with "step2_generate_formal_proof" replaced by "step3_check_data_label"
    and ".json" replaced by "_check.json".

    The --batch_size, --total_segments and --min_memory_gb options are
    accepted but not used by this function.
    """
    parser = argparse.ArgumentParser(description="Evaluate Lean proofs and generate data labels.")
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSON file")
    parser.add_argument("--batch_size", type=int, default=4096, help="Batch size for evaluation")
    parser.add_argument("--total_segments", type=int, default=1, help="Total number of segments to split data into")
    parser.add_argument("--min_memory_gb", type=int, default=20, help="Minimum available memory in GB")
    args = parser.parse_args() # type: ignore

    # setup logger
    log_path = "logs/step3_check_data_label.log"
    logger = setup_logger(log_path)

    # Load and preprocess data
    data_list = load_large_data(args.input_file)
    state_formal_data, hyp_formal_data = preprocess_data(data_list)

    # Prepare data structure for processing
    processed_data = {
        "original_data": data_list,
        "state_formal_data": state_formal_data,
        "hyp_formal_data": hyp_formal_data,
    }

    # Generate output file name
    output_file = args.input_file.replace(
        "step2_generate_formal_proof", "step3_check_data_label"
    ).replace(".json", "_check.json")
    if output_file == args.input_file:
        # Neither substitution matched; never write the labels over the input.
        output_file = f"{args.input_file}_check.json"
    args.output_file = output_file

    # Step 1: Evaluate state formal proofs
    # evaluate_generation takes (data_list, key, prefix); passing `args` as
    # well collides with `key` and raises TypeError.
    logger.info("Start evaluating state formal proofs...")
    processed_data["original_data"] = evaluate_generation(
        processed_data, key="state_formal_data", prefix="state"
    )

    # Step 2: Evaluate hypothesis formal proofs
    logger.info("Start evaluating hypothesis formal proofs...")
    processed_data["original_data"] = evaluate_generation(
        processed_data, key="hyp_formal_data", prefix="hyp"
    )

    # Save final results
    labelled = processed_data["original_data"]
    save_json(labelled, args.output_file)
    logger.info(f"All evaluation results saved to {args.output_file}")
    # An empty input has no first record to show.
    if labelled:
        logger.info("Example output:")
        logger.info(labelled[0])

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
