import os
import time
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import List, Dict, Any

from dotenv import load_dotenv
from openai import OpenAI


class close_model:
    """
    Official OpenAI Chat Completions client wrapper for batched generation.

    Environment variables:
      - OPENAI_API_KEY: API key
      - OPENAI_BASE_URL (optional): custom base URL if needed
    """

    def __init__(self, debug: bool = False, model_id: str | None = None, base_url: str | None = None):
        self.debug = debug
        self.model = model_id
        load_dotenv(override=True)
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = base_url or os.getenv("OPENAI_BASE_URL")
        if not api_key:
            raise ValueError("Missing env: set OPENAI_API_KEY for API access")
        self.client = OpenAI(api_key=api_key, base_url=base_url, timeout=1800)
        self.temperature = 0.6

    def _log(self, msg: str):
        if self.debug:
            print(f"[{datetime.now():%Y-%m-%d %H:%M:%S}] {msg}", flush=True)

    def _generate_single(self, messages: List[Dict[str, str]]):
        retries = 3
        response = None
        for attempt in range(retries):
            try:
                if not isinstance(messages, list) or len(messages) == 0 or not all(isinstance(m, dict) for m in messages):
                    raise ValueError("messages must be a non-empty list of dicts")
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature,
                )
                answer = response.choices[0].message.content if response.choices else ""
                return {
                    "raw_outputs": response.model_dump(),
                    "answer": (answer or "").strip(),
                }
            except Exception as e:
                self._log(f"Attempt {attempt + 1} failed: {e}")
                time.sleep(random.uniform(1, 2) if attempt == 0 else random.uniform(10, 20))
                if attempt == retries - 1:
                    if response is not None:
                        return {"raw_outputs": response.model_dump(), "answer": None}
                    return {"raw_outputs": {"error": str(e)}, "answer": "No Answer! \\boxed{No Answer}"}

    @staticmethod
    def _split_with_index(indexed_data, num_workers, num_messages_per_worker):
        assert num_workers * num_messages_per_worker == len(indexed_data), "length mismatch"
        return [indexed_data[i * num_messages_per_worker : (i + 1) * num_messages_per_worker] for i in range(num_workers)]

    def batch_generate(self, messages_list: List[List[Dict[str, str]]], num_workers: int, num_messages_per_worker: int):
        indexed_data = list(enumerate(messages_list))
        chunks = self._split_with_index(indexed_data, num_workers, num_messages_per_worker)

        def _worker(worker_id: int, chunk):
            self._log(f"Worker {worker_id} start: {len(chunk)} messages")
            out = []
            for local_i, (idx, messages) in enumerate(chunk, 1):
                self._log(f"Worker {worker_id} msg {local_i}/{len(chunk)} (global {idx}) START")
                res = self._generate_single(messages)
                self._log(f"Worker {worker_id} msg {local_i}/{len(chunk)} (global {idx}) DONE")
                out.append((idx, res))
            self._log(f"Worker {worker_id} all done")
            return out

        results_pairs = []
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            futures = [pool.submit(_worker, wid, chunk) for wid, chunk in enumerate(chunks, 1)]
            for fut in as_completed(futures):
                results_pairs.extend(fut.result())

        results_pairs.sort(key=lambda x: x[0])
        ordered_results = [r for _, r in results_pairs]
        assert len(ordered_results) == len(messages_list), "result length mismatch"
        return ordered_results


