import torch
import os
from tqdm import tqdm  # 添加进度条
import numpy as np
from torch.utils.data import DataLoader

def load_images(dataset='cifar100', max_tasks=10):
    """加载并合并所有任务的图像和标签"""
    all_images, all_labels = [], []
    
    for i in range(max_tasks):
        data_path = f'../../R-DFCIL-master/checkpoint/{dataset}_10_data_resume/data_{i}.pt'
        if not os.path.exists(data_path):
            print(f"找不到任务{i}的数据，停止加载")
            break
            
        try:
            image, label = torch.load(data_path)
            all_images.append(image)
            all_labels.append(label)  # 修复：原代码错误地用了labels.append(labels)
        except Exception as e:
            print(f"加载任务{i}数据时出错: {e}")
    
    # 合并所有数据
    return torch.cat(all_images) if all_images else None, torch.cat(all_labels) if all_labels else None

def calculate_oracle_protos_and_covs(model, device='cuda:7', current_task=9, dataset='cifar100', batch_size=500):
    """计算每个类别的原型向量、协方差矩阵及其秩，并可视化特征值分布"""
    print(f"开始计算{dataset}的原型向量和协方差矩阵...")
    model.eval()
    model.to(device)
    
    # 1. 加载数据
    all_images, all_labels = load_images(dataset=dataset, max_tasks=(current_task + 1))
    if all_images is None:
        print("没有找到数据，无法计算")
        return [], [], []
    
    print(f"加载了{len(all_images)}张图像, {len(torch.unique(all_labels))}个类别")
    
    # 2. 批量提取特征
    features, targets = [], []
    num_batches = (len(all_images) + batch_size - 1) // batch_size
    
    with torch.no_grad():
        for i in tqdm(range(num_batches), desc="提取特征"):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, len(all_images))
            
            batch_images = all_images[start_idx:end_idx].to(device)
            batch_labels = all_labels[start_idx:end_idx].to(device)
            
            batch_features = model(batch_images)["features"]
            
            features.append(batch_features)
            targets.append(batch_labels)
    
    # 3. 合并所有特征和标签
    features = torch.cat(features)
    targets = torch.cat(targets)
    
    # 4. 计算每个类的原型、协方差矩阵和矩阵秩
    protos = []
    covs = []
    ranks = []
    unique_labels = torch.unique(targets, sorted=True)
    
    # 存储每个类别的特征值（用于绘图）
    class_eigenvalues = {}
    
    for class_id in tqdm(unique_labels, desc="计算原型和协方差"):
        mask = targets == class_id
        class_features = features[mask]
        
        if len(class_features) < 2:  # 至少需要2个样本才能计算协方差
            print(f"警告: 类别{class_id}样本数量不足({len(class_features)})")
            continue
            
        # 计算原型（均值）
        proto = class_features.mean(dim=0)
        
        # 计算协方差矩阵
        # 中心化数据
        centered_features = class_features - proto.unsqueeze(0)
        # 计算协方差矩阵 (N-1是无偏估计)
        cov = torch.mm(centered_features.T, centered_features) / (len(class_features) - 1)
        
        # 计算协方差矩阵的特征值
        eigenvalues, _ = torch.linalg.eigh(cov)  # 使用eigh更稳定，返回升序特征值
        
        # 将特征值从大到小排序
        eigenvalues = eigenvalues.flip(0)  # 反转使其从大到小排序
        
        # 保存特征值用于后续绘图
        class_eigenvalues[class_id.item()] = eigenvalues.cpu().numpy()
        
        # 计算矩阵的秩（使用阈值筛选有效特征值）
        threshold = 1e-5 * torch.max(eigenvalues)
        rank = torch.sum(eigenvalues > threshold).item()
        
        protos.append(proto)
        covs.append(cov)
        ranks.append(rank)
        
        # 打印每个类的统计信息
        print(f"类别 {class_id}:")
        print(f"  - 样本数: {len(class_features)}")
        print(f"  - 特征维度: {proto.shape}")
        print(f"  - 协方差矩阵维度: {cov.shape}")
        print(f"  - 矩阵秩: {rank}")
        print(f"  - 条件数: {eigenvalues[0]/eigenvalues[-1] if rank > 0 else float('inf'):.2e}")
        print(f"  - 最大特征值: {eigenvalues[0]:.4e}")
        print(f"  - 最小特征值: {eigenvalues[-1]:.4e}")
    
    print(f"\n统计信息:")
    print(f"- 成功处理类别数: {len(protos)}")
    print(f"- 平均矩阵秩: {sum(ranks)/len(ranks):.2f}")
    print(f"- 最大矩阵秩: {max(ranks)}")
    print(f"- 最小矩阵秩: {min(ranks)}")
    
    # 5. 绘制特征值分布图
    if class_eigenvalues:
        try:
            # 调用通用特征值绘图函数
            plot_eigenvalues(
                eigenvalues_dict=class_eigenvalues,
                output_prefix=f"{dataset}_eigenvalues_task{current_task}",
                max_classes=len(unique_labels),
                max_top_eigenvalues=512,
                figsize_main=(15, 10),
                figsize_top=(12, 6),
                selected_classes=None,  # 自动选择类别
                plot_linear=True,
                plot_log=True,
                plot_cumulative=True,
                plot_condition=True,
                plot_top=True,
                dpi=300
            )
        except Exception as e:
            print(f"绘制特征值分布图时出错: {e}")
    
    return protos, covs, ranks

def calculate_syn_protos_covs(model, synthesizer, device='cuda:7'):
    """
    计算合成样本的类别原型、协方差矩阵及其秩
    
    参数:
        model: 提取特征的模型
        synthesizer: 样本合成器,需有sample方法
        device: 计算设备
        plot_eigenvalues: 是否绘制特征值分布图
    
    返回:
        protos: 类别原型列表
        covs: 协方差矩阵列表
        ranks: 矩阵秩列表
    """
    print(f"开始计算合成样本的原型和协方差...")
    model.eval()
    model.to(device)
    
    # 1. 生成并提取合成样本特征，直到每个类别都有至少500个样本
    all_features = []
    all_labels = []
    class_sample_counts = {}  # 跟踪每个类别的样本数量
    
    print("开始生成样本,每个类别需要500个样本...")
    with torch.no_grad():
        while True:
            # 生成合成样本
            images, labels = synthesizer.sample()  # 按照要求不传参数
            
            # 确保数据在正确设备上
            images = images.to(device)
            labels = labels.to(device)
            
            features = model(images)["features"]
            # 更新类别样本计数
            for label in torch.unique(labels):
                label_item = label.item()
                mask = labels == label
                count = mask.sum().item()
                
                if label_item not in class_sample_counts:
                    class_sample_counts[label_item] = 0
                class_sample_counts[label_item] += count
            
            all_features.append(features)
            all_labels.append(labels)
            
            # 检查是否所有类别都有至少500个样本
            min_samples = min(class_sample_counts.values()) if class_sample_counts else 0
            num_classes_with_enough = sum(1 for count in class_sample_counts.values() if count >= 500)
            
            print(f"已生成样本: {len(all_features)*len(images)}，类别覆盖: {len(class_sample_counts)},"
                  f"达到500+样本的类别: {num_classes_with_enough}，最少样本数: {min_samples}")
            
            if min_samples >= 500 and len(class_sample_counts) > 0:
                print("所有类别都已达到500个样本要求,停止生成")
                break
    
    # 合并所有特征和标签
    features = torch.cat(all_features)
    labels = torch.cat(all_labels)
    
    print(f"数据集特征提取完成，共 {len(features)} 个样本")
    
    # 2. 对每个类别随机选择恰好500个样本
    balanced_features = []
    balanced_labels = []
    
    unique_labels = torch.unique(labels)
    print(f"检测到{len(unique_labels)}个唯一类别，正在进行平衡采样...")
    
    for class_id in unique_labels:
        mask = labels == class_id
        class_features = features[mask]
        class_labels = labels[mask]
        
        # 如果样本数超过500，随机选择500个
        if len(class_features) > 500:
            indices = torch.randperm(len(class_features))[:500]
            class_features = class_features[indices]
            class_labels = class_labels[indices]
        # 如果样本数不足500，发出警告
        elif len(class_features) < 500:
            print(f"警告: 类别{class_id.item()}只有{len(class_features)}个样本,少于要求的500个")
        
        balanced_features.append(class_features)
        balanced_labels.append(class_labels)
    
    # 使用平衡后的数据
    features = torch.cat(balanced_features)
    labels = torch.cat(balanced_labels)
    
    print(f"平衡后样本总数: {len(features)}，预期数量: {len(unique_labels) * 500}")
    
    # 3. 计算每个类的原型、协方差矩阵和矩阵秩
    protos = []
    covs = []
    ranks = []
    class_counts = []
    condition_numbers = []
    
    # 存储每个类别的特征值（用于绘图）
    class_eigenvalues = {}
    
    for class_id in tqdm(unique_labels, desc="计算类别统计量"):
        mask = labels == class_id
        class_features = features[mask]
        class_count = len(class_features)
        class_counts.append(class_count)
        
        if class_count < 2:  # 需要至少2个样本
            print(f"警告: 类别{class_id.item()}样本数量不足({class_count})，跳过")
            continue
            
        # 计算原型（均值）
        proto = class_features.mean(dim=0)
        
        # 计算协方差矩阵
        # 中心化数据
        centered_features = class_features - proto.unsqueeze(0)
        # 计算协方差矩阵 (N-1是无偏估计)
        cov = torch.mm(centered_features.T, centered_features) / (class_count - 1)
        
        # 计算矩阵的特征值和秩
        try:
            # 使用eigh更稳定，返回升序特征值
            eigenvalues, _ = torch.linalg.eigh(cov)
            # 将特征值从大到小排序
            eigenvalues = eigenvalues.flip(0)
            
            # 保存特征值用于后续绘图
            class_eigenvalues[class_id.item()] = eigenvalues.cpu().numpy()
            
            # 计算矩阵的秩
            threshold = 1e-5 * torch.max(eigenvalues)
            rank = torch.sum(eigenvalues > threshold).item()
            
            # 计算条件数 (最大特征值/最小非零特征值)
            non_zero_vals = eigenvalues[eigenvalues > threshold]
            condition_number = (non_zero_vals[0] / non_zero_vals[-1]).item() if len(non_zero_vals) > 0 else float('inf')
            condition_numbers.append(condition_number)
            
        except Exception as e:
            print(f"计算类别{class_id.item()}的特征值分解时出错: {e}")
            rank = 0
            condition_number = float('inf')
        
        protos.append(proto)
        covs.append(cov)
        ranks.append(rank)
        
        # 打印每个类的详细信息
        print(f"类别 {class_id.item()}:")
        print(f"  - 样本数: {class_count}")
        print(f"  - 特征维度: {proto.shape}")
        print(f"  - 协方差矩阵维度: {cov.shape}")
        print(f"  - 矩阵秩: {rank}/{cov.shape[0]}")
        print(f"  - 条件数: {condition_number:.2e}")
        if len(eigenvalues) > 0:
            print(f"  - 最大特征值: {eigenvalues[0]:.4e}")
            print(f"  - 最小特征值: {eigenvalues[-1]:.4e}")
    
    # 4. 打印统计信息（不返回stats）
    avg_rank = sum(ranks) / len(ranks) if ranks else 0
    max_rank = max(ranks) if ranks else 0
    min_rank = min(ranks) if ranks else 0
    avg_condition = sum(condition_numbers) / len(condition_numbers) if condition_numbers else float('inf')
    
    print("\n===== 合成样本统计信息 =====")
    print(f"处理类别数: {len(protos)}")
    print(f"总样本数: {len(features)}")
    print(f"平均每类样本数: {len(features) / len(unique_labels):.1f}")
    print(f"平均矩阵秩: {avg_rank:.2f}")
    print(f"最大矩阵秩: {max_rank}")
    print(f"最小矩阵秩: {min_rank}")
    print(f"平均条件数: {avg_condition:.2e}")
    
    # 5. 绘制样本分布和矩阵秩统计图
    try:
        import matplotlib.pyplot as plt
        
        # 样本分布和秩分布图
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.bar(range(len(class_counts)), class_counts)
        plt.xlabel('Class Index')
        plt.ylabel('Sample Count')
        plt.title('Synthetic Sample Class Distribution')
        
        # 秩分布
        plt.subplot(1, 2, 2)
        plt.hist(ranks, bins=min(20, max(ranks) if ranks else 1), alpha=0.7)
        plt.axvline(avg_rank, color='r', linestyle='--', 
                   label=f"Mean Rank: {avg_rank:.2f}")
        plt.xlabel('Matrix Rank')
        plt.ylabel('Frequency')
        plt.title('Covariance Matrix Rank Distribution')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('syn_stats.png')
        print("统计图表已保存至 syn_stats.png")
    except Exception as e:
        print(f"基础统计可视化失败: {e}")
    
    # 6. 绘制特征值分布图（新增功能）
    if class_eigenvalues:
        try:
            # 调用通用特征值绘图函数
            plot_eigenvalues(
                eigenvalues_dict=class_eigenvalues,
                output_prefix="syn_eigenvalues",
                max_classes=len(unique_labels),
                max_top_eigenvalues=512,
                figsize_main=(15, 10),
                figsize_top=(12, 6),
                selected_classes=None,  # 自动选择类别
                plot_linear=True,
                plot_log=True,
                plot_cumulative=True,
                plot_condition=True,
                plot_top=True,
                dpi=300
            )
        except Exception as e:
            print(f"绘制特征值分布图时出错: {e}")
    
    return protos, covs, ranks

def calculate_real_protos_covs(model, train_loader, device='cuda:7'):
    """
    计算真实数据的类别原型、协方差矩阵及其秩
    
    参数:
        model: 提取特征的模型
        train_loader: 训练数据加载器
        device: 计算设备
    
    返回:
        protos: 类别原型列表
        covs: 协方差矩阵列表
        ranks: 矩阵秩列表
    """
    print(f"开始计算真实数据的原型和协方差...")
    model.eval()
    model.to(device)
    
    # 1. 从数据加载器提取特征和标签
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc="提取特征")):
            # 将数据移动到指定设备
            images = images.to(device)
            labels = labels.to(device)
            
            # 提取特征
            features = model(images)["features"]
            
            all_features.append(features.cpu())  # 移到CPU节省GPU内存
            all_labels.append(labels.cpu())
            
            # 定期打印进度
            if (batch_idx + 1) % 10 == 0:
                print(f"已处理 {batch_idx + 1}/{len(train_loader)} 批次")
    
    # 合并所有特征和标签
    features = torch.cat(all_features).to(device)
    labels = torch.cat(all_labels).to(device)
    
    print(f"数据集特征提取完成，共 {len(features)} 个样本")
    
    # 2. 为每个类别收集样本
    unique_labels = torch.unique(labels, sorted=True)
    print(f"数据集包含 {len(unique_labels)} 个类别")
    
    # 统计每个类别的样本数量
    class_counts = {}
    for class_id in unique_labels:
        mask = labels == class_id
        class_counts[class_id.item()] = mask.sum().item()
    
    # 打印类别分布情况
    min_samples = min(class_counts.values())
    max_samples = max(class_counts.values())
    avg_samples = sum(class_counts.values()) / len(class_counts)
    
    print(f"类别样本统计: 最少 {min_samples}, 最多 {max_samples}, 平均 {avg_samples:.1f}")
    
    # 3. 计算每个类的原型、协方差矩阵和矩阵秩
    protos = []
    covs = []
    ranks = []
    condition_numbers = []
    feature_dim = features.shape[1]  # 特征维度
    
    # 存储每个类别的特征值（用于绘图）
    class_eigenvalues = {}
    
    for class_id in tqdm(unique_labels, desc="计算统计量"):
        class_id_item = class_id.item()
        mask = labels == class_id
        class_features = features[mask]
        class_count = len(class_features)
        
        if class_count < 2:  # 需要至少2个样本计算协方差
            print(f"警告: 类别 {class_id_item} 样本数量不足({class_count})，跳过")
            continue
        
        # 计算原型（均值向量）
        proto = class_features.mean(dim=0)
        
        # 计算协方差矩阵
        # 中心化数据
        centered_features = class_features - proto.unsqueeze(0)
        
        # 处理大矩阵内存问题 - 分块计算
        if class_count > 10000 or feature_dim > 1000:
            print(f"类别 {class_id_item} 样本数量大，使用分块计算")
            # 分块计算协方差矩阵，避免内存溢出
            cov = torch.zeros((feature_dim, feature_dim), device=device)
            chunk_size = 1000  # 调整块大小以适应内存
            
            for i in range(0, class_count, chunk_size):
                end = min(i + chunk_size, class_count)
                chunk = centered_features[i:end]
                # 累积协方差
                cov += torch.mm(chunk.t(), chunk)
            
            cov = cov / (class_count - 1)  # 无偏估计
        else:
            # 直接计算协方差矩阵
            cov = torch.mm(centered_features.t(), centered_features) / (class_count - 1)
        
        # 计算矩阵的特征值、秩和条件数
        try:
            # 使用eigh直接计算特征值（升序）
            eigenvalues, _ = torch.linalg.eigh(cov)
            # 将特征值从大到小排序
            eigenvalues = eigenvalues.flip(0)
            
            # 保存特征值用于后续绘图
            class_eigenvalues[class_id_item] = eigenvalues.cpu().numpy()
            
            # 计算矩阵的秩
            threshold = 1e-5 * torch.max(eigenvalues)
            rank = torch.sum(eigenvalues > threshold).item()
            
            # 计算条件数
            non_zero_vals = eigenvalues[eigenvalues > threshold]
            condition_number = (non_zero_vals[0] / non_zero_vals[-1]).item() if len(non_zero_vals) > 0 else float('inf')
            condition_numbers.append(condition_number)
        except Exception as e:
            print(f"计算类别 {class_id_item} 的特征值分解时出错: {e}")
            rank = 0
            condition_number = float('inf')
        
        protos.append(proto)
        covs.append(cov)
        ranks.append(rank)
        
        # 打印详细信息
        print(f"类别 {class_id_item}:")
        print(f"  - 样本数: {class_count}")
        print(f"  - 特征维度: {proto.shape}")
        print(f"  - 协方差矩阵维度: {cov.shape}")
        print(f"  - 矩阵秩: {rank}/{cov.shape[0]}")
        print(f"  - 条件数: {condition_number:.2e}")
        if len(eigenvalues) > 0:
            print(f"  - 最大特征值: {eigenvalues[0]:.4e}")
            print(f"  - 最小特征值: {eigenvalues[-1]:.4e}")
    
    # 4. 打印统计信息
    avg_rank = sum(ranks) / len(ranks) if ranks else 0
    max_rank = max(ranks) if ranks else 0
    min_rank = min(ranks) if ranks else 0
    avg_condition = sum(condition_numbers) / len(condition_numbers) if condition_numbers else float('inf')
    
    print("\n===== 真实数据统计信息 =====")
    print(f"处理类别数: {len(protos)}")
    print(f"总样本数: {len(features)}")
    print(f"特征维度: {feature_dim}")
    print(f"平均每类样本数: {avg_samples:.1f}")
    print(f"平均矩阵秩: {avg_rank:.2f} ({avg_rank/feature_dim*100:.1f}%的特征维度)")
    print(f"最大矩阵秩: {max_rank}")
    print(f"最小矩阵秩: {min_rank}")
    print(f"平均条件数: {avg_condition:.2e}")
    
    # 5. 基础可视化结果
    try:
        import matplotlib.pyplot as plt
        import numpy as np
        
        # 准备数据
        class_ids = list(class_counts.keys())
        counts = list(class_counts.values())
        
        # 排序以便更好地可视化
        sorted_indices = np.argsort(class_ids)
        sorted_class_ids = [class_ids[i] for i in sorted_indices]
        sorted_counts = [counts[i] for i in sorted_indices]
        
        # 样本分布图
        plt.figure(figsize=(14, 6))
        plt.subplot(1, 2, 1)
        plt.bar(range(len(sorted_class_ids)), sorted_counts)
        plt.xlabel('Class')
        plt.ylabel('Sample Count')
        plt.title('Dataset Class Distribution')
        if len(sorted_class_ids) > 20:
            plt.xticks(range(0, len(sorted_class_ids), len(sorted_class_ids)//10), 
                      [sorted_class_ids[i] for i in range(0, len(sorted_class_ids), len(sorted_class_ids)//10)])
        
        # 秩分布图
        plt.subplot(1, 2, 2)
        plt.hist(ranks, bins=min(20, max(ranks) if ranks else 1), alpha=0.7)
        plt.axvline(avg_rank, color='r', linestyle='--', 
                   label=f"Mean Rank: {avg_rank:.2f}")
        plt.xlabel('Matrix Rank')
        plt.ylabel('Class Count')
        plt.title('Covariance Matrix Rank Distribution')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('real_data_stats.png')
        print("Statistics chart saved to real_data_stats.png")
        
        # 条件数分布
        plt.figure(figsize=(10, 6))
        log_condition = [np.log10(c) if c < float('inf') else 20 for c in condition_numbers]
        plt.hist(log_condition, bins=20, alpha=0.7)
        plt.xlabel('Condition Number (log10)')
        plt.ylabel('Class Count')
        plt.title('Covariance Matrix Condition Number Distribution')
        plt.savefig('condition_number_dist.png')
        print("Condition number distribution saved to condition_number_dist.png")
        
    except Exception as e:
        print(f"基础统计可视化失败: {e}")
    
    # 6. 绘制特征值分布图（新增功能）
    if class_eigenvalues:
        try:
            # 提取数据集名称（如果可用）
            dataset_name = getattr(train_loader.dataset, 'name', 'real_data')
            
            # 调用通用特征值绘图函数
            plot_eigenvalues(
                eigenvalues_dict=class_eigenvalues,
                output_prefix=f"real_eigenvalues",
                max_classes=len(unique_labels),
                max_top_eigenvalues=512,
                figsize_main=(15, 10),
                figsize_top=(12, 6),
                selected_classes=None,  # 自动选择类别
                plot_linear=True,
                plot_log=True,
                plot_cumulative=True,
                plot_condition=True,
                plot_top=True,
                dpi=300
            )
        except Exception as e:
            print(f"绘制特征值分布图时出错: {e}")
    
    return protos, covs, ranks

def compute_l2_distance(oracle_protos, protos):
    """
    计算两组原型向量之间的欧式距离
    
    参数:
        oracle_protos: 列表，包含参考原型向量 (torch.tensor)
        protos: 列表，包含待评估原型向量 (torch.tensor)
        
    返回:
        distances: 列表，包含每对原型的欧式距离
        mean_distance: 平均欧式距离
        std_distance: 距离标准差
    """
    # 1. 检查输入有效性
    if len(oracle_protos) != len(protos):
        raise ValueError(f"原型列表长度不匹配: {len(oracle_protos)} vs {len(protos)}")
    
    # 2. 计算每对原型的欧式距离
    distances = []
    for i, (oracle_proto, proto) in enumerate(zip(oracle_protos, protos)):
        # 确保在同一设备上
        if oracle_proto.device != proto.device:
            proto = proto.to(oracle_proto.device)
            
        # 计算欧式距离 (L2范数)
        distance = torch.norm(oracle_proto - proto, p=2).item()
        distances.append(distance)
    
    # 3. 计算统计量
    mean_distance = sum(distances) / len(distances) if distances else 0
    std_distance = (
        torch.std(torch.tensor(distances)).item() 
        if len(distances) > 1 else 0
    )
    
    return distances, mean_distance, std_distance

# 升级版本：支持批量计算和可视化
def compute_proto_distances(oracle_protos, protos, return_details=False, visualize=False):
    """
    增强版原型距离计算，支持详细分析和可视化
    
    参数:
        oracle_protos: 参考原型列表
        protos: 评估原型列表
        return_details: 是否返回详细距离信息
        visualize: 是否生成距离分布图
        
    返回:
        字典，包含距离统计信息
    """
    # 转换为相同数据类型和设备
    device = oracle_protos[0].device
    oracle_tensor = torch.stack([p.float().to(device) for p in oracle_protos])
    proto_tensor = torch.stack([p.float().to(device) for p in protos]).view(len(oracle_protos), -1)
    current_task = len(oracle_protos) / 10
    print(f'oracle proto shape: {oracle_tensor.shape}')
    print(f'proto shape: {proto_tensor.shape}')
    # 批量计算欧式距离
    distances = torch.norm(oracle_tensor - proto_tensor, dim=1)
    print(f'distance shape: {distances.shape}')
    
    # 统计信息
    results = {
        'mean': distances.mean().item(),
        'std': distances.std().item(),
        'min': distances.min().item(),
        'max': distances.max().item(),
        'median': distances.median().item()
    }
    
    # 可视化距离分布
    if visualize:
        try:
            import matplotlib.pyplot as plt
            plt.figure(figsize=(10, 6))
            plt.hist(distances.cpu().numpy(), bins=20, alpha=0.7)
            plt.axvline(results['mean'], color='r', linestyle='--', label=f"Mean: {results['mean']:.4f}")
            plt.title("Prototype Vectors Euclidean Distance Distribution")
            plt.xlabel("Euclidean Distance")
            plt.ylabel("Frequency")
            plt.legend()
            plt.grid(alpha=0.3)
            plt.savefig(f"proto_{current_task}__distances.png")
            print(f"距离分布图已保存至 proto_distances.png")
        except Exception as e:
            print(f"可视化失败: {e}")
    
    # 返回详细信息
    if return_details:
        results['distances'] = distances.cpu().tolist()
        
    return results

def plot_eigenvalues(eigenvalues_dict, 
                     output_prefix="eigenvalues",
                     max_classes=10, 
                     max_top_eigenvalues=100,
                     figsize_main=(15, 10),
                     figsize_top=(12, 6),
                     selected_classes=None,
                     plot_linear=True,
                     plot_log=True,
                     plot_cumulative=True,
                     plot_condition=True,
                     plot_top=True,
                     dpi=300):
    """
    绘制协方差矩阵特征值的通用函数
    
    参数:
        eigenvalues_dict (dict): 字典,键为类别ID,值为特征值数组（已从大到小排序）
        output_prefix (str): 输出文件名前缀
        max_classes (int): 最多绘制多少个类别
        max_top_eigenvalues (int): 在顶部特征值对比图中最多显示多少个特征值
        figsize_main (tuple): 主图的尺寸
        figsize_top (tuple): 顶部特征值图的尺寸
        selected_classes (list): 指定要绘制的类别ID列表,如果为None则自动选择
        plot_linear (bool): 是否绘制线性尺度图
        plot_log (bool): 是否绘制对数尺度图
        plot_cumulative (bool): 是否绘制累积方差图
        plot_condition (bool): 是否绘制条件数图
        plot_top (bool): 是否绘制顶部特征值对比图
        dpi (int): 图像分辨率
    
    返回:
        list: 保存的图像文件路径列表
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    
    if not eigenvalues_dict:
        print("没有特征值数据可绘制")
        return []
    
    saved_files = []
    output_prefix = "eigenvalues/" + output_prefix
    if not os.path.exists("eigenvalues/"):
        os.makedirs("eigenvalues/")
    # 选择要绘制的类别
    if selected_classes is None:
        # 自动选择有代表性的类别
        if len(eigenvalues_dict) > max_classes:
            # 等间隔采样
            indices = np.linspace(0, len(eigenvalues_dict)-1, max_classes, dtype=int)
            class_ids_to_plot = [list(eigenvalues_dict.keys())[i] for i in indices]
        else:
            class_ids_to_plot = list(eigenvalues_dict.keys())
    else:
        # 使用指定的类别，确保它们存在于数据中
        class_ids_to_plot = [c for c in selected_classes if c in eigenvalues_dict]
        
    # 确定要绘制的子图数量
    num_plots = sum([plot_linear, plot_log, plot_cumulative, plot_condition])
    if num_plots == 0:
        # 如果没有启用任何图表，启用所有图表
        plot_linear = plot_log = plot_cumulative = plot_condition = True
        num_plots = 4
    
    # 确定子图布局
    if num_plots <= 2:
        rows, cols = 1, num_plots
    else:
        rows, cols = 2, (num_plots + 1) // 2
    
    # 创建主图
    plt.figure(figsize=figsize_main)
    
    plot_index = 1
    
    # 1. 线性尺度特征值图
    if plot_linear:
        plt.subplot(rows, cols, plot_index)
        for class_id in class_ids_to_plot:
            eigenvalues = eigenvalues_dict[class_id]
            plt.plot(np.arange(len(eigenvalues)), eigenvalues, 
                     label=f"Class {class_id}")
        
        plt.title("Eigenvalues of Covariance Matrices (Linear Scale)")
        plt.xlabel("Eigenvalue Index")
        plt.ylabel("Eigenvalue Magnitude")
        plt.grid(alpha=0.3)
        
        # 如果类别少于等于max_classes，显示图例
        if len(class_ids_to_plot) <= max_classes:
            plt.legend()
        
        plot_index += 1
    
    # 2. 对数尺度特征值图
    if plot_log:
        plt.subplot(rows, cols, plot_index)
        for class_id in class_ids_to_plot:
            eigenvalues = eigenvalues_dict[class_id]
            plt.semilogy(np.arange(len(eigenvalues)), eigenvalues, 
                         label=f"Class {class_id}")
        
        plt.title("Eigenvalues of Covariance Matrices (Log Scale)")
        plt.xlabel("Eigenvalue Index")
        plt.ylabel("Eigenvalue Magnitude (log)")
        plt.grid(alpha=0.3)
        
        if len(class_ids_to_plot) <= max_classes:
            plt.legend()
        
        plot_index += 1
    
    # 3. 累积方差图
    if plot_cumulative:
        plt.subplot(rows, cols, plot_index)
        for class_id in class_ids_to_plot:
            eigenvalues = eigenvalues_dict[class_id]
            # 计算累积解释方差比例
            cumulative_var = np.cumsum(eigenvalues) / np.sum(eigenvalues)
            plt.plot(np.arange(len(eigenvalues)), cumulative_var, 
                     label=f"Class {class_id}")
        
        plt.title("Cumulative Explained Variance")
        plt.xlabel("Number of Eigenvalues")
        plt.ylabel("Cumulative Explained Variance Ratio")
        plt.grid(alpha=0.3)
        plt.axhline(y=0.95, color='r', linestyle='--', label="95% Explained Variance")
        
        if len(class_ids_to_plot) <= max_classes:
            plt.legend()
        
        plot_index += 1
    
    # 4. 条件数图
    if plot_condition:
        plt.subplot(rows, cols, plot_index)
        for class_id in class_ids_to_plot:
            eigenvalues = eigenvalues_dict[class_id]
            # 计算不同截断下的条件数
            condition_numbers = [eigenvalues[0] / eigenvalues[i] if eigenvalues[i] > 0 
                                 else float('inf') for i in range(1, len(eigenvalues))]
            plt.semilogy(np.arange(1, len(eigenvalues)), condition_numbers, 
                         label=f"Class {class_id}")
        
        plt.title("Condition Number vs. Truncation Level")
        plt.xlabel("Truncation Index")
        plt.ylabel("Condition Number (log)")
        plt.grid(alpha=0.3)
        
        if len(class_ids_to_plot) <= max_classes:
            plt.legend()
    
    plt.tight_layout()
    
    # 保存主图
    main_file = f"{output_prefix}_all.png"
    plt.savefig(main_file, dpi=dpi)
    print(f"特征值分布图已保存至: {main_file}")
    saved_files.append(main_file)
    
    # 绘制顶部特征值对比图
    if plot_top:
        plt.figure(figsize=figsize_top)
        
        # 每个类别取前N个特征值
        for class_id in class_ids_to_plot:
            eigenvalues = eigenvalues_dict[class_id]
            n_eigen = min(len(eigenvalues), max_top_eigenvalues)
            plt.semilogy(np.arange(n_eigen), eigenvalues[:n_eigen], 
                         label=f"Class {class_id}")
        
        plt.title("Top Eigenvalues Comparison")
        plt.xlabel("Eigenvalue Index")
        plt.ylabel("Eigenvalue Magnitude (log)")
        plt.grid(alpha=0.3)
        plt.legend()
        
        plt.tight_layout()
        
        # 保存顶部特征值图
        top_file = f"{output_prefix}_top{max_top_eigenvalues}.png"
        plt.savefig(top_file, dpi=dpi)
        print(f"主要特征值对比图已保存至: {top_file}")
        saved_files.append(top_file)
    
    return saved_files  