import openai
from loguru import logger
from tqdm import tqdm
from typing import *
from multiprocessing import Pool
import time
import random

# Type alias for the optional time-series payload: a list of series, each
# series being a list of float values.
Timeseries = List[List[float]]

class AskGPTAPI:
    """
    A thin wrapper for an OpenAI-compatible Chat Completions API that supports:
      - Text-only queries (backward compatible)
      - Optional timeseries inputs appended to the first user message's `content`
        as [{"timeseries": ts1}, {"timeseries": ts2}, ...]

    Every request sends ``extra_body={"thinking": {"type": "enabled" | "disabled"}}``.
    This is a provider-specific extension; presumably the configured endpoint
    accepts it (TODO confirm per backend).
    """

    # Failed (non-rate-limit) attempts tolerated per request; the request is
    # abandoned with RuntimeError once the failure count exceeds this value.
    _MAX_FAILED_ATTEMPTS = 20

    def __init__(self, api_key: str, base_url: str, model: str, num_workers: int = 4, timeout: int = 60):
        """
        Args:
            api_key: API key passed to ``openai.OpenAI``.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Model name sent with every request.
            num_workers: Number of worker processes used by ``batch_ask_api``.
            timeout: Per-request timeout forwarded to ``chat.completions.create``.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.num_workers = num_workers
        self.timeout = timeout
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)

    # -------- Internal helpers --------
    @staticmethod
    def _build_content(question: str, timeseries: Optional[Timeseries]) -> List[Dict[str, Any]]:
        """
        Build the content array for a single user message.

        The first element is always the text block. If `timeseries` is given,
        one ``{"timeseries": [float, ...]}`` entry per series is appended after it.

        Raises:
            TypeError: if `timeseries` is not a list, one of its elements is not
                a list/tuple, or an element contains a value ``float()`` rejects.
        """
        content: List[Dict[str, Any]] = [{"type": "text", "text": question}]
        if timeseries is None:
            return content
        if not isinstance(timeseries, list):
            raise TypeError("timeseries must be a list of sequences (list of list of floats).")
        for idx, ts in enumerate(timeseries):
            if not isinstance(ts, (list, tuple)):
                raise TypeError(f"timeseries[{idx}] must be a list/tuple of floats.")
            try:
                # Send the coerced values, not the raw ones: numeric strings and
                # non-builtin numeric types would otherwise reach the JSON payload
                # unconverted (or fail to serialize at all).
                values = [float(v) for v in ts]
            except (TypeError, ValueError, OverflowError) as e:
                raise TypeError(f"timeseries[{idx}] contains non-numeric values: {e}") from e
            content.append({"timeseries": values})
        return content

    def _create_with_retry(self, messages: List[Dict[str, Any]], thinking: bool):
        """
        Call ``chat.completions.create`` until it succeeds and return the response.

        Retry policy:
          - Bad-request / authentication / permission / not-found errors cannot
            succeed on retry; they are raised at once as ``RuntimeError``
            chained to the original error.
          - ``openai.RateLimitError`` or any error whose message contains
            "limit" sleeps 20-50 s and retries without consuming the failure budget.
          - Any other error is logged and retried; once more than
            ``_MAX_FAILED_ATTEMPTS`` attempts have failed, ``RuntimeError`` is raised.
        """
        non_retryable = (
            openai.BadRequestError,
            openai.AuthenticationError,
            openai.PermissionDeniedError,
            openai.NotFoundError,
        )
        failures = 0
        while True:
            if failures > self._MAX_FAILED_ATTEMPTS:
                logger.error("Too many timeout!")
                raise RuntimeError("Too many timeout!")
            try:
                return self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    timeout=self.timeout,
                    extra_body={
                        "thinking": {
                            "type": "enabled" if thinking else "disabled"
                        }
                    }
                )
            except Exception as err:
                # Checked before the "limit" heuristic so a deterministic client
                # error mentioning e.g. a token limit cannot loop forever.
                if isinstance(err, non_retryable):
                    logger.error(err)
                    raise RuntimeError(f"Non-retryable API error: {err}") from err
                if isinstance(err, openai.RateLimitError) or 'limit' in str(err).lower():
                    time.sleep(random.randint(20, 50))
                else:
                    logger.error(err)
                    logger.error("API timeout, trying again...")
                    failures += 1

    # -------- Public single-call API --------
    def ask_api(self, question: str, timeseries: Optional[Timeseries] = None, thinking: bool=False) -> str:
        """
        Send a single request and return the first choice's message content.

        - question: plain text prompt
        - timeseries: optional list of time series, each a list[float]
          Example: [[...ts1...], [...ts2...]]
        - thinking: selects "enabled"/"disabled" for the provider's
          ``thinking`` extension in ``extra_body``

        The returned content may be None if the provider sends no text
        content (TODO confirm for the target backend).

        Raises:
            TypeError: if `timeseries` fails validation.
            RuntimeError: on a non-retryable API error or too many failures.
        """
        messages = [{
            "role": "user",
            "content": self._build_content(question, timeseries)
        }]
        response = self._create_with_retry(messages, thinking)
        return response.choices[0].message.content

    # -------- Batch API --------
    def batch_ask_api(
        self,
        questions: List[Union[str, Tuple[str, Optional[Timeseries]]]],
        use_tqdm: bool = True,
        thinking: bool = False
    ) -> List[str]:
        """
        Batch call with multiprocessing; results keep the order of `questions`.

        Each entry in `questions` can be:
          - "What is ... ?"                     # text-only
          - ("Analyze these TS", [[...], [...]]) # with timeseries
          - ("Just text", None)                 # explicit None keeps old behavior

        An exception raised for any item propagates out of this call.
        """
        # Only primitive settings are shipped to workers; each worker builds
        # its own client.
        worker_data = [
            (item, self.api_key, self.base_url, self.model, self.timeout, thinking)
            for item in questions
        ]
        if not worker_data:
            # Nothing to do: avoid spawning a process pool.
            return []

        with Pool(processes=self.num_workers) as pool:
            results_iter = pool.imap(_ask_api_worker_func, worker_data)
            if use_tqdm:
                results_iter = tqdm(results_iter, total=len(worker_data))
            # Fully consume before the pool is torn down on context exit.
            return list(results_iter)


# -------- Standalone worker function for multiprocessing --------
def _ask_api_worker_func(args) -> str:
    """
    Answer one batch item inside a worker process.

    Kept as a module-level function so ``multiprocessing`` can pickle it.

    Args:
        args: tuple ``(item, api_key, base_url, model, timeout, thinking)`` where
            ``item`` is either a question string or a ``(question, timeseries)``
            tuple (``timeseries`` may be ``None``).

    Returns:
        The first choice's message content, as returned by ``AskGPTAPI.ask_api``.

    Raises:
        TypeError: if ``item`` is neither a string nor a tuple, or the
            timeseries fails validation.
        RuntimeError: if the request keeps failing.
    """
    item, api_key, base_url, model, timeout, thinking = args

    # Determine question and timeseries
    if isinstance(item, tuple):
        question, timeseries = item
    elif isinstance(item, str):
        question, timeseries = item, None
    else:
        raise TypeError("Each item must be either a string question or a (question, timeseries) tuple.")

    # Delegate to the class so payload construction, validation and retry
    # policy live in one place instead of a drifting copy. The worker only
    # receives primitive settings (not the parent's client), so it builds its
    # own client here.
    api = AskGPTAPI(api_key=api_key, base_url=base_url, model=model, timeout=timeout)
    return api.ask_api(question, timeseries, thinking=thinking)
