"""
step1_solve_counter_example.py
This script is used to generate counter-examples for mathematical statements using configurable solvers.
"""

import argparse
import logging
import os
import random
from typing import Any, Dict, List, Callable

from datasets import load_dataset

from train_utils import load_data, save_json, get_solver
from logger import setup_logger

# Module-scoped logger; handlers/format are configured by setup_logger() in main().
logger = logging.getLogger(name=__name__)

def preprocess_data(dataset, args) -> List[Dict]:
    """
    Shuffle the dataset rows and optionally keep only the first ``args.num_problems``.

    Args:
        dataset: Object exposing ``to_list()`` that returns the rows as a list
            (e.g. the HuggingFace dataset loaded in ``main``).
        args: Namespace with an integer ``num_problems``. Any negative value
            keeps every row; otherwise at most that many rows are kept.

    Returns:
        The shuffled (and possibly truncated) list of rows. As a side effect,
        ``args.total_data_size`` is set to the length of that list.
    """
    data_list = dataset.to_list()
    # Shuffle with the module-level ``random`` state so truncation samples a
    # random subset; seed ``random`` beforehand for reproducible sampling.
    random.shuffle(data_list)
    if args.num_problems >= 0:
        # Slicing already clamps to the list length, so no min() is needed.
        # Negative values (not just -1) are treated as "use all" rather than
        # being used as a negative slice bound that would drop the tail.
        data_list = data_list[:args.num_problems]
    args.total_data_size = len(data_list)
    logger.info(f"Data size after preprocessing: {args.total_data_size}")
    return data_list

def counter_example_generate(
    solver: Callable,
    data_list: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Run the solver once over all problems and attach its counter-examples to each.

    Args:
        solver: Callable that takes the full list of problem dicts and returns an
            iterable of result dicts, presumably one per problem in the same
            order. Each result dict may contain ``"solver_input"``,
            ``"solver_outputs"`` (list) and ``"counter_examples"`` (list,
            paired element-wise with ``solver_outputs``).
        data_list: Problem records. They are not modified.

    Returns:
        A new list with one dict per (problem, result) pair. Each dict is a
        shallow copy of the problem record plus:
        ``"problem_id"``, the pair's index, and ``"results"``, one entry per
        truthy counter-example together with the solver output it was paired with.

    Raises:
        Exception: Anything raised by ``solver`` is logged (with traceback) and
            re-raised, since there is nothing to format without solver results.
    """
    try:
        results = list(solver(data_list))
    except Exception as e:
        logger.exception(f"Solver failed: {e}")
        raise

    if len(results) != len(data_list):
        # zip() below pairs only the overlapping prefix; surface the mismatch
        # instead of silently dropping problems or indexing out of range.
        logger.warning(
            f"Solver returned {len(results)} results for {len(data_list)} problems; "
            "only the overlapping prefix is kept."
        )

    all_results = []
    for problem_id, (problem, result) in enumerate(zip(data_list, results)):
        solver_input = result.get("solver_input", None)
        # `or []` also tolerates keys that are present but set to None.
        solver_outputs = result.get("solver_outputs") or []
        counter_examples = result.get("counter_examples") or []
        # Shallow copy so the caller's records are not mutated in place.
        formatted_result = dict(problem)
        formatted_result.update({
            "problem_id": problem_id,
            "results": [
                {
                    "counter_example": example,
                    "solver_input": solver_input,
                    "solver_outputs": solver_output,
                }
                for example, solver_output in zip(counter_examples, solver_outputs)
                # Keep only samples where a counter-example was actually found.
                if example
            ],
        })
        all_results.append(formatted_result)
    return all_results

def main():
    """
    Command-line entry point: sample problems, run the solver, save the results.

    Pipeline:
      1. Parse CLI arguments and set up logging to
         ``logs/step1_solve_counter_example.log``.
      2. Load the ``train`` split of either a local ``.json``/``.jsonl`` file or
         a HuggingFace dataset, then shuffle/truncate it with ``preprocess_data``.
      3. Build the solver with ``get_solver`` and run ``counter_example_generate``.
      4. Save the results with ``save_json`` to
         ``<save_dir>/step1_solve_counter_example_<data>_<n>_<solver>_<k>.json``.
    """
    parser = argparse.ArgumentParser(description="Batch generate counter-examples using configurable solver.")
    parser.add_argument("--solver_name", type=str, default="erdos", help="Solver to use")
    parser.add_argument("--solver_path", type=str, default="models/Mistral-7B-v0.3", help="Path to the solver")
    parser.add_argument("--num_problems", type=int, default=-1, help="Number of problems to sample from the dataset (-1 = all)")
    parser.add_argument("--solver_k", type=int, default=8, help="Number of samples (pass@k) to generate for each problem")
    parser.add_argument("--gpu", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling parameter")
    parser.add_argument("--dataset_path", type=str, default="datasets/train_data.json", help="Path or name of the dataset to use")
    parser.add_argument("--save_dir", type=str, default="save/", help="Directory in which to write the output JSON file")
    parser.add_argument("--shuffle_seed", type=int, default=None, help="Seed for shuffling the dataset before sampling (default: unseeded)")
    args = parser.parse_args()  # type: ignore

    # setup logger
    log_path = "logs/step1_solve_counter_example.log"
    setup_logger(log_path)

    # Load dataset (local JSON/JSONL file or HuggingFace dataset name)
    data_name = args.dataset_path.split("/")[-1].split(".")[0]
    if args.dataset_path.endswith((".json", ".jsonl")):
        dataset = load_dataset("json", data_files=args.dataset_path, split="train")
        logger.info(f"Local dataset: {args.dataset_path}")
    else:
        dataset = load_dataset(args.dataset_path, split="train")
        logger.info(f"HuggingFace dataset: {args.dataset_path}")

    # preprocess_data shuffles with the global `random` state; seeding it makes
    # the sampled subset repeatable for the same input data.
    if args.shuffle_seed is not None:
        random.seed(args.shuffle_seed)
    data_list = preprocess_data(dataset, args)

    solver = get_solver(args.solver_name, args)
    # Create the output directory before the (potentially long) solver run so
    # the final save cannot fail just because the directory is missing.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    # os.path.join avoids a leading "/" (root path) when save_dir is empty and
    # a doubled separator when it already ends in "/".
    args.output_file = os.path.join(
        args.save_dir,
        f"step1_solve_counter_example_{data_name}_{args.total_data_size}_"
        f"{args.solver_name}_{args.solver_k}.json",
    )

    data_list = counter_example_generate(solver, data_list)

    logger.info("All counter-examples generated successfully.")
    # An empty dataset or --num_problems 0 can leave nothing to preview.
    if data_list:
        logger.info("Example output:")
        logger.info(data_list[0])
    else:
        logger.warning("No results were produced; saving an empty list.")
    save_json(data_list, args.output_file)
    logger.info(f"Final results saved to {args.output_file}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
