import argparse
import json
import os
from typing import Any, Dict, List, Tuple

import json_repair
from openai import OpenAI
from tqdm import tqdm

# Prompt template for QFS QA-pair generation, rendered with ``str.format``.
# Placeholders: ``topic_length`` (number of topics), ``document_name`` (source
# document label) and ``content_data`` (the rendered topics/facts block).
# The doubled braces ``{{ }}`` are escapes that render as literal JSON braces.
prompt = """---Role---

You are a **Query-Focused Summary (QFS) QA pair generator**

---Goal---

I will provide you with **{topic_length} topics** along with their corresponding **topic descriptions and facts** (these are extracted from a source document {document_name}).
Generate a question along with a thorough, detailed, and complete answer to the query that incorporates all relevant information from the provided Topics and Facts.
Do not simplify or summarize aggressively—aim for maximum coverage, including different perspectives, detailed facts, and nuanced insights.
Do not make anything up or include information where the supporting evidence is not provided.

---Topics and Facts---
{content_data}

---Response Rules---

- Target format and length: Multiple Paragraphs
- Use markdown formatting with appropriate section headings
- Concisely cover all {topic_length} topics in a coherent and logically organized manner.
- Preserve factual accuracy without adding external information.

Response:
**Output in JSON format (do NOT add anything else)**:
{{
    "question": ...,
    "answer": ...
}}
"""

def parse_args():
    """Parse the command-line options of the QA-pair generation script.

    Returns:
        An ``argparse.Namespace`` exposing ``input``, ``output``, ``model``,
        ``temperature``, ``doc_name`` and ``max_count``.
    """
    # One row per option: (flag, value type, default, help text).
    # Rows are registered in this order, which is also the --help order.
    option_table = (
        (
            "--input",
            str,
            "./picture_books/sample_topics.json",
            "Path to the input topics JSON file",
        ),
        (
            "--output",
            str,
            "./qfs_qa_pairs.json",
            "Path to save the generated QA pairs JSON",
        ),
        (
            "--model",
            str,
            "o4-mini",
            "OpenAI model to use (e.g., gpt-4o, o4-mini)",
        ),
        (
            "--temperature",
            float,
            0.2,
            "Sampling temperature for generation",
        ),
        (
            "--doc-name",
            str,
            "picture book",
            "Document name to embed in the prompt",
        ),
        (
            "--max-count",
            int,
            None,
            "Optional limit on number of topic groups to process (for testing)",
        ),
    )

    cli = argparse.ArgumentParser(
        description="Generate QFS QA pairs from sampled topics"
    )
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()

def read_json(path: str):
    """Load and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, encoding="utf-8") as source:
        text = source.read()
    return json.loads(text)

def write_json(obj, path: str):
    """Serialize *obj* as 2-space-indented JSON into *path* (UTF-8).

    Non-ASCII characters are written as-is rather than escaped.
    """
    encoder = json.JSONEncoder(ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as sink:
        # Stream the encoder's chunks straight to the file.
        sink.writelines(encoder.iterencode(obj))

def build_prompt(
    topics: List[Dict[str, Any]], topic_length: int, document_name: str
) -> Tuple[str, List[str]]:
    """Render the generation prompt for one group of topics.

    Each topic contributes a section made of its name, its description and a
    bullet list of its facts (de-duplicated and sorted within the topic).

    Args:
        topics: Topic records; each must provide ``'topic'``, ``'description'``
            and ``'facts'`` (an iterable of dicts with a ``'fact'`` string).
        topic_length: Number of topics announced in the prompt text.
        document_name: Source document label embedded in the prompt.

    Returns:
        A ``(prompt_text, facts)`` tuple. ``facts`` lists every fact that was
        placed in the prompt, in topic order, without duplicates. It is empty
        when ``topics`` is empty.

    Raises:
        KeyError: If a topic or fact record lacks one of the required keys.
    """
    sections = []
    collected_facts: List[str] = []
    for topic in topics:
        header = f"Topic: {topic['topic']}\nDescription: {topic['description']}"
        topic_facts = sorted({fact['fact'] for fact in topic['facts']})
        # Accumulate across topics: returning only the loop's last value would
        # drop the facts of every topic but the final one.
        collected_facts.extend(topic_facts)
        bullet_list = "\n".join(f'- {fact}' for fact in topic_facts)
        sections.append(f'{header}\n\nFacts:\n{bullet_list}')
    content = "\n---\n\n".join(sections)
    rendered = prompt.format(
        document_name=document_name,
        topic_length=topic_length,
        content_data=content,
    )
    # dict.fromkeys drops facts shared between topics while keeping order.
    return rendered, list(dict.fromkeys(collected_facts))

def call_gpt(client: OpenAI, prompt: str, model: str, temperature: float) -> str:
    """Send *prompt* to the chat completions endpoint and return the reply text.

    Args:
        client: OpenAI client used to issue the request.
        prompt: User message content.
        model: Model identifier passed to the API.
        temperature: Sampling temperature passed to the API.

    Returns:
        The first choice's message content with surrounding whitespace removed.

    Raises:
        ValueError: If the response has no choices or the first choice carries
            no text content.
    """
    # NOTE(review): OpenAI reasoning models (the script's default --model is
    # o4-mini) may reject non-default temperature values — confirm against the
    # API reference and drop/override the parameter for those models if so.
    response = client.chat.completions.create(
        model=model,
        temperature=temperature,
        messages=[
            {"role": "system", "content": "You generate exactly one Query and one Answer."},
            {"role": "user", "content": prompt},
        ],
    )
    if not response.choices:
        raise ValueError("Model response contained no choices")

    first_choice = response.choices[0]
    content = first_choice.message.content
    if content is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise ValueError(
            "Model response contained no text content "
            f"(finish_reason={first_choice.finish_reason!r})"
        )
    return content.strip()

def parse_response(raw: str) -> Dict[str, str]:
    """Extract the generated question and answer from the model's raw reply.

    The reply is decoded with ``json_repair`` so that slightly malformed JSON
    can still be read.

    Args:
        raw: Text returned by the model.

    Returns:
        ``{"query": <question>, "answer": <answer>}``.

    Raises:
        ValueError: If the reply does not decode to a JSON object or if its
            ``"question"`` / ``"answer"`` field is missing or empty. The message
            includes the full raw text for debugging.
    """
    decoded = json_repair.loads(raw)
    # A non-object result (string, list, ...) cannot hold the expected keys;
    # report it like any other unparseable reply, with the raw text attached.
    if not isinstance(decoded, dict):
        raise ValueError("Could not parse model output. Full text:\n" + raw)

    query = decoded.get("question")
    answer = decoded.get("answer")
    if not query or not answer:
        raise ValueError("Could not parse model output. Full text:\n" + raw)
    return {"query": query, "answer": answer}

def main():
    """Generate one QA pair per topic group and write all results to JSON.

    The input file maps a topic-count key (a string such as ``"3"``) to a
    mapping of group key -> list of topics. For every group a prompt is built,
    sent to the model, and the parsed QA pair is stored under
    ``output[topic_count_key][group_key]`` together with the facts, the numeric
    topic count and the original topics. A failure in the model call or in
    parsing is recorded in the output (empty query/answer plus an ``error``
    message) instead of aborting the run.

    ``--max-count`` caps the total number of groups processed across all
    topic counts; ``None`` (the default) means no limit.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    args = parse_args()

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable not set")
    client = OpenAI(api_key=api_key)

    data_in = read_json(args.input)
    data_out = {}

    # Number of groups still allowed by --max-count; None means unlimited.
    remaining = args.max_count

    for topic_len, topics_groups in data_in.items():
        if remaining is not None and remaining <= 0:
            break

        groups = list(topics_groups.items())
        if remaining is not None:
            groups = groups[:remaining]
            remaining -= len(groups)

        # JSON object keys are strings; the prompt needs the numeric count.
        topic_length = int(topic_len)
        data_out.setdefault(topic_len, {})
        for topics_key, topics in tqdm(
            groups, desc=f"Processing topic_len={topic_len}", unit="doc"
        ):
            # Named user_prompt so the module-level template is not shadowed.
            user_prompt, facts = build_prompt(topics, topic_length, args.doc_name)

            try:
                raw_reply = call_gpt(client, user_prompt, args.model, args.temperature)
                qa_pair = parse_response(raw_reply)
            except Exception as exc:
                # Best effort: keep going and record the failure for this group.
                print(f"⚠️  Error for {topics_key}: {exc}")
                qa_pair = {"query": "", "answer": "", "error": str(exc)}

            data_out[topic_len][topics_key] = {
                'qa_pair': qa_pair,
                'facts': facts,
                # Store the topic count itself, not the length of the key string.
                'topic_len': topic_length,
                'topics': topics,
            }

    write_json(data_out, args.output)
    print(f"\n✅ Done! All QA pairs saved to {args.output}")

# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
