import re
import os
import ast
import json
import random
import argparse
from glob import glob
from prompts import * 
from utils import is_valid_question, Generator
from bing_search import fetch_page_content, extract_snippet_with_context

# Character budget for web-page context: passed as ``context_chars`` when
# extracting text around a search snippet; when that extraction fails, the raw
# page text is truncated to twice this many characters instead.
MAX_DOC_LEN = 1000

def load_batch_items(args):
    """Load the per-question inputs for indices in ``[args.start_idx, args.end_idx)``.

    For every index two JSON files are read:

    * the first (alphabetically) match of
      ``{args.search_results_dir}/{idx}_*.json``, whose ``"sub_questions"``
      value becomes the item's related queries, and
    * ``{args.scraper_outputs_dir}/qa_{idx}_scraper_result.json``, which
      provides the original question, the original answer and the entities
      found in the question.

    Loading is best effort: an index with a missing file is skipped silently,
    and an index whose files are unreadable, malformed or lack the expected
    keys is skipped with a printed message.

    Args:
        args: Object exposing ``start_idx``, ``end_idx``,
            ``search_results_dir`` and ``scraper_outputs_dir``.

    Returns:
        list[dict]: One dict per successfully loaded index with the keys
        ``idx``, ``org_question``, ``org_answer``, ``entity`` and
        ``related_queries``.
    """
    items = []
    for idx in range(args.start_idx, args.end_idx):
        # Sorting makes the choice deterministic if several files match.
        search_matches = sorted(glob(f"{args.search_results_dir}/{idx}_*.json"))
        scraper_matches = glob(f"{args.scraper_outputs_dir}/qa_{idx}_scraper_result.json")
        if not search_matches or not scraper_matches:
            # Missing inputs are expected for sparse index ranges.
            continue
        try:
            with open(search_matches[0], "r", encoding="utf-8") as f:
                # Each sub-question is a dict with the keys
                # 'question', 'snippet', 'title', 'link', 'top_results'.
                related_queries = json.load(f)["sub_questions"]
            with open(scraper_matches[0], "r", encoding="utf-8") as f:
                scraper_data = json.load(f)
            org_question = scraper_data["original_question"]
            org_answer = scraper_data["original_answer"]
            entity = scraper_data["steps"]["1_entities"]["output"]["entities_in_question"]
        except (OSError, ValueError, KeyError, TypeError, IndexError) as err:
            # ValueError covers json.JSONDecodeError and decoding errors.
            # KeyboardInterrupt/SystemExit are deliberately not caught.
            print(f"Skipping idx {idx}: {type(err).__name__}: {err}")
            continue
        items.append({
            "idx": idx,
            "org_question": org_question,
            "org_answer": org_answer,
            "entity": entity,
            "related_queries": related_queries,
        })
    return items

def check_whether_related_queries_are_related_to_entity(batch_items, model):
    """Flag each related query (RQ) according to the entity-inclusion check.

    One prompt is built per related query from ``prompt_rq_including_entity``
    and sent to the model in a single batch. Every related-query dict is then
    annotated in place with:

    * ``"idx"``: its position inside the item's ``related_queries`` list, and
    * ``"valid_related_queries"``: ``True`` only when the model's reply,
      lower-cased and stripped, is exactly ``"yes"``.

    Args:
        batch_items: Items as produced by ``load_batch_items``.
        model: Object whose ``generate(prompts, ...)`` returns one reply
            string per prompt, in prompt order.

    Returns:
        The same ``batch_items`` list (mutated in place).

    Raises:
        StopIteration: If the model returns fewer replies than prompts.
    """
    prompts = [
        prompt_rq_including_entity.format(
            reference_question=item["org_question"],
            entity=item["entity"],
            reference_answer=item["org_answer"],
            # Pass the question text, as the later stages do, instead of the
            # whole search-result dict (snippets, links, top results).
            test_question=rq["question"],
        )
        for item in batch_items
        for rq in item["related_queries"]
    ]
    responses = model.generate(prompts, max_tokens=128, repetition_penalty=1.0, apply_chat=True, enable_thinking=False)

    # Replies come back in prompt order; walk them with an iterator rather
    # than list.pop(0), which is O(n) per call.
    response_iter = iter(responses)
    for item in batch_items:
        for idx, rq in enumerate(item["related_queries"]):
            response = next(response_iter)
            rq["idx"] = idx
            rq["valid_related_queries"] = response.lower().strip() == "yes"
        # NOTE: a stricter rule (invalidating every related query of an item
        # that has too few valid ones) was sketched here earlier but is
        # currently disabled.
    return batch_items
    
def get_answers_for_related_queries(batch_items, model, url_cache, top_k=3):
    """Produce a description for every still-valid related query.

    For each related query whose ``"valid_related_queries"`` flag is true, the
    pages behind its first ``top_k`` search results are fetched (unless already
    present in ``url_cache``), reduced to a context window around the search
    snippet, and handed to the model via ``prompt_for_webpage_reasonchain``.
    The reply, minus a leading ``**Final Information**`` header, is stored in
    ``rq["desc"]``. A related query whose description is
    "No helpful information found." is marked invalid.

    Args:
        batch_items: Items whose related queries carry the
            ``"valid_related_queries"`` flag.
        model: Object whose ``generate(prompts, ...)`` returns one reply
            string per prompt, in prompt order.
        url_cache: Dict mapping URL -> ``{"url", "context", "snippet"}``;
            updated in place with every newly processed URL. A cached entry
            keeps the snippet/context of the first query that used the URL.
        top_k: Number of top search results to use per related query.

    Returns:
        The same ``batch_items`` list (mutated in place).

    Raises:
        StopIteration: If the model returns fewer replies than prompts.
    """
    final_info_marker = "**Final Information**"
    no_info_reply = "no helpful information found."

    # Only related queries that passed the previous check are processed.
    pending = [
        (item, rq)
        for item in batch_items
        for rq in item["related_queries"]
        if rq["valid_related_queries"]
    ]

    # 1. Collect every URL that is not cached yet; dict.fromkeys removes
    #    duplicates while keeping first-seen order.
    urls_to_fetch = list(dict.fromkeys(
        result["link"]
        for _, rq in pending
        for result in rq["top_results"][:top_k]
        if result["link"] not in url_cache
    ))

    # 2. Fetch at once (skipped entirely when everything is cached).
    documents = {}
    if urls_to_fetch:
        documents = fetch_page_content(urls_to_fetch, use_jina=False, jina_api_key=os.getenv("JINA_API_KEY"))  # {url: context}

    # 3. Build one prompt per pending related query.
    prompts = []
    for item, rq in pending:
        doc_sections = []
        for i, result in enumerate(rq["top_results"][:top_k]):
            url = result["link"]
            if url not in url_cache:
                raw_context = documents.get(url, "") or ""
                snippet = result.get("snippet")
                # Search snippets mark matched terms with <b> tags.
                snippet = snippet.replace('<b>', '').replace('</b>', '') if snippet else None
                success, filtered_context = extract_snippet_with_context(raw_context, snippet, context_chars=MAX_DOC_LEN)
                context = filtered_context if success else raw_context[:MAX_DOC_LEN * 2]
                url_cache[url] = {"url": url, "context": context, "snippet": snippet}
            doc_info = url_cache[url]
            doc_sections.append(
                f"**Web Page {i + 1}:**\n"
                + json.dumps(doc_info, ensure_ascii=False, indent=2)
                + "\n"
            )
        prompts.append(prompt_for_webpage_reasonchain.format(
            search_query=rq["question"],
            document="".join(doc_sections),
            reference_entity=item["entity"],
        ))
    responses = model.generate(prompts, max_tokens=1024, repetition_penalty=1.0, apply_chat=True, enable_thinking=False)

    # 4. Attach the descriptions; replies are in the same order as `pending`.
    response_iter = iter(responses)
    for _, rq in pending:
        desc = next(response_iter).strip()
        # Remove the header as a prefix. str.strip(chars) would treat the
        # marker as a set of characters and also eat legitimate leading or
        # trailing letters of the description.
        if desc.startswith(final_info_marker):
            desc = desc[len(final_info_marker):]
        desc = desc.strip()
        rq["desc"] = desc
        # Exit - no helpful information found: drop this related query so the
        # later stages ignore it.
        if desc.lower() == no_info_reply:
            rq["valid_related_queries"] = False

    return batch_items

def check_whether_related_queries_and_answers_are_related_to_answer(batch_items, model):
    """Re-check valid related queries against the original answer and compact them.

    For every related query still flagged valid, a prompt is built from
    ``prompt_rq_including_answer`` using the query's question, its description
    (``rq["desc"]``) and the item's original answer. A related query stays
    valid only when the model's reply, lower-cased and stripped, is exactly
    ``"no"``.

    Afterwards each item's ``"related_queries"`` list is removed and replaced
    by ``item["valid_related_queries"]``: a list of
    ``{"idx", "question", "description"}`` dicts for the surviving queries.

    Args:
        batch_items: Items whose valid related queries already carry
            ``"idx"`` and ``"desc"``.
        model: Object whose ``generate(prompts, ...)`` returns one reply
            string per prompt, in prompt order.

    Returns:
        The same ``batch_items`` list (mutated in place).

    Raises:
        StopIteration: If the model returns fewer replies than prompts.
    """
    prompts = [
        prompt_rq_including_answer.format(
            question=rq["question"],
            answer=rq["desc"],
            entity=item["org_answer"],
        )
        for item in batch_items
        for rq in item["related_queries"]
        if rq["valid_related_queries"]
    ]
    responses = model.generate(prompts, max_tokens=1024, repetition_penalty=1.0, apply_chat=True, enable_thinking=False)

    # Replies are consumed in prompt order; an iterator avoids the O(n) cost
    # of list.pop(0) on every reply.
    response_iter = iter(responses)
    for item in batch_items:
        for rq in item["related_queries"]:
            # Only queries that were valid when the prompts were built have a
            # reply; each flag is read before it is updated.
            if rq["valid_related_queries"]:
                rq["valid_related_queries"] = next(response_iter).lower().strip() == "no"

        item["valid_related_queries"] = [
            {"idx": rq["idx"], "question": rq["question"], "description": rq["desc"]}
            for rq in item["related_queries"]
            if rq["valid_related_queries"]
        ]
        del item["related_queries"]
    return batch_items
           
def _parse_selected_indices(response, n_candidates):
    """Parse a clue-selection reply into unique indices within ``[0, n_candidates)``.

    Returns ``None`` when the first line of the reply (after removing the
    ``**Selected Clues:**`` header) is not a list/tuple literal.
    """
    text = response.replace("**Selected Clues:**", "").strip()
    first_line = text.split("\n")[0].strip()
    try:
        parsed = ast.literal_eval(first_line)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    if not isinstance(parsed, (list, tuple)):
        return None
    indices = []
    for value in parsed:
        # bool is a subclass of int, so it is rejected explicitly.
        if isinstance(value, bool) or not isinstance(value, int):
            continue
        if 0 <= value < n_candidates and value not in indices:
            indices.append(value)
    return indices


def _parse_summarized_clues(response):
    """Parse a clue-summarization reply into a list of strings, or ``None`` on failure."""
    text = response.replace("**Summarized Clues:** ", "").strip()
    # Strip again after removing code fences: leftover leading whitespace can
    # make literal_eval fail with an indentation error.
    text = text.replace("```json", "").replace("```", "").strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    if not isinstance(parsed, (list, tuple)):
        return None
    if not all(isinstance(clue, str) for clue in parsed):
        return None
    return list(parsed)


def select_and_summarize_clues_from_related_queries(batch_items, model):
    """Pick a random subset of model-selected clues per item and summarize them.

    Stage 1 asks the model (``prompt_for_clue_selection``) which of the item's
    valid related-query descriptions to use, keeps the in-range indices,
    and randomly samples up to ``random.randint(3, 5)`` of them. Stage 2 asks
    the model (``prompt_for_clue_summarization``) to summarize the sampled
    clues.

    Each item gains ``"selected_indices"``, ``"selected_clues"`` and
    ``"summarized_clues"``. When a reply cannot be parsed, a message is
    printed and a per-item fallback is used instead of leaking state from the
    previous item: all candidate indices for the selection stage, and the
    unsummarized selected clues for the summarization stage.

    Args:
        batch_items: Items carrying ``"entity"`` and the compacted
            ``"valid_related_queries"`` list of ``{"description", ...}`` dicts.
        model: Object whose ``generate(prompts, ...)`` returns one reply
            string per prompt, in prompt order.

    Returns:
        The same ``batch_items`` list (mutated in place).

    Raises:
        StopIteration: If the model returns fewer replies than prompts.
    """
    # Select clues
    prompts = [
        prompt_for_clue_selection.format(
            entity=item["entity"],
            input_list="[\n" + ",\n".join([i["description"] for i in item["valid_related_queries"]]) + "\n]",
        )
        for item in batch_items
    ]
    responses = model.generate(prompts, max_tokens=1024, repetition_penalty=1.0, apply_chat=True, enable_thinking=False)
    response_iter = iter(responses)
    for item in batch_items:
        response = next(response_iter)
        candidates = item["valid_related_queries"]
        selected_indices = _parse_selected_indices(response, len(candidates))
        if selected_indices is None:
            print(f"Could not parse clue selection for idx {item['idx']}: {response!r}")
            selected_indices = list(range(len(candidates)))
        # Randomly sample clues
        n_clues = min(len(selected_indices), random.randint(3, 5))
        selected_indices = random.sample(selected_indices, n_clues)
        item['selected_clues'] = [candidates[i]["description"] for i in selected_indices]
        item['selected_indices'] = selected_indices

    # Summarize clues
    prompts = [
        prompt_for_clue_summarization.format(
            entity=item["entity"],
            input_list="[\n" + ",\n".join(item["selected_clues"]) + "\n]",
        )
        for item in batch_items
    ]
    responses = model.generate(prompts, max_tokens=1024, repetition_penalty=1.0, apply_chat=True, enable_thinking=False)
    response_iter = iter(responses)
    for item in batch_items:
        response = next(response_iter)
        summarized_clues = _parse_summarized_clues(response)
        if summarized_clues is None:
            print(f"Could not parse clue summary for idx {item['idx']}: {response!r}")
            # Later stages join these with str.join, so keep a list of strings.
            summarized_clues = list(item["selected_clues"])
        item['summarized_clues'] = summarized_clues
    return batch_items
 
 
def generate_and_validate_complex_and_integrated_questions(batch_items, model):
    """Generate (or regenerate) the complex and integrated question of each item.

    Step 4 builds a complex question from ``item["summarized_clues"]`` for
    every item that does not already hold a non-empty, valid ``"complex_q"``.
    The question is valid when it is non-empty and ``is_valid_question``
    accepts it for every entity of the item.

    Step 5 merges the original question with the valid complex question for
    every item that does not already hold a non-empty, valid
    ``"integrated_q"``. The integrated question is valid when
    ``is_valid_question`` accepts it for the original answer. When the reply
    contains no usable integrated question, ``"complex_q_valid"`` is reset to
    ``False`` so that the complex question is regenerated on the next call.

    Args:
        batch_items: Items carrying ``"entity"``, ``"org_question"``,
            ``"org_answer"`` and ``"summarized_clues"``.
        model: Object whose ``generate(prompts, ...)`` returns one reply
            string per prompt, in prompt order.

    Returns:
        The same ``batch_items`` list (mutated in place); processed items gain
        ``"complex_q"``, ``"complex_q_valid"`` and, when step 5 ran for them,
        ``"integrated_q"`` and ``"integrated_q_valid"``.
    """
    complex_marker = "**Complex Question:**"
    integrated_marker = "**Integrated Question:**"

    # Step 4: Generate complex questions for items that need it
    needs_complex = [
        item for item in batch_items
        if not (item.get("complex_q") and item.get("complex_q_valid"))
    ]
    if needs_complex:
        prompts_complex = [
            prompt_for_complex_question_generation.format(
                entity=item["entity"],
                input_list="[\n" + ",\n".join(item["summarized_clues"]) + "\n]",
            )
            for item in needs_complex
        ]
        responses_complex = model.generate(prompts_complex, max_tokens=1024, apply_chat=True, enable_thinking=False)
        for item, response in zip(needs_complex, responses_complex):
            complex_q = ""
            if complex_marker in response:
                complex_q = response.split(complex_marker)[1].replace("**", "").strip()
            entities = item["entity"]
            if isinstance(entities, str):
                # A bare string would otherwise be checked character by character.
                entities = [entities]
            # An empty question is never valid, whatever is_valid_question says.
            valid = bool(complex_q) and all(
                is_valid_question(complex_q, entity) for entity in entities
            )
            item["complex_q"] = complex_q
            item["complex_q_valid"] = valid

    # Step 5: Generate integrated questions for items that need it
    needs_integration = [
        item for item in batch_items
        if not (item.get("integrated_q") and item.get("integrated_q_valid"))
        # Only proceed if complex question is valid
        and item.get("complex_q") and item.get("complex_q_valid")
    ]
    if needs_integration:
        prompts_integration = [
            prompt_for_question_integration.format(
                question_1=item["org_question"],
                question_2=item["complex_q"],
                entity=item["entity"],
            )
            for item in needs_integration
        ]
        responses_integration = model.generate(prompts_integration, max_tokens=1024, apply_chat=True, enable_thinking=False)
        for item, response in zip(needs_integration, responses_integration):
            integrated_q = ""
            if integrated_marker in response:
                integrated_q = response.split(integrated_marker)[1].replace("**", "").strip()
            if integrated_q:
                # Capitalize only the first character; str.capitalize() would
                # lower-case the rest of the question.
                integrated_q = integrated_q[0].upper() + integrated_q[1:]
                valid = is_valid_question(integrated_q, item["org_answer"])
            else:
                # Marker missing or nothing after it: no usable question.
                valid = False
                item["complex_q_valid"] = False
            item["integrated_q"] = integrated_q
            item["integrated_q_valid"] = valid
    return batch_items


# 6. Save results
def save_results(items, output_dir):
    """Write one ``generate_result_{idx}.json`` file per item into ``output_dir``.

    Mandatory fields are copied from the item (a missing one raises
    ``KeyError``); the four question fields fall back to an empty string /
    ``False`` when absent. The path of every written file is printed.

    Args:
        items: Iterable of fully processed item dicts.
        output_dir: Existing directory that receives the JSON files.
    """
    # (output key, item key) pairs for fields every item must provide.
    required_fields = (
        ("idx", "idx"),
        ("original_question", "org_question"),
        ("original_answer", "org_answer"),
        ("entity", "entity"),
        ("selected_clues", "selected_clues"),
        ("selected_indices", "selected_indices"),
        ("valid_related_queries", "valid_related_queries"),
        ("summarized_clues", "summarized_clues"),
    )
    # (key, default) pairs for fields that may not have been generated.
    optional_fields = (
        ("complex_q", ""),
        ("complex_q_valid", False),
        ("integrated_q", ""),
        ("integrated_q_valid", False),
    )
    for entry in items:
        destination = os.path.join(output_dir, f"generate_result_{entry['idx']}.json")
        record = {out_key: entry[src_key] for out_key, src_key in required_fields}
        for key, default in optional_fields:
            record[key] = entry.get(key, default)
        with open(destination, "w") as handle:
            json.dump(record, handle, indent=2)
        print(f"Saved: {destination}")
        
def main():
    """Run the question-generation pipeline for a range of item indices.

    Steps: parse CLI arguments, load the URL cache from
    ``{output_dir}/url_cache.json`` (if present), load the batch, validate and
    describe the related queries, persist the URL cache, then run up to
    ``--max_retries`` rounds of clue selection and question generation, and
    finally save every item that obtained a valid complex or integrated
    question.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--start_idx", "-s", type=int, default=0)
    parser.add_argument("--end_idx", "-e", type=int, default=1000)
    parser.add_argument("--output_dir", "-o", type=str, default="nq_train/generate_questions")
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-32B")
    # NOTE: --max_tokens is accepted but not used by this function.
    parser.add_argument("--max_tokens", type=int, default=8196)
    parser.add_argument("--max_retries", type=int, default=5)
    parser.add_argument("--search_results_dir", type=str, default="nq_train/search_results")
    parser.add_argument("--scraper_outputs_dir", type=str, default="nq_train/scraper_outputs")
    args = parser.parse_args()
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    cache_path = f"{output_dir}/url_cache.json"
    if os.path.exists(cache_path):
        with open(cache_path, "r") as f:
            url_cache = json.load(f)
    else:
        url_cache = {}

    model = Generator(args.model)
    batch_items = load_batch_items(args)
    batch_items = check_whether_related_queries_are_related_to_entity(batch_items, model)
    batch_items = get_answers_for_related_queries(batch_items, model, url_cache)

    # Persist the cache to the same path it is loaded from, and do it as soon
    # as fetching is over: the cache is not modified afterwards, so a failure
    # in a later stage no longer discards the fetched pages.
    with open(cache_path, "w") as f:
        json.dump(url_cache, f, indent=2)

    batch_items = check_whether_related_queries_and_answers_are_related_to_answer(batch_items, model)

    max_retries = args.max_retries
    done_items = []
    done_indices = set()
    for _ in range(max_retries):
        batch_items = select_and_summarize_clues_from_related_queries(batch_items, model)
        batch_items = generate_and_validate_complex_and_integrated_questions(batch_items, model)
        # Repeat while any item still has integrated_q_valid == False.
        # An item counts as done once its complex OR integrated question is
        # valid; done_items keeps a reference to the item dict, so later retry
        # rounds on the same dict also update the record that gets saved.
        valid_items = [item for item in batch_items if item.get("integrated_q_valid", False) or item.get("complex_q_valid", False)]
        invalid_items = [item for item in batch_items if not item.get("integrated_q_valid", False)]
        for item in valid_items:
            if item['idx'] not in done_indices:
                done_items.append(item)
                done_indices.add(item['idx'])
        if not invalid_items:
            break
        batch_items = invalid_items
    save_results(done_items, args.output_dir)
    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main() 