from __future__ import annotations
import os,time
from typing import Dict, Tuple
from sentence_transformers import SentenceTransformer

try:
    from openai import OpenAI
    _HAS_OPENAI = True
except Exception:
    _HAS_OPENAI = False

try:
    from google import genai
    import time
    from google.genai import types
    from google.genai import errors
    _HAS_GEMINI = True
except Exception:
    _HAS_GEMINI = False

# Optional local-inference backends. transformers/torch and vLLM are probed
# independently so that a missing vLLM install still leaves the plain
# Hugging Face generation path usable; both flags are always defined.
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    _HAS_TRANSFORMERS = True
except Exception:
    _HAS_TRANSFORMERS = False

try:
    from vllm import LLM, SamplingParams
    _HAS_VLLM = True
except Exception:
    _HAS_VLLM = False

# API credentials for the hosted providers used by generate(). These values are
# passed explicitly to the SDK clients, which overrides any key those clients
# would otherwise pick up, so read them from the environment here. The original
# placeholder strings remain as fallbacks when the variables are unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")


# Process-wide store of already-loaded (tokenizer, model-or-engine) pairs,
# keyed by "<model>|<backend>" strings.
_CACHE: Dict[str, Tuple[object, object]] = {}
# Download/cache directory handed to Hugging Face from_pretrained calls;
# overridable via the MODEL_CACHE_DIR environment variable.
_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "./model_cache")


# Canonical short name -> Hugging Face model id.
MODEL_ID_MAP = {
    "qwen3-4b": "Qwen/Qwen3-4B-Instruct-2507",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
}

# Hugging Face model id -> weights location used by the vLLM backend. Each entry
# can be redirected to a local checkout through its environment variable and
# otherwise falls back to the Hugging Face id itself.
LOCAL_LLM_PATH = {
    hf_id: os.getenv(env_var, hf_id)
    for hf_id, env_var in (
        ("Qwen/Qwen3-4B-Instruct-2507", "QWEN3_4B_PATH"),
        ("meta-llama/Llama-3.1-8B-Instruct", "LLAMA3_1_8B_PATH"),
    )
}

# Alternative spellings accepted by resolve_model_id(), mapped to the
# canonical short names used as MODEL_ID_MAP keys.
ALIASES = {
    # Qwen3 4B
    "qwen-3-4b": "qwen3-4b",
    "qwen3_4b": "qwen3-4b",
    # Llama 3.1 8B
    "llama-3.1-8b": "llama3.1-8b",
    "llama3_1-8b": "llama3.1-8b",
}


def load_embed_model(model_name: str = "Qwen/Qwen3-Embedding-0.6B"):
    """Construct a sentence-transformers embedding model.

    Args:
        model_name: Model id or path forwarded unchanged to SentenceTransformer.

    Returns:
        The newly created SentenceTransformer instance (not cached).
    """
    return SentenceTransformer(model_name)


def resolve_model_id(name: str) -> str:
    """Translate a short model name (or an alias of one) into its full model id.

    Matching ignores surrounding whitespace and letter case, and alias
    spellings from ALIASES are folded onto their canonical short name first.

    Args:
        name: Short name such as "qwen3-4b", or an alias such as "qwen3_4b".

    Returns:
        The model id stored in MODEL_ID_MAP for that short name.

    Raises:
        KeyError: if the normalized name is not a known short name; the message
            lists every accepted short name.
    """
    normalized = name.strip().lower()
    canonical = ALIASES.get(normalized, normalized)
    try:
        hf_id = MODEL_ID_MAP[canonical]
    except KeyError as missing:
        known = sorted(MODEL_ID_MAP.keys())
        raise KeyError(
            f"Unknown model short name '{name}'. Known: {known}"
        ) from missing
    return hf_id


def _infer_provider(model: str) -> str:
    m = model.lower()
    if any(k in m for k in ["llama", "qwen", "mistral",  "gemma","deepseek", "mixtral"]):
        return "transformers"
    return model

def _retry_api_call(call, max_retries: int = 5, retry_delay: float = 2) -> str:
    """Run ``call`` with exponential-backoff retries; return "" if all attempts fail.

    Best-effort policy shared by the hosted-API backends: any exception is
    printed and retried after ``retry_delay`` seconds (doubling each time).
    After ``max_retries`` failed attempts the last error is printed and an
    empty string is returned instead of raising.
    """
    for attempt in range(max_retries):
        try:
            return call()
        except Exception as e:
            if attempt < max_retries - 1:
                print(f"Has Error (try {attempt + 1}/{max_retries}): {e}")
                print(f"will retry in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2
            else:
                print(f"reach max number of retry: {e}")
    return ""


def _chat_text(tok, prompt: str) -> str:
    """Render ``prompt`` as a single user turn using the tokenizer's chat template.

    Falls back to the raw prompt when the tokenizer has no chat template
    (e.g. base models), where apply_chat_template is not usable.
    """
    if hasattr(tok, "apply_chat_template") and getattr(tok, "chat_template", None):
        return tok.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    return prompt


def _generate_openai(prompt: str, max_new_tokens: int) -> str:
    """Query the OpenAI Responses API (model fixed to 'gpt-5.2', reasoning off)."""
    if not _HAS_OPENAI:
        raise RuntimeError("openai not installed. pip install openai")
    client = OpenAI(
        api_key=OPENAI_API_KEY,
    )

    def _call() -> str:
        resp = client.responses.create(
            model='gpt-5.2',
            input=prompt,
            max_output_tokens=max_new_tokens,
            reasoning={"effort": "none"},
        )
        return (resp.output_text).strip()

    return _retry_api_call(_call)


def _generate_gemini(prompt: str, max_new_tokens: int) -> str:
    """Query Gemini (model fixed to 'gemini-2.5-flash', no thinking, BLOCK_NONE safety)."""
    if not _HAS_GEMINI:
        raise RuntimeError("gemini not installed. pip install -q -U google-genai")

    client = genai.Client(api_key=GEMINI_API_KEY)
    safety_settings = [
        types.SafetySetting(category=category, threshold="BLOCK_NONE")
        for category in (
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
    ]

    def _call() -> str:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt,
            config=types.GenerateContentConfig(
                max_output_tokens=max_new_tokens,
                thinking_config=types.ThinkingConfig(thinking_budget=0),
                safety_settings=safety_settings,
            ),
        )
        text = response.text
        if text is None:
            # response.text is None when no text part came back; raise a clear
            # error (still retried) instead of an AttributeError from .strip().
            raise ValueError("Gemini response contained no text")
        return text.strip()

    return _retry_api_call(_call)


def _generate_vllm(model: str, prompt: str, temperature: float, max_new_tokens: int) -> str:
    """Generate with a cached vLLM engine (tokenizer from ``model``, weights from LOCAL_LLM_PATH)."""
    key = f"{model}|vllm"
    tok, llm = _CACHE.get(key, (None, None))

    if tok is None:
        tok = AutoTokenizer.from_pretrained(
            model,
            trust_remote_code=True,
            cache_dir=_CACHE_DIR,
        )
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token

        llm = LLM(
            # Prefer a configured local checkout; for ids not in the map, let
            # vLLM resolve the id itself rather than raising KeyError.
            model=LOCAL_LLM_PATH.get(model, model),
            max_model_len=8192,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.4,
        )
        _CACHE[key] = (tok, llm)

    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=float(temperature),
        n=1,
        stop_token_ids=[tok.eos_token_id] if tok.eos_token_id is not None else None,
    )

    outputs = llm.generate([_chat_text(tok, prompt)], sampling_params)
    return outputs[0].outputs[0].text.strip()


def _generate_hf(model: str, prompt: str, temperature: float, max_new_tokens: int) -> str:
    """Generate with a cached bf16 Hugging Face model (flash_attention_2, device_map='auto')."""
    key = f"{model}|bf16"
    tok, mdl = _CACHE.get(key, (None, None))
    if tok is None:
        tok = AutoTokenizer.from_pretrained(model, trust_remote_code=True, cache_dir=_CACHE_DIR)
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        mdl = AutoModelForCausalLM.from_pretrained(
            model,
            dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            cache_dir=_CACHE_DIR,
            attn_implementation="flash_attention_2",
        )
        mdl.eval()
        _CACHE[key] = (tok, mdl)

    inputs = tok(_chat_text(tok, prompt), return_tensors="pt").to(mdl.device)

    do_sample = temperature > 0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    if do_sample:
        # temperature is only meaningful when sampling; passing 0.0 together
        # with do_sample=False makes transformers warn about an unused flag.
        gen_kwargs["temperature"] = float(temperature)

    out_ids = mdl.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    return tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True).strip()


def generate(model: str, prompt: str, temperature: float = 0, max_new_tokens: int = 256) -> str:
    """Generate a completion for ``prompt`` using the backend implied by ``model``.

    Routing (via _infer_provider):
      * "openai": OpenAI Responses API with the model fixed to 'gpt-5.2'.
        ``temperature`` is not forwarded.
      * "gemini": 'gemini-2.5-flash' with thinking disabled and BLOCK_NONE
        safety settings. ``temperature`` is not forwarded.
      * anything else: ``model`` is treated as a local Hugging Face model id.
        vLLM is used when it is importable, otherwise plain transformers in
        bf16. Loaded tokenizer/model pairs are kept in _CACHE across calls.

    The hosted-API branches are best-effort: they retry up to 5 times with
    exponential backoff and return "" if every attempt fails.

    Args:
        model: Provider name ("openai"/"gemini") or local model id.
        prompt: User message, wrapped with the tokenizer's chat template for
            local models when one is available.
        temperature: Sampling temperature for local models; 0 means greedy.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The generated text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: if the SDK/library needed for the selected backend is
            not installed.
    """
    provider = _infer_provider(model)

    if provider == "openai":
        return _generate_openai(prompt, max_new_tokens)

    if provider == "gemini":
        return _generate_gemini(prompt, max_new_tokens)

    if not _HAS_TRANSFORMERS:
        raise RuntimeError("transformers/torch not installed. pip install transformers accelerate torch")

    if _HAS_VLLM:
        return _generate_vllm(model, prompt, temperature, max_new_tokens)

    return _generate_hf(model, prompt, temperature, max_new_tokens)