import argparse
import json
import logging
from typing import Any, Dict, List, Callable
from multiprocessing import Process, Manager

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np

from datasets import load_dataset, Dataset
from train_utils import load_data, load_large_data, save_json, get_solver, get_prover
from eval_utils import print_matrix
from src.logger import setup_logger

from src.step1_solve_counter_example import counter_example_generate
from src.step1_solve_counter_example import preprocess_data as preprocess_data_step1
from src.step2_generate_formal_proof import formal_proof_generate
from src.step2_generate_formal_proof import preprocess_data as preprocess_data_step2
from src.step3_check_data_label import evaluate_generation
from src.step3_check_data_label import preprocess_data as preprocess_data_step3

def execute_counter_example_generate(data_list: List[Dict[str, Any]], args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Run step 1 (counter-example generation) over one batch of problems.

    When ``args.use_counterexample`` is falsy, a solver is built with
    ``get_solver(args.solver_name, args)`` and ``counter_example_generate``
    produces the new ``data_list``, whose items carry a ``"results"`` list.
    Otherwise ``data_list`` is returned unchanged and no solver is loaded.

    NOTE(review): the condition reads inverted relative to the flag name;
    presumably ``use_counterexample`` means "reuse counter-examples already
    present in the data" - confirm against preprocess_data_step1.

    Args:
        data_list: Preprocessed problems for this batch.
        args: Parsed CLI arguments (reads ``solver_name`` and
            ``use_counterexample``; also passed through to ``get_solver``).

    Returns:
        The (possibly regenerated) list of problems.
    """
    logging.info("Start generating counter-examples...")
    if not args.use_counterexample:
        # Only load the solver model when it is actually used; loading it
        # unconditionally wasted time and accelerator memory.
        solver = get_solver(args.solver_name, args)
        data_list = counter_example_generate(solver, data_list)
        # Drop our reference as soon as generation is done.
        del solver
        total_counterexamples = sum(len(item["results"]) for item in data_list)
        logging.info(f"Total number of valid counter-examples: {total_counterexamples}")
    logging.info("All counter-examples generated successfully.")
    # Guard against an empty batch: indexing [0] would raise IndexError.
    if data_list:
        logging.info("Example output:")
        logging.info(data_list[0])
    return data_list

def execute_formal_proof_generate(data_list: List[Dict[str, Any]], args: argparse.Namespace):
    """Run step 2 (formal proof generation) over one batch of problems.

    A prover is built with ``get_prover(args.prover_name, args)`` and used to
    generate proofs for each item's ``"formal_statement"`` (prefix ``"state"``).
    When ``args.check_hypothesis`` is truthy, a second pass generates proofs
    for ``"dropped_hypothesis"`` (prefix ``"hyp"``).

    Args:
        data_list: Problems produced by step 1 and ``preprocess_data_step2``.
        args: Parsed CLI arguments (reads ``prover_name`` and
            ``check_hypothesis``; also passed through to ``get_prover``).

    Returns:
        The list returned by the last ``formal_proof_generate`` call.
    """
    logging.info("Start generating formal proofs...")
    prover = get_prover(args.prover_name, args)
    data_list = formal_proof_generate(data_list, key="formal_statement", prefix="state", prover=prover)
    if args.check_hypothesis:
        logging.info("Start generating dropped hypotheses...")
        data_list = formal_proof_generate(data_list, key="dropped_hypothesis", prefix="hyp", prover=prover)
    # Drop our reference as soon as generation is done.
    del prover

    # Count generated proofs across every result of every problem; results
    # without a "<prefix>_formal_proof" entry contribute zero.
    total_formal_proofs = sum(
        len(res.get("state_formal_proof", []))
        for item in data_list
        for res in item["results"]
    )
    logging.info(f"Total number of valid formal proofs: {total_formal_proofs}")
    if args.check_hypothesis:
        total_dropped_hypotheses = sum(
            len(res.get("hyp_formal_proof", []))
            for item in data_list
            for res in item["results"]
        )
        logging.info(f"Total number of valid dropped hypotheses: {total_dropped_hypotheses}")
    logging.info("All formal proofs generated successfully.")
    # Guard against an empty batch: indexing [0] would raise IndexError.
    if data_list:
        logging.info("Example output:")
        logging.info(data_list[0])
    return data_list

def execute_check_data_label(processed_data: Dict[str, Any], args: argparse.Namespace):
    """Run step 3 label checking via ``evaluate_generation``.

    The state-level formal data is always evaluated; the hypothesis-level
    data is evaluated afterwards only when ``args.check_hypothesis`` is set.
    Each evaluation's return value replaces ``processed_data["original_data"]``
    so the next evaluation (and the caller) sees the updated records.

    Returns:
        The final ``processed_data["original_data"]``.
    """
    # (key in processed_data, result prefix, log message) for each pass, in order.
    passes = [("state_formal_data", "state", "Start checking data label...")]
    if args.check_hypothesis:
        passes.append(("hyp_formal_data", "hyp", "Start checking dropped hypotheses..."))

    for data_key, result_prefix, banner in passes:
        logging.info(banner)
        processed_data["original_data"] = evaluate_generation(
            processed_data, key=data_key, prefix=result_prefix
        )

    logging.info("All evaluation results generated successfully.")
    return processed_data["original_data"]

def process_step1_and_step2(batch_data, args):
    """Run step 1 (counter-examples) and then step 2 (formal proofs) on a batch.

    Step 2's input is prepared by ``preprocess_data_step2`` from step 1's output.
    Returns whatever ``execute_formal_proof_generate`` returns.
    """
    with_counterexamples = execute_counter_example_generate(batch_data, args)
    proof_inputs = preprocess_data_step2(with_counterexamples, args)
    return execute_formal_proof_generate(proof_inputs, args)

def process_step3(batch_data, args):
    """Run step 3 (label checking) on a batch that already went through steps 1-2.

    ``preprocess_data_step3`` splits the batch into state- and hypothesis-level
    formal data. Both parts are bundled with the untouched batch and handed to
    ``execute_check_data_label``, whose return value is passed through.
    """
    state_part, hyp_part = preprocess_data_step3(batch_data)
    return execute_check_data_label(
        {
            "original_data": batch_data,
            "state_formal_data": state_part,
            "hyp_formal_data": hyp_part,
        },
        args,
    )

# How often (seconds) the coordinator re-checks worker liveness while blocked.
_PIPELINE_POLL_SECONDS = 1.0


def _pipeline_gpu_worker(gpu_queue, cpu_queue, args):
    """GPU stage: run steps 1-2 on each batch and forward the result downstream.

    Defined at module level (not as a closure) so it can be pickled under the
    spawn/forkserver multiprocessing start methods. On the ``(None, None)``
    sentinel it forwards the sentinel to ``cpu_queue``, after every earlier
    batch, so the CPU stage shuts down only once all work has reached it.
    """
    while True:
        batch_idx, batch_data = gpu_queue.get()
        if batch_idx is None:
            break
        logging.info(f"Process step1 and step2 of batch index {batch_idx} and length {len(batch_data)}")
        try:
            result = process_step1_and_step2(batch_data, args)
        except Exception as e:
            logging.error(f"GPU worker error: {e}")
            # Re-raise so the process exits non-zero; the coordinator detects it.
            raise
        cpu_queue.put((batch_idx, result))
    cpu_queue.put((None, None))


def _pipeline_cpu_worker(cpu_queue, result_queue, args):
    """CPU stage: run step 3 on each batch and publish ``(batch_idx, result)``."""
    while True:
        batch_idx, processed_batch = cpu_queue.get()
        if batch_idx is None:
            break
        logging.info(f"Process step3 of batch index {batch_idx} and length {len(processed_batch)}")
        try:
            result = process_step3(processed_batch, args)
        except Exception as e:
            logging.error(f"CPU worker error: {e}")
            raise
        result_queue.put((batch_idx, result))


def _check_pipeline_workers(workers, allow_clean_exit):
    """Raise RuntimeError if a worker has exited in a way the caller cannot accept.

    A worker still running (``exitcode is None``) is always fine. A non-zero
    exit code is always a failure. A clean exit (code 0) is a failure only when
    ``allow_clean_exit`` is False, i.e. while batches are still being fed.
    """
    for proc in workers:
        code = proc.exitcode
        if code is None:
            continue
        if code != 0 or not allow_clean_exit:
            raise RuntimeError(f"Pipeline worker {proc.name!r} exited with code {code}")


def _put_while_workers_alive(q, item, workers, poll_seconds=_PIPELINE_POLL_SECONDS):
    """Put ``item`` on bounded queue ``q``, failing fast if a worker has died.

    A plain blocking ``put`` would hang forever once the consumer is gone.
    """
    import queue as queue_module

    while True:
        _check_pipeline_workers(workers, allow_clean_exit=False)
        try:
            q.put(item, timeout=poll_seconds)
            return
        except queue_module.Full:
            continue


def process_batches_pipeline(data_list, args):
    """
    Process the data in batches using a pipeline of GPU and CPU workers.
    The pipeline is as follows:
    1. GPU worker:
        - Generates counter-examples and formal proofs (steps 1-2).
    2. CPU worker:
        - Checks the data label (step 3) and hands each batch's result back
          to this coordinating process (saving happens in the caller).

    The two stages overlap: while the CPU worker checks batch ``k`` the GPU
    worker can already process batch ``k + 1``. The queues are bounded
    (``maxsize=1``) so at most one batch waits between stages.

    If either worker crashes, the other one is terminated and a RuntimeError
    is raised instead of hanging on a blocked queue.

    Args:
        data_list: Preprocessed problems, split into chunks of ``args.batch_size``.
        args: Parsed CLI arguments, forwarded to the worker processes.

    Returns:
        The concatenated step-3 results of all batches, in batch order.

    Raises:
        ValueError: If ``args.batch_size`` is not positive.
        RuntimeError: If a worker process exits abnormally.
    """
    if args.batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {args.batch_size}")

    total_results = []

    with Manager() as manager:
        gpu_queue = manager.Queue(maxsize=1)
        cpu_queue = manager.Queue(maxsize=1)
        result_queue = manager.Queue()

        gpu_process = Process(target=_pipeline_gpu_worker, args=(gpu_queue, cpu_queue, args))
        cpu_process = Process(target=_pipeline_cpu_worker, args=(cpu_queue, result_queue, args))
        workers = (gpu_process, cpu_process)
        gpu_process.start()
        cpu_process.start()

        # Ceiling division: the old `len // size + 1` produced a spurious extra batch.
        num_batches = (len(data_list) + args.batch_size - 1) // args.batch_size
        logging.info(f"Total {num_batches} batches to process")
        try:
            for batch_idx in range(num_batches):
                start_idx = batch_idx * args.batch_size
                batch_data = data_list[start_idx:start_idx + args.batch_size]
                _put_while_workers_alive(gpu_queue, (batch_idx, batch_data), workers)

            # The GPU worker forwards this sentinel to the CPU worker once it is done.
            _put_while_workers_alive(gpu_queue, (None, None), workers)

            # Poll instead of a blocking join, so a crash in one stage cannot
            # leave the coordinator waiting forever on the other stage.
            while any(proc.is_alive() for proc in workers):
                _check_pipeline_workers(workers, allow_clean_exit=True)
                for proc in workers:
                    proc.join(timeout=_PIPELINE_POLL_SECONDS)
            _check_pipeline_workers(workers, allow_clean_exit=True)
        except BaseException:
            for proc in workers:
                if proc.is_alive():
                    proc.terminate()
            for proc in workers:
                proc.join()
            raise

        # Both producers have exited, so no more items can arrive.
        finished_batches = []
        while not result_queue.empty():
            finished_batches.append(result_queue.get())

    finished_batches.sort(key=lambda pair: pair[0])
    for _, batch_result in finished_batches:
        total_results.extend(batch_result)
    return total_results

def _cumulative_pass_matrix(results: List[Dict[str, Any]], key: str, max_i: int, max_j: int) -> np.ndarray:
    """Return a (max_i, max_j) 0/1 matrix of cumulative passes for one problem.

    Row ``i`` of the raw grid is filled from ``results[i][key]`` (one value per
    formal-proof attempt). Extra results beyond ``max_i`` and extra attempts
    beyond ``max_j`` are ignored. Shorter or missing rows are zero-padded, and
    ``None`` values are treated as a zero row. Cell ``(i, j)`` of the returned
    matrix is 1.0 iff any of the first ``i + 1`` results passed within its
    first ``j + 1`` attempts.
    """
    passed = np.zeros((max_i, max_j))
    for row_idx, res in enumerate(results[:max_i]):
        value = res.get(key, 0)
        if value is None:
            continue
        # ravel() lets a scalar default (0) fill just the first cell.
        row = np.asarray(value, dtype=float).ravel()[:max_j]
        passed[row_idx, :row.size] = row
    return (np.cumsum(np.cumsum(passed, axis=0), axis=1) > 0).astype(float)


def report_results(data_list: List[Dict[str, Any]], max_i: int, max_j: int):
    """Build per-problem cumulative pass@(i x j) indicator matrices.

    Rows index counter-example results (up to ``max_i``); columns index formal
    proof attempts (up to ``max_j``). Each result's ``"state_passed"`` and
    ``"hyp_passed"`` entries are copied row by row into a fixed-size grid.
    Ragged or over-long entries are therefore truncated or zero-padded. The
    previous version either crashed (ragged ``np.array``) or silently zeroed
    the whole problem in those cases.

    Problems whose ``"results"`` list is empty are skipped entirely.

    Args:
        data_list: Problems, each with a ``"results"`` list of dicts.
        max_i: Number of rows (counter-example generations).
        max_j: Number of columns (formal proof generations).

    Returns:
        ``(state_matrices, hyp_matrices)``: two equally long lists of
        ``(max_i, max_j)`` float arrays, one per non-skipped problem.
    """
    total_state_matrix = []
    total_hyp_matrix = []
    for d in data_list:
        results = d["results"]
        if len(results) == 0:
            continue
        total_state_matrix.append(_cumulative_pass_matrix(results, "state_passed", max_i, max_j))
        total_hyp_matrix.append(_cumulative_pass_matrix(results, "hyp_passed", max_i, max_j))
    return total_state_matrix, total_hyp_matrix
    
def _str2bool(value) -> bool:
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    With ``type=bool``, argparse calls ``bool("False")``, which is ``True``,
    so ``--flag False`` used to enable the flag. This converter keeps the
    ``--flag True`` / ``--flag False`` CLI form but parses it correctly.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _print_result_report(total_results, state_res_matrix, hyp_res_matrix):
    """Print and log the averaged pass@(i x j) matrices and summary statistics."""
    state_res_acc_matrix = np.mean(state_res_matrix, axis=0)
    state_res_count_matrix = np.sum(state_res_matrix, axis=0)
    hyp_res_acc_matrix = np.mean(hyp_res_matrix, axis=0)
    hyp_res_count_matrix = np.sum(hyp_res_matrix, axis=0)
    # select matrix style: box, simple, table, heatmap, markdown
    matrix_style = "markdown"

    print("\n" + "="*60)
    print("Note: (i, j) means pass@ixj rate")
    print("  i: number of counterexample generations, j: number of formal proof generations")
    print("\n" + "="*60)
    print("📊 STATE RESULT MATRIX")
    print("="*60)
    msg = print_matrix(state_res_acc_matrix, state_res_count_matrix, "State", matrix_style)
    logging.info("The State Result Matrix:")
    logging.info(msg)

    print("\n" + "="*60)
    print("📊 HYPOTHESIS RESULT MATRIX")
    print("="*60)
    msg = print_matrix(hyp_res_acc_matrix, hyp_res_count_matrix, "Hyp", matrix_style)
    logging.info("The Hypothesis Result Matrix:")
    logging.info(msg)

    print("\n" + "="*60)
    print("📈 SUMMARY STATISTICS")
    print("="*60)
    print(f"Total problems: {len(total_results)}")
    print(f"State matrix count: {np.mean(state_res_count_matrix):.4f}")
    print(f"Hyp matrix count: {np.mean(hyp_res_count_matrix):.4f}")
    print(f"State matrix count std: {np.std(state_res_count_matrix):.4f}")
    print(f"Hyp matrix count std: {np.std(hyp_res_count_matrix):.4f}")
    print(f"State matrix acc: {np.mean(state_res_acc_matrix):.4f}")
    print(f"Hyp matrix acc: {np.mean(hyp_res_acc_matrix):.4f}")
    print(f"State matrix acc std: {np.std(state_res_acc_matrix):.4f}")
    print(f"Hyp matrix acc std: {np.std(hyp_res_acc_matrix):.4f}")
    print("="*60)


def main():
    """CLI entry point: load data, run steps 1-3, save results, print the report.

    Steps:
        1. Load a local JSON/JSONL file or a HuggingFace dataset and preprocess it.
        2. Run the GPU/CPU batch pipeline (counter-examples, proofs, label checks).
        3. Save all results as JSON under ``--save_dir``.
        4. Print/log the pass@(i x j) matrices and summary statistics.
    """
    parser = argparse.ArgumentParser(description="Batch generate counter-examples using configurable solver.")
    parser.add_argument("--solver_name", type=str, default="mistral", help="Solver to use")
    parser.add_argument("--prover_name", type=str, default="deepseek_v15_rl", help="Prover to use")
    parser.add_argument("--solver_path", type=str, default="mistralai/Mistral-7B-Instruct-v0.3", help="Path to the solver")
    parser.add_argument("--prover_path", type=str, default="deepseek-ai/DeepSeek-Prover-V1.5-RL", help="Path to the prover")
    parser.add_argument("--num_problems", type=int, default=-1, help="Number of problems to generate")
    parser.add_argument("--solver_k", type=int, default=8, help="Number of samples (pass@k) to generate for each problem")
    parser.add_argument("--prover_k", type=int, default=8, help="Number of samples (pass@k) to generate for each problem")
    parser.add_argument("--default_header", type=int, default=0, help="Whether to use default header")
    parser.add_argument("--gpu", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling parameter")
    parser.add_argument("--batch_size", type=int, default=20000, help="Batch size")
    parser.add_argument("--dataset_path", type=str, default="datasets/test_data.json", help="Path or name of the dataset to use")
    parser.add_argument("--save_dir", type=str, default="save/", help="Path to the output file")
    ##
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument("--use_counterexample", type=_str2bool, default=False, help="Whether to use counterexample")
    parser.add_argument("--check_hypothesis", type=_str2bool, default=False, help="Whether to check hypothesis")
    args = parser.parse_args() # type: ignore

    # Make sure the log/output directory exists before anything writes to it.
    os.makedirs(args.save_dir, exist_ok=True)
    setup_logger(f"{args.save_dir}/eval.log")

    # print model info and generate config
    logging.info(f"Solver model: {args.solver_path}")
    logging.info(f"Prover model: {args.prover_path}")
    logging.info(f"Counterexample generate {args.solver_k} times for each problem")
    logging.info(f"Formalproof generate {args.prover_k} times for each problem")

    ## step1
    data_name = args.dataset_path.split("/")[-1].split(".")[0]
    if args.dataset_path.endswith(".json") or args.dataset_path.endswith(".jsonl"):
        dataset = load_dataset("json", data_files=args.dataset_path, split="train")
        logging.info(f"Local dataset: {args.dataset_path}")
    else:
        dataset = load_dataset(args.dataset_path, split="train")
        logging.info(f"HuggingFace dataset: {args.dataset_path}")
    data_list = preprocess_data_step1(dataset, args)

    # A negative --num_problems (default -1) means "use every problem";
    # slicing [0:-1] would silently drop the last one.
    if args.num_problems >= 0:
        data_list = data_list[:args.num_problems]

    # step1 + step2 (+ step3 inside the pipeline's CPU stage)
    total_results = process_batches_pipeline(data_list, args)

    # save results
    # NOTE(review): nothing in this file sets args.total_data_size; it may be
    # set by preprocess_data_step1. Fall back to the number of problems run.
    total_data_size = getattr(args, "total_data_size", len(data_list))
    args.output_file = (
        f"{args.save_dir}/evaluate_results_{data_name}_{total_data_size}_"
        f"{args.solver_name}_{args.solver_k}_{args.prover_name}_{args.prover_k}.json"
    )
    save_json(total_results, args.output_file)
    logging.info(f"Final results saved to {args.output_file}")

    # step4 - report results
    state_res_matrix, hyp_res_matrix = report_results(total_results, args.solver_k, args.prover_k)
    if not state_res_matrix:
        # Averaging an empty list yields NaN scalars that print_matrix cannot render.
        logging.warning("No problem produced any results; skipping the result report.")
        return
    _print_result_report(total_results, state_res_matrix, hyp_res_matrix)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    # multiprocessing's spawn/forkserver start methods re-import the main
    # module in each child, so this guard also keeps workers from re-running main().
    main()

