"""
The fast api implementation of LLM
Using VLLM framework to adapt for multiple mainstream LLM architectures
"""
from fastapi import FastAPI, Request
import argparse
import json
import torch
import time
import uuid
import sys
import os
import uvicorn
import subprocess
import re
import numpy as np
import torch.nn.functional as F
from global_utils.reward_models import auto_get_rm, rm_path_dict
import asyncio
from typing import List, Dict, Any
from fastapi import FastAPI
from uvicorn.config import LOGGING_CONFIG
import time
from termcolor import colored
import logging

LOGGING_CONFIG["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelprefix)s %(message)s"
LOGGING_CONFIG["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelprefix)s %(client_addr)s - \"%(request_line)s\" %(status_code)s"


logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler1 = logging.StreamHandler()
formatter = logging.Formatter(
    "%(asctime)s - %(module)s - %(funcName)s - line:%(lineno)d - %(levelname)s - %(message)s"
)
logger.addHandler(handler1)

# os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

pending: List[Dict[str, Any]] = []  # {"data": payload, "future": Future}
lock = asyncio.Lock()
FIRST_REQUEST_TS: float | None = None  # monotonic time of first item in queue
CHECK_INTERVAL = 0.05   # how often batch_worker wakes up (seconds)
TIMEOUT_SECONDS = 2.0   # max wait time for the first request in batch

# Command line argument parsing
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default='Qwen2.5-Math-PRM-7B',
                    help="Key into rm_path_dict selecting the reward model to load")
parser.add_argument("--port", type=int, default=6006,
                    help="Port the HTTP server listens on")
parser.add_argument("--batch_size", type=int, default=8,
                    help="Number of requests to batch together")

args = parser.parse_args()
# batch_worker serves pending[:batch_size]; a value < 1 yields empty or short
# slices and strands requests in the queue forever, so reject it up front.
if args.batch_size < 1:
    parser.error("--batch_size must be >= 1")
gpu_num = torch.cuda.device_count()
# Create the FastAPI application
app = FastAPI()

# GPU cleanup function
def torch_gc():
    """Release cached CUDA allocator memory and force-collect CUDA IPC memory.

    A no-op on machines where CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def get_response(data_batch, *, model=None, rm_batch_size=1):
    """Score several request payloads in one model call and split the rewards back.

    All ``'question'`` and ``'response'`` sequences from every payload are
    concatenated and sent to ``model.obtain_reward`` in a single call. The flat
    reward sequence is then sliced back into one chunk per payload, in input
    order. Each chunk spans ``len(payload['question'])`` positions.

    NOTE(review): assumes each payload's 'question' and 'response' are parallel
    (same length). A mismatch shifts the pairing for every later payload in
    the batch. Confirm against clients.

    Args:
        data_batch: list of request payload dicts with 'question' and 'response'.
        model: object exposing ``obtain_reward(questions, responses, batch_size)``.
            Defaults to the globally loaded ``args.model``.
        rm_batch_size: third argument forwarded to ``obtain_reward`` (default 1).

    Returns:
        A list with one reward slice per payload.
    """
    if model is None:
        model = args.model

    questions = []
    responses = []
    lengths = []
    for payload in data_batch:
        questions.extend(payload['question'])
        responses.extend(payload['response'])
        lengths.append(len(payload['question']))

    reward = model.obtain_reward(questions, responses, rm_batch_size)

    # Undo the flattening: consecutive slices of `reward`, one per payload.
    restored = []
    start = 0
    for length in lengths:
        restored.append(reward[start:start + length])
        start += length
    return restored
    



async def batch_worker():
    """Background coroutine that serves queued requests in batches.

    Every ``CHECK_INTERVAL`` seconds it inspects ``pending`` under ``lock`` and
    takes a batch when either:

    * at least ``args.batch_size`` requests are queued (the oldest
      ``args.batch_size`` are taken), or
    * the oldest queued request has waited ``TIMEOUT_SECONDS`` (everything
      queued is flushed).

    Each request's future receives its reward slice, or the exception raised
    during inference. Futures that are already done (e.g. cancelled) are
    skipped, so they cannot raise ``InvalidStateError`` out of this loop and
    kill the worker. A dead worker would leave every later request hanging.

    NOTE(review): get_response runs synchronously here and blocks the event
    loop during inference. Consider ``loop.run_in_executor`` once the model
    is confirmed safe to call from another thread (torch grad mode is
    thread-local).
    """
    global FIRST_REQUEST_TS
    loop = asyncio.get_running_loop()
    while True:
        await asyncio.sleep(CHECK_INTERVAL)
        async with lock:
            if not pending:
                FIRST_REQUEST_TS = None  # queue drained – reset timer
                continue

            now = loop.time()
            first_ts = FIRST_REQUEST_TS if FIRST_REQUEST_TS is not None else now
            ready_by_size = len(pending) >= args.batch_size
            ready_by_time = now - first_ts >= TIMEOUT_SECONDS
            if not (ready_by_size or ready_by_time):
                continue  # still collecting

            # Determine slice to serve this round
            if ready_by_size:
                batch = pending[:args.batch_size]
                del pending[:args.batch_size]
            else:  # timeout: flush everything accumulated so far
                print(colored('Timeout pending size: '+str(len(pending)),'red'))
                batch = list(pending)
                pending.clear()
            if not pending:
                FIRST_REQUEST_TS = None  # reset for the next cycle
        # ---------------- queue section end (lock released) -----------------
        try:
            reward = get_response([item['data'] for item in batch])
        except Exception as exc:
            for item in batch:
                if not item["future"].done():
                    item["future"].set_exception(exc)
        else:
            for item, ans in zip(batch, reward):
                if not item["future"].done():
                    item["future"].set_result(ans)
            # Never leave a caller awaiting forever if fewer results came back.
            for item in batch:
                if not item["future"].done():
                    item["future"].set_exception(
                        RuntimeError("No reward was produced for this request")
                    )
        finally:
            torch_gc()



@app.on_event("startup")
async def _startup():
    asyncio.create_task(batch_worker())
    print(f"Server {args.server_name} (batch={args.batch_size}) ready on port {args.port}")


# The main function for POST request
@app.post("/")
async def create_item(request: Request):
    json_post_raw = await request.json()

    loop = asyncio.get_event_loop()
    fut: asyncio.Future = loop.create_future()

    async with lock:
        global FIRST_REQUEST_TS
        pending.append({"data": json_post_raw, "future": fut})
        if len(pending) == 1:
            FIRST_REQUEST_TS = loop.time()  # start timeout window
        # If we just hit batch size, let worker notice quickly (no-op here)
    # print(fut.result())

    return await fut


# Main server function
def _find_internal_ip():
    """Return the first 172.x.x.x IPv4 address reported by ``ip addr``, or None.

    Returns None rather than raising when the ``ip`` binary is missing or
    cannot be executed.
    """
    try:
        ip_output = subprocess.run(['ip', 'addr'], capture_output=True, text=True).stdout
    except OSError:
        return None
    match = re.search(r'inet\s+(172\.\d+\.\d+\.\d+)', ip_output)
    return match.group(1) if match else None


def main():
    """Load the reward model, record this server's internal IP, and run uvicorn.

    Side effects:
      * stores the loaded model as ``args.model`` (read by get_response) and a
        unique ``args.server_name`` (printed by the startup hook);
      * appends ``server_name`` and IP to ``./all_server.txt`` when a 172.x
        address is found.
    """
    model = auto_get_rm(args.model_name)(rm_path_dict[args.model_name], args.model_name, device='auto')
    args.model = model
    # Short random suffix so several servers of the same model can be told apart.
    server_name = args.model_name + '-' + str(uuid.uuid4()).split('-')[0]
    args.server_name = server_name
    print(f'Server {server_name} started and waiting for requests!')

    internal_ip = _find_internal_ip()
    if internal_ip is not None:
        print("The ip address is:", internal_ip)
        with open('./all_server.txt', 'a', encoding='utf-8') as f:
            f.write(f"server_name: {server_name} ip: {internal_ip} \n")
    else:
        print("No match ip address!")

    # Start FastAPI with a single worker so the in-process batching queue is shared.
    uvicorn.run(app, host='0.0.0.0', port=int(args.port), workers=1)


if __name__ == '__main__':
    main()






