import os
import json
import asyncio
import time
from pathlib import Path
from typing import List, Dict, Any, Optional
from tqdm.asyncio import tqdm
import traceback 
import pdb

from agentscope.model import DashScopeChatModel, OpenAIChatModel

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# Optional custom endpoint for an OpenAI-compatible server. None lets the
# OpenAI client fall back to its default endpoint.
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL")
MODEL_NAME = "gpt-5-2025-08-07"
BATCH_SIZE = 60  # items whose LLM calls are gathered concurrently per batch
TEMPERATURE = 0.7
MAX_TRIES_PER_BATCH = 50  # max attempts per model call in call_model_with_retry
API_CALL_INTERVAL = 5.0  # seconds slept between retries and between batches

ORIGINAL_TASK_FILE = "../deep_research_bench/data/prompt_data/query.jsonl"  # relative to the CWD
PROCESSED_DATA_FILE = "/data/processed_data.jsonl"
OUTPUT_FILE = "/data/raw_query.jsonl"

# ===== Model initialization =====
# Use the key that matches the selected backend. Previously both branches
# referenced an undefined `API_KEY`, which crashed the script at import time.
# NOTE(review): DASHSCOPE_API_KEY is assumed to be the env var holding the
# DashScope key for Qwen models — confirm for your deployment.
if "qwen" in MODEL_NAME:
    API_KEY = os.environ.get("DASHSCOPE_API_KEY")
    MODEL = DashScopeChatModel(
        api_key=API_KEY,
        model_name=MODEL_NAME,
        enable_thinking=False,
        stream=False,
    )
else:
    API_KEY = OPENAI_API_KEY
    MODEL = OpenAIChatModel(
        api_key=API_KEY,
        model_name=MODEL_NAME,
        stream=False,
        client_args={
            # Endpoint URL, not the API key (the key was previously passed here).
            "base_url": OPENAI_BASE_URL,
        },
    )

# ========== Async LLM Calling ==========
def _content_parts(content: Any) -> List[str]:
    """Collect the raw text pieces from a response ``content`` field.

    A plain string is returned as a single piece. A list is treated as a
    sequence of blocks: each block may be a dict or an object, and its text is
    read from ``text`` (falling back to ``content`` when that is missing or
    empty). Blocks without truthy text are skipped. Any other type yields no
    pieces.
    """
    if isinstance(content, str):
        return [content]
    if not isinstance(content, list):
        return []
    parts = []
    for blk in content:
        if isinstance(blk, dict):
            t = blk.get("text") or blk.get("content")
        else:
            t = getattr(blk, "text", None) or getattr(blk, "content", None)
        if t:
            parts.append(str(t))
    return parts


async def call_model_async(messages: list) -> str:
    """Call the configured LLM asynchronously and return its reply as plain text.

    Handles three response shapes:

    * an async-iterable (streaming) response, whose chunk contents are
      concatenated with newlines;
    * an object with a ``content`` attribute (a string or a list of blocks);
    * an OpenAI-style dict (``choices[0].message.content`` / ``choices[0].text``)
      or a dict with a top-level text field.

    Falls back to ``str(res)`` for any other non-streaming response.

    Raises:
        Any exception raised by the model call or while consuming a stream.
        Callers such as ``call_model_with_retry`` rely on this to retry.
    """
    try:
        res = await MODEL(messages=messages, temperature=TEMPERATURE)
    except TypeError:
        # Some model wrappers only accept messages positionally.
        res = await MODEL(messages, temperature=TEMPERATURE)

    # Streaming response. Errors while iterating are deliberately NOT swallowed.
    # Previously a failure here fell through to `str(res)` and returned the
    # generator's repr as if it were model output.
    if callable(getattr(res, "__aiter__", None)):
        parts = []
        async for chunk in res:
            parts.extend(_content_parts(getattr(chunk, "content", None)))
        # The stream is consumed, so there is nothing else to inspect.
        return "\n".join(parts).strip()

    # Non-streaming response exposing `.content`.
    content = getattr(res, "content", None)
    if isinstance(content, str):
        return content.strip()
    parts = _content_parts(content)
    if parts:
        return "\n".join(parts).strip()

    # OpenAI-compatible dict structure fallback.
    if isinstance(res, dict):
        try:
            choices = res.get("choices") or []
            if choices:
                ch0 = choices[0]
                if isinstance(ch0, dict):
                    msg = ch0.get("message", {})
                    if isinstance(msg, dict) and isinstance(msg.get("content"), str):
                        return msg["content"].strip()
                    if isinstance(ch0.get("text"), str):
                        return ch0["text"].strip()
            for k in ("output", "data", "message", "content", "text"):
                v = res.get(k)
                if isinstance(v, str):
                    return v.strip()
        except (AttributeError, IndexError, KeyError, TypeError):
            # Malformed dict shape (e.g. `choices` is not a list); use the repr below.
            pass

    return str(res)


async def call_model_with_retry(messages: List[Dict]) -> Optional[str]:
    """Call the model via ``call_model_async``, retrying on any exception.

    Makes up to ``MAX_TRIES_PER_BATCH`` attempts. It waits ``API_CALL_INTERVAL``
    seconds between consecutive failed attempts, but not after the last one.
    Every failure is printed (with a traceback for non-JSON errors) and swallowed.

    Args:
        messages: Chat messages forwarded unchanged to ``call_model_async``.

    Returns:
        The model's text response, or None if every attempt failed.
    """
    for attempt in range(1, MAX_TRIES_PER_BATCH + 1):
        try:
            return await call_model_async(messages)
        except json.JSONDecodeError as e:
            # call_model_async does not parse JSON itself. This branch only
            # distinguishes JSON errors in the log if parsing is added later.
            print(f"Attempt {attempt} failed: Invalid JSON format. Error: {e}. Retrying...")
        except Exception as e:
            print(f"Attempt {attempt} failed with an unexpected error: {e}. Retrying...")
            traceback.print_exc()
        # Do not wait after the final attempt; we are about to give up anyway.
        if attempt < MAX_TRIES_PER_BATCH:
            await asyncio.sleep(API_CALL_INTERVAL)

    print("All retry attempts failed. Returning None.")
    return None

async def generate_finegrained_query_with_llm(original_query, qa_pairs):
    """Placeholder for LLM-based fine-grained query generation.

    Accepts the original query and its list of question/answer dicts. It
    currently ignores both and always resolves to an empty string.
    """
    # TODO: build a prompt from original_query + qa_pairs and send it to the
    # model. Until then, every item receives an empty fine-grained query.
    empty_query = ""
    return empty_query

def _load_jsonl(path):
    """Read a JSON Lines file into a list of records, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_qa_pairs(path_log, nodes_map):
    """Reconstruct the question/answer pairs taken along one trajectory path.

    Each step of ``path_log`` is indexed as ``step[0]`` (whose last element is
    the id of the node visited at that step) and ``step[1]`` (the list of node
    ids still to visit). The code assumes a node's newly discovered children
    are placed at the FRONT of the to-visit list. It infers the chosen
    option(s) by matching those children against each option's
    ``next_node_id`` list.

    Args:
        path_log: Sequence of steps as described above.
        nodes_map: Mapping of node id -> node dict with ``text`` and ``options``.

    Returns:
        ``(qa_pairs, visited_question_nodes)``. ``qa_pairs`` is a list of
        ``{"question": str, "answer": [str, ...]}`` dicts. Steps whose node is
        unknown or has no options are skipped.
    """
    qa_pairs = []
    visited_question_nodes = []
    for prev_step, current_step in zip(path_log, path_log[1:]):
        question_id = current_step[0][-1]
        question_node = nodes_map.get(question_id)
        if not question_node or not question_node.get('options'):
            continue
        visited_question_nodes.append(question_node)

        prev_to_visit = prev_step[1]
        current_to_visit = current_step[1]
        # One node (this question) was popped from prev_to_visit; whatever
        # remains beyond that was pushed by the chosen option(s). Clamp at 0:
        # a negative count would make the slice below take from the END.
        num_added = max(0, len(current_to_visit) - (len(prev_to_visit) - 1))
        chosen_next_ids = current_to_visit[:num_added]
        options = question_node['options']

        if num_added < 2:
            # Single choice: the option whose successors equal the pushed ids,
            # defaulting to the first option when none matches exactly.
            chosen_answer_text = [options[0]['text']]
            for option in options:
                if option['next_node_id'] == chosen_next_ids:
                    chosen_answer_text = [option['text']]
                    break
        else:
            # Multiple choice: every option whose successors were all pushed.
            chosen_answer_text = [
                option['text']
                for option in options
                if all(node_id in chosen_next_ids for node_id in option['next_node_id'])
            ]

        qa_pairs.append({
            "question": question_node['text'],
            "answer": chosen_answer_text,
        })
    return qa_pairs, visited_question_nodes


def _build_processing_items(original_tasks, process_datas):
    """Expand each (original task, processed record) pair into one item per path.

    Records are paired by position. If the two lists differ in length, a
    warning is printed and the surplus records are ignored.
    """
    if len(original_tasks) != len(process_datas):
        print(f"Warning: {len(original_tasks)} original tasks but {len(process_datas)} "
              f"processed records; only the first {min(len(original_tasks), len(process_datas))} "
              f"pairs will be processed.")

    processing_items = []
    for original_task, process_data in zip(original_tasks, process_datas):
        nodes_map = process_data['finegrained_tree']['nodes']
        all_paths_logs_input = process_data['trajectory']
        original_id = original_task['id']
        original_query = process_data['original_query']
        simple_query = process_data['simple_query']
        missing_intent = process_data['missing_intent']

        for path_index, path_log in enumerate(all_paths_logs_input, 1):
            qa_pairs, visited_question_nodes = _build_qa_pairs(path_log, nodes_map)
            missing_details = [{
                "inquiry": node['text'],
                "options": [opt['text'] for opt in node.get('options', [])]
            } for node in visited_question_nodes]

            processing_items.append({
                "qa_pairs": qa_pairs,
                "metadata": {
                    "id": f"{original_id}_{path_index}",
                    "topic": original_task['topic'],
                    "language": original_task['language'],
                    "original_query": original_query,
                    "simple_query": simple_query,
                    "missing_intent": missing_intent,
                    "missing_details": missing_details,
                    "choices": [qa_pair["answer"] for qa_pair in qa_pairs]
                }
            })
    return processing_items


async def main_parallel():
    """Generate fine-grained queries for every trajectory path and save them.

    Steps:

    1. Load ``ORIGINAL_TASK_FILE`` and ``PROCESSED_DATA_FILE`` (JSON Lines).
    2. Build one item per trajectory path.
    3. Run ``generate_finegrained_query_with_llm`` over the items in batches of
       ``BATCH_SIZE``, sleeping ``API_CALL_INTERVAL`` seconds between batches.
    4. Write the successful items' metadata, plus ``finegrained_query``, to
       ``OUTPUT_FILE`` as JSON Lines. Failed items are reported and dropped.
    """
    print("Loading data files...")
    original_tasks = _load_jsonl(ORIGINAL_TASK_FILE)
    process_datas = _load_jsonl(PROCESSED_DATA_FILE)

    print("Preparing all processing items...")
    processing_items = _build_processing_items(original_tasks, process_datas)

    final_dataset = []
    print(f"Total items to process: {len(processing_items)}. Starting parallel processing with BATCH_SIZE={BATCH_SIZE}...")

    with tqdm(total=len(processing_items), desc="Generating Queries") as pbar:
        for i in range(0, len(processing_items), BATCH_SIZE):
            batch_items = processing_items[i:i + BATCH_SIZE]

            tasks = [
                generate_finegrained_query_with_llm(item['metadata']['original_query'], item['qa_pairs'])
                for item in batch_items
            ]

            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            for item, finegrained_query in zip(batch_items, batch_results):
                # BaseException, not Exception: gather may hand back a
                # CancelledError, which must not end up in the JSON output.
                if isinstance(finegrained_query, BaseException):
                    print(f"Error processing item with id {item['metadata']['id']}: {finegrained_query}")
                    # print_exc() only works inside an `except` block; print the
                    # gathered exception's own traceback instead.
                    traceback.print_exception(
                        type(finegrained_query), finegrained_query, finegrained_query.__traceback__
                    )
                    continue

                path_data = item['metadata']
                path_data['finegrained_query'] = finegrained_query
                final_dataset.append(path_data)

            pbar.update(len(batch_items))

            if i + BATCH_SIZE < len(processing_items):
                await asyncio.sleep(API_CALL_INTERVAL)

    print(f"\nWriting {len(final_dataset)} successful results to {OUTPUT_FILE}...")
    output_path = Path(OUTPUT_FILE)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open('w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in final_dataset)

    print("Processing complete.")

if __name__ == "__main__":
    # Script entry point: drive the async pipeline to completion on a fresh event loop.
    pipeline = main_parallel()
    asyncio.run(pipeline)