#!/usr/bin/env python3
# LC_Agent/inference/novelqa_openai_runner.py
"""
Orchestrator Test Runner (OpenAI version)
Processes NovelQA documents through the *OpenAI* orchestrator.
"""


from datetime import datetime  # 从模块导入 datetime 类
import json
import os, argparse
from tqdm import tqdm
from transformers import AutoTokenizer
import fire

from LC_Agent.inference.openai_orchestrator import OpenAIOrchestratorVLLM as OpenAIOrchestrator
import time
from LC_Agent.orchestrator import ExecLogger        # same logger as before
from LC_Agent.utils import (
    read_json, render_one_question_nqa,
    find_bxx_txt_files, get_finished_questions_type
)

# -------------------------------------------------------------------

def eval_nqa_openai(cfg_path: str, 
                    temperature: float, 
                    topp: float, 
                    topk: int, 
                    max_turns_exp: int,
                    max_context_exp: int, 
                    books_base_dir: str = None, 
                    questions_base_dir: str = None, 
                    log_dir: str = None, 
                    result_dir: str = None, 
                    output_fp: str = None, 
                    test_oe: bool = False, 
                    model_name: str = "Qwen3-8B-SLM", 
                    max_output_tokens: int = 4096,
                    max_turns_to_fail: int = 200,
                    tool_config_path: str = None,
                    system_prompt_name: str = None,
                    close_book: bool = False,
                    book_meta_path: str = None
                    ) -> str:
    # 1) read endpoint config (simple JSON)
    openai_cfg = read_json(cfg_path)
    #   expected keys:  { "base_url": "...", "api_key": "..." }
    # 2) collect book files
    books = find_bxx_txt_files(books_base_dir)
    processed = os.listdir(result_dir) if os.path.exists(result_dir) else []
    finished  = get_finished_questions_type(result_dir, processed)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

    if test_oe:
        print(f"[INFO] Testing both MCQ and OE questions")
    # Ensure output directory exists
    if output_fp:
        output_dir = os.path.dirname(output_fp)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

    if close_book:
        book_meta = read_json(book_meta_path) if book_meta_path else {}
        print(f"[INFO] Closed-book setting: only the question is provided.")

    with open(output_fp, "w", encoding="utf-8") as fout:
        for book in books:
            book_path = os.path.join(books_base_dir, book)
            q_path    = os.path.join(questions_base_dir, book.replace(".txt", ".json"))
            if not (os.path.exists(book_path) and os.path.exists(q_path)):
                print(f"[WARN] Missing files for {book}, skipping"); continue

            with open(book_path, "r", encoding="utf-8") as f:
                doc_text = f.read().strip()
            book_id = os.path.splitext(book)[0]
            questions = read_json(q_path)
            
            if close_book:
                book_title = book_meta[book_id]["title"]
                book_author = book_meta[book_id]["author"]

            for qid, qobj in tqdm(questions.items(), desc=f"{book_id} questions"):
                if test_oe:
                    qtype_to_test = (True, False)
                else:
                    qtype_to_test = (True,)
                
                for as_mcq in qtype_to_test:
                # for as_mcq in (True,):
                    qtype = "mc" if as_mcq else "oe"
                    # if qtype in finished.get(book_id, {}).get(qid, []):
                    #    continue

                    user_q, correct_ans = render_one_question_nqa(qobj, as_mcq=as_mcq)

                    output_log_dir   = f"{log_dir}/{book_id}/{qid}/{qtype}"
                    output_result_dir = f"{result_dir}/{book_id}/{qid}/{qtype}"

                    logger = ExecLogger(
                        log_dir   = output_log_dir,
                        results_dir = output_result_dir
                    )
                    
                    if close_book:
                        doc_text = ""
                        user_q = f"You are a literature professor. Try your best to answer the question based on your own knowledge of the book \"{book_title}\" by {book_author}.\n{user_q}"
                    
                    print(f"\n[Q] {user_q[:]}…")
                    
                    orchestrator = OpenAIOrchestrator(
                        openai_cfg       = openai_cfg,
                        document_content = doc_text,
                        temperature      = temperature,
                        tokenizer        = tokenizer,
                        logger           = logger,
                        model_name       = model_name,
                        topp             = topp,
                        topk             = topk,
                        max_turns_exp    = max_turns_exp,
                        max_context_exp  = max_context_exp,
                        max_output_tokens= max_output_tokens,
                        tool_config_path = tool_config_path,
                        system_prompt_name = system_prompt_name
                    )
                    try:
                        last_payload = orchestrator.run(user_q, max_turns_to_fail=max_turns_to_fail)
                    except Exception as e:
                        print(f"[ERROR] Orchestrator failed: {e}")
                        last_payload = None
                    meta = {
                        "api_call_count": orchestrator.api_call_counter,
                        "notes_count": len(orchestrator.state_manager.notes),
                        "last_payload": last_payload
                    }
                    #logger.save_inference_result(user_q, orchestrator, meta)
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    trajectory_fn = f"{qid}_{qtype}_{book_id}_traj_t{temperature}_topp{topp}_topk{topk}_{max_context_exp}_{timestamp}.json"
                    results_fn = f"{qid}_{qtype}_{book_id}_results_t{temperature}_topp{topp}_topk{topk}_{max_context_exp}_{timestamp}.json"
                    orchestrator.save_trajectory(out_dir=log_dir, filename=trajectory_fn, correct_answer=correct_ans)
                    logger.save_final_result(orchestrator, user_q, correct_ans, meta, filename=results_fn)

                    fout.write(json.dumps({
                        "book_id": book_id,
                        "question_id": qid,
                        "question": user_q,
                        "correct_answer": correct_ans,
                        "final_answer": orchestrator._extract_final_answer(),
                        "trajectory_fp": os.path.join(log_dir, trajectory_fn),
                        "results_fp": os.path.join(output_result_dir, results_fn),
                    }) + "\n")
            fout.flush()
            # break
    print(f"Final output saved to {output_fp}")


if __name__ == "__main__":
    # Expose only the evaluation entry point as a CLI subcommand. A bare
    # fire.Fire() would also surface every imported module and helper
    # (json, os, read_json, AutoTokenizer, ...). The subcommand name is
    # unchanged, so `python novelqa_openai_runner.py eval_nqa_openai ...`
    # keeps working.
    fire.Fire({"eval_nqa_openai": eval_nqa_openai})
