import os
import argparse
import json
from glob import glob
from utils import (
    get_qa_pairs, 
    call_openai_api, 
    call_openrouter_api, 
    call_local_api, 
    prompt_for_question_to_entities, 
    QueryScraper
)
import urllib3
# Silence urllib3's InsecureRequestWarning for the whole process.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

from serpapi import GoogleSearch
# SerpAPI key taken from the SERP_API_KEY environment variable (None when unset).
# NOTE(review): neither GoogleSearch nor serp_api_key is referenced in the
# visible part of this file -- confirm whether another module relies on them.
serp_api_key = os.getenv("SERP_API_KEY")



def get_args():
    """Parse the command-line options for one scraper run.

    Options (long flag, short flag, type, default):
        --dataset_name / -d (str, "nq_train")
        --start_idx / -s (int, 0)
        --end_idx / -e (int, 100000)
        --api_call / -a (str, "local")
        --model / -m (str, "Qwen/Qwen3-32B")
        --output_folder / -o (str, "scraper_outputs")
        --search_results_folder (str, "search_results")

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # Each entry: (flags, value type, default).
    option_specs = (
        (("--dataset_name", "-d"), str, "nq_train"),
        (("--start_idx", "-s"), int, 0),
        (("--end_idx", "-e"), int, 100000),
        (("--api_call", "-a"), str, "local"),
        (("--model", "-m"), str, "Qwen/Qwen3-32B"),
        (("--output_folder", "-o"), str, "scraper_outputs"),
        (("--search_results_folder",), str, "search_results"),
    )
    cli = argparse.ArgumentParser()
    for flags, value_type, fallback in option_specs:
        cli.add_argument(*flags, type=value_type, default=fallback)
    return cli.parse_args()

def _request_entities(call_api_func, prompt, model, max_retries=3):
    """Ask the model for entities, retrying until a JSON object is returned.

    Returns:
        A ``(entities, raw_response, error)`` tuple. ``entities`` is the
        parsed JSON object, or ``{}`` when every attempt failed.
        ``raw_response`` is the last string the API returned (``None`` if it
        never returned one). ``error`` describes the last failure, or is
        ``None`` on success.
    """
    raw_response = None
    error = None
    for _ in range(max_retries):
        try:
            raw_response = call_api_func(prompt, max_tokens=100, model=model)
            parsed = json.loads(raw_response)
            if not isinstance(parsed, dict):
                # The caller does .get() lookups, so anything but an object
                # is treated as a failed attempt and retried.
                raise ValueError("entity response is not a JSON object")
        except Exception as exc:  # best effort: any failure triggers a retry
            error = f"{type(exc).__name__}: {exc}"
            continue
        return parsed, raw_response, None
    return {}, raw_response, error


def process_qa_pair_scraper(args, qa_pair, call_api_func=None, idx=0, model=None):
    """Extract entities for one QA pair and fetch related queries for it.

    Args:
        args: Namespace providing ``search_results_dir`` (read in step 2).
        qa_pair: Mapping with ``"question"`` and ``"answer"`` keys.
        call_api_func: Callable invoked as
            ``call_api_func(prompt, max_tokens=100, model=model)`` that
            returns the model's reply as a string.
        idx: Index of the QA pair; used in log output and file names.
        model: Model identifier forwarded to ``call_api_func``.

    Returns:
        ``None`` when the step 0 marker file exists, otherwise a result dict
        with the original question/answer, ``overall_status`` and per-step
        ``status`` / ``output`` / ``error`` / ``raw_openai_response`` entries.
    """
    print("\n------------ Processing QA pair: ", idx, "------------\n")
    q, a = qa_pair["question"], qa_pair["answer"]
    step_names = ["1_entities", "2_related_queries"]
    result = {
        "original_question": q,
        "original_answer": a,
        "overall_status": "initiated",
        "steps": {
            name: {"status": "pending", "output": {}, "error": None, "raw_openai_response": None}
            for name in step_names
        },
    }

    ##############################################################################
    # Step 0: Skip QA pairs that already have a result file
    ##############################################################################
    # NOTE(review): this path is hard-coded (it ignores the configured output
    # folder) and ends in "__", so it does not match the
    # "qa_{i}_scraper_result.json" files main() writes. Left unchanged because
    # the mismatch may be a deliberate way to force re-processing -- confirm.
    existing_result_path = f"scraper_outputs/qa_{idx}_scraper_result__.json"
    if os.path.exists(existing_result_path):
        print(f"Skipping QA pair {idx} because it already exists in the scraper_outputs directory")
        return None

    ##############################################################################
    # Step 1: Find entities in question and answer
    ##############################################################################
    entities_prompt = prompt_for_question_to_entities.format(question=q, answer=a)
    entities_json, entities_response_str, entities_error = _request_entities(
        call_api_func, entities_prompt, model
    )
    print(entities_json)
    step1 = result["steps"]["1_entities"]
    step1["raw_openai_response"] = entities_response_str
    step1["error"] = entities_error
    eq = entities_json.get("question", [])
    ea = entities_json.get("answer", [])
    step1["output"] = {"entities_in_question": eq, "entities_in_answer": ea}
    # Continue only with exactly one question entity and at least one answer entity.
    if not eq or len(eq) >= 2 or not ea:
        result["overall_status"] = "failed_at_1_entities"
        step1["status"] = "failed"
        return result
    step1["status"] = "success"

    ##############################################################################
    # Step 2: Related queries
    ##############################################################################
    step2 = result["steps"]["2_related_queries"]
    ent = eq[-1]
    step2["output"]["primary_entity_used"] = ent

    including_entity = ent.lower()
    excluding_entities = [word for answer_entity in ea for word in answer_entity.lower().split()]

    # Only the part before the first "/" goes into the file name, so the
    # entity cannot introduce an extra path component.
    ent_ = ent.split("/")[0]
    search_results_filename = f"{args.search_results_dir}/{idx}_{ent_.lower().replace(' ', '_')}.json"
    step2["output"]["search_results_file"] = search_results_filename

    if os.path.exists(search_results_filename):
        # Cached results: parse the file so that a corrupt cache still raises.
        with open(search_results_filename) as f:
            json.load(f)
    else:
        # Build the scraper only when a scrape is actually required.
        scraper = QueryScraper()
        related_data = scraper.find_related_queries(
            q,
            including_entity=including_entity,
            excluding_entities=excluding_entities,
            max_depth=0,
            max_queries=30,
            check_top_k=1,
            check_entity=True,
        )
        # The target folder may not exist yet on a first run.
        os.makedirs(os.path.dirname(search_results_filename) or ".", exist_ok=True)
        with open(search_results_filename, "w") as f:
            json.dump(related_data, f, indent=2)

    # Both the cached and the freshly scraped paths count as success.
    # NOTE(review): overall_status is left at "initiated" here, as in the
    # original -- confirm whether consumers expect a final success value.
    step2["status"] = "success"
    return result

def main():
    """Run the scraper over a slice of a dataset and write per-pair results.

    Reads the CLI options, derives ``args.output_dir`` and
    ``args.search_results_dir`` under the dataset folder, processes each QA
    pair with ``process_qa_pair_scraper``, writes one JSON file per processed
    pair and finally a summary JSON listing every processed pair.

    Raises:
        ValueError: If ``--api_call`` is not "openai", "openrouter" or "local".
    """
    args = get_args()
    args.output_dir = f"{args.dataset_name}/{args.output_folder}"
    args.search_results_dir = f"{args.dataset_name}/{args.search_results_folder}"
    qa_pairs = get_qa_pairs(args.dataset_name, args.start_idx, args.end_idx)

    api_backends = {
        "openai": call_openai_api,
        "openrouter": call_openrouter_api,
        "local": call_local_api,
    }
    if args.api_call not in api_backends:
        raise ValueError(f"Invalid API call: {args.api_call}")
    call_api_func = api_backends[args.api_call]

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    # process_qa_pair_scraper writes scraped results into this folder, so it
    # has to exist before the first pair is processed.
    os.makedirs(args.search_results_dir, exist_ok=True)

    all_results_summary = []
    for i, qa_pair in enumerate(qa_pairs, start=args.start_idx):
        detailed_result = process_qa_pair_scraper(args, qa_pair, call_api_func, idx=i, model=args.model)
        if detailed_result is None:
            # The pair was skipped; it gets no output file and no summary entry.
            continue
        output_filename = os.path.join(output_dir, f"qa_{i}_scraper_result.json")
        summary_entry = {
            "qa_pair_index": i,
            "original_question": qa_pair["question"],
        }
        try:
            with open(output_filename, "w") as f:
                json.dump(detailed_result, f, indent=2)
        except Exception as e:
            # A failed save is recorded in the summary instead of aborting the run.
            summary_entry["status"] = "failed_to_save_json"
            summary_entry["error"] = str(e)
        else:
            summary_entry["status"] = detailed_result["overall_status"]
            summary_entry["output_file"] = output_filename
        all_results_summary.append(summary_entry)

    summary_filename = os.path.join(
        output_dir,
        f"_summary_of_scraper_runs_from_{args.start_idx}_to_{args.end_idx}.json",
    )
    with open(summary_filename, "w") as f:
        json.dump(all_results_summary, f, indent=2)

# Run the scraper only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()