from datetime import datetime

import torch
import logging


def config_logging(log_file):
    """Configure the root logger to emit bare messages to the console and to a file.

    The root logger is set to INFO and every handler it configures uses a
    ``'%(message)s'`` formatter.

    - If the root logger has no handlers, a console ``StreamHandler`` is added.
      Otherwise the formatter is applied to the first existing handler.
    - A ``FileHandler`` for ``log_file`` is added, unless the root logger already
      has one writing to the same path. Calling this function repeatedly therefore
      does not duplicate every log line in the file.

    Args:
        log_file: Path of the log file. It is opened in append mode (the
            ``FileHandler`` default) with UTF-8 encoding.

    Returns:
        The root ``logging.Logger``.
    """
    import os

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(fmt='%(message)s')

    if not logger.handlers:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    else:
        # NOTE(review): assumes the first pre-existing handler is the console one.
        logger.handlers[0].setFormatter(formatter)

    # FileHandler stores the absolute path in ``baseFilename``. Compare against it
    # so that repeated calls reuse the existing handler instead of stacking a new one.
    target = os.path.abspath(os.fspath(log_file))
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            handler.setFormatter(formatter)
            break
    else:
        # Explicit UTF-8 so output does not depend on the platform's locale encoding.
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


def log_memory_usage(prefix: str, device: torch.device, logger):
    """Log CUDA memory statistics for ``device``; do nothing on non-CUDA devices.

    Three figures are reported, in MB:

    - memory currently allocated by tensors (``torch.cuda.memory_allocated``);
    - memory reserved by the caching allocator (``torch.cuda.memory_reserved``),
      labelled "Cached";
    - the peak allocated value (``torch.cuda.max_memory_allocated``).

    Args:
        prefix: Label placed before the statistics in the log line.
        device: Device to query. Nothing is logged unless ``device.type == 'cuda'``.
        logger: Object whose ``info`` method receives the formatted line.
    """
    if device.type != 'cuda':
        return

    bytes_per_mb = 1024 ** 2
    used_mb = torch.cuda.memory_allocated(device) / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved(device) / bytes_per_mb
    peak_mb = torch.cuda.max_memory_allocated(device) / bytes_per_mb

    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = (
        f"[{stamp}] {prefix} - Allocated: {used_mb:.2f}MB, "
        f"Cached: {reserved_mb:.2f}MB, Max Allocated: {peak_mb:.2f}MB"
    )
    logger.info(message)


def _tensor_nbytes(t):
    """Return the byte size of ``t`` if it is a tensor, else 0 (e.g. empty-layer placeholders)."""
    if isinstance(t, torch.Tensor):
        return t.element_size() * t.nelement()
    return 0


def log_kv_cache_size(past_key_values, logger):
    """Log and return the total size of a KV cache in MB.

    Two input forms are accepted:

    - an object exposing parallel ``key_cache`` / ``value_cache`` sequences of
      per-layer tensors;
    - otherwise, an iterable of per-layer ``(key, value)`` pairs (legacy tuple
      format).

    Per-layer entries that are not tensors contribute 0 bytes. Examples are
    empty placeholders for layers that have not been filled yet.

    Args:
        past_key_values: The cache, or ``None``.
        logger: Object whose ``info`` method receives the formatted line.

    Returns:
        The total size in MB (float). Returns ``0`` without logging when
        ``past_key_values`` is ``None``.
    """
    if past_key_values is None:
        return 0

    if hasattr(past_key_values, "key_cache") and hasattr(past_key_values, "value_cache"):
        layer_pairs = zip(past_key_values.key_cache, past_key_values.value_cache)
    else:
        # Legacy format: each element is a (key, value) pair for one layer.
        layer_pairs = ((layer[0], layer[1]) for layer in past_key_values)

    total_size = sum(_tensor_nbytes(k) + _tensor_nbytes(v) for k, v in layer_pairs)
    total_mb = total_size / 1024 / 1024

    logger.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] KV cache size: {total_mb:.2f}MB")
    return total_mb
