import numpy as np
import torch
import gc
import psutil
import os

def log_memory(tag=""):
    """Print the resident set size (RSS) of the current process in GB.

    Args:
        tag: Label printed in square brackets in front of the measurement.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_gb = rss_bytes / 1024 ** 3
    print(f"[{tag}] Memory usage: {rss_gb:.2f} GB")

def safe_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion. ``None`` and non-tensor values (e.g. NumPy arrays) are
    returned unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    with torch.no_grad():
        return tensor.detach().cpu().numpy()

def metric_numpy_optimized(pred, true, null_val=0.0):
    """Compute masked regression metrics between predictions and targets.

    An element is ignored when either ``pred`` or ``true`` is NaN. When
    ``null_val`` is not NaN, elements whose target equals ``null_val`` are
    ignored as well.

    Args:
        pred: Predictions (torch tensor, NumPy array, or array-like).
        true: Targets with the same shape as ``pred``.
        null_val: Target value that marks missing entries. NaN means only NaN
            entries are treated as missing.

    Returns:
        Tuple ``(mae, mse, rmse, mape, mspe, acc)``:
        - MAPE and MSPE are fractions, not percentages.
        - Per-element relative errors are clipped to 5.0 (MAPE) and their
          squares to 25.0 (MSPE) to limit the influence of near-zero targets.
        - ``acc`` is ``1 - MAE / mean(|true|)``.

    Raises:
        ValueError: If no element survives the mask.
    """
    # asarray makes plain array-likes (e.g. lists) support boolean indexing.
    pred_np = np.asarray(safe_to_numpy(pred))
    true_np = np.asarray(safe_to_numpy(true))

    # NaNs in either array are always excluded. Previously a NaN prediction
    # slipped through when null_val was NaN and turned every metric into NaN.
    mask = ~np.isnan(true_np) & ~np.isnan(pred_np)
    if not np.isnan(null_val):
        mask &= true_np != null_val

    if not mask.any():
        raise ValueError("All labels are invalid based on the mask.")

    pred_valid = pred_np[mask]
    true_valid = true_np[mask]

    diff = pred_valid - true_valid
    abs_diff = np.abs(diff)

    mae = np.mean(abs_diff)
    mse = np.mean(diff ** 2)
    rmse = np.sqrt(mse)

    # Avoid division by zero: targets with |true| < eps are replaced by +eps.
    eps = 1e-8
    true_safe = np.where(np.abs(true_valid) < eps, eps, true_valid)
    rel_err = diff / true_safe
    # Clip relative errors so a few near-zero targets cannot dominate the mean.
    mape = np.mean(np.clip(np.abs(rel_err), 0, 5.0))
    mspe = np.mean(np.clip(rel_err ** 2, 0, 25.0))

    # Accuracy: 1 minus MAE relative to the mean target magnitude.
    acc = 1 - mae / (np.mean(np.abs(true_valid)) + eps)

    return mae, mse, rmse, mape, mspe, acc

def metric_fast(pred, real, null_val=0.0):
    """Compute metrics, switching to chunked processing only for huge inputs.

    Both inputs are converted with ``safe_to_numpy``. If ``pred`` has more
    than 100,000,000 elements, the work is delegated to
    ``batch_metric_optimized``. Otherwise a single ``metric_numpy_optimized``
    call is made, which avoids per-batch overhead.

    Returns:
        ``(mae, mse, rmse, mape, mspe, acc)`` from whichever path was taken.
    """
    big_input_threshold = 100_000_000
    pred_arr, true_arr = safe_to_numpy(pred), safe_to_numpy(real)

    if pred_arr.size <= big_input_threshold:
        return metric_numpy_optimized(pred_arr, true_arr, null_val=null_val)
    return batch_metric_optimized(pred_arr, true_arr, null_val=null_val)

def batch_metric_optimized(pred_np, true_np, batch_size=50000, null_val=0.0):
    """Compute masked regression metrics over a large input in chunks.

    How it works:
    - Both inputs are flattened with ``ravel`` (a view when possible, so no
      full copy is made) and processed ``batch_size`` elements at a time.
    - Running sums are accumulated across chunks, and the metrics are derived
      once from the global totals. This matches a single pass over all valid
      elements, up to floating-point rounding.
    - Per-chunk metrics are deliberately NOT averaged: that is biased when
      chunks hold different numbers of valid elements, and it is wrong for the
      non-linear RMSE and ACC.

    Masking and metric definitions:
    - An element is ignored when either value is NaN.
    - When ``null_val`` is not NaN, elements whose target equals ``null_val``
      are also ignored.
    - Targets with ``|true| < 1e-8`` are replaced by ``1e-8`` in relative
      errors.
    - Per-element relative errors are clipped to 5.0 (MAPE) and their squares
      to 25.0 (MSPE).
    - ``acc`` is ``1 - MAE / mean(|true|)``.

    Args:
        pred_np: Predictions (array-like).
        true_np: Targets with the same number of elements as ``pred_np``.
        batch_size: Number of flattened elements per chunk; must be positive.
        null_val: Target value that marks missing entries.

    Returns:
        Tuple ``(mae, mse, rmse, mape, mspe, acc)`` of ``np.float64``.

    Raises:
        ValueError: If the element counts differ, ``batch_size`` is not
            positive, or no element is valid across the whole input.
    """
    pred_flat = np.asarray(pred_np).ravel()
    true_flat = np.asarray(true_np).ravel()
    if pred_flat.size != true_flat.size:
        raise ValueError(
            "pred and true must have the same number of elements "
            f"({pred_flat.size} != {true_flat.size})."
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    eps = 1e-8
    mask_null = not np.isnan(null_val)

    n_valid = 0
    sum_abs = sum_sq = sum_ape = sum_spe = sum_abs_true = 0.0

    for start in range(0, pred_flat.size, batch_size):
        p = pred_flat[start:start + batch_size]
        t = true_flat[start:start + batch_size]

        mask = ~np.isnan(t) & ~np.isnan(p)
        if mask_null:
            mask &= t != null_val
        count = int(mask.sum())
        if count == 0:
            # A fully masked chunk contributes nothing; skip it rather than
            # aborting the whole computation.
            continue

        # Accumulate in float64 so long sums over float32 data stay accurate.
        p = p[mask].astype(np.float64, copy=False)
        t = t[mask].astype(np.float64, copy=False)
        diff = p - t
        true_safe = np.where(np.abs(t) < eps, eps, t)
        rel = diff / true_safe

        n_valid += count
        sum_abs += np.abs(diff).sum()
        sum_sq += (diff ** 2).sum()
        sum_ape += np.clip(np.abs(rel), 0, 5.0).sum()
        sum_spe += np.clip(rel ** 2, 0, 25.0).sum()
        sum_abs_true += np.abs(t).sum()

    if n_valid == 0:
        raise ValueError("All labels are invalid based on the mask.")

    mae = sum_abs / n_valid
    mse = sum_sq / n_valid
    rmse = np.sqrt(mse)
    mape = sum_ape / n_valid
    mspe = sum_spe / n_valid
    acc = 1 - mae / (sum_abs_true / n_valid + eps)

    return tuple(np.float64(v) for v in (mae, mse, rmse, mape, mspe, acc))

def metric(pred, real, batch_size=None):
    """Main entry point: return ``(mae, mse, rmse, mape, mspe, acc)``.

    Args:
        pred: Predictions (torch tensor or NumPy array).
        real: Targets; entries equal to 0.0 are treated as missing
            (``null_val=0.0`` is passed to ``metric_fast``).
        batch_size: Accepted for backward compatibility but ignored; whether
            to batch is decided inside ``metric_fast``.
    """
    # Ask PyTorch to release unused cached GPU memory before computing.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    outcome = metric_fast(pred, real, null_val=0.0)

    # One garbage-collection pass after the metrics are computed.
    gc.collect()
    return outcome

# Minimal variant with different semantics from metric():
# - zero targets are always treated as missing;
# - MAPE/MSPE are not clipped;
# - all zeros are returned instead of raising when nothing is valid.
def metric_simple(pred, real):
    """Compute ``(mae, mse, rmse, mape, mspe, acc)`` with a fixed NaN/zero mask.

    An element is dropped when ``pred`` or ``real`` is NaN, or when
    ``real == 0``. If no element remains, ``(0, 0, 0, 0, 0, 0)`` is returned.
    """
    pred_np, true_np = safe_to_numpy(pred), safe_to_numpy(real)

    keep = ~np.isnan(pred_np) & ~np.isnan(true_np) & (true_np != 0.0)
    if not keep.any():
        return 0, 0, 0, 0, 0, 0

    y_hat = pred_np[keep]
    y = true_np[keep]

    err = y_hat - y
    abs_err = np.abs(err)
    mae = np.mean(abs_err)
    mse = np.mean(err ** 2)
    rmse = np.sqrt(mse)

    eps = 1e-8
    # Zero targets are already masked out; eps only floors very small |y|.
    denom = np.maximum(np.abs(y), eps)
    mape = np.mean(abs_err / denom)
    mspe = np.mean((err / denom) ** 2)

    acc = 1 - mae / (np.mean(np.abs(y)) + eps)

    return mae, mse, rmse, mape, mspe, acc

# Example usage: a quick benchmark of metric() and metric_simple().
def main():
    """Time ``metric`` and ``metric_simple`` on random data and print results.

    Elapsed times use ``time.perf_counter``, a monotonic high-resolution clock.
    ``time.time`` is wall-clock time and can jump during a measurement.
    """
    import time

    # Random test data (unseeded, so values differ between runs).
    size = (1000, 32, 32)
    pred = np.random.randn(*size)
    true = np.random.randn(*size)

    # Main (optimized) implementation.
    print("测试优化版本:")
    start_time = time.perf_counter()
    mae, mse, rmse, mape, mspe, acc = metric(pred, true)
    print(f"优化版本耗时: {time.perf_counter() - start_time:.4f}秒")
    print(f'MAE: {mae:.6f}, MSE: {mse:.6f}, RMSE: {rmse:.6f}')
    print(f'MAPE: {mape:.6f}, MSPE: {mspe:.6f}, ACC: {acc:.6f}')

    # Minimal implementation.
    print("\n测试简化版本:")
    start_time = time.perf_counter()
    mae, mse, rmse, mape, mspe, acc = metric_simple(pred, true)
    print(f"简化版本耗时: {time.perf_counter() - start_time:.4f}秒")
    print(f'MAE: {mae:.6f}, MSE: {mse:.6f}, RMSE: {rmse:.6f}')
    print(f'MAPE: {mape:.6f}, MSPE: {mspe:.6f}, ACC: {acc:.6f}')

if __name__ == "__main__":
    # Run the demo benchmark only when executed as a script, not on import.
    main()