import os
import json
import multiprocessing as mp
import time
import argparse
import logging

from copy import deepcopy
from tqdm import tqdm
from datasets import load_dataset
from multiprocessing import Process, Manager

from mini_bench.agent import ChecklistGenerationAgent
from mini_bench.utils import load_json, save_json
from mini_bench.inference_utils import run_parallel_evaluation

# Root logging setup: INFO level, each record rendered as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

def generate_checklist_process(config, data):
    """Generate a checklist response for a single data item.

    Args:
        config: Mapping with an "agent_config" entry (passed to
            ``ChecklistGenerationAgent``) and a "prompt_type" entry.
        data: The data item handed to the agent.

    Returns:
        A ``(result, cost)`` tuple. ``result`` is a deep copy of ``data``
        (taken before the agent is called) with a "response" key holding the
        first element of the agent's response list; ``cost`` is the second
        value returned by the agent's ``generate_response``.
    """
    # Copy the input up front so the returned record is independent of `data`.
    output = deepcopy(data)

    agent_settings, prompt_kind = config["agent_config"], config["prompt_type"]
    agent = ChecklistGenerationAgent(agent_settings)

    responses, cost = agent.generate_response(
        data,
        prompt_type=prompt_kind,
        constraint_str_list=["CHECKLISTS"],
    )

    # Only the first generated response is kept.
    output["response"] = responses[0]
    return output, cost

def setup(args):
    """Prepare the output directory, the dataset and the run configuration.

    Args:
        args: Parsed command-line arguments. Reads ``save_dir``,
            ``dataset_name``, ``num_data``, ``model_name``, ``base_url``,
            ``api_key``, ``temperature``, ``num_generate`` and ``prompt_type``.

    Returns:
        A ``(dataset, config)`` tuple: ``dataset`` is the "train" split as a
        list (truncated to the first ``num_data`` items when that is not
        ``None``), and ``config`` is a dict with "agent_config" and
        "prompt_type" entries.
    """
    # Make sure the output directory exists before anything else runs.
    os.makedirs(args.save_dir, exist_ok=True)

    # Materialize the "train" split as a plain list of items.
    records = list(load_dataset(args.dataset_name, split="train"))

    limit = args.num_data
    if limit is not None:
        records = records[:limit]

    agent_config = dict(
        model_name=args.model_name,
        base_url=args.base_url,
        api_key=args.api_key,
        temperature=args.temperature,
        num_generate=args.num_generate,
    )

    return records, {"agent_config": agent_config, "prompt_type": args.prompt_type}

def save_final_results(args, results: list[dict]):
    """Extract the generated checklist from each result and save all as JSON.

    The checklist is taken from the "response" text as follows:
      * if it contains "[CHECKLISTS]", the stripped text after the last marker;
      * otherwise, if it contains "<answer>", the stripped text after the last
        "<answer>" and before the next "</answer>" (if any);
      * otherwise the response unchanged.

    Args:
        args: Parsed arguments; reads ``save_dir`` and ``save_model_name``.
        results: Result dicts with "response", "task_id", "intent",
            "start_url", "gt_checklist" and "source_name" keys.

    The output is written to
    ``<save_dir>/<save_model_name>_checklist_generation.json``.
    """
    passthrough_keys = ("task_id", "intent", "start_url", "gt_checklist")

    records = []
    for item in results:
        raw = item["response"]

        # Pull the checklist out of the raw model response.
        if "[CHECKLISTS]" in raw:
            parsed = raw.rpartition("[CHECKLISTS]")[2].strip()
        elif "<answer>" in raw:
            after_open_tag = raw.rpartition("<answer>")[2]
            parsed = after_open_tag.partition("</answer>")[0].strip()
        else:
            parsed = raw

        record = {key: item[key] for key in passthrough_keys}
        record["generated_checklist"] = parsed
        record["source_name"] = item["source_name"]
        records.append(record)

    save_path = os.path.join(args.save_dir, f"{args.save_model_name}_checklist_generation.json")
    with open(save_path, 'w') as out_file:
        json.dump(records, out_file, indent=4)
    logger.info(f"Results successfully saved to {save_path}")

def main(args):
    """Run checklist generation over the dataset and save the parsed results.

    Args:
        args: Parsed command-line arguments, forwarded to ``setup`` and
            ``save_final_results``; ``num_workers``, ``model_name`` and
            ``dataset_name`` are also read here.
    """
    dataset, config = setup(args)

    # Fan the per-item generation function out over the dataset.
    evaluation_kwargs = {
        "dataset": dataset,
        "process_func": generate_checklist_process,
        "config": config,
        "num_workers": args.num_workers,
        "description": f"Evaluating {args.model_name} on {args.dataset_name}",
    }
    collected, total_cost = run_parallel_evaluation(**evaluation_kwargs)

    logger.info(f"Evaluation finished. Total cost: {total_cost}")
    logger.info(f"Number of results collected: {len(collected)}")

    # Persist the parsed checklists.
    save_final_results(args, collected)


if __name__ == "__main__":
    # Command-line entry point: parse the options, log the (redacted)
    # configuration, run the evaluation and report how long it took.
    parser = argparse.ArgumentParser(description="Run parallel evaluation using a specified agent.")
    parser.add_argument("--dataset_name", type=str, default="webprm/mini_bench", help="Name of the dataset to load from Hugging Face datasets.")
    parser.add_argument("--save_dir", type=str, default="output/mini_bench", help="Directory to save the evaluation results.")
    # The value is used as a filename prefix, not as the full filename.
    parser.add_argument("--save_model_name", type=str, default="results.json", help="Prefix of the results filename; results are written to <save_dir>/<save_model_name>_checklist_generation.json.")
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini", help="Name of the model used for evaluation (e.g., gpt-4o-mini).")
    parser.add_argument("--base_url", type=str, default="https://api.openai.com/v1", help="Base URL for the API.")
    parser.add_argument("--api_key", type=str, required=True, help="API key for the service.")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature for generation.")
    parser.add_argument("--num_generate", type=int, default=1, help="Number of responses to generate per item.")
    parser.add_argument("--num_workers", type=int, default=10, help="Number of parallel worker processes.")
    parser.add_argument("--num_data", type=int, default=None, help="Number of data points to evaluate (None for all).")
    parser.add_argument("--prompt_type", type=str, default="default", help="Prompt type for checklist generation.")
    args = parser.parse_args()

    # Never write the API key to the log: mask it in the logged copy only,
    # leaving `args` itself untouched for the actual run.
    logged_config = {**vars(args), "api_key": "***"}
    logger.info(f"Starting evaluation with config: {logged_config}")

    # perf_counter is monotonic, so the elapsed time is not affected by
    # system clock adjustments during a long run.
    start_time = time.perf_counter()
    main(args)
    elapsed = time.perf_counter() - start_time
    logger.info(f"Total evaluation time: {elapsed:.2f} seconds")