"""日志模块"""

import sys
import datetime
import os
import time
from typing import Optional
from pathlib import Path

# Level names ordered by rank; should_log() emits a record when its level's
# index is >= the index of the configured threshold.
LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR", "SUCCESS", "SPECIAL"]

# Default thresholds; init() may override them.
console_level = "DEBUG"
file_level = "TRACE"

# Directory for log files: <directory of this module>/output/logs.
logs_path = Path(__file__).parent.joinpath("output", "logs")
# Active log file path; None (no file output) until init() selects one.
log_file_path: Optional[str] = None

# ANSI escape sequences used to colour console output.
COLORS = dict(
    red="\033[91m",
    green="\033[92m",
    yellow="\033[93m",
    blue="\033[94m",
    magenta="\033[95m",
    cyan="\033[96m",
    white="\033[97m",
    grey="\033[90m",
    reset="\033[0m",
)

# WebSocket 日志推送回调列表
_websocket_callbacks = []

def register_websocket_callback(callback):
    """Subscribe *callback* to every log record (used for WebSocket push).

    The callback is later invoked as ``callback(level, color, log_message)``.
    No de-duplication is done: registering the same callable twice makes it
    fire twice per record.
    """
    _websocket_callbacks.extend((callback,))

def unregister_websocket_callback(callback):
    """Remove one registration of *callback*; do nothing if it is not registered."""
    try:
        _websocket_callbacks.remove(callback)
    except ValueError:
        # list.remove raises ValueError when the item is absent.
        pass


# ===== LLM Token 统计 =====
_token_stats = {
    "calls": 0,
    "prompt_tokens": 0,
    "completion_tokens": 0,
    "total_tokens": 0,
}

def reset_token_stats():
    """Zero all four LLM token counters in place."""
    _token_stats.update(
        calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        total_tokens=0,
    )

def add_token_usage(prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: Optional[int] = None):
    """Add one LLM call's token usage to the running totals.

    Args:
        prompt_tokens: Prompt tokens for this call; falsy values (e.g. None) count as 0.
        completion_tokens: Completion tokens for this call; falsy values count as 0.
        total_tokens: Total for this call; when None it is derived as prompt + completion.

    Returns:
        Tuple ``(prompt, completion, total)`` of the integers that were recorded.
    """
    prompt = int(prompt_tokens) if prompt_tokens else 0
    completion = int(completion_tokens) if completion_tokens else 0
    total = prompt + completion if total_tokens is None else int(total_tokens)

    increments = (
        ("calls", 1),
        ("prompt_tokens", prompt),
        ("completion_tokens", completion),
        ("total_tokens", total),
    )
    for key, amount in increments:
        _token_stats[key] += amount
    return prompt, completion, total

def log_token_usage(model: str = "", prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: Optional[int] = None):
    """Record one call's token usage and log it at TRACE level.

    The logged line shows this call's counts (model shown as '-' when empty)
    followed by the running totals from token_summary_text().
    """
    prompt, completion, total = add_token_usage(prompt_tokens, completion_tokens, total_tokens)
    model_label = model or '-'
    summary = token_summary_text()
    trace(
        f"【LLM Token 本次】model={model_label} prompt={prompt} completion={completion} total={total}+{summary}"
    )
def token_summary_text() -> str:
    """Return a one-line summary of the accumulated LLM token totals."""
    stats = _token_stats
    return (
        "【LLM Token 累计】"
        f"calls={stats['calls']} "
        f"prompt={stats['prompt_tokens']} "
        f"completion={stats['completion_tokens']} "
        f"total={stats['total_tokens']}"
    )

def log_token_summary():
    """Log the accumulated LLM token totals at INFO level."""
    summary = token_summary_text()
    info(summary)


def init(log_filename=None, console_log_level="DEBUG", file_log_level="TRACE"):
    """Set up file logging and the console/file level thresholds.

    Creates ``logs_path`` if needed, selects the log file and truncates it,
    updates the module-level thresholds, and zeroes the token statistics.

    Args:
        log_filename: File name joined onto ``logs_path``. When falsy, the name
            ``log_<YYYY-mm-dd HH>.log`` (local time) is used, so calling init()
            again within the same hour reuses and truncates the same file.
        console_log_level: Minimum level echoed to stdout. Names not in
            ``LEVELS`` are ignored and the current threshold is kept.
        file_log_level: Minimum level written to the log file. Names not in
            ``LEVELS`` are ignored likewise.
    """
    global log_file_path, console_level, file_level

    os.makedirs(logs_path, exist_ok=True)

    if log_filename:
        log_file_path = os.path.join(logs_path, log_filename)
    else:
        date_str = time.strftime("%Y-%m-%d %H", time.localtime())
        log_file_path = os.path.join(logs_path, "log_" + date_str + ".log")

    # Start with an empty file; color_print() appends to it afterwards.
    with open(log_file_path, "w", encoding="utf-8"):
        pass

    if console_log_level in LEVELS:
        console_level = console_log_level
    if file_log_level in LEVELS:
        file_level = file_log_level

    # Reset token counters so totals do not carry over between runs. This only
    # zeroes dict entries, so a blanket try/except here would merely hide bugs.
    reset_token_stats()


def should_log(level, target_level):
    """Return True when *level* ranks at or above *target_level* in LEVELS.

    Raises ValueError (from list.index) if either name is not in LEVELS.
    """
    rank = LEVELS.index
    return not rank(level) < rank(target_level)


def color_print(level, color, *args, sep=" ", end="\n"):
    """Format one log record and send it to the console, the log file and callbacks.

    The record text is ``"<YYYY-mm-dd HH:MM:SS.mmm> [<LEVEL>] <message>"``,
    where the timestamp is local time truncated to milliseconds and the
    message is ``sep.join(str(a) for a in args)``.

    Args:
        level: Level name from ``LEVELS``. It is compared with ``console_level``
            and ``file_level`` to decide where the record is written.
        color: Key into ``COLORS`` used to colour the console output.
        *args: Values converted with ``str()`` and joined into the message.
        sep: Separator placed between ``args``.
        end: Terminator passed to ``print`` for the console copy. The file copy
            is always terminated with a newline.

    Raises:
        ValueError: if ``level`` is not in ``LEVELS`` (raised by should_log).
        KeyError: if ``color`` is not in ``COLORS`` and the record is printed
            to the console.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    log_message = f"{timestamp} [{level}] {sep.join(map(str, args))}"

    if should_log(level, console_level):
        print(
            f"{COLORS[color]}{log_message}{COLORS['reset']}", end=end, file=sys.stdout
        )

    # File output only happens once init() has set log_file_path.
    if log_file_path and should_log(level, file_level):
        with open(log_file_path, "a", encoding="utf-8") as log_file:
            log_file.write(log_message + "\n")

    # Callbacks receive every record, regardless of the level thresholds.
    # Iterate over a snapshot because a callback may (un)register callbacks
    # while it runs. Mutating the list being iterated would silently skip the
    # next callback.
    for callback in tuple(_websocket_callbacks):
        try:
            callback(level, color, log_message)
        except Exception:
            # Best-effort delivery: a failing callback must not break logging
            # for the caller or for the remaining callbacks.
            pass


# Per-level logging helpers; each forwards to color_print with a fixed level and colour.
def trace(*args, sep=" ", end="\n"):
    """Emit a TRACE record; console output is white."""
    level, colour = "TRACE", "white"
    color_print(level, colour, *args, sep=sep, end=end)


def debug(*args, sep=" ", end="\n"):
    """Emit a DEBUG record; console output is grey."""
    level, colour = "DEBUG", "grey"
    color_print(level, colour, *args, sep=sep, end=end)


def info(*args, sep=" ", end="\n"):
    """Emit an INFO record; console output is blue."""
    level, colour = "INFO", "blue"
    color_print(level, colour, *args, sep=sep, end=end)


def warning(*args, sep=" ", end="\n"):
    """Emit a WARNING record; console output is yellow."""
    level, colour = "WARNING", "yellow"
    color_print(level, colour, *args, sep=sep, end=end)


def error(*args, sep=" ", end="\n"):
    """Emit an ERROR record; console output is red."""
    level, colour = "ERROR", "red"
    color_print(level, colour, *args, sep=sep, end=end)


def success(*args, sep=" ", end="\n"):
    """Emit a SUCCESS record; console output is green."""
    level, colour = "SUCCESS", "green"
    color_print(level, colour, *args, sep=sep, end=end)


def special(*args, sep=" ", end="\n"):
    """Emit a SPECIAL record; console output is cyan."""
    level, colour = "SPECIAL", "cyan"
    color_print(level, colour, *args, sep=sep, end=end)
