import asyncio
import os
import threading
import time
from typing import Any, List, Optional

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests
import logging

# Seconds the background heartbeat thread sleeps between heartbeats.
# Read from FASTCHAT_WORKER_HEART_BEAT_INTERVAL (default "45"); the env var
# name is kept for compatibility with previous code.
WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", "45"))

def build_logger(name: str, filename: Optional[str] = None) -> logging.Logger:
    """Return the named logger, configuring it on first use.

    On the first call for ``name`` the logger is set to INFO and given a
    ``StreamHandler`` (stderr). When ``filename`` is given, a UTF-8 file
    handler is attached as well. The file goes in the directory named by the
    ``LOGDIR`` environment variable, which is created if missing. When
    ``LOGDIR`` is unset or empty, ``filename`` is used as-is (relative to the
    current working directory).

    A logger that already has handlers is returned unchanged, so repeated
    calls never duplicate handlers.

    File logging is best-effort. If the log directory or file cannot be
    created, the failure is logged and the logger keeps only its stream
    handler.
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        return logger
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(fmt)
    logger.addHandler(stream_handler)
    if filename:
        logdir = os.environ.get("LOGDIR", "")
        try:
            if logdir:
                # makedirs can fail too (permissions, a path component that is
                # a regular file). Keep it inside the try so it is treated like
                # a FileHandler failure instead of crashing logger creation.
                os.makedirs(logdir, exist_ok=True)
                filepath = os.path.join(logdir, filename)
            else:
                filepath = filename
            # Explicit encoding so non-ASCII records don't depend on the locale.
            file_handler = logging.FileHandler(filepath, encoding="utf-8")
        except Exception:
            logger.exception("Failed to create log file handler")
        else:
            file_handler.setFormatter(fmt)
            logger.addHandler(file_handler)
    return logger

def pretty_print_semaphore(sem: Any) -> str:
    """Render a semaphore's private counters as ``"value=<v>,waiters=<n>"``.

    Reads the ``_value`` and ``_waiters`` attributes used by asyncio-style
    semaphores. A missing ``_value`` shows as ``None``, and a missing or
    ``None`` ``_waiters`` counts as 0 waiters. If anything goes wrong while
    inspecting or formatting, falls back to ``str(sem)``.
    """
    try:
        current = getattr(sem, "_value", None)
        queue = getattr(sem, "_waiters", None)
        n_waiting = len(queue) if queue is not None else 0
        return f"value={current},waiters={n_waiting}"
    except Exception:
        return str(sem)

def heart_beat_worker(obj, interval: Optional[float] = None):
    """Send heartbeats from ``obj`` forever; meant to be a daemon thread target.

    Sleeps ``interval`` seconds (``WORKER_HEART_BEAT_INTERVAL`` when not
    given), then calls ``obj.send_heart_beat()``, and repeats.

    An ``Exception`` raised by one heartbeat is logged and the loop continues.
    Without this, a single failure (e.g. a failed re-registration with the
    controller) would silently end the thread and stop every later heartbeat.
    """
    if interval is None:
        interval = WORKER_HEART_BEAT_INTERVAL
    # Same logger name that BaseModelWorker passes to build_logger().
    log = logging.getLogger("model_worker")
    while True:
        time.sleep(interval)
        try:
            obj.send_heart_beat()
        except Exception:
            log.exception("Heart beat failed; retrying after the next interval")

class Conversation:
    """Minimal conversation-template holder.

    Attributes:
        template_name: Template name; any falsy input (``None``, ``""``)
            becomes ``"default"``.
        sep_style: Separator style code, initialised to ``0``.
    """

    def __init__(self, template_name: Optional[str] = None):
        if not template_name:
            template_name = "default"
        self.template_name = template_name
        self.sep_style = 0

# Module-level singletons read by the HTTP handlers below. The first
# BaseModelWorker constructed assigns itself to `worker` and, if `logger` is
# still None, creates it with build_logger().
worker = None
logger = None

# FastAPI application that the worker endpoints below are registered on.
app = FastAPI()

class BaseModelWorker:
    """Shared plumbing for a model worker served by the FastAPI ``app``.

    Handles registration and periodic heartbeats with the controller. Also
    reports status and queue length, counts tokens, and exposes the
    conversation template.

    Subclasses supply inference by overriding ``generate_stream_gate``,
    ``generate_gate``, ``get_embeddings`` and ``reward_inference_gate``; all
    four raise ``NotImplementedError`` here.

    The first instance constructed also initialises the module-level
    ``logger`` (if unset) and ``worker`` globals used by the HTTP handlers.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: Optional[str] = None,
        multimodal: bool = False,
    ):
        """Store configuration and set up the process-wide globals.

        Args:
            controller_addr: Controller base URL; endpoint paths such as
                ``/register_worker`` are appended to it.
            worker_addr: This worker's address, sent to the controller as
                ``worker_name``.
            worker_id: Identifier included in heartbeat log lines.
            model_path: Model location; one trailing ``/`` is stripped.
            model_names: Names to advertise; ``[model_path]`` when empty.
            limit_worker_concurrency: Size of the concurrency semaphore, i.e.
                how many requests may hold a slot at once.
            conv_template: Conversation template name; ``model_path`` is used
                when not given.
            multimodal: Flag reported to the controller on registration.
        """
        global logger, worker

        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_path = model_path
        self.model_names = model_names or [model_path]
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
        self.conv.sep_style = int(self.conv.sep_style)
        self.multimodal = multimodal
        # None here. count_token() and the /model_details endpoint read these,
        # so subclasses are expected to set them.
        self.tokenizer = None
        self.context_len = None
        self.call_ct = 0
        # Created lazily by acquire_worker_semaphore() on the first request.
        self.semaphore = None
        self.heart_beat_thread = None

        if logger is None:
            # [2:] drops the century from the year, e.g. "240131_093000".
            time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))[2:]
            logger = build_logger("model_worker", f"llm_{time_str}.log")
        if worker is None:
            worker = self

    def make_conv_template(self, conv_template: Optional[str] = None, model_path: Optional[str] = None) -> "Conversation":
        """Build the Conversation for ``conv_template``, falling back to ``model_path``."""
        return Conversation(conv_template or model_path)

    def init_heart_beat(self):
        """Register with the controller, then start the daemon heartbeat thread."""
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,), daemon=True)
        self.heart_beat_thread.start()

    def register_to_controller(self):
        """POST this worker's name and status to ``<controller>/register_worker``.

        Raises:
            requests.exceptions.RequestException: The request failed or timed out.
            RuntimeError: The controller replied with a non-200 status.
        """
        logger.info("Register to controller")
        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
            "multimodal": self.multimodal,
        }
        # send_heart_beat() re-registers through this method from the heartbeat
        # thread. Without a timeout, a hung controller would block that thread
        # forever. Uses the same 5 s timeout as the heartbeat request.
        r = requests.post(url, json=data, timeout=5)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if r.status_code != 200:
            raise RuntimeError(
                f"Failed to register to controller at {url}: HTTP {r.status_code}"
            )

    def send_heart_beat(self):
        """Report queue length to ``<controller>/receive_heart_beat``.

        Retries every 10 seconds until the controller returns a JSON object
        with an ``exist`` key. If ``exist`` is falsy (the controller no longer
        knows this worker), re-registers.
        """
        logger.info(f"Send heart beat. Models: {self.model_names}. Semaphore: {pretty_print_semaphore(self.semaphore)}. "
                    f"call_ct: {self.call_ct}. worker_id: {self.worker_id}. ")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            # ValueError: non-JSON body. Older `requests` versions raise a plain
            # json/simplejson decode error here, which is not a RequestException.
            # TypeError: the JSON body is not an object (e.g. null or a list).
            except (requests.exceptions.RequestException, KeyError, TypeError, ValueError) as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(10)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return the number of requests holding or waiting for a concurrency slot.

        Computed as ``limit - semaphore._value + len(semaphore._waiters)``.
        Returns 0 before the semaphore has been created.
        """
        if self.semaphore is None:
            return 0
        else:
            semaphore_value = self.semaphore._value if self.semaphore._value is not None else self.limit_worker_concurrency
            waiter_count = 0 if self.semaphore._waiters is None else len(self.semaphore._waiters)
            return self.limit_worker_concurrency - semaphore_value + waiter_count

    def get_status(self):
        """Return the status dict sent to the controller (names, speed, queue length)."""
        return {"model_names": self.model_names, "speed": 1, "queue_length": self.get_queue_length()}

    def count_token(self, params):
        """Count the tokens in ``params["prompt"]``.

        Uses ``self.tokenizer(prompt).input_ids``. If that raises
        ``TypeError``, falls back to ``self.tokenizer.num_tokens(prompt)``.
        """
        prompt = params["prompt"]

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {"count": input_echo_len, "error_code": 0}
        return ret

    def get_conv_template(self):
        """Return ``{"conv": <Conversation>}`` for this worker."""
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        """Streaming generation; must be overridden by subclasses."""
        raise NotImplementedError

    def generate_gate(self, params):
        """Non-streaming generation; must be overridden by subclasses."""
        raise NotImplementedError

    def get_embeddings(self, params):
        """Embedding computation; must be overridden by subclasses."""
        raise NotImplementedError

    def reward_inference_gate(self, params):
        """Reward-model inference; must be overridden by subclasses."""
        raise NotImplementedError


def release_worker_semaphore():
    """Give one concurrency slot back to the global worker's semaphore."""
    sem = worker.semaphore
    sem.release()

def acquire_worker_semaphore():
    """Return an awaitable that takes one concurrency slot on the global worker.

    The ``asyncio.Semaphore`` is created on first use and sized by
    ``worker.limit_worker_concurrency``.
    """
    sem = worker.semaphore
    if sem is None:
        sem = asyncio.Semaphore(worker.limit_worker_concurrency)
        worker.semaphore = sem
    return sem.acquire()

def create_background_tasks():
    """Build a ``BackgroundTasks`` whose only task is ``release_worker_semaphore``.

    Attached to a response as its ``background``, so the concurrency slot is
    released once the response has been sent.
    """
    tasks = BackgroundTasks()
    tasks.add_task(release_worker_semaphore)
    return tasks



@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    """Stream generation output for the JSON request body.

    A concurrency slot is acquired before generation starts. Normally it is
    released by the response's background task after streaming completes. If
    ``worker.generate_stream_gate`` raises before a generator is produced,
    the slot is released immediately instead of being leaked.
    """
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        generator = worker.generate_stream_gate(params)
    except BaseException:
        # e.g. the base class's NotImplementedError. No response (and thus no
        # background task) will ever run, so return the slot here.
        release_worker_semaphore()
        raise
    background_tasks = create_background_tasks()
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    """Run non-streaming generation for the JSON request body.

    ``worker.generate_gate`` runs via ``asyncio.to_thread`` so the event loop
    stays responsive. The concurrency slot is released in ``finally``, so an
    exception or cancellation of this handler cannot leak it. On cancellation
    the worker thread may still be running when the slot is released.
    """
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        output = await asyncio.to_thread(worker.generate_gate, params)
    finally:
        release_worker_semaphore()
    return JSONResponse(output)


@app.post("/worker_reward_inference")
async def reward_inference(request: Request):
    """Run reward-model inference for the JSON request body.

    ``worker.reward_inference_gate`` runs via ``asyncio.to_thread``. The
    concurrency slot is released in ``finally``, so an exception or
    cancellation of this handler cannot leak it.
    """
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        output = await asyncio.to_thread(worker.reward_inference_gate, params)
    finally:
        release_worker_semaphore()
    return JSONResponse(output)


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    """Compute embeddings for the JSON request body.

    The concurrency slot is released in ``finally``, so an exception from
    ``worker.get_embeddings`` cannot leak it.
    """
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        # NOTE(review): called synchronously on the event loop, unlike the
        # generate/reward endpoints which use asyncio.to_thread. Confirm this
        # is intended.
        embedding = worker.get_embeddings(params)
    finally:
        release_worker_semaphore()
    return JSONResponse(content=embedding)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    """Return the dict produced by ``worker.get_status()``."""
    status = worker.get_status()
    return status


@app.post("/count_token")
async def api_count_token(request: Request):
    """Count prompt tokens for the JSON request body via ``worker.count_token``."""
    body = await request.json()
    result = worker.count_token(body)
    return result


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    """Return the worker's conversation template as given by ``get_conv_template``."""
    template = worker.get_conv_template()
    return template


@app.post("/model_details")
async def api_model_details(request: Request):
    """Return ``{"context_length": worker.context_len}``."""
    context_length = worker.context_len
    return {"context_length": context_length}
