#!/usr/bin/env python3
"""
Central utilities: LLM client, argument parsers, file helpers, citation & table helpers.
All hard-coded paths removed – rely on ANON_PROJECT_ROOT or relative paths only.
"""

import os
import re
import json
import shutil
import argparse
import requests
from loguru import logger

# ---------- LLM client --------------------------------------------------------
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # kept for compatibility
MAX_NUM_TOKENS = 1000

# fmt: off
AVAILABLE_PROXY_MODELS = [
]
# fmt: on


class ChatAgent:
    """
    Thin wrapper over HTTP chat completions.
    End-points and keys are passed via environment variables; no hard-coded URLs.

    Models listed in AVAILABLE_PROXY_MODELS are posted to PROXY_URL with a
    ``Bearer $PROXY_KEY`` Authorization header; every other model is posted to
    ROUTER_URL with no Authorization header.
    """

    # Reply handed back by chat() once every attempt has failed.
    _FALLBACK_REPLY = "sorry, i can't answer"

    def __init__(self, model_name: str):
        """Select the end-point and request headers for *model_name*."""
        self.model_name = model_name
        # choose end-point
        if model_name in AVAILABLE_PROXY_MODELS:
            self.url = os.getenv("PROXY_URL", "")
            self.headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {os.getenv('PROXY_KEY', '')}",
            }
        else:
            self.url = os.getenv("ROUTER_URL", "")
            self.headers = {"Content-Type": "application/json"}

    # ------------------------------------------------------------------ #
    # low-level chat
    # ------------------------------------------------------------------ #
    def chat(
        self,
        prompt: str,
        retry: int = 8,
        temperature: float = 0.75,
        max_tokens: int = MAX_NUM_TOKENS,
        n_responses: int = 1,
        system_message: str | None = None,
    ) -> str:
        """
        Send one system + user exchange and return the first choice's content.

        Args:
            prompt: User message.
            retry: Maximum number of HTTP attempts; zero or less means no
                attempt is made and the fallback reply is returned.
            temperature: Sampling temperature forwarded in the payload.
            max_tokens: Forwarded only when it differs from MAX_NUM_TOKENS, so
                the default leaves the limit to the server.
            n_responses: Forwarded as ``n``; only the first choice is returned.
            system_message: System prompt; a generic one is used when falsy.

        Returns:
            The reply text, or "sorry, i can't answer" once all attempts fail.

        Raises:
            requests.RequestException: For non-transient request errors such
                as an empty or invalid URL, which retrying cannot fix.
        """
        sys_msg = system_message or "you are a helpful assistant and an expert in research."
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": prompt},
            ],
            "temperature": temperature,
            "n": n_responses,
            "stop": None,
            "seed": 0,
        }
        if max_tokens != MAX_NUM_TOKENS:
            payload["max_tokens"] = max_tokens

        # A loop instead of recursion: bounded for any value of ``retry``.
        for _ in range(max(retry, 0)):
            logger.debug(f"Chat request: {payload}")
            try:
                resp = requests.post(self.url, headers=self.headers, json=payload, timeout=120)
            except (requests.ConnectionError, requests.Timeout) as exc:
                # Transient network failures count as a failed attempt, just
                # like a non-200 status, instead of escaping the retry logic.
                logger.warning(f"Chat request failed: {exc}")
                continue

            if resp.status_code != 200:
                logger.warning(resp.text)
                continue

            try:
                return resp.json()["choices"][0]["message"]["content"]
            except (ValueError, KeyError, IndexError, TypeError) as exc:
                # 200 with a non-JSON or unexpectedly shaped body: try again.
                logger.warning(f"Malformed chat response: {exc}")

        return self._FALLBACK_REPLY

    # ------------------------------------------------------------------ #
    # high-level helpers
    # ------------------------------------------------------------------ #
    def _chat_with_block_retry(self, prompt: str, pattern: str, label: str, max_retries: int, parse=None):
        """Ask until the reply contains a block matching *pattern*; None when attempts run out."""
        for attempt in range(max_retries, 0, -1):
            try:
                raw = self.chat(prompt)
                block = re.search(pattern, raw, re.DOTALL)
                if not block:
                    raise ValueError(f"No {label} block found.")
                body = block.group(1)
                return body if parse is None else parse(body)
            except Exception as e:
                # Deliberately broad: any failure in this attempt (request
                # error, missing block, bad payload) just triggers another try.
                logger.warning(f"{label} parse failed ({attempt} left): {e}")
        return None

    def chat_with_json_retry(self, prompt: str, max_retries: int = 5) -> dict | None:
        """
        Extract JSON block from LLM reply; retry on failure.

        Returns the decoded content of the first ```json fenced block, or None
        when no attempt produced a parseable block.
        """
        return self._chat_with_block_retry(prompt, r"```json\n(.*?)\n```", "JSON", max_retries, json.loads)

    def chat_with_latex_retry(self, prompt: str, max_retries: int = 5) -> str | None:
        """
        Extract LaTeX block from LLM reply; retry on failure.

        Returns the text inside the first ```latex fenced block, or None when
        no attempt produced such a block.
        """
        return self._chat_with_block_retry(prompt, r"```latex\n(.*?)\n```", "LaTeX", max_retries)


# ---------- argument parsers --------------------------------------------------
class ArgParser:
    """Re-usable argument builder."""

    def __init__(self, description: str = "AI-Scientist helper"):
        self.parser = argparse.ArgumentParser(description=description)

    def add_topic_args(self):
        self.parser.add_argument("--topic", help="Main research topic.")
        self.parser.add_argument("--topic_description", default="xxxxx", help="Short sub-topic phrase.")
        return self

    def add_model_args(self):
        self.parser.add_argument("--model_name", default="deepseek-v3", help="LLM to use.")
        return self

    def add_csv_args(self):
        self.parser.add_argument("--original_csv_path", default="", help="CSV with paper metadata.")
        return self

    def add_revise_args(self):
        self.parser.add_argument("--revision_name", default="revision", help="Revision folder suffix.")
        self.parser.add_argument("--model_used_for_review", default="iclr2025_conference", help="Review JSON folder.")
        self.parser.add_argument("--old_path_name", default="latex", help="Previous revision to copy.")
        self.parser.add_argument("--mode", choices=["normal", "force"], default="normal", help="Score-gate behaviour.")
        return self

    def parse(self):
        return self.parser.parse_args()


# ---------- file helpers ------------------------------------------------------
def copy_folder(src: str, dst: str) -> None:
    """
    Deep-copy a directory tree.

    Files are copied with ``shutil.copy2``. ``dst`` may already exist:
    its contents are merged with, and overwritten by, those of ``src``.
    ``src`` is read before anything is created, so a missing source raises
    without leaving an empty ``dst`` behind.

    Raises:
        FileNotFoundError: If ``src`` does not exist.
        NotADirectoryError: If ``src`` is not a directory.
        shutil.Error: If individual entries could not be copied.
    """
    # dirs_exist_ok=True keeps the merge-into-existing-destination behaviour.
    shutil.copytree(src, dst, dirs_exist_ok=True)


# ---------- LaTeX helpers -----------------------------------------------------
def insert_tables_for_references(llm_output: str, tables_already: list[str]) -> str:
    """
    Insert pre-generated tables next to their first \\ref{label} mention.
    tables_already: list of full table environments containing \\label{...}.

    The text is split into paragraphs on blank lines. A table is placed
    directly before the first paragraph that references its label, separated
    by a blank line. Each table is inserted at most once, and a table whose
    text already occurs in llm_output is never inserted. At most one table is
    inserted per paragraph.

    Returns llm_output unchanged when there are no tables or none carries a
    \\label{...}.
    """
    if not tables_already:
        return llm_output

    # Index each table by the first \label{...} it contains; unlabeled tables are ignored.
    lookup = {m.group(1): tbl for tbl in tables_already if (m := re.search(r"\\label{([^}]+)}", tbl))}
    if not lookup:
        return llm_output

    placed: set[str] = set()  # labels whose table has been inserted during this call
    updated = []

    for para in llm_output.split("\n\n"):
        for ref in re.findall(r"\\ref{([^}]+)}", para):
            table = lookup.get(ref)
            # Skip unknown labels, tables inserted for an earlier paragraph,
            # and tables that were already present in the incoming text.
            if table is None or ref in placed or table in llm_output:
                continue
            para = f"{table}\n\n{para}"
            placed.add(ref)
            break  # one insertion per paragraph
        updated.append(para)

    return "\n\n".join(updated)