from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv
import os, json, math, asyncio
import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os

# Select visible GPU IDs (e.g., use GPUs 0,1,2,3).
# This assignment unconditionally overwrites any CUDA_VISIBLE_DEVICES value
# inherited from the parent environment.
# NOTE(review): this runs after `import torch`; presumably it still takes
# effect because CUDA has not been initialized yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# Load dotenv settings before the clients below read os.environ. The
# os.environ[...] lookups raise KeyError at import time if a variable is unset.
load_dotenv()
client_gpt = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=os.environ["OPENAI_BASE_URL"])
client_deepseek = OpenAI(api_key=os.environ["DEEPSEEK_API_KEY"], base_url=os.environ["DEEPSEEK_BASE_URL"])
# Unlike client_gpt, the async GPT client is built with no explicit arguments.
# NOTE(review): presumably the SDK then falls back to OPENAI_API_KEY /
# OPENAI_BASE_URL from the environment -- confirm both clients hit the same endpoint.
client_gpt_async = AsyncOpenAI()
client_deepseek_async = AsyncOpenAI(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url=os.environ["DEEPSEEK_BASE_URL"]
)

# Checkpoint location and precision for the locally hosted model; both can be
# overridden through environment variables. Despite the QWEN_ prefix, the
# built-in default path points at a Meta-Llama-3-8B-Instruct checkpoint.
QWEN_MODEL_PATH = os.environ.get("QWEN_MODEL_PATH", "./models/LLM-Research/Meta-Llama-3-8B-Instruct")
QWEN_DTYPE = os.environ.get("QWEN_DTYPE", "float16")  # "auto" | "float16" | "bfloat16" | "float32"

# Map the configured dtype name to the value passed as `torch_dtype` when the
# model is loaded. Any name not listed here (including "float32" itself and
# misspellings) resolves to torch.float32.
_DTYPE_BY_NAME = {
    "auto": "auto",
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}
dtype = _DTYPE_BY_NAME.get(QWEN_DTYPE, torch.float32)

# Load the tokenizer and model once at import time; every local-inference
# helper in this module reads the `tokenizer_qwen` / `model_qwen` globals.
# Use accelerate's device_map="auto" to spread across multiple GPUs
print(">>> Loading Qwen with accelerate (multi-GPU support)...")
# NOTE(review): trust_remote_code=True presumably lets Python code shipped
# with the checkpoint run at load time -- confirm the model source is trusted.
tokenizer_qwen = AutoTokenizer.from_pretrained(QWEN_MODEL_PATH, trust_remote_code=True)

model_qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_PATH,
    torch_dtype=dtype,  # resolved from QWEN_DTYPE above
    device_map="auto",
    trust_remote_code=True
)

# Switch the module to evaluation mode; this file only runs inference.
model_qwen.eval()

# Global defaults for sampling, seeds, etc.
# DEFAULT_SEED is sent as `seed` on the OpenAI chat-completion requests made
# by this module.
# NOTE(review): DEFAULT_TEMPERATURE / DEFAULT_TOP_P / DEFAULT_TOP_K do not
# appear to be read by the completion helpers here -- confirm whether other
# code depends on them.
DEFAULT_TEMPERATURE = 0.1
DEFAULT_TOP_P = 1.0
DEFAULT_TOP_K = 1
DEFAULT_SEED = 42


def get_openai_completion_with_token_probs(
    prompt: str,
    history: Optional[list] = None,
    model: str = "gpt-4.1-mini",
    logprobs: bool = True,
    top_logprobs: int = 0,  # >0 to return per-step top-k candidates
    temperature: float = 0,
) -> Dict[str, Any]:
    """
    Non-streaming chat completion via `client_gpt`. Returns the full text and
    per-token probability info.

    Args:
      prompt: User message appended after `history`.
      history: Prior messages as a list of {"role", "content"} dicts, or a JSON
        string encoding such a list. A string that is not valid JSON, or that
        decodes to something other than a list, is treated as empty history.
      model: Model name sent to the API.
      logprobs: Whether to request per-token log-probabilities.
      top_logprobs: Number of alternative candidates to request per token.
        Only sent when `logprobs` is truthy and the value is > 0.
      temperature: Sampling temperature sent to the API.

    Returns:
      {
        "text": <full generated text>,
        "tokens": [  # per token; empty when the response carries no logprobs
          {
            "token": str,
            "logprob": float,
            "prob": float,             # exp(logprob)
            "top_logprobs": [          # optional
              {"token": str, "logprob": float, "prob": float}, ...
            ] or None
          }, ...
        ]
      }
    """
    if history is None:
        history = []
    elif isinstance(history, str):
        try:
            history = json.loads(history)
        except json.JSONDecodeError:
            history = []
        # Valid JSON that is not a list (e.g. an object or a number) cannot be
        # used as a message list; fall back to empty history like bad JSON does.
        if not isinstance(history, list):
            history = []

    messages = [*history, {"role": "user", "content": prompt}]

    # Key options: non-streaming; enable logprobs / top_logprobs
    request_kwargs: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "stream": False,
        "logprobs": logprobs,
        "seed": DEFAULT_SEED,
    }
    # Only send top_logprobs together with logprobs (same guard as the async
    # variant); otherwise omit the field instead of sending an explicit null.
    if logprobs and top_logprobs and top_logprobs > 0:
        request_kwargs["top_logprobs"] = top_logprobs

    response = client_gpt.chat.completions.create(**request_kwargs)

    choice = response.choices[0]
    full_text = choice.message.content or ""

    # Parse per-token probabilities.
    # In the new SDK: choice.logprobs.content is a list with token/logprob/top_logprobs
    lp_block = getattr(choice, "logprobs", None)
    lp_content = getattr(lp_block, "content", None) if lp_block else None

    tokens_info = []
    for step in lp_content or []:
        candidates = [
            {
                "token": cand.token,
                "logprob": cand.logprob,
                "prob": math.exp(cand.logprob),
            }
            for cand in (getattr(step, "top_logprobs", None) or [])
        ]
        tokens_info.append({
            "token": step.token,
            "logprob": step.logprob,
            "prob": math.exp(step.logprob),
            "top_logprobs": candidates or None,
        })

    return {"text": full_text, "tokens": tokens_info}


async def get_openai_completion_with_token_probs_async(
    prompt: str,
    history: Optional[list] = None,
    model: str = "gpt-4.1-mini",
    logprobs: bool = True,
    top_logprobs: int = 0,     # >0 to return per-step top-k candidates
    temperature: float = 0.0,
    stop: Optional[List[str]] = None,
    max_tokens: Optional[int] = 1000,
) -> Dict[str, Any]:
    """
    Async variant (non-streaming) using `client_gpt_async`. Returns the full
    text and per-token probability info.

    Args:
      prompt: User message appended after `history`.
      history: Prior messages as a list of {"role", "content"} dicts, or a JSON
        string encoding such a list. A string that is not valid JSON, or that
        decodes to something other than a list, is treated as empty history.
      model: Model name sent to the API.
      logprobs: Whether to request per-token log-probabilities.
      top_logprobs: Number of alternative candidates to request per token.
        Only sent when `logprobs` is truthy and the value is > 0.
      temperature: Sampling temperature sent to the API.
      stop: Optional stop sequences; sent only when non-empty.
      max_tokens: Optional cap on generated tokens; sent only when not None.

    Returns:
      {
        "text": str,
        "tokens": [
          {"token": str, "logprob": float, "prob": float, "top_logprobs": [ ... ] or None},
          ...
        ]
      }
    """
    if history is None:
        history = []
    elif isinstance(history, str):
        try:
            history = json.loads(history)
        except json.JSONDecodeError:
            history = []
        # Valid JSON that is not a list cannot serve as a message list.
        if not isinstance(history, list):
            history = []

    messages = [*history, {"role": "user", "content": prompt}]

    request_kwargs: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "stream": False,
        "logprobs": logprobs,
        "seed": DEFAULT_SEED,
    }
    # These parameters are part of the signature, so they must reach the API;
    # optional ones are omitted rather than sent as explicit nulls.
    if stop:
        request_kwargs["stop"] = stop
    if max_tokens is not None:
        request_kwargs["max_tokens"] = max_tokens
    if logprobs and top_logprobs and top_logprobs > 0:
        request_kwargs["top_logprobs"] = top_logprobs

    resp = await client_gpt_async.chat.completions.create(**request_kwargs)

    choice = resp.choices[0]
    full_text = choice.message.content or ""

    lp_block = getattr(choice, "logprobs", None)
    lp_content = getattr(lp_block, "content", None) if lp_block else None

    tokens_info = []
    for step in lp_content or []:
        candidates = [
            {
                "token": cand.token,
                "logprob": cand.logprob,
                "prob": math.exp(cand.logprob),
            }
            for cand in (getattr(step, "top_logprobs", None) or [])
        ]
        tokens_info.append({
            "token": step.token,
            "logprob": step.logprob,
            "prob": math.exp(step.logprob),
            "top_logprobs": candidates or None,
        })

    return {"text": full_text, "tokens": tokens_info}


def _apply_chat_template_qwen(messages: List[Dict[str, str]]) -> str:
    """
    Render a list of {role, content} messages into a single prompt string.

    The tokenizer's own chat template is tried first, with the generation
    prompt appended. If that call raises for any reason, a plain transcript
    is built instead: one "ROLE: content" line per message (role defaults to
    "user", content to ""), followed by a trailing "ASSISTANT:" line.
    """
    try:
        rendered = tokenizer_qwen.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Best-effort fallback: role-prefixed lines joined by newlines.
        lines = [
            f"{msg.get('role', 'user').upper()}: {msg.get('content', '')}"
            for msg in messages
        ]
        lines.append("ASSISTANT:")
        return "\n".join(lines)
    return rendered


def _cut_by_stops(text: str, stop: Optional[List[str]]) -> str:
    """
    Return `text` truncated just before the earliest stop string it contains.

    Empty stop strings are ignored. When `stop` is empty/None or none of the
    stop strings occur, `text` is returned unchanged.
    """
    if not stop:
        return text
    # Positions of every non-empty stop string that actually occurs in text.
    hits = [
        pos
        for pos in (text.find(marker) for marker in stop if marker)
        if pos != -1
    ]
    return text[:min(hits)] if hits else text


def get_qwen3_local_completion_with_token_probs(
    prompt: str,
    history: Optional[list] = None,
    model: Optional[str] = None,          # Kept for signature compatibility; ignored for local inference
    logprobs: bool = True,
    top_logprobs: int = 0,                # >0 to return per-step top-k candidates
    temperature: float = 0.0,
    stop: Optional[List[str]] = None,
    max_tokens: int = 2048,
) -> Dict[str, Any]:
    """
    Synchronous, non-streaming inference on the locally loaded model
    (`model_qwen` / `tokenizer_qwen`).

    Args:
      prompt: User message appended after `history`.
      history: Prior messages as a list of {"role", "content"} dicts, or a JSON
        string encoding such a list. A string that is not valid JSON, or that
        decodes to something other than a list, is treated as empty history.
      model: Ignored; present only to mirror the remote-API helpers.
      logprobs: When False, "tokens" in the result is an empty list.
      top_logprobs: Number of highest-probability candidates to report per step.
      temperature: 0 selects greedy decoding; > 0 enables sampling.
      stop: Stop strings. They are applied to the decoded text after generation
        finishes; they are not passed to `generate`, so generation itself
        still runs until the model stops or `max_tokens` is reached.
      max_tokens: Maximum number of newly generated tokens.

    Returns:
      {
        "text": str,
        "tokens": [{"token": str, "logprob": float, "prob": float, "top_logprobs": [...] or None}, ...]
      }
      When the text was cut at a stop string, "tokens" is heuristically
      trimmed to the tokens covering the kept text (none if the text is empty).
    """
    if history is None:
        history = []
    elif isinstance(history, str):
        try:
            history = json.loads(history)
        except json.JSONDecodeError:
            history = []
        # Valid JSON that is not a list cannot serve as a message list.
        if not isinstance(history, list):
            history = []

    messages = [*history, {"role": "user", "content": prompt}]
    prompt_text = _apply_chat_template_qwen(messages)

    inputs = tokenizer_qwen(prompt_text, return_tensors="pt")
    # Move the encoded prompt onto the device reported by the model.
    inputs = {k: v.to(model_qwen.device) for k, v in inputs.items()}
    input_len = inputs["input_ids"].shape[-1]

    do_sample = temperature > 0.0
    gen_out = model_qwen.generate(
        **inputs,
        do_sample=do_sample,
        temperature=max(temperature, 1e-6) if do_sample else None,
        output_scores=True,                 # Crucial: capture per-step scores
        return_dict_in_generate=True,
        max_new_tokens=max_tokens,
    )

    # Only the newly generated portion
    gen_ids = gen_out.sequences[0][input_len:]
    full_text = tokenizer_qwen.decode(gen_ids, skip_special_tokens=True)
    decoded = _cut_by_stops(full_text, stop)

    tokens_info: List[Dict[str, Any]] = []
    scores = getattr(gen_out, "scores", None)
    if logprobs and scores:
        # One score tensor per generated token; each is indexed as [batch, vocab].
        # NOTE(review): these scores are presumably the logits after generation's
        # logits processors (temperature etc.) -- confirm if raw model
        # probabilities are required.
        for step_id, logits in zip(gen_ids.tolist(), scores):
            # Upcast so log-softmax over the vocabulary is not computed in half precision.
            logp = F.log_softmax(logits[0].float(), dim=-1)

            lp_val = float(logp[step_id].item())
            top_list = None
            if top_logprobs and top_logprobs > 0:
                vals, idxs = torch.topk(logp, k=min(top_logprobs, logp.shape[-1]))
                top_list = [
                    {
                        "token": tokenizer_qwen.convert_ids_to_tokens(int(tid)),
                        "logprob": float(v),
                        "prob": math.exp(float(v)),
                    }
                    for v, tid in zip(vals.tolist(), idxs.tolist())
                ]

            tokens_info.append(
                {
                    "token": tokenizer_qwen.convert_ids_to_tokens(step_id),
                    "logprob": lp_val,
                    "prob": math.exp(lp_val),
                    "top_logprobs": top_list,
                }
            )

        # Trim token stats only when a stop string actually shortened the text.
        # Heuristic: keep tokens until their per-token string lengths cover the
        # kept text. Checking before appending means an empty text keeps no tokens.
        if len(decoded) < len(full_text):
            kept: List[Dict[str, Any]] = []
            covered = 0
            for entry in tokens_info:
                if covered >= len(decoded):
                    break
                kept.append(entry)
                covered += len(tokenizer_qwen.convert_tokens_to_string([entry["token"]]))
            tokens_info = kept

    return {"text": decoded, "tokens": tokens_info}


async def get_qwen3_local_completion_with_token_probs_async(
    prompt: str,
    history: Optional[list] = None,
    model: Optional[str] = None,
    logprobs: bool = True,
    top_logprobs: int = 0,
    temperature: float = 0.0,
    stop: Optional[List[str]] = None,
    max_tokens: int = 2048,
) -> Dict[str, Any]:
    """
    Awaitable front-end for get_qwen3_local_completion_with_token_probs.

    The blocking local inference call is handed to the running loop's
    executor (passed as None) so the event loop is not blocked while the
    model generates. All arguments are forwarded unchanged and the sync
    function's result dict is returned as-is.
    """

    def _run_blocking() -> Dict[str, Any]:
        # Executed off the event loop; forwards every argument by keyword.
        return get_qwen3_local_completion_with_token_probs(
            prompt=prompt,
            history=history,
            model=model,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            temperature=temperature,
            stop=stop,
            max_tokens=max_tokens,
        )

    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(None, _run_blocking)
