import random
import numpy as np
import torch
import logging
import sys


def set_random_seed(seed: int):
    """Seed every random number generator the project relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    CUDA generators with the same ``seed``. It then always switches cuDNN to
    deterministic mode with benchmark autotuning disabled, so repeated runs
    produce the same results.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch default generator
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every visible CUDA device (multi-GPU)
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN speed for reproducibility: always use deterministic kernels and
    # skip the benchmark-driven algorithm search, whose choice can vary per run.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    

def setup_logging(log_file: str = "analysis_run.log", level: int = logging.INFO):
    """Configure the "MoE_Analysis" logger to write to both the console and a file.

    Safe to call more than once. Handlers attached by a previous call are
    detached and closed before new ones are added, so records are not
    duplicated and the previous log file handle is not leaked.

    Args:
        log_file: Path of the log file. It is opened with mode 'w', so every call
            overwrites any existing file. Defaults to "analysis_run.log".
        level: Logging level applied to the logger and both handlers.
            Defaults to logging.INFO.

    Returns:
        logging.Logger: The configured "MoE_Analysis" logger.
    """
    # A named logger (not the root logger); getLogger returns the same object
    # for the same name, which is why stale handlers must be removed below.
    logger = logging.getLogger("MoE_Analysis")
    logger.setLevel(level)

    # Remove handlers left by an earlier call, closing each one so the old
    # FileHandler releases its file. Iterate over a copy because removeHandler
    # mutates logger.handlers. Only this logger's own handlers are touched;
    # hasHandlers() would also inspect ancestor loggers.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # File output; mode='w' overwrites the log from the previous run.
    file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
    file_handler.setLevel(level)

    # Console output.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(level)

    # One shared format for both destinations.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

    # NOTE: propagation to ancestor loggers is left enabled (as before). If the
    # root logger also has handlers, records will be emitted there as well.
    return logger


def calculate_compression_stats(original_model, compressed_model):
    """
    Compute parameter-count and memory-footprint statistics for a compressed model.

    Args:
        original_model: The uncompressed model. It must provide ``.parameters()``
            and ``.buffers()`` yielding tensors (e.g. a ``torch.nn.Module``).
        compressed_model: The compressed model, with the same requirements.

    Returns:
        dict: Compression statistics with these keys:
            - "original_parameters" / "compressed_parameters": total number of
              parameter elements in each model.
            - "parameter_reduction": original minus compressed parameter count.
            - "compression_ratio": compressed / original parameter count, i.e.
              the fraction of parameters kept (lower means more compression).
              NaN if the original model has no parameters.
            - "original_size_mb" / "compressed_size_mb": bytes occupied by
              parameters plus buffers, divided by 1024 * 1024.
            - "size_reduction": 1 - compressed_size / original_size, i.e. the
              fraction of memory saved. NaN if the original size is 0.
            - "size_reduction_mb": original minus compressed size.
    """
    def count_parameters(model):
        # Total number of scalar elements across all parameters.
        return sum(p.numel() for p in model.parameters())

    def get_model_size_mb(model):
        # Memory held by parameters and buffers: element count * bytes per element.
        param_size = sum(p.numel() * p.element_size()
                         for p in model.parameters())
        buffer_size = sum(b.numel() * b.element_size()
                          for b in model.buffers())
        return (param_size + buffer_size) / (1024 * 1024)

    def safe_ratio(numerator, denominator):
        # A ratio against an empty original model is undefined; report NaN
        # instead of raising ZeroDivisionError (e.g. for nn.Identity).
        return numerator / denominator if denominator else float("nan")

    original_params = count_parameters(original_model)
    compressed_params = count_parameters(compressed_model)

    original_size_mb = get_model_size_mb(original_model)
    compressed_size_mb = get_model_size_mb(compressed_model)

    compression_ratio = safe_ratio(compressed_params, original_params)
    size_reduction = 1 - safe_ratio(compressed_size_mb, original_size_mb)

    stats = {
        "original_parameters": original_params,
        "compressed_parameters": compressed_params,
        "parameter_reduction": original_params - compressed_params,
        "compression_ratio": compression_ratio,
        "original_size_mb": original_size_mb,
        "compressed_size_mb": compressed_size_mb,
        "size_reduction": size_reduction,
        "size_reduction_mb": original_size_mb - compressed_size_mb
    }

    return stats
