import os
import re
import time

import pandas as pd
from openai import OpenAI

from dotenv import load_dotenv
load_dotenv(override=True)

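# google-genai SDK: newer releases expose client.close() for clean shutdown (no gRPC/semaphore leak);
# close_gemini_client() falls back to closing the underlying httpx client on releases that lack it.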
from google import genai
from google.genai import types

import helper_functions

# Shared Gemini client for this module. Stays None until _get_gemini_client()
# creates it on first use; close_gemini_client() closes it and resets it to None.
_gemini_client = None


def _get_gemini_client():
    """Hand back the module-wide Gemini client, building it lazily.

    The API key is read from ``GEMINI_API_KEY``; when that variable is unset
    or empty, ``GOOGLE_API_KEY`` is used instead.
    """
    global _gemini_client
    if _gemini_client is not None:
        return _gemini_client

    key = os.getenv("GEMINI_API_KEY")
    if not key:
        key = os.getenv("GOOGLE_API_KEY")
    _gemini_client = genai.Client(api_key=key)
    return _gemini_client


def close_gemini_client():
    """Shut down the shared Gemini client, if one exists, and forget it.

    Meant to be called on exit when Gemini was used. The module-level
    reference is reset to None even when closing raises; the exception still
    propagates to the caller.
    """
    global _gemini_client
    active = _gemini_client
    if active is None:
        return

    try:
        if hasattr(active, "close"):
            active.close()
            return

        # Fallback for clients without close() (noted for google-genai 1.7):
        # close the underlying httpx client instead, if it is reachable and
        # reports that it is still open.
        inner = getattr(active, "_api_client", None)
        if inner is None:
            return
        transport = getattr(inner, "_httpx_client", None)
        if transport is None:
            return
        if getattr(transport, "is_closed", True):
            return
        transport.close()
    finally:
        _gemini_client = None


def generate(row: pd.Series = None, sys_prompt: str = "", prompt_type: str = None,
             client=None, model_name: str = "", preformatted_prompt: str = None,
             max_retries: int = 3, temperature: float = 0):
    """
    Generate a single Likert rating from 1 to 5.

    Args:
        row: Data row forwarded to ``helper_functions.format_prompt`` when no
            preformatted prompt is supplied.
        sys_prompt: System instruction. Sent as the system message for GPT
            models and prepended to the prompt text for Gemini models.
        prompt_type: Prompt selector forwarded to
            ``helper_functions.format_prompt``.
        client: Optional OpenAI client to reuse for ``gpt-*`` models. When
            None, one is created from ``OPENAI_API_KEY`` on first use and
            reused across retries. Not used for Gemini models, which go
            through the module's shared Gemini client.
        model_name: Model identifier. Names starting with ``"gpt-"`` are sent
            to OpenAI; names starting with ``"gemini"`` are sent to Gemini
            under that name (a bare ``"gemini"`` maps to ``"gemini-2.5-pro"``).
        preformatted_prompt: Ready-made prompt text; when given, ``row`` and
            ``prompt_type`` are not used.
        max_retries: Maximum number of attempts before giving up.
        temperature: Sampling temperature passed to the model.

    Returns:
        The rating as an int between 1 and 5, or None when every attempt
        failed or the model name is not supported.
    """
    # Prepare prompt
    if preformatted_prompt is not None:
        prompt = preformatted_prompt
    else:
        prompt = helper_functions.format_prompt(prompt_type, row)

    # Honour a caller-supplied client; otherwise one is built lazily inside
    # the retry loop so construction errors are reported like any other failure.
    openai_client = client
    attempts = 0

    while attempts < max_retries:
        try:
            # GPT / OpenAI API call
            if model_name.startswith("gpt-"):
                if openai_client is None:
                    openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
                response = openai_client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": sys_prompt},
                        {"role": "user", "content": prompt + "\n\nRespond with a single integer from 1 to 5. No explanation."}
                    ],
                    temperature=temperature
                )
                raw = response.choices[0].message.content.strip()

                # Accept only a bare number between 1 and 5
                if raw.isdigit() and 1 <= int(raw) <= 5:
                    return int(raw)
                raise ValueError(f"Invalid numeric response: {raw}")

            # Gemini API call
            if model_name.startswith("gemini"):
                gemini_client = _get_gemini_client()
                # Use the requested model; a bare "gemini" is not a model id,
                # so it keeps the previous default.
                gemini_model = "gemini-2.5-pro" if model_name == "gemini" else model_name
                full_prompt = (
                    f"{sys_prompt}\n\n{prompt}\n\n"
                    "Output only a single number between 1 and 5. Do not include any text, explanation, quotes, or formatting. Just the digit. For example: 3"
                )
                response = gemini_client.models.generate_content(
                    model=gemini_model,
                    contents=full_prompt,
                    config=types.GenerateContentConfig(temperature=temperature),
                )
                raw = (response.text or "").strip()
                # Prefer a standalone digit; otherwise take the first 1-5 digit anywhere
                match = re.search(r"\b([1-5])\b", raw)
                if match:
                    return int(match.group(1))
                match_fallback = re.search(r"([1-5])", raw)
                if match_fallback:
                    return int(match_fallback.group(1))
                raise ValueError(f"Gemini: Could not parse valid Likert score: '{raw}'")

            # An unsupported model name can never succeed, so report it once
            # and stop instead of sleeping through pointless retries.
            error = ValueError(f"Unsupported model: {model_name}")
            print(f"ERROR calling {model_name}: {repr(error)}")
            return None

        except Exception as e:
            print(f"ERROR calling {model_name}: {repr(e)}")
            attempts += 1
            if attempts == max_retries:
                return None
            time.sleep(1)

    return None
