#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
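HF → Agent Runner (openai_orchestrator-based).

Runs the OpenAI-compatible agent orchestrator over every sample of a Hugging
Face dataset split and writes one JSON result per line. Exposed as a CLI via
python-fire (e.g. the `eval_hfds_openai` command).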
"""

import argparse
import importlib
import json
import os
import sys
from typing import Any, Callable, Dict, Optional, Tuple
import importlib.util
from pathlib import Path


from datasets import load_dataset
from tqdm import tqdm

from transformers import AutoTokenizer
import datetime
import fire
# -----------------------------
# Orchestrator import shim
# -----------------------------
def _import_orchestrator():
    """
    Import the orchestrator from openai_orchestrator.py.
    We try the most common names used in your codebase.
    """
    try:
        # Common alias seen in your previous code
        from LC_Agent.inference.openai_orchestrator import OpenAIOrchestratorVLLM as OpenAIOrchestrator
        return OpenAIOrchestrator

    except Exception as e:
        raise ImportError(
            "Could not import OpenAIOrchestrator(VLLM) from openai_orchestrator.py. "
            "Please ensure the file is on PYTHONPATH and exposes one of:\n"
            "  - OpenAIOrchestratorVLLM\n  - OpenAIOrchestrator\n"
            f"Original error: {e}"
        )


def _load_callable(spec: str) -> Callable:
    """
    Load a callable from a string.
    Accepts:
      - 'pkg.module:func'
      - '/abs/or/rel/path/to/file.py:func'
    Returns the function object.
    """
    if ":" not in spec:
        raise ValueError(f"Function spec must be 'module_or_path:func', got: {spec}")
    mod_part, func_name = spec.split(":", 1)

    if mod_part.endswith(".py") or "/" in mod_part or mod_part.startswith("."):
        # Treat as file path
        path = Path(mod_part).resolve()
        if not path.exists():
            raise FileNotFoundError(f"Function file not found: {path}")
        module_name = path.stem + "_dyn"
        spec_obj = importlib.util.spec_from_file_location(module_name, str(path))
        if spec_obj is None or spec_obj.loader is None:
            raise ImportError(f"Cannot load module from {path}")
        module = importlib.util.module_from_spec(spec_obj)
        spec_obj.loader.exec_module(module)
    else:
        # Treat as importable module
        module = importlib.import_module(mod_part)

    fn = getattr(module, func_name, None)
    if not callable(fn):
        raise AttributeError(f"'{func_name}' not found or not callable in {mod_part}")
    return fn


def read_json(file_path: str) -> dict:
    """
    Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8. Whatever top-level JSON value the file
    holds is returned as parsed; the ``dict`` annotation is the expected case.
    """
    with open(file_path, encoding="utf-8") as fh:
        text = fh.read()
    return json.loads(text)

# -----------------------------
# Agent call
# -----------------------------
def eval_hfds_openai(vllm_cfg: str,
                    model_name: str,
                    temperature: float,
                    top_p: float,
                    top_k: int,
                    max_turns_exp: int,
                    max_context_exp: int,

                    trajectory_dir: str,
                    output_fp: str,

                    dataset_name: str,
                    dataset_split: str,
                    item_to_question: str,   # e.g., "my_funcs:item_to_question_fn"
                    item_to_context: str,    # e.g., "my_funcs:item_to_context_fn"
                    item_to_answer: str,     # e.g., "my_funcs:item_to_answer_fn"
                    item_to_meta: Optional[str] = None,        # e.g., "my_funcs:item_to_meta_fn"
                    output_postprocess: Optional[str] = None,  # spec of a fn(output_fp)
                    model_answer_key: str = 'final_answer',
                    correct_answer_key: str = 'correct_answer',
                    tool_config_path: Optional[str] = None,    # e.g., "path/to/tools.json"
                    system_prompt_name: Optional[str] = None,  # e.g., "TRAIN_SYSTEM_PROMPT"

                    tokenizer_path: str = "Qwen/Qwen3-8B",
                    max_output_tokens: int = 4096,
                    max_turns_to_fail: int = 80) -> None:
    """
    Evaluate the agent orchestrator on a Hugging Face dataset split.

    For every sample, a new orchestrator is built and ``orchestrator.run`` is
    called with the extracted question. One JSON record per sample is written
    to ``output_fp`` (JSON Lines, flushed after each sample).

    Args:
        vllm_cfg: Path to a JSON file. Its parsed contents are passed to the
            orchestrator as ``openai_cfg``.
        model_name, temperature, top_p, top_k, max_turns_exp,
        max_context_exp, max_output_tokens, tool_config_path,
        system_prompt_name: Forwarded to the orchestrator constructor
            (``top_p`` / ``top_k`` are passed as ``topp`` / ``topk``).
        trajectory_dir: If non-empty, it is created if needed and one
            trajectory file per sample is saved there via ``save_trajectory``.
        output_fp: JSONL output path. It is opened in ``"w"`` mode, so an
            existing file is overwritten. Its parent directory is created if
            needed.
        dataset_name, dataset_split: Passed to ``datasets.load_dataset``.
        item_to_question, item_to_context, item_to_answer:
            ``"module_or_path:func"`` specs. Each resolved function is called
            with one dataset sample.
        item_to_meta: Optional spec of a function whose (dict) result is
            merged into each record. If that dict contains ``"task_name"``,
            it is also stored under ``"others"`` and the question under
            ``"input"``.
        output_postprocess: Optional spec of a function called once with
            ``output_fp`` after all samples are written.
        model_answer_key, correct_answer_key: Record keys used for the
            model's final answer and the reference answer.
        tokenizer_path: Passed to ``AutoTokenizer.from_pretrained``.
        max_turns_to_fail: Passed to ``orchestrator.run``.

    Exceptions raised by ``orchestrator.run`` or by final-answer extraction
    are reported on stderr, and the sample is recorded with an empty final
    answer. Any other exception (e.g. from the item_to_* functions, the
    orchestrator constructor or ``save_trajectory``) aborts the run. Records
    written before that point have already been flushed.
    """
    # Get orchestrator class
    OpenAIOrchestrator = _import_orchestrator()

    # Read endpoint config
    openai_cfg = read_json(vllm_cfg)
    print(f"[INFO] Loading tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    data = load_dataset(dataset_name, split=dataset_split)

    # Resolve external functions
    item_to_question_fn = _load_callable(item_to_question)
    item_to_context_fn = _load_callable(item_to_context)
    item_to_answer_fn = _load_callable(item_to_answer)
    item_to_meta_fn = _load_callable(item_to_meta) if item_to_meta else None
    output_postprocess_fn = _load_callable(output_postprocess) if output_postprocess else None

    # Create output / trajectory directories. exist_ok=True makes a separate
    # existence check unnecessary.
    output_dir = os.path.dirname(output_fp) if output_fp else ""
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if trajectory_dir:
        os.makedirs(trajectory_dir, exist_ok=True)

    with open(output_fp, "w", encoding="utf-8") as fout:
        for idx, sample in enumerate(tqdm(data, desc="Processing samples")):
            # Extract per-sample inputs via the user-supplied functions.
            context = item_to_context_fn(sample)
            question = item_to_question_fn(sample)
            correct_ans = item_to_answer_fn(sample)
            meta_dict = item_to_meta_fn(sample) if item_to_meta_fn else {}

            # Built per sample because the document content differs per sample.
            orchestrator = OpenAIOrchestrator(
                openai_cfg=openai_cfg,
                document_content=context,
                temperature=temperature,
                tokenizer=tokenizer,
                max_turns_exp=max_turns_exp,
                max_context_exp=max_context_exp,
                max_output_tokens=max_output_tokens,
                topp=top_p,
                topk=top_k,
                model_name=model_name,
                tool_config_path=tool_config_path,
                system_prompt_name=system_prompt_name,
            )

            last_payload = None
            try:
                last_payload = orchestrator.run(question, max_turns_to_fail=max_turns_to_fail)
                final_answer = orchestrator._extract_final_answer()
            except Exception as e:
                # A failing sample is recorded with an empty answer instead of
                # aborting the evaluation. tqdm.write keeps the progress bar
                # intact, and the type name is included because some
                # exceptions carry an empty message.
                tqdm.write(
                    f"[ERROR] Orchestrator failed for sample {idx}: "
                    f"{type(e).__name__}: {e}",
                    file=sys.stderr,
                )
                final_answer = ""

            # Save trajectory if directory specified
            if trajectory_dir:
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                traj_filename = f"sample_{idx}_t{temperature}_topp{top_p}_topk{top_k}_{timestamp}.json"
                orchestrator.save_trajectory(
                    out_dir=trajectory_dir,
                    filename=traj_filename,
                    correct_answer=correct_ans,
                    meta_info={"last_payload": last_payload},
                )

            # The raw context is deliberately not included in the record.
            result = {
                "dataset": dataset_name,
                "split": dataset_split,
                "model": model_name,
                "sample_id": idx,
                "question": question,
                "system_prompt": system_prompt_name,
                model_answer_key: final_answer,        # caller-chosen key
                correct_answer_key: correct_ans,       # caller-chosen key
                "api_call_count": getattr(orchestrator, 'api_call_counter', 0),
            }

            if meta_dict:
                # Note: meta keys are merged last and may overwrite the fields above.
                result.update(meta_dict)
                if 'task_name' in meta_dict:  # patch for niah
                    result['others'] = meta_dict
                    result['input'] = question

            fout.write(json.dumps(result) + "\n")
            fout.flush()

    if output_postprocess_fn:
        output_postprocess_fn(output_fp)

    print(f"Final output saved to {output_fp}")



# -----------------------------
# Main
# -----------------------------

if __name__ == "__main__":
    # fire.Fire() without a component turns this module into a CLI: its
    # top-level callables (e.g. eval_hfds_openai) become subcommands whose
    # parameters are passed as --flags. Fire inspects the calling frame, so
    # keep this call at module level.
    fire.Fire()
