# 文件名: utils.py

import json
import logging
import os
import re
import time

from jinja2 import Template
from openai import OpenAI

# 导入新的配置文件
import config_rl as config

# --- Logger Setup ---
# Configure the root logger when this module is imported: INFO and above, each
# line prefixed with a timestamp and level name. Note: logging.basicConfig does
# nothing if the root logger already has handlers, so if another module
# configured logging first, this format/level will not take effect.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def load_prompt_template(prompt_name: str) -> Template:
    """Read ``<config.PROMPT_DIR>/<prompt_name>.txt`` and compile it as a Jinja2 template.

    Args:
        prompt_name: Base name of the prompt file, without the ``.txt`` suffix.

    Returns:
        A ``jinja2.Template`` built from the file's UTF-8 text.

    Raises:
        FileNotFoundError: If the prompt file does not exist (logged before re-raising).
        Exception: Any other failure while reading or compiling the template is
            logged and re-raised unchanged.
    """
    template_path = os.path.join(config.PROMPT_DIR, f"{prompt_name}.txt")
    try:
        with open(template_path, encoding='utf-8') as handle:
            source_text = handle.read()
        return Template(source_text)
    except FileNotFoundError:
        # Missing files get a dedicated message; the exception still propagates.
        logging.error(f"Prompt template not found: {template_path}")
        raise
    except Exception as err:
        logging.error(f"Error loading prompt template {template_path}: {err}")
        raise

def call_llm(
    prompt: str,
    temperature: float = 0.7,
    *,
    max_retries: int = 3,
    base_delay: float = 5,
) -> str:
    """Send ``prompt`` as a single user message to the configured chat model.

    Failed calls are retried with exponential backoff: after failed attempt
    ``k`` (0-based) the function sleeps ``base_delay * 2**k`` seconds, except
    after the final attempt.

    Args:
        prompt: Text sent as the only ``user`` message.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Total number of attempts (keyword-only; default 3).
        base_delay: Initial backoff delay in seconds (keyword-only; default 5).

    Returns:
        The model's reply text, or ``""`` if every attempt failed or the reply
        carried no text content.

    Raises:
        ValueError: If ``config.LLM_API_KEY`` is empty or missing.
    """
    if not config.LLM_API_KEY:
        raise ValueError("LLM_API_KEY not found. Please check your .env file and config_rl.py.")

    for attempt in range(max_retries):
        try:
            # NOTE(review): the OpenAI SDK client also retries some errors on its
            # own by default, so the number of HTTP requests per call can exceed
            # `max_retries` — confirm this compounding is intended.
            client = OpenAI(api_key=config.LLM_API_KEY, base_url=config.LLM_API_BASE)
            response = client.chat.completions.create(
                model=config.LLM_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
            # `message.content` is Optional in the SDK (e.g. tool-call-only
            # replies); normalise None to "" so the `-> str` contract holds.
            return response.choices[0].message.content or ""
        except Exception as e:
            # Broad catch is deliberate: any failure is retried, and the caller
            # only ever sees "" rather than an exception.
            logging.error(f"Error calling LLM API on attempt {attempt + 1}/{max_retries}: {e}")
            if attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)
                logging.info(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                logging.error("LLM API call failed after all retries.")
                return ""
    # Only reached when max_retries <= 0 (no attempt made).
    return ""

def parse_llm_json_output(raw_output: str) -> dict | None:
    """解析LLM可能返回的被markdown包围的JSON字符串。"""
    cleaned_output = raw_output.strip()
    if cleaned_output.startswith("```json"):
        cleaned_output = cleaned_output[7:]
    if cleaned_output.endswith("```"):
        cleaned_output = cleaned_output[:-3]
    cleaned_output = cleaned_output.strip()

    try:
        return json.loads(cleaned_output)
    except json.JSONDecodeError:
        logging.error(f"Failed to decode LLM JSON output: {cleaned_output}")
        return None