import os
import sys
import queue
from datetime import datetime
from multiprocessing.pool import ThreadPool
from threading import Lock
from typing import List, Dict, Any, Tuple, Optional
from dotenv import load_dotenv
from openai import OpenAI


class v3_batch:
    """
    DeepSeek (official) OpenAI-compatible batch chat wrapper.

    Sends many chat-completion requests concurrently from a thread pool and
    returns one result dict per request, in input order. Per-request API
    failures are captured in the corresponding result instead of raising.

    Environment variables (a local ``.env`` file is loaded first):
      - DEEPSEEK_API_KEY   (required)
      - DEEPSEEK_BASE_URL  (required, e.g., https://api.deepseek.com)
      - DEEPSEEK_MODEL     (optional, default "deepseek-chat";
                            e.g., deepseek-chat or deepseek-reasoner)
    """

    def __init__(self, debug: bool = False):
        """Read configuration from the environment and build the API client.

        Args:
            debug: When True, print progress and per-request errors.

        Raises:
            ValueError: If DEEPSEEK_API_KEY or DEEPSEEK_BASE_URL is unset or empty.
        """
        self.debug = debug
        load_dotenv()
        api_key = os.getenv("DEEPSEEK_API_KEY")
        base_url = os.getenv("DEEPSEEK_BASE_URL")
        self.model_id = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
        if not api_key or not base_url:
            raise ValueError("Missing env: set DEEPSEEK_API_KEY and DEEPSEEK_BASE_URL")
        # Very generous client timeout (24 h expressed in seconds) so that
        # long-running generations are not cut off.
        self.client = OpenAI(api_key=api_key, base_url=base_url, timeout=24 * 3600)

    @staticmethod
    def _now_str() -> str:
        """Return the current local time as 'YYYY-MM-DD HH:MM:SS' for log lines."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    @staticmethod
    def _safe_model_dump(obj: Any) -> Any:
        """Return ``obj.model_dump()`` if the object provides it, else ``obj`` unchanged."""
        return obj.model_dump() if hasattr(obj, "model_dump") else obj

    @staticmethod
    def _extract_message_fields(resp: Any) -> Tuple[Optional[Any], Optional[Any]]:
        """Pull ``(content, reasoning_content)`` from the first choice of a response.

        Returns ``(None, None)`` when the response has no choices or its shape
        is unexpected; a missing ``reasoning_content`` yields ``None`` for it.
        """
        try:
            choices = getattr(resp, "choices", None)
            if not choices:
                return None, None
            msg = choices[0].message
            return getattr(msg, "content", None), getattr(msg, "reasoning_content", None)
        except Exception:
            # Best-effort extraction: the raw response is still kept by the caller.
            return None, None

    def _request_one(self, request: Dict[str, Any], worker_id: int, idx: int) -> Dict[str, Any]:
        """Perform one chat-completion call and pack it into a result dict.

        Never raises for ``Exception`` subclasses: on failure the result holds
        ``{"error": str(e)}`` under ``raw_outputs`` and ``None`` elsewhere.
        """
        try:
            resp = self.client.chat.completions.create(**request)
            answer, reasoning = self._extract_message_fields(resp)
            return {
                "raw_outputs": self._safe_model_dump(resp),
                "reasoning": reasoning,
                "answer": answer,
            }
        except Exception as e:
            packed = {
                "raw_outputs": {"error": str(e)},
                "reasoning": None,
                "answer": None,
            }
            if self.debug:
                print(
                    f"[{self._now_str()}] Worker-{worker_id} idx={idx} ERROR: {e}",
                    file=sys.stderr,
                )
            return packed

    def batch_generate(
        self,
        messages_list: List[List[Dict[str, str]]],
        num_workers: int,
        num_messages_per_worker: int,
    ) -> List[Dict[str, Any]]:
        """Run one chat completion per conversation using ``num_workers`` threads.

        Args:
            messages_list: One conversation per request; each is passed as the
                ``messages`` argument of ``chat.completions.create``.
            num_workers: Number of worker threads (must be >= 1).
            num_messages_per_worker: Only used for a debug-mode sanity check
                that ``num_workers * num_messages_per_worker`` matches
                ``len(messages_list)``; all messages are processed regardless.

        Returns:
            A list aligned with ``messages_list``. Each entry is a dict with:
            ``raw_outputs`` (the dumped response, or ``{"error": ...}``),
            ``reasoning`` (the message's ``reasoning_content`` or None) and
            ``answer`` (the message's ``content`` or None). Slots whose worker
            died unexpectedly hold ``{"raw_outputs": {"error": "empty"}, ...}``.

        Raises:
            ValueError: If ``num_workers`` is less than 1.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")

        total = len(messages_list)
        planned = num_workers * num_messages_per_worker
        if self.debug and planned != total:
            print(
                f"[WARN {self._now_str()}] planned({planned}) != total({total}), proceeding with total.",
                file=sys.stderr,
            )

        results: List[Optional[Dict[str, Any]]] = [None] * total
        results_lock = Lock()

        # FIFO work queue: every request first, then one ``None`` sentinel per
        # worker, so each worker stops only after all requests were dequeued.
        q: "queue.Queue[Optional[Tuple[int, Dict[str, Any]]]]" = queue.Queue()
        for idx, msgs in enumerate(messages_list):
            q.put((idx, {"model": self.model_id, "messages": msgs}))
        for _ in range(num_workers):
            q.put(None)

        def worker_fn(worker_id: int) -> None:
            if self.debug:
                print(f"[{self._now_str()}] Worker-{worker_id} START", flush=True)
            while True:
                item = q.get()
                if item is None:
                    if self.debug:
                        print(f"[{self._now_str()}] Worker-{worker_id} EXIT", flush=True)
                    q.task_done()
                    break
                idx, request = item
                try:
                    # The result is computed fresh for this idx; nothing from a
                    # previous iteration can leak into this slot.
                    packed = self._request_one(request, worker_id, idx)
                    with results_lock:
                        results[idx] = packed
                finally:
                    q.task_done()

            if self.debug:
                print(f"[{self._now_str()}] Worker-{worker_id} END", flush=True)

        def on_worker_error(exc: BaseException) -> None:
            # apply_async would otherwise drop this silently; the affected
            # slots remain None and are filled with an "empty" placeholder below.
            print(
                f"[ERROR {self._now_str()}] worker crashed: {exc!r}",
                file=sys.stderr,
            )

        with ThreadPool(num_workers) as pool:
            for wid in range(num_workers):
                pool.apply_async(worker_fn, args=(wid,), error_callback=on_worker_error)
            pool.close()
            pool.join()

        final: List[Dict[str, Any]] = [
            res
            if res is not None
            else {"raw_outputs": {"error": "empty"}, "reasoning": None, "answer": None}
            for res in results
        ]

        if self.debug:
            print(f"[{self._now_str()}] Batch END", flush=True)

        return final


