import math
import os
import time
import uuid
from typing import List, Optional

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel

from src.models.language_model import load_model_and_tokenizer

# -----------------------------
# Initialize FastAPI and Hugging Face model
# -----------------------------
app = FastAPI(title="Local OpenAI /v1/responses Clone")

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Weights are read from a local mirror laid out as ./local/models/<hf repo id>.
# NOTE(review): local_files_only=True presumably forbids network downloads inside
# the project helper load_model_and_tokenizer — confirm against its implementation.
hf_model_name = "meta-llama/Llama-3.2-3B-Instruct"
local_model_dir = f"./local/models/{hf_model_name}"
model, tokenizer = load_model_and_tokenizer(
    local_model_dir,
    device=device,
    local_files_only=True,
)

def sanitize_float(x, fallback: float = -100.0) -> float:
    """Return ``x`` as a plain, JSON-safe ``float``.

    Standard JSON cannot encode ``inf``, ``-inf`` or ``nan``. Log-probabilities
    can be ``-inf`` for tokens the sampler masked out, for example. Any
    non-finite value is therefore replaced with ``fallback``, a very negative
    stand-in rather than a real log-probability.

    Args:
        x: Any real number (int, float, or anything with ``__float__``).
        fallback: Value returned when ``x`` is infinite or NaN.

    Returns:
        ``float(x)`` if it is finite, otherwise ``fallback``.
    """
    value = float(x)
    if not math.isfinite(value):
        return fallback
    return value

# -----------------------------
# Request & Response Schemas
# -----------------------------
class ResponseRequest(BaseModel):
    """Request body accepted by ``POST /v1/responses``."""

    model: str  # echoed back in the response; generation always uses the single module-level model
    input: str  # prompt text passed to the tokenizer as-is
    max_output_tokens: Optional[int] = 5  # new-token budget per choice; the endpoint treats 0/None as 5
    metadata: Optional[dict] = None  # extra options; the endpoint reads "top_k", "top_logprobs" and "logprobs"
    temperature: Optional[float] = None  # None -> fall back to model.generation_config.temperature (or 1.0)
    n: Optional[int] = 1  # number of completions / choices; each one is a separate generation call

class ContentItem(BaseModel):
    """One piece of content inside a choice."""

    type: str  # the endpoint only emits "output_text"
    text: str  # decoded generated text (prompt excluded, special tokens skipped)

class ResponseChoice(BaseModel):
    """A single generated completion returned in ``ResponseResponse.choices``."""

    id: str  # the endpoint uses the stringified choice index
    type: str  # the endpoint always sets "message"
    index: int  # 0-based position among the returned choices
    content: List[ContentItem]
    # Per generated step, a list of {"token_id", "token", "logprob"} dicts for the
    # top candidates; an empty list when logprobs were not requested.
    logprobs: Optional[List] = None
    # NOTE(review): the endpoint always reports "stop", even when generation was
    # cut off by the token budget.
    finish_reason: Optional[str] = "stop"

class ResponseResponse(BaseModel):
    """Top-level response body for ``POST /v1/responses``.

    NOTE(review): the ``choices`` list mirrors the Chat Completions shape, not the
    ``output`` array of OpenAI's Responses API. Confirm clients expect this.
    """

    id: str  # response identifier
    object: str  # the endpoint always sets "response"
    created: int  # Unix timestamp in whole seconds (int(time.time()))
    model: str  # echo of the requested model name
    choices: List[ResponseChoice]

# -----------------------------
# Helper Function: Generate tokens & logprobs
# -----------------------------
def generate_completion(prompt: str, max_tokens: int, top_k: int, top_logprobs: int, temperature: float, return_logprobs: bool):
    """Generate one completion for ``prompt`` with the module-level model.

    Args:
        prompt: Text passed to the tokenizer unchanged (no chat template is
            applied here).
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k cutoff, forwarded to ``generate`` only when sampling.
        top_logprobs: Number of candidates to report per generated step. It is
            clamped to ``[0, vocab_size]`` so ``torch.topk`` cannot raise on
            out-of-range values.
        temperature: Sampling temperature; values ``<= 0`` select greedy decoding.
        return_logprobs: Whether to request scores from ``generate`` and build
            the per-step candidate lists.

    Returns:
        Tuple ``(output_text, top_logprobs_list)``:

        - ``output_text`` is the decoded continuation, with the prompt removed
          and special tokens skipped.
        - ``top_logprobs_list`` has one entry per generated step. Each entry is
          a list of ``{"token_id", "token", "logprob"}`` dicts in descending
          logprob order. The list is empty when logprobs were not requested.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoded.input_ids
    do_sample = temperature > 0

    gen_kwargs = {
        "input_ids": input_ids,
        # pad_token_id is set to eos below, so generate() cannot infer the mask
        # on its own; pass the tokenizer's mask explicitly (None if absent).
        "attention_mask": encoded.get("attention_mask"),
        "max_new_tokens": max_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "return_dict_in_generate": True,
        "output_scores": return_logprobs,
        "do_sample": do_sample,
        # When decoding greedily, pass neutral sampling parameters.
        "temperature": temperature if do_sample else 1.0,
        "top_k": top_k if do_sample else None,
    }

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids = outputs.sequences[0, input_ids.shape[1]:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    top_logprobs_list = []
    step_scores = getattr(outputs, "scores", None)
    if return_logprobs and step_scores:
        # `scores` holds one (batch, vocab) tensor per generated step, taken
        # *after* the logits processors (temperature, top-k, ...). Masked tokens
        # are therefore -inf; sanitize_float keeps them JSON-serializable.
        vocab_size = step_scores[0].shape[-1]
        k = max(0, min(int(top_logprobs), vocab_size))
        for scores in step_scores:
            # float32 gives a numerically stable log-softmax under fp16/bf16 models.
            log_probs = F.log_softmax(scores.float(), dim=-1)
            topk_logprobs, topk_indices = torch.topk(log_probs, k, dim=-1)
            ids = topk_indices[0].tolist()
            tokens = tokenizer.convert_ids_to_tokens(ids)
            top_logprobs_list.append([
                {
                    "token_id": int(idx),
                    "token": token,
                    "logprob": sanitize_float(lp),
                }
                for idx, token, lp in zip(ids, tokens, topk_logprobs[0].tolist())
            ])

    return output_text, top_logprobs_list

# -----------------------------
# API Endpoint
# -----------------------------
@app.post("/v1/responses", response_model=ResponseResponse)
def responses(req: ResponseRequest, authorization: Optional[str] = Header(None)):
    """Serve a minimal OpenAI-style ``/v1/responses`` request with the local model.

    The request schema lacks some generation options, so these are read from
    ``req.metadata``:

    - ``top_k`` (default 20)
    - ``top_logprobs`` (default: ``top_k``)
    - ``logprobs`` (default false)

    Integer options may be ints or numeric strings. Boolean options accept
    ``true``/``false`` or strings such as ``"true"``/``"false"``.
    ``authorization`` is accepted for client compatibility, but it is never
    checked. One ``generate_completion`` call is made per requested choice.

    NOTE(review): FastAPI runs plain ``def`` endpoints in a threadpool, so
    concurrent requests can call ``model.generate`` at the same time. Confirm
    this is acceptable on the target device.

    Raises:
        HTTPException: 400 if an integer metadata option cannot be parsed.
    """
    metadata = req.metadata or {}
    top_k = _metadata_int(metadata, "top_k", 20)
    top_logprobs = _metadata_int(metadata, "top_logprobs", top_k)
    return_logprobs = _metadata_bool(metadata, "logprobs", False)

    # 0/None fall back to 5; negative budgets would make generate() fail.
    max_tokens = max(1, req.max_output_tokens or 5)

    temperature = req.temperature
    if temperature is None:
        temperature = getattr(model.generation_config, "temperature", None)
    if temperature is None:
        temperature = 1.0

    # Always return at least one choice; n <= 0 would otherwise yield none.
    num_choices = max(1, req.n or 1)

    choices = []
    for i in range(num_choices):
        output_text, logprobs_obj = generate_completion(
            prompt=req.input,
            max_tokens=max_tokens,
            top_k=top_k,
            top_logprobs=top_logprobs,
            temperature=temperature,
            return_logprobs=return_logprobs
        )

        choices.append(ResponseChoice(
            id=str(i),
            type="message",
            index=i,
            content=[ContentItem(type="output_text", text=output_text)],
            logprobs=logprobs_obj,
            # NOTE(review): generate_completion does not report why generation
            # stopped, so "length" is never distinguished from "stop".
            finish_reason="stop"
        ))

    return ResponseResponse(
        # Unique per response so clients can tell responses apart.
        id=f"resp-{uuid.uuid4().hex}",
        object="response",
        created=int(time.time()),
        model=req.model,
        choices=choices
    )


def _metadata_int(metadata: dict, key: str, default: int) -> int:
    """Read integer option ``key`` from request metadata, or return ``default``.

    Accepts ints and numeric strings, because OpenAI-style metadata values are
    strings. Anything else is rejected with HTTP 400 instead of failing later
    inside ``generate``.
    """
    value = metadata.get(key)
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=400,
            detail=f"metadata.{key} must be an integer, got {value!r}",
        ) from None


def _metadata_bool(metadata: dict, key: str, default: bool) -> bool:
    """Read boolean option ``key`` from request metadata, or return ``default``.

    Strings are parsed: ``"1"``, ``"true"``, ``"yes"`` and ``"on"`` (any case)
    mean true, and everything else means false. This stops ``"false"`` from
    being truthy the way ``bool("false")`` is.
    """
    value = metadata.get(key)
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)

# -----------------------------
# Run server
# -----------------------------
if __name__ == "__main__":
    # Defaults match the original hard-coded bind (all interfaces, port 8000).
    # The endpoint performs no authentication, so set SERVER_HOST=127.0.0.1 to
    # keep the server local-only.
    uvicorn.run(
        app,
        host=os.environ.get("SERVER_HOST", "0.0.0.0"),
        port=int(os.environ.get("SERVER_PORT", "8000")),
    )
