import torch
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.manifold import TSNE
from scipy.spatial.distance import cdist


def visualize_pu_scores(self, all_test_features, all_test_labels, dataset):
    """Plot the PU-score distribution and ROC curve, and pick a seen/unseen threshold.

    Every test feature is scored with ``self.pu_learner.predict``. A sample is
    "seen" when its label is in ``self.seenclasses``. The threshold maximising
    Youden's J (TPR - FPR) on the ROC curve is selected. Two figures are saved:
    ``PU_{dataset}.png`` (score histograms per group) and
    ``PU_ROC_{dataset}.png`` (ROC curve with AUC).

    Args:
        all_test_features: Tensor of test features; moved to ``self.device``
            before scoring.
        all_test_labels: Class labels of the test samples (tensor or
            array-like), aligned with ``all_test_features``.
        dataset: Name used in the output file names.

    Returns:
        The selected threshold. Falls back to the mean PU score when either
        group (seen / unseen) is empty, or when the ROC threshold lies outside
        the observed score range.
    """
    plt.rcParams['font.family'] = ['sans-serif']
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC']
    plt.rcParams['axes.unicode_minus'] = False

    # Score under no_grad, like the sibling visualizers: .numpy() raises on a
    # tensor that still tracks gradients.
    with torch.no_grad():
        all_pu_scores = self.pu_learner.predict(all_test_features.to(self.device)).cpu().numpy()

    # Normalise seen-class ids and test labels so set membership compares
    # plain Python / NumPy values rather than tensors.
    if isinstance(self.seenclasses, torch.Tensor):
        seen_classes_set = set(self.seenclasses.cpu().numpy().tolist())
    else:
        seen_classes_set = set(self.seenclasses)

    if isinstance(all_test_labels, torch.Tensor):
        test_labels_list = all_test_labels.cpu().numpy()
    else:
        test_labels_list = np.array(all_test_labels)

    # Explicit bool dtype: an empty list would otherwise become float64 and
    # `~is_seen` below would raise TypeError.
    is_seen = np.array([int(label) in seen_classes_set for label in test_labels_list], dtype=bool)

    # Diagnostic output: group sizes and score ranges.
    seen_count = np.sum(is_seen)
    unseen_count = len(is_seen) - seen_count
    print(f"标签分布检查: 已见类={seen_count}, 未见类={unseen_count}")
    print(f"PU分数范围: [{all_pu_scores.min():.4f}, {all_pu_scores.max():.4f}]")
    print(
        f"已见类样本的PU分数范围: [{all_pu_scores[is_seen].min():.4f}, {all_pu_scores[is_seen].max():.4f}]" if seen_count > 0 else "没有已见类样本")
    print(
        f"未见类样本的PU分数范围: [{all_pu_scores[~is_seen].min():.4f}, {all_pu_scores[~is_seen].max():.4f}]" if unseen_count > 0 else "没有未见类样本")

    # An ROC curve needs both groups; otherwise fall back to the mean score.
    if seen_count == 0:
        print("错误: 没有找到已见类样本，检查标签匹配逻辑")
        return np.mean(all_pu_scores)

    if unseen_count == 0:
        print("错误: 没有找到未见类样本，检查标签匹配逻辑")
        return np.mean(all_pu_scores)

    from sklearn.metrics import roc_curve, auc
    fpr, tpr, thresholds = roc_curve(is_seen, all_pu_scores)
    roc_auc = auc(fpr, tpr)

    # Maximise Youden's J only over points with tpr > 0 and fpr < 1. This
    # excludes the degenerate "accept nothing" / "accept everything" ends.
    valid_indices = (tpr > 0) & (fpr < 1.0)
    if np.any(valid_indices):
        valid_positions = np.flatnonzero(valid_indices)
        optimal_idx = valid_positions[np.argmax(tpr[valid_positions] - fpr[valid_positions])]
    else:
        optimal_idx = np.argmax(tpr - fpr)

    optimal_threshold = thresholds[optimal_idx]

    # roc_curve's first threshold lies above every score. Guard against
    # selecting it (or anything else outside the observed range).
    if optimal_threshold > all_pu_scores.max() or optimal_threshold < all_pu_scores.min():
        print(f"警告: 计算的阈值({optimal_threshold:.4f})超出PU分数范围，使用均值替代")
        optimal_threshold = np.mean(all_pu_scores)
        # Recompute TPR / FPR for the substituted threshold.
        predictions = all_pu_scores > optimal_threshold
        tpr_new = np.sum(predictions & is_seen) / np.sum(is_seen)
        fpr_new = np.sum(predictions & ~is_seen) / np.sum(~is_seen)
        print(f"调整后阈值: {optimal_threshold:.4f}, 此时TPR={tpr_new:.4f}, FPR={fpr_new:.4f}")
    else:
        print(f"最佳阈值: {optimal_threshold:.4f}, 此时TPR={tpr[optimal_idx]:.4f}, FPR={fpr[optimal_idx]:.4f}")

    # Score histograms per group.
    hist_fig = plt.figure(figsize=(10, 6))
    plt.hist(all_pu_scores[is_seen], bins=50, alpha=0.5, label='Seen Classes')
    plt.hist(all_pu_scores[~is_seen], bins=50, alpha=0.5, label='Unseen Classes')
    plt.legend()
    plt.title('PU Score Distribution')
    plt.xlabel('PU Score')
    plt.ylabel('Count')
    plt.savefig(f'PU_{dataset}.png')
    plt.close(hist_fig)  # release the figure's memory

    # ROC curve with the chance diagonal.
    roc_fig = plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.3f}')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve for PU Classification')
    plt.legend(loc='lower right')
    plt.savefig(f'PU_ROC_{dataset}.png')
    plt.close(roc_fig)  # release the figure's memory

    return optimal_threshold

def visualize_distributions(self, train_features, test_seen_features, test_unseen_features, dataset):
    """Save a 2x2 overview of PU-score distributions for the train / seen-test / unseen-test sets.

    The four panels are:

    - KDEs of all three splits with their means.
    - A histogram comparison of seen vs. unseen test scores.
    - The cumulative histograms of the same two groups.
    - Box plots of all three splits.

    Every panel marks a decision threshold, taken as the midpoint between the
    mean seen-test score and the mean unseen-test score. The figure is written
    to ``pu_score_detailed_distribution_{dataset}.png`` at 300 dpi.

    Args:
        train_features: Training (positive) features; moved to ``self.device``.
        test_seen_features: Test features from seen classes.
        test_unseen_features: Test features from unseen classes.
        dataset: Name used in the output file name.
    """
    plt.rcParams['font.family'] = ['sans-serif']
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC']
    plt.rcParams['axes.unicode_minus'] = False

    with torch.no_grad():
        # PU scores per split, converted to NumPy for plotting.
        train_np = self.pu_learner.predict(train_features.to(self.device)).cpu().numpy()
        seen_np = self.pu_learner.predict(test_seen_features.to(self.device)).cpu().numpy()
        unseen_np = self.pu_learner.predict(test_unseen_features.to(self.device)).cpu().numpy()

    # Heuristic threshold: midpoint of the seen / unseen test means.
    threshold = (seen_np.mean() + unseen_np.mean()) / 2

    fig = plt.figure(figsize=(12, 8))

    # Panel 1: kernel density estimates of the PU scores.
    plt.subplot(2, 2, 1)
    sns.kdeplot(train_np, label='训练集(正样本)', fill=True)
    sns.kdeplot(seen_np, label='已见类测试集', fill=True)
    sns.kdeplot(unseen_np, label='未见类测试集', fill=True)

    plt.axvline(train_np.mean(), color='blue', linestyle='--', alpha=0.7)
    plt.axvline(seen_np.mean(), color='orange', linestyle='--', alpha=0.7)
    plt.axvline(unseen_np.mean(), color='green', linestyle='--', alpha=0.7)
    plt.axvline(threshold, color='red', linestyle='-', alpha=0.7, label='决策阈值')

    plt.title('PU分数分布 (核密度估计)')
    plt.xlabel('分数')
    plt.ylabel('密度')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 2: score histograms, seen vs. unseen.
    plt.subplot(2, 2, 2)
    plt.hist(seen_np, bins=30, alpha=0.5, label='已见类测试集')
    plt.hist(unseen_np, bins=30, alpha=0.5, label='未见类测试集')
    plt.axvline(threshold, color='red', linestyle='-', alpha=0.7, label='决策阈值')
    plt.title('分数直方图对比')
    plt.xlabel('分数')
    plt.ylabel('样本数量')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 3: empirical CDFs (normalised cumulative histograms).
    plt.subplot(2, 2, 3)
    plt.hist(seen_np, bins=50, alpha=0.5, label='已见类测试集', cumulative=True, density=True)
    plt.hist(unseen_np, bins=50, alpha=0.5, label='未见类测试集', cumulative=True, density=True)
    plt.axvline(threshold, color='red', linestyle='-', alpha=0.7, label='决策阈值')
    plt.title('累积分布函数')
    plt.xlabel('分数')
    plt.ylabel('累积概率')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 4: box plots. Tick labels are set through xticks instead of
    # boxplot(labels=...): that keyword was renamed to `tick_labels` in
    # Matplotlib 3.9 and the old name is deprecated. xticks works on every
    # version; boxes sit at x = 1..n by default.
    plt.subplot(2, 2, 4)
    data_to_plot = [train_np, seen_np, unseen_np]
    plt.boxplot(data_to_plot)
    plt.xticks([1, 2, 3], ['训练集', '已见类', '未见类'])
    plt.axhline(threshold, color='red', linestyle='-', alpha=0.7, label='决策阈值')
    plt.title('分数分布箱线图')
    plt.ylabel('分数')
    plt.legend()  # otherwise the threshold line's label is never shown
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'pu_score_detailed_distribution_{dataset}.png', dpi=300)
    plt.close(fig)

def visualize_distributions2(self, train_features, test_seen_features, test_unseen_features, dataset):
    """Save one KDE plot of PU scores for the train / seen-test / unseen-test splits.

    Dashed vertical lines mark each split's mean score. A solid red line marks
    the midpoint between the seen-test and unseen-test means. The figure is
    written to ``PU_Kernel_{dataset}.png``.

    Args:
        train_features: Training (positive) features; moved to ``self.device``.
        test_seen_features: Test features from seen classes.
        test_unseen_features: Test features from unseen classes.
        dataset: Name used in the output file name.
    """
    plt.rcParams.update({
        'font.family': ['sans-serif'],
        'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC'],
        'axes.unicode_minus': False,
    })

    splits = (train_features, test_seen_features, test_unseen_features)
    with torch.no_grad():
        raw_scores = [self.pu_learner.predict(feats.to(self.device)) for feats in splits]
        train_np, seen_np, unseen_np = (score.cpu().numpy() for score in raw_scores)

    fig = plt.figure(figsize=(12, 8))
    ax = fig.gca()

    # One filled density curve per split.
    for values, caption in (
        (train_np, 'Training set (Positive)'),
        (seen_np, 'Seen-class Test'),
        (unseen_np, 'Unseen-class Test'),
    ):
        sns.kdeplot(values, label=caption, fill=True, ax=ax)

    # Per-split mean markers.
    for values, colour in ((train_np, 'blue'), (seen_np, 'orange'), (unseen_np, 'green')):
        ax.axvline(values.mean(), color=colour, linestyle='--', alpha=0.7)

    # NOTE: labelled "Optimal", but this is simply the midpoint of the
    # seen-test and unseen-test means.
    midpoint = (seen_np.mean() + unseen_np.mean()) / 2
    ax.axvline(midpoint, color='red', linestyle='-', alpha=0.7, label='Optimal Threshold')

    ax.set_title('Kernel Density Estimation')
    ax.set_xlabel('PU Score')
    ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(f'PU_Kernel_{dataset}.png')
    plt.close(fig)


def visualize_pu_scores_enhanced(self, train_features, test_seen_features, test_unseen_features, dataset):
    """Save a four-panel PU-score report for seen vs. unseen test samples.

    The four panels are:

    1. KDEs of the train / seen-test / unseen-test scores, with their means
       and the decision threshold.
    2. The ROC curve and AUC, with seen as positive and unseen as negative.
    3. The confusion matrix at the decision threshold.
    4. Precision, recall and F1 as functions of the threshold.

    The decision threshold is the midpoint between the mean seen-test score
    and the mean unseen-test score. A sample is predicted "seen" when its
    score is strictly greater than the threshold.

    Output files: ``pu_scores_performance_{dataset}.png`` and
    ``分开评估2.png``, both at 300 dpi.

    Args:
        train_features: Training (positive) features; moved to ``self.device``.
        test_seen_features: Test features from seen classes.
        test_unseen_features: Test features from unseen classes.
        dataset: Name used in the output file name.
    """
    from sklearn.metrics import auc, confusion_matrix, roc_curve

    plt.rcParams['font.family'] = ['sans-serif']
    # Same fallback chain as the other visualizers, so CJK text also renders
    # where SimHei / Microsoft YaHei are unavailable.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC']
    plt.rcParams['axes.unicode_minus'] = False

    with torch.no_grad():
        train_np = self.pu_learner.predict(train_features.to(self.device)).cpu().numpy()
        seen_np = self.pu_learner.predict(test_seen_features.to(self.device)).cpu().numpy()
        unseen_np = self.pu_learner.predict(test_unseen_features.to(self.device)).cpu().numpy()

    # Heuristic decision threshold: midpoint of the seen / unseen test means.
    threshold = (seen_np.mean() + unseen_np.mean()) / 2

    # Ground truth: 1 = seen (positive), 0 = unseen (negative).
    all_scores = np.concatenate([seen_np, unseen_np])
    all_labels = np.concatenate([
        np.ones(len(seen_np), dtype=int),
        np.zeros(len(unseen_np), dtype=int),
    ])
    combined_pred = (all_scores > threshold).astype(int)

    fig = plt.figure(figsize=(20, 16))

    # Panel 1: density plot with the decision threshold.
    ax1 = fig.add_subplot(2, 2, 1)
    sns.kdeplot(train_np, label='训练集(正样本)', fill=True, ax=ax1)
    sns.kdeplot(seen_np, label='已见类测试集', fill=True, ax=ax1)
    sns.kdeplot(unseen_np, label='未见类测试集', fill=True, ax=ax1)

    ax1.axvline(train_np.mean(), color='blue', linestyle='--', alpha=0.7)
    ax1.axvline(seen_np.mean(), color='orange', linestyle='--', alpha=0.7)
    ax1.axvline(unseen_np.mean(), color='green', linestyle='--', alpha=0.7)
    ax1.axvline(threshold, color='red', linestyle='-', alpha=0.7, label='决策阈值')

    ax1.set_title('PU分数分布与决策阈值', fontsize=16)
    ax1.set_xlabel('PU分数', fontsize=14)
    ax1.set_ylabel('密度', fontsize=14)
    ax1.legend(fontsize=12)
    ax1.grid(True, alpha=0.3)

    # Panel 2: ROC curve. Use sklearn's exact ROC. A fixed grid of thresholds
    # with a strict `>` never predicts the minimum-score samples as positive,
    # so it cannot reach the (1, 1) corner; its AUC would be biased low and
    # depend on the grid resolution.
    ax2 = fig.add_subplot(2, 2, 2)
    fpr, tpr, _ = roc_curve(all_labels, all_scores)
    roc_auc = auc(fpr, tpr)

    ax2.plot(fpr, tpr, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
    ax2.plot([0, 1], [0, 1], 'k--', label='随机猜测')
    ax2.fill_between(fpr, tpr, alpha=0.2)
    ax2.set_title('ROC曲线', fontsize=16)
    ax2.set_xlabel('假阳性率', fontsize=14)
    ax2.set_ylabel('真阳性率', fontsize=14)
    ax2.legend(fontsize=12)
    ax2.grid(True, alpha=0.3)

    # Panel 3: confusion matrix at the decision threshold. labels=[0, 1]
    # always yields a 2x2 matrix, so the fixed tick labels below stay correct
    # even when one class is absent.
    ax3 = fig.add_subplot(2, 2, 3)
    cm = confusion_matrix(all_labels, combined_pred, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax3)
    ax3.set_title(f'混淆矩阵 (准确率: {np.mean(combined_pred == all_labels):.3f})', fontsize=16)
    ax3.set_xlabel('预测标签', fontsize=14)
    ax3.set_ylabel('真实标签', fontsize=14)
    ax3.set_xticklabels(['未见类', '已见类'])
    ax3.set_yticklabels(['未见类', '已见类'])

    # Panel 4: precision / recall / F1 over 100 evenly spaced thresholds.
    ax4 = fig.add_subplot(2, 2, 4)
    sweep = np.linspace(np.min(all_scores), np.max(all_scores), 100)
    # Row t holds every sample's prediction at threshold sweep[t].
    preds = all_scores[None, :] > sweep[:, None]
    is_pos = all_labels == 1
    tp = (preds & is_pos).sum(axis=1)
    fp = (preds & ~is_pos).sum(axis=1)
    fn = (~preds & is_pos).sum(axis=1)
    # Ratios with an empty denominator are defined as 0.
    precision = np.divide(tp, tp + fp, out=np.zeros(len(sweep)), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros(len(sweep)), where=(tp + fn) > 0)
    f1_scores = np.divide(2 * precision * recall, precision + recall,
                          out=np.zeros(len(sweep)), where=(precision + recall) > 0)

    ax4.plot(sweep, precision, label='精确率')
    ax4.plot(sweep, recall, label='召回率')
    ax4.plot(sweep, f1_scores, label='F1分数')
    ax4.axvline(threshold, color='red', linestyle='-', alpha=0.7, label='选择的阈值')

    ax4.set_title('不同阈值下的分类性能', fontsize=16)
    ax4.set_xlabel('阈值', fontsize=14)
    ax4.set_ylabel('性能指标', fontsize=14)
    ax4.legend(fontsize=12)
    ax4.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(f'pu_scores_performance_{dataset}.png', dpi=300)
    # NOTE(review): fixed file name, so every dataset overwrites the same
    # file. Kept to preserve existing output; confirm it is still wanted.
    fig.savefig('分开评估2.png', dpi=300)
    plt.close(fig)

def visualize_feature_space(self, train_features, test_seen_features, test_unseen_features, dataset):
    """Embed train / seen-test / unseen-test features in 2-D with t-SNE and plot them.

    Only the first 500 rows of each split are used, to keep t-SNE fast. The
    scatter plot is written to ``feature_space_tsne_{dataset}.png`` at 300 dpi.

    Args:
        train_features: Training feature tensor.
        test_seen_features: Seen-class test feature tensor.
        test_unseen_features: Unseen-class test feature tensor.
        dataset: Name used in the output file name.
    """
    import pandas as pd

    plt.rcParams['font.family'] = ['sans-serif']
    # Same fallback chain as the other visualizers, so CJK text also renders
    # where SimHei / Microsoft YaHei are unavailable.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC']
    plt.rcParams['axes.unicode_minus'] = False

    # Take at most the first `max_samples` rows of each split (slicing clamps
    # at the length). detach() is needed because .numpy() refuses tensors
    # that track gradients.
    max_samples = 500
    train_subset, seen_subset, unseen_subset = (
        feats[:max_samples].detach().cpu().numpy()
        for feats in (train_features, test_seen_features, test_unseen_features)
    )

    combined_features = np.vstack([train_subset, seen_subset, unseen_subset])

    # Integer group ids (0 = train, 1 = seen test, 2 = unseen test). They
    # match the integer keys of the palette / marker dicts below exactly.
    group_ids = np.concatenate([
        np.full(len(train_subset), 0, dtype=int),
        np.full(len(seen_subset), 1, dtype=int),
        np.full(len(unseen_subset), 2, dtype=int),
    ])

    print("正在执行t-SNE降维...")
    # The iteration count is left at its default (1000, the value previously
    # passed explicitly). The `n_iter` keyword was deprecated in scikit-learn
    # 1.5 in favour of `max_iter` and is scheduled for removal.
    # t-SNE requires perplexity < n_samples, so clamp it for tiny inputs.
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(combined_features) - 1))
    embedded = tsne.fit_transform(combined_features)

    fig = plt.figure(figsize=(12, 10))

    df = pd.DataFrame({
        'x': embedded[:, 0],
        'y': embedded[:, 1],
        'label': group_ids,
    })

    # One colour and one marker per group.
    palette = {0: 'blue', 1: 'green', 2: 'red'}
    sns.scatterplot(
        data=df, x='x', y='y', hue='label',
        palette=palette, alpha=0.8,
        markers={0: 'o', 1: 's', 2: '^'},
        style='label', s=100,
        hue_order=[0, 1, 2]
    )

    # Replace seaborn's numeric legend entries with readable group names.
    handles, _ = plt.gca().get_legend_handles_labels()
    plt.legend(handles, ['训练集', '已见类测试集', '未见类测试集'],
               title='数据集类型', fontsize=12, title_fontsize=14)

    plt.title('特征空间降维可视化 (t-SNE)', fontsize=16)
    plt.xlabel('t-SNE维度1', fontsize=14)
    plt.ylabel('t-SNE维度2', fontsize=14)
    plt.tight_layout()
    plt.savefig(f'feature_space_tsne_{dataset}.png', dpi=300)
    print(f"特征空间可视化已保存到 feature_space_tsne_{dataset}.png")
    plt.close(fig)


def visualize_semantic_correction(self, att_orig, att_corrected, dataset):
    """Compare class attribute vectors before and after semantic correction.

    Saves a four-panel figure to
    ``semantic_correction_visualization_{dataset}.png`` at 300 dpi:

    1. Box plots of the per-class correction amount (1 - cosine similarity
       between the original and corrected vector), seen vs. unseen classes.
    2. A heat map of the first (up to) 20 attribute dimensions of one randomly
       chosen seen class and one unseen class, before and after correction.
    3. The cosine-similarity matrix of the first (up to) 20 classes before
       correction.
    4. The same matrix after correction.

    Args:
        att_orig: Original attribute tensor. Row i is indexed by class id, as
            used with ``self.seenclasses`` / ``self.unseenclasses``.
        att_corrected: Corrected attribute tensor; presumably the same shape
            as ``att_orig`` (TODO confirm with caller).
        dataset: Name used in the output file name.
    """
    plt.rcParams['font.family'] = ['sans-serif']
    # Same fallback chain as the other visualizers, so CJK text also renders
    # where SimHei / Microsoft YaHei are unavailable.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC']
    plt.rcParams['axes.unicode_minus'] = False

    # detach(): the attributes may still track gradients, which .numpy() refuses.
    att_orig_np = att_orig.detach().cpu().numpy()
    att_corrected_np = att_corrected.detach().cpu().numpy()

    seenclasses = self.seenclasses.cpu().numpy()
    unseenclasses = self.unseenclasses.cpu().numpy()

    # Column 0: cosine similarity between original and corrected vector.
    # Column 1: correction amount (1 - similarity).
    # Rows with an all-zero original or corrected vector stay at 0.
    sim_matrix = np.zeros((att_orig_np.shape[0], 2))
    for i, (orig, corr) in enumerate(zip(att_orig_np, att_corrected_np)):
        orig_len = np.linalg.norm(orig)
        corr_len = np.linalg.norm(corr)
        # Test the norms, not the sum. A non-zero vector with mixed signs can
        # sum to 0, and a zero corrected vector would otherwise divide by zero.
        if orig_len == 0 or corr_len == 0:
            continue
        sim = np.dot(orig / orig_len, corr / corr_len)
        sim_matrix[i, 0] = sim
        sim_matrix[i, 1] = 1 - sim

    fig = plt.figure(figsize=(20, 16))

    # Panel 1: correction amount, seen vs. unseen classes. Tick labels are set
    # separately because boxplot(labels=...) is deprecated since Matplotlib
    # 3.9 (renamed `tick_labels`); boxes sit at x = 1, 2.
    ax1 = fig.add_subplot(2, 2, 1)
    seen_correction = [sim_matrix[i, 1] for i in seenclasses if i < len(sim_matrix)]
    unseen_correction = [sim_matrix[i, 1] for i in unseenclasses if i < len(sim_matrix)]
    ax1.boxplot([seen_correction, unseen_correction])
    ax1.set_xticks([1, 2])
    ax1.set_xticklabels(['已见类', '未见类'])
    ax1.set_title('已见类与未见类的语义校正量对比', fontsize=16)
    ax1.set_ylabel('校正量 (1-相似度)', fontsize=14)
    ax1.grid(True, alpha=0.3)

    # Panel 2: attribute vectors before/after correction for one random seen
    # class and one random unseen class.
    ax2 = fig.add_subplot(2, 2, 2)
    if len(seenclasses) > 0 and len(unseenclasses) > 0:
        seen_idx = np.random.choice(seenclasses)
        unseen_idx = np.random.choice(unseenclasses)

        # Show at most the first 20 attribute dimensions.
        max_dims = 20
        feature_dim = min(max_dims, att_orig_np.shape[1])

        data_to_plot = np.vstack([
            att_orig_np[seen_idx, :feature_dim],
            att_corrected_np[seen_idx, :feature_dim],
            att_orig_np[unseen_idx, :feature_dim],
            att_corrected_np[unseen_idx, :feature_dim]
        ])

        sns.heatmap(data_to_plot, cmap='coolwarm', center=0, ax=ax2)
        ax2.set_title(f'校正前后的属性向量对比 (前{feature_dim}维)', fontsize=16)
        ax2.set_yticks([0.5, 1.5, 2.5, 3.5])  # centres of the four heat-map rows
        ax2.set_yticklabels([f'已见类{seen_idx}(原始)', f'已见类{seen_idx}(校正)',
                             f'未见类{unseen_idx}(原始)', f'未见类{unseen_idx}(校正)'])
        ax2.set_xlabel('属性维度', fontsize=14)
    else:
        ax2.text(0.5, 0.5, '没有足够的已见类和未见类数据',
                 horizontalalignment='center', verticalalignment='center')

    # Panels 3 and 4 show at most the first 20 classes.
    max_classes = 20
    n_classes = min(max_classes, att_orig_np.shape[0])

    # Panel 3: class-to-class cosine similarity before correction.
    ax3 = fig.add_subplot(2, 2, 3)
    sim_before = 1 - cdist(att_orig_np, att_orig_np, metric='cosine')
    sim_before_subset = sim_before[:n_classes, :n_classes]
    sns.heatmap(sim_before_subset, cmap='viridis', vmin=0, vmax=1, ax=ax3)
    ax3.set_title(f'校正前的类别相似度矩阵 (前{n_classes}个类别)', fontsize=16)
    ax3.set_xlabel('类别索引', fontsize=14)
    ax3.set_ylabel('类别索引', fontsize=14)

    # Panel 4: class-to-class cosine similarity after correction.
    ax4 = fig.add_subplot(2, 2, 4)
    sim_after = 1 - cdist(att_corrected_np, att_corrected_np, metric='cosine')
    sim_after_subset = sim_after[:n_classes, :n_classes]
    sns.heatmap(sim_after_subset, cmap='viridis', vmin=0, vmax=1, ax=ax4)
    ax4.set_title(f'校正后的类别相似度矩阵 (前{n_classes}个类别)', fontsize=16)
    ax4.set_xlabel('类别索引', fontsize=14)
    ax4.set_ylabel('类别索引', fontsize=14)

    plt.tight_layout()
    plt.savefig(f'semantic_correction_visualization_{dataset}.png', dpi=300)
    print(f"语义校正可视化已保存到 semantic_correction_visualization_{dataset}.png")
    plt.close(fig)

def visualize_training_process(self, loss_history, dataset):
    """Plot the PU-learning loss curve and save it to ``training_loss_history_{dataset}.png``.

    The plot shows:

    - The raw loss per recorded step.
    - A moving average with window ``min(25, len(loss_history) // 4)``, drawn
      only when that window is larger than 1.
    - A marker and annotation at the minimum loss.

    The y-axis switches to log scale when every loss is strictly positive and
    the max/min ratio exceeds 10. An empty history prints a message and
    returns without plotting.

    Args:
        loss_history: Sequence of scalar loss values, one per recorded step.
        dataset: Name used in the output file name.
    """
    plt.rcParams['font.family'] = ['sans-serif']
    # Same fallback chain as the other visualizers, so CJK text also renders
    # where SimHei / Microsoft YaHei are unavailable.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS', 'Noto Sans CJK SC']
    plt.rcParams['axes.unicode_minus'] = False

    losses = np.asarray(loss_history, dtype=float)
    if losses.size == 0:
        # Bail out before opening a figure: argmin below would raise on an empty array.
        print("损失历史为空，跳过训练过程可视化")
        return

    fig = plt.figure(figsize=(12, 8))

    # Raw loss curve.
    plt.plot(losses, label='总损失', linewidth=2)

    # Moving-average smoothing. 'valid' mode yields len - window + 1 points,
    # aligned to the end of each window.
    window_size = min(25, len(losses) // 4)
    if window_size > 1:
        smoothed = np.convolve(losses, np.ones(window_size) / window_size, mode='valid')
        plt.plot(np.arange(window_size - 1, len(losses)), smoothed,
                 label=f'平滑损失 (窗口={window_size})', linewidth=3, linestyle='--')

    # Mark the minimum loss.
    min_loss_idx = np.argmin(losses)
    min_loss = losses[min_loss_idx]
    plt.scatter(min_loss_idx, min_loss, color='red', s=100, zorder=5)
    plt.annotate(f'最小损失: {min_loss:.4f}',
                 xy=(min_loss_idx, min_loss), xytext=(min_loss_idx + 10, min_loss * 1.2),
                 arrowprops=dict(facecolor='black', shrink=0.05, width=1.5),
                 fontsize=12)

    plt.title('PU学习训练过程中的损失变化', fontsize=16)
    plt.xlabel('迭代轮次', fontsize=14)
    plt.ylabel('损失值', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)

    # Log scale only when all losses are strictly positive. With a zero
    # minimum, max/min is inf (divide-by-zero) and a log axis would hide that
    # point; with negative values the ratio is meaningless.
    if min_loss > 0 and np.max(losses) / min_loss > 10:
        plt.yscale('log')
        plt.title('PU学习训练过程中的损失变化 (对数刻度)', fontsize=16)

    plt.tight_layout()
    plt.savefig(f'training_loss_history_{dataset}.png', dpi=300)
    print(f"训练过程可视化已保存到 training_loss_history_{dataset}.png")
    plt.close(fig)