import argparse
import shutil
from multiprocessing import Process, Manager
import logging
import random
import time
from typing import Any, Dict, List, Callable

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np

from datasets import load_dataset, Dataset
from train_utils import load_data, load_large_data, save_json, get_solver, get_prover
from eval_utils import print_matrix
from src.logger import setup_logger

from src.step1_solve_counter_example import counter_example_generate
from src.step1_solve_counter_example import preprocess_data as preprocess_data_step1
from src.step2_generate_formal_proof import formal_proof_generate
from src.step2_generate_formal_proof import preprocess_data as preprocess_data_step2
from src.step3_check_data_label import evaluate_generation
from src.step3_check_data_label import preprocess_data as preprocess_data_step3
from src.step4_update_model_params import preprocess_data as preprocess_data_step4
from src.step4_update_model_params import update_model_params

def execute_counter_example_generate(data_list: List[Dict[str, Any]], args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Run step 1: generate counter-examples for every item and persist them.

    Args:
        data_list: Items to process; each returned item is expected to carry
            a ``"results"`` list.
        args: Run configuration. Reads ``solver_name`` (to build the solver)
            and ``generate_output_file`` (where the JSON dump is written).

    Returns:
        The list returned by ``counter_example_generate``.
    """
    logging.info("Start generating counter-examples...")
    generated = counter_example_generate(get_solver(args.solver_name, args), data_list)

    # Every entry of an item's "results" list counts as one counter-example.
    total_counterexamples = sum(len(entry["results"]) for entry in generated)
    logging.info(f"Total number of valid counter-examples: {total_counterexamples}")

    output_path = args.generate_output_file
    save_json(generated, output_path)
    logging.info(f"All counter-examples results saved to {output_path}")

    # Short pause before handing over to the next stage (kept from the
    # original flow; presumably gives the solver backend time to settle).
    time.sleep(1)
    return generated

def execute_formal_proof_generate(data_list: List[Dict[str, Any]], args: argparse.Namespace):
    """Run step 2: generate formal proofs (and optionally dropped-hypothesis proofs).

    Args:
        data_list: Items whose ``"results"`` entries are passed to the prover.
        args: Run configuration. Reads ``prover_name``, ``check_hypothesis``
            and ``generate_output_file``.

    Returns:
        The list returned by the last ``formal_proof_generate`` call.
    """
    logging.info("Start generating formal proofs...")
    prover = get_prover(args.prover_name, args)
    proved = formal_proof_generate(data_list, key="formal_statement", prefix="state", prover=prover)
    if args.check_hypothesis:
        logging.info("Start generating dropped hypotheses...")
        proved = formal_proof_generate(proved, key="dropped_hypothesis", prefix="hyp", prover=prover)

    # Tally the generated proofs of both kinds across every result entry;
    # a result without the corresponding key contributes zero.
    state_total = 0
    hyp_total = 0
    for entry in proved:
        for res in entry["results"]:
            state_total += len(res.get("state_formal_proof", []))
            hyp_total += len(res.get("hyp_formal_proof", []))
    logging.info(f"Total number of valid formal proofs: {state_total}")
    logging.info(f"Total number of valid dropped hypotheses: {hyp_total}")

    output_path = args.generate_output_file
    save_json(proved, output_path)
    logging.info(f"All formal proofs results saved to {output_path}")

    # Short pause before handing over to the next stage (kept from the
    # original flow).
    time.sleep(1)
    return proved

def execute_check_data_label(processed_data: Dict[str, Any], args: argparse.Namespace):
    """Run step 3: evaluate generated proofs and persist the check results.

    ``processed_data["original_data"]`` is overwritten in place with the
    output of each ``evaluate_generation`` pass.

    Args:
        processed_data: Mapping with ``"original_data"``,
            ``"state_formal_data"`` and ``"hyp_formal_data"`` entries.
        args: Run configuration. Reads ``check_hypothesis`` and
            ``check_output_file``.

    Returns:
        The updated ``processed_data["original_data"]``.
    """
    logging.info("Start checking data label...")
    checked = evaluate_generation(processed_data, key="state_formal_data", prefix="state")
    processed_data["original_data"] = checked
    if args.check_hypothesis:
        logging.info("Start checking dropped hypotheses...")
        checked = evaluate_generation(processed_data, key="hyp_formal_data", prefix="hyp")
        processed_data["original_data"] = checked

    # An item counts as passed when at least one of its results has a
    # positive pass rate. Counters are floats so the log output keeps its
    # "<n>.0" formatting.
    state_hits = 0.0
    hyp_hits = 0.0
    for entry in processed_data["original_data"]:
        outcomes = entry["results"]
        state_flags = [res.get("state_passed_rate", 0.0) > 0.0 for res in outcomes]
        hyp_flags = [res.get("hyp_passed_rate", 0.0) > 0.0 for res in outcomes]
        state_hits += any(state_flags)
        hyp_hits += any(hyp_flags)
    logging.info(f"Total number of passed solution in checking state formal proofs: {state_hits}")
    logging.info(f"Total number of passed solution in checking dropped hypotheses: {hyp_hits}")

    save_json(processed_data["original_data"], args.check_output_file)
    logging.info(f"All proof check results saved to {args.check_output_file}")
    return processed_data["original_data"]

def execute_train_model_params(counterexample_data: List[Dict[str, Any]], formal_proof_data: List[Dict[str, Any]]):
    """Run step 4: update the solver first, then the prover.

    Args:
        counterexample_data: Training samples passed to the ``"solver"`` update.
        formal_proof_data: Training samples passed to the ``"prover"`` update.
    """
    updates = (
        (counterexample_data, "solver"),
        (formal_proof_data, "prover"),
    )
    for samples, role in updates:
        update_model_params(samples, role)

def report_results(data_list: List[Dict[str, Any]], max_i: int, max_j: int):
    """Build cumulative pass indicator matrices for every non-empty item.

    For each item, the per-result ``"state_passed"`` / ``"hyp_passed"`` values
    are stacked into a ``(max_i, max_j)`` grid (left as zeros when the stacked
    shape does not fit). Cell ``(i, j)`` of the reported matrix is ``1.0``
    when the sum of the grid over rows ``0..i`` and columns ``0..j`` is
    positive, else ``0.0``.

    Args:
        data_list: Items carrying a ``"results"`` list of dicts.
        max_i: Number of rows of each reported matrix.
        max_j: Number of columns of each reported matrix.

    Returns:
        ``(state_matrices, hyp_matrices)``: two lists with one float matrix
        per item that has at least one result.
    """

    def _build_grid(outcomes, field):
        # Start from zeros and copy the stacked values in only if they line
        # up exactly with the leading rows of the grid.
        grid = np.zeros((max_i, max_j))
        stacked = np.array([res.get(field, 0) for res in outcomes])
        leading_rows = grid[:len(outcomes), :]
        if leading_rows.shape == stacked.shape:
            grid[:len(outcomes), :] = stacked
        return grid

    def _cumulative_indicator(grid):
        # 2-D prefix sum, thresholded to a 0/1 float mask.
        prefix = grid.cumsum(axis=0).cumsum(axis=1)
        return (prefix > 0).astype(float)

    state_matrices = []
    hyp_matrices = []
    for entry in data_list:
        outcomes = entry["results"]
        if not len(outcomes):
            continue
        state_grid = _build_grid(outcomes, "state_passed")
        hyp_grid = _build_grid(outcomes, "hyp_passed")
        state_matrices.append(_cumulative_indicator(state_grid))
        hyp_matrices.append(_cumulative_indicator(hyp_grid))
    return state_matrices, hyp_matrices

def process_step1_and_step2(batch_data, args):
    """Chain counter-example generation (step 1) into formal-proof generation (step 2).

    Returns whatever ``execute_formal_proof_generate`` returns for the batch.
    """
    # Step 1: counter-examples.
    with_counterexamples = execute_counter_example_generate(batch_data, args)
    # Step 2: reshape the step-1 output, then generate the formal proofs.
    proof_inputs = preprocess_data_step2(with_counterexamples, args)
    return execute_formal_proof_generate(proof_inputs, args)

def process_step3(batch_data, args):
    """Prepare the step-3 payload for a batch and run the label check on it.

    Returns whatever ``execute_check_data_label`` returns for the batch.
    """
    state_part, hyp_part = preprocess_data_step3(batch_data)
    payload = dict(
        original_data=batch_data,
        state_formal_data=state_part,
        hyp_formal_data=hyp_part,
    )
    return execute_check_data_label(payload, args)

def process_step4(batch_data, args):
    """Split the checked batch into solver/prover training data and train on it.

    ``args.alpha`` is forwarded to the step-4 preprocessing. Returns ``None``.
    """
    solver_samples, prover_samples = preprocess_data_step4(batch_data, args.alpha)
    execute_train_model_params(solver_samples, prover_samples)

def process_batches_pipeline(
    data_list,
    args,
    extra_data_path="/mnt/data1/Anony/Erdos-Prover/datasets/eval/countermath_bench.json",
    num_extra_samples=100,
    poll_interval=1.0,
):
    """
    Process the data in batches using a pipeline of GPU and CPU workers.
    The pipeline is as follows:
    1. GPU worker:
        - The GPU worker is responsible for generating counter-examples and formal proofs.
    2. CPU worker:
        - The CPU worker is responsible for checking the data label.
        - The CPU worker is responsible for saving the results.

    Args:
        data_list: Sliceable sequence of items, split into chunks of
            ``args.batch_size``.
        args: Run configuration. Reads ``batch_size``, ``save_path``,
            ``solver_name`` and ``prover_name``; the workers set
            ``generate_output_file`` / ``check_output_file`` on their own
            copy of it.
        extra_data_path: JSON file with additional items; a random subset is
            appended to every batch before step 1.
        num_extra_samples: How many of the shuffled extra items to append.
        poll_interval: Seconds between liveness checks while the parent
            waits on a queue slot or on a worker process.

    Returns:
        The concatenation of all step-3 results, in completion order.

    Raises:
        ValueError: If ``args.batch_size`` is not positive.
        RuntimeError: If a worker process exits with a non-zero exit code.
            Without this check the parent would block forever on a full
            queue whose consumer is gone.
    """
    if args.batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {args.batch_size}")

    total_results = []

    with Manager() as manager:
        # Both hand-over queues hold a single batch so at most one batch is
        # buffered ahead of each stage.
        gpu_queue = manager.Queue(maxsize=1)
        cpu_queue = manager.Queue(maxsize=1)
        result_queue = manager.Queue()

        def gpu_worker(gpu_queue, cpu_queue):
            while True:
                batch_idx, batch_data = gpu_queue.get()

                # (None, None) is the shutdown sentinel.
                if batch_data is None:
                    break
                try:
                    ##### ADD additional unlabeled data
                    tmp_data = load_dataset("json", data_files=extra_data_path, split="train")
                    new_data_list = tmp_data.to_list()
                    random.shuffle(new_data_list)
                    batch_data.extend(new_data_list[0:num_extra_samples])
                    #### start training
                    logging.info(f"Process step1 and step2 of batch index {batch_idx} and length {len(batch_data)}")
                    args.generate_output_file = f"{args.save_path}/train_data_{batch_idx}_{len(batch_data)}_{args.solver_name}_{args.prover_name}_generate.json"
                    result = process_step1_and_step2(batch_data, args)
                    cpu_queue.put((batch_idx, result))
                except Exception as e:
                    import traceback
                    logging.error(f"GPU worker error: {e}")
                    print(traceback.format_exc())
                    # Re-raising ends this process with a non-zero exit code,
                    # which the parent turns into a RuntimeError.
                    raise e
                finally:
                    # Keeps the queue's unfinished-task counter balanced.
                    gpu_queue.task_done()

        def cpu_worker(cpu_queue, result_queue):
            while True:
                batch_idx, processed_batch = cpu_queue.get()
                # (None, None) is the shutdown sentinel.
                if processed_batch is None:
                    break
                try:
                    logging.info(f"Process step3 of batch index {batch_idx} and length {len(processed_batch)}")
                    args.check_output_file = f"{args.save_path}/train_data_{batch_idx}_{len(processed_batch)}_{args.solver_name}_{args.prover_name}_check.json"
                    result = process_step3(processed_batch, args)
                    result_queue.put(result)
                except Exception as e:
                    logging.error(f"CPU worker error: {e}")
                    raise e
                finally:
                    cpu_queue.task_done()

        gpu_process = Process(target=gpu_worker, args=(gpu_queue, cpu_queue))
        cpu_process = Process(target=cpu_worker, args=(cpu_queue, result_queue))
        workers = (("GPU", gpu_process), ("CPU", cpu_process))

        def ensure_workers_healthy():
            # exitcode is None while a process runs and 0 after a clean exit;
            # anything else means the worker crashed or was killed.
            for label, proc in workers:
                if proc.exitcode not in (None, 0):
                    raise RuntimeError(
                        f"{label} worker exited unexpectedly with exit code {proc.exitcode}"
                    )

        def put_checked(target_queue, item):
            # The parent is the only producer whenever this is called, so a
            # free slot cannot be taken by anyone else and put() will not
            # block. Polling full() (instead of a blocking put) lets a dead
            # consumer be noticed rather than waiting forever.
            while True:
                ensure_workers_healthy()
                if not target_queue.full():
                    target_queue.put(item)
                    return
                time.sleep(poll_interval)

        def wait_for(proc):
            # Join with a timeout so a failure of the *other* worker (which
            # could leave this one blocked on a queue) is also detected.
            while proc.is_alive():
                proc.join(timeout=poll_interval)
                ensure_workers_healthy()
            ensure_workers_healthy()

        gpu_process.start()
        cpu_process.start()
        try:
            # Ceiling division: no phantom extra batch when the data length
            # is an exact multiple of the batch size.
            num_batches = (len(data_list) + args.batch_size - 1) // args.batch_size
            logging.info(f"Total {num_batches} batches to process")
            for i in range(num_batches):
                start_idx = i * args.batch_size
                end_idx = (i + 1) * args.batch_size
                batch_data = data_list[start_idx:end_idx]

                if batch_data:
                    put_checked(gpu_queue, (i, batch_data))

            # The GPU worker consumes its queue in order, so once it has seen
            # the sentinel and exited, every result has already been handed
            # to the CPU queue; only then is the CPU sentinel enqueued.
            put_checked(gpu_queue, (None, None))
            wait_for(gpu_process)
            put_checked(cpu_queue, (None, None))
            wait_for(cpu_process)
        finally:
            # On failure, do not leave a surviving worker blocked on a queue.
            for _, proc in workers:
                if proc.is_alive():
                    proc.terminate()
                    proc.join()

        while not result_queue.empty():
            total_results.extend(result_queue.get())

    return total_results
    
def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "on", "1"}:
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if normalized in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

def main():
    """Parse the CLI options and run the iterative training loop.

    Per epoch: preprocess the dataset (step 1 input), run the batched
    generate/check pipeline (steps 1-3), update both models (step 4) and
    copy the working model directories into the epoch's output folder.

    Note: ``--solver_path`` / ``--prover_path`` are copied with
    ``shutil.copytree`` and therefore must be local directories.
    """
    parser = argparse.ArgumentParser(description="Batch generate counter-examples using configurable solver.")
    parser.add_argument("--solver_name", type=str, default="mistral", help="Solver to use")
    parser.add_argument("--prover_name", type=str, default="deepseek_v15_rl", help="Prover to use")
    parser.add_argument("--solver_path", type=str, default="mistralai/Mistral-7B-Instruct-v0.3", help="Path to the solver")
    parser.add_argument("--prover_path", type=str, default="deepseek-ai/DeepSeek-Prover-V1.5-RL", help="Path to the prover")
    parser.add_argument("--num_problems", type=int, default=-1, help="Number of problems to generate")
    parser.add_argument("--solver_k", type=int, default=8, help="Number of samples (pass@k) to generate for each problem")
    parser.add_argument("--prover_k", type=int, default=8, help="Number of samples (pass@k) to generate for each problem")
    parser.add_argument("--default_header", type=int, default=0, help="Whether to use default header")
    parser.add_argument("--gpu", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling parameter")
    parser.add_argument("--alpha", type=float, default=0.2, help="Alpha for updating model parameters")
    ## training parameters
    parser.add_argument("--training_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=20000, help="Batch size")
    parser.add_argument("--dataset_path", type=str, default="datasets/test_data.json", help="Path or name of the dataset to use")
    parser.add_argument("--save_dir", type=str, default="save/", help="Path to the output file")
    ##
    # A real boolean parser: with type=bool, "--check_hypothesis False" would
    # silently evaluate to True.
    parser.add_argument("--check_hypothesis", type=_str_to_bool, default=True, help="Whether to check hypothesis")
    args = parser.parse_args() # type: ignore

    # Configure logging before the first logging.info() call: the module-level
    # logging functions implicitly call basicConfig() when the root logger has
    # no handler, which would drop these INFO records at the default level.
    # NOTE(review): setup_logger is defined elsewhere; confirm it has no
    # ordering requirement beyond the log directory existing.
    os.makedirs(args.save_dir, exist_ok=True)
    setup_logger(f"{args.save_dir}/train.log")

    # process dataset
    if args.dataset_path.endswith((".json", ".jsonl")):
        dataset = load_dataset("json", data_files=args.dataset_path, split="train")
        logging.info(f"Local dataset: {args.dataset_path}")
    else:
        dataset = load_dataset(args.dataset_path, split="train")
        logging.info(f"HuggingFace dataset: {args.dataset_path}")

    # print model info and generate config
    logging.info(f"Solver model: {args.solver_path}")
    logging.info(f"Prover model: {args.prover_path}")
    logging.info(f"Counterexample generate {args.solver_k} times for each problem")
    logging.info(f"Formalproof generate {args.prover_k} times for each problem")

    # save init model: start every run from a fresh working copy of both models
    solver_workdir = "./models/RL_full_sft_counterexample_solve"
    prover_workdir = "./models/RL_full_sft_formalproof_generate"
    if os.path.exists(solver_workdir):
        shutil.rmtree(solver_workdir)
    if os.path.exists(prover_workdir):
        shutil.rmtree(prover_workdir)

    shutil.copytree(args.solver_path, solver_workdir)
    shutil.copytree(args.prover_path, prover_workdir)
    # From here on the working copies are the models that get loaded and updated.
    args.solver_path = solver_workdir
    args.prover_path = prover_workdir

    for epoch in range(args.training_epochs):
        logging.info(f"Training epoch {epoch}...")

        os.makedirs(f"{args.save_dir}/train_data_{epoch}", exist_ok=True)
        args.save_path = f"{args.save_dir}/train_data_{epoch}"

        ## step1 + step2 + step3
        data_list = preprocess_data_step1(dataset, args)
        total_results = process_batches_pipeline(data_list, args)

        # step4
        process_step4(total_results, args)

        logging.info(f"Training epoch {epoch} completed.")
        # copy model to save_dir
        shutil.copytree(solver_workdir, f"{args.save_path}/RL_Epoch{epoch}_full_sft_counterexample_solve")
        shutil.copytree(prover_workdir, f"{args.save_path}/RL_Epoch{epoch}_full_sft_formalproof_generate")

    return

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

