import argparse
import asyncio
import json
import time
import uuid
from fastapi import FastAPI, Request, Form, BackgroundTasks, UploadFile, File, HTTPException, status
from fastapi.responses import StreamingResponse, Response, JSONResponse

import os
import shutil
from functools import lru_cache
from pathlib import Path
from typing import Any, List, Union, Optional
import uvicorn
from datetime import timedelta

import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
# Serializes access to the model. Left as None here and created lazily inside
# `completions` as asyncio.Semaphore(1), i.e. one request generates at a time.
model_semaphore = None

import logging
from loguru import logger


def get_args():
    """Parse the server's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``host``, ``port``, ``model_path``,
        ``model_name``, ``device`` and ``use_cache``.
    """
    option_specs = (
        ("--host", dict(type=str, default="0.0.0.0")),
        ("--port", dict(type=int, default=80)),
        ("--model-path", dict(type=str, default="/tmp",
                              help="The path to the weights")),
        ("--model-name", dict(type=str, default="LLaMA-3-8B-Chat")),
        ("--device", dict(type=str, default="cuda")),
        # Enables reuse of the previous request's KV cache in `generate`.
        ("--use_cache", dict(action="store_true")),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

@lru_cache(maxsize=1)
def get_model(model_path: str, device: str):
    """Load a causal language model and its tokenizer from ``model_path``.

    The model, tokenizer and ``GenerationConfig`` are all loaded with Hugging
    Face ``from_pretrained`` from the same ``model_path``. The result is
    memoized by ``lru_cache(maxsize=1)``, so a repeated call with the same
    arguments returns the already-loaded objects.

    Args:
        model_path: Location passed to ``from_pretrained`` for the weights,
            tokenizer and generation config.
        device: Torch device the model is moved to (e.g. ``"cuda"``).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    # Reuse the EOS id(s) as the padding id (presumably because the model ships
    # without a dedicated pad token — confirm for new checkpoints).
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    return model, tokenizer


app = FastAPI()
# Parsed at import time from sys.argv.
args = get_args()
# Cross-request KV-cache reuse state, written by `generate` only when
# --use_cache is set: the previous request's past_key_values and the token ids
# of its prompt + reply, which the next prompt is prefix-matched against.
last_past_key_values = None
last_input_ids = []
# Loaded once at import time onto args.device.
model, tokenizer = get_model(args.model_path, args.device)




def release_model_semaphore():
    """Return one permit to the module-level ``model_semaphore``.

    Only reads the global (no rebinding), so no ``global`` statement is needed.
    Assumes ``model_semaphore`` has already been created by the request handler.
    """
    model_semaphore.release()


def is_prefix(src, dst):
    """Return True when non-empty ``dst`` is a proper prefix of ``src``.

    ``dst`` must be strictly shorter than ``src``: an empty ``dst`` or one at
    least as long as ``src`` yields False. Elements are compared with ``!=``.
    """
    if len(dst) == 0 or len(dst) >= len(src):
        return False
    return not any(a != b for a, b in zip(src, dst))

def prefix_num(src, dst):
    """Return the shared-prefix length of ``src`` and ``dst``, capped below the shorter length.

    The result is the index of the first mismatch among the first
    ``min(len(src), len(dst)) - 1`` positions, or that cap when they all match.
    The result is -1 when either sequence is empty.

    The cap matters to ``generate``, which reuses the KV cache for ``hit_len``
    tokens and feeds ``prompt_ids[hit_len:]``. The cap keeps at least one
    prompt token to feed. It also keeps ``hit_len`` within the stored cache,
    which is one position shorter than ``last_input_ids`` because the final
    generated token is never fed back to the model.
    """
    limit = min(len(src), len(dst)) - 1
    return next((i for i in range(limit) if not src[i] == dst[i]), limit)


def prepare_kv_cache(kv_cache, prefix_len):
    """Copy the first ``prefix_len`` positions of every layer's key/value tensors.

    ``kv_cache`` is iterated per layer; element 0 and 1 of each entry are
    sliced as ``[:, :, :prefix_len, :]`` (assumed layout
    (batch, heads, seq, head_dim) — TODO confirm) and made contiguous, so the
    result does not share storage with the input cache.

    Returns:
        A list of ``(key, value)`` tuples, one per layer.
    """
    def trim(tensor):
        return tensor[:, :, :prefix_len, :].contiguous()

    return [(trim(layer[0]), trim(layer[1])) for layer in kv_cache]


# Token budget for prompt + completion (hard-coded; assumed to match the served
# model's context window — confirm when changing --model-path).
_MAX_CONTEXT_TOKENS = 8192
# Minimum shared-prefix length before the previous request's KV cache is reused.
_MIN_CACHE_HIT = 32


def _sse_chunk(req_id, delta, finish_reason):
    """Serialize one OpenAI ``chat.completion.chunk`` as a server-sent event line."""
    ret = {
        "id": req_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": args.model_name,
        "system_fingerprint": None,
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
    }
    return f"data: {json.dumps(ret, ensure_ascii=False)}\n\n"


@torch.inference_mode
def generate(params):
    """Greedily decode a chat completion for an OpenAI-style request body.

    Keys read from ``params``:
    - ``messages`` (required): rendered with the tokenizer's chat template.
    - ``stream``: defaults to False.
    - ``max_tokens``: defaults to 256; an explicit ``None`` is treated as unset.

    This is a generator. With ``stream`` it yields SSE strings: a role chunk,
    then content chunks (the last one carries ``finish_reason``), then the
    ``[DONE]`` event. Without ``stream`` it yields exactly one
    ``chat.completion`` dict.

    When ``args.use_cache`` is set, the previous request's KV cache is reused
    if its token ids share at least ``_MIN_CACHE_HIT`` leading tokens with this
    prompt. This request's cache then replaces it, but only if at least one
    token was generated.

    Raises:
        RuntimeError: if the prompt leaves no room in the token budget.
        KeyError: if ``params`` has no ``messages``.
    """
    global last_input_ids, last_past_key_values
    logger.info(f"input params: {params}")
    req_id = str(uuid.uuid4())
    eos_ids = model.generation_config.eos_token_id
    # generation_config may hold a single EOS id or a list of ids; normalize to
    # a set so the stop check works for both (``token in 2`` would raise).
    if isinstance(eos_ids, int):
        im_end_ids = {eos_ids}
    else:
        im_end_ids = set(eos_ids or ())
    messages = params["messages"]
    stream = params.get("stream", False)
    prompt_ids = tokenizer.apply_chat_template(messages,
                                               add_generation_prompt=True)
    logger.info(f"after apply chat template prompt_ids: {prompt_ids}")
    logger.info(f"after apply chat template content: {tokenizer.decode(prompt_ids)}")
    # A prompt that fills the whole budget would leave zero tokens to generate.
    if len(prompt_ids) >= _MAX_CONTEXT_TOKENS:
        raise RuntimeError(
            f"Prompt Length Error: length {len(prompt_ids)} should be lower than {_MAX_CONTEXT_TOKENS}.")
    max_tokens = params.get("max_tokens")
    if max_tokens is None:
        max_tokens = 256
    max_new_tokens = min(_MAX_CONTEXT_TOKENS - len(prompt_ids), max_tokens)

    hit_len = prefix_num(last_input_ids, prompt_ids)
    if args.use_cache and last_past_key_values is not None and hit_len >= _MIN_CACHE_HIT:
        logger.info(f"hit cache: {hit_len}.")
        past_key_values = prepare_kv_cache(last_past_key_values, hit_len)
        input_ids = torch.as_tensor([prompt_ids[hit_len:]], device=args.device)
    else:
        logger.info("miss hit cache.")
        past_key_values = None
        input_ids = torch.as_tensor([prompt_ids], device=args.device)

    finish_reason = None
    output_token_ids = []
    # Streamed tokens whose decoded text still ends in an incomplete character.
    pending_ids = []
    while len(output_token_ids) < max_new_tokens:
        if output_token_ids:
            # Incremental step: feed only the token produced by the last step.
            step_ids = torch.as_tensor([[output_token_ids[-1]]], device=args.device)
        else:
            step_ids = input_ids
        out = model(input_ids=step_ids, use_cache=True,
                    past_key_values=past_key_values)
        past_key_values = out.past_key_values
        token = int(torch.argmax(out.logits[0, -1, :]))
        output_token_ids.append(token)

        # An EOS that is also the last allowed token is a natural stop.
        if token in im_end_ids:
            finish_reason = "stop"
        elif len(output_token_ids) == max_new_tokens:
            finish_reason = "length"
        stopped = finish_reason is not None

        if stream:
            if len(output_token_ids) == 1:
                # Role preamble; the finish reason rides on the content chunks.
                yield _sse_chunk(req_id, {"role": "assistant", "content": ""}, None)
            pending_ids.append(token)
            content = str(tokenizer.decode(pending_ids, skip_special_tokens=True))
            # A trailing U+FFFD means a multi-byte character is split across
            # tokens: hold it back until complete, except on the final step,
            # where the chunk carrying finish_reason must always be sent.
            if stopped or not content or content[-1] != "\ufffd":
                pending_ids = []
                yield _sse_chunk(req_id, {"content": content}, finish_reason)
        if stopped:
            break

    if finish_reason is None:
        # Only reachable when the loop never ran (non-positive max_tokens).
        finish_reason = "length"
    if args.use_cache and output_token_ids:
        # The cache now covers prompt_ids + output_token_ids[:-1] (the last
        # token is never fed back); prefix_num's min(len)-1 cap accounts for it.
        # Skipped when nothing was generated, so the stored cache and ids stay
        # in sync.
        last_past_key_values = past_key_values
        last_input_ids = prompt_ids + output_token_ids
    reply = str(tokenizer.decode(output_token_ids, skip_special_tokens=True))
    logger.info(f"reply : {reply}")
    if stream:
        yield "data: [DONE]\n\n"
    else:
        ret = {
            "id": req_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": args.model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": reply
                    },
                    "logprobs": None,
                    "finish_reason": finish_reason
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt_ids),
                "completion_tokens": len(output_token_ids),
                "total_tokens": len(prompt_ids) + len(output_token_ids)
            },
            "system_fingerprint": None
        }
        yield ret


@app.post('/v1/chat/completions')
async def completions(request: Request):
    """Serve an OpenAI-compatible chat completion request via ``generate``.

    Generation is serialized through ``model_semaphore`` (created with one
    permit), so only one request drives the model at a time.

    The permit is released exactly once on every exit path:
    - when a stream finishes, fails, or is abandoned by the client;
    - as soon as a non-streaming result has been computed;
    - when ``generate`` raises.

    Releasing only from a response background task is not enough. Starlette
    skips the background task when the handler or the stream raises, which
    would leave the permit taken and block every later request.
    """
    # fastapi.concurrency re-exports Starlette's threadpool helpers; they run
    # the blocking, synchronous `generate` generator off the event loop.
    from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool

    global model_semaphore
    params = await request.json()
    if model_semaphore is None:
        # Created lazily inside the handler so it binds to the running loop.
        model_semaphore = asyncio.Semaphore(1)
    await model_semaphore.acquire()

    released = False

    def release_once():
        # All callers run on the event-loop thread (asyncio primitives are not
        # thread-safe); the flag prevents a double release when both the
        # stream cleanup and the background task fire.
        nonlocal released
        if not released:
            released = True
            release_model_semaphore()

    async def release_after_response():
        # Async, so Starlette awaits it on the loop rather than in a worker thread.
        release_once()

    handed_off = False
    try:
        generator = generate(params)
        if params.get("stream", False):
            async def stream_then_release():
                try:
                    async for chunk in iterate_in_threadpool(generator):
                        yield chunk
                finally:
                    release_once()

            # Also release after the response: covers a client disconnect while
            # a chunk is being sent, when the stream may not be finalized promptly.
            background_tasks = BackgroundTasks()
            background_tasks.add_task(release_after_response)
            response = StreamingResponse(stream_then_release(),
                                         background=background_tasks,
                                         media_type="text/event-stream")
            handed_off = True
            return response
        # Non-streaming: generate yields exactly one completion dict.
        result = (await run_in_threadpool(list, generator))[0]
        return JSONResponse(result)
    finally:
        if not handed_off:
            release_once()

if __name__ == "__main__":
    # Script entry point: serve `app` with the host/port parsed at import time.
    server_options = {"host": args.host, "port": args.port, "log_level": "info"}
    uvicorn.run(app, **server_options)

