import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from PUModels.ProbPULearner_Z import PULearner
from ZSLModels.PSVMA.models.engine.inferencer import eval_zs_gzsl
from util.report_model import *
from util.visual import visualize_distributions, visualize_pu_scores, visualize_distributions2


class PSVMAPUTrainer:
    def __init__(self, model, dataloader, device, lambda_prob_pu=1.0):
        """Wrap a trained PSVMA model with a PU learner for seen/unseen calibration.

        Args:
            model: PSVMA model; ``feat_channel`` is read, plus ``attritube_num``
                if present, otherwise ``V.size(1)``.
            dataloader: 4-tuple ``(tr_dataloader, tu_loader, ts_loader, res)``.
                ``res`` must provide ``train_id`` / ``test_id`` (seen / unseen
                class ids), ``test_label_seen`` / ``test_label_unseen`` and
                ``att_seen`` / ``att_unseen`` (one attribute row per class id,
                in the same order as the id lists). Values may be tensors or
                array-likes accepted by ``torch.tensor``.
            device: device on which every tensor built here is placed.
            lambda_prob_pu: forwarded unchanged to ``PULearner``.

        NOTE(review): ``correct_att`` / ``eval_model`` also read ``self.min``,
        ``self.max`` and ``self.alpha``, which are not initialised here —
        presumably the caller assigns them after construction; confirm.
        """
        self.model = model
        self.device = device
        self.batch_size = 50  # default batch size (not read by the methods shown in this class)

        def to_device_tensor(value):
            # torch.tensor(<Tensor>) copies but emits a UserWarning recommending
            # clone().detach(); res entries may be tensors or array-likes.
            if isinstance(value, torch.Tensor):
                return value.detach().clone().to(device)
            return torch.tensor(value, device=device)

        # Unpack the loaders and the res dict that carries seen/unseen class info.
        self.tr_dataloader, self.tu_loader, self.ts_loader, self.res = dataloader

        # train_id -> seen class ids, test_id -> unseen class ids.
        self.seenclasses = to_device_tensor(self.res['train_id'])
        self.unseenclasses = to_device_tensor(self.res['test_id'])

        # Test labels for the seen and unseen test splits.
        self.test_seen_label = to_device_tensor(self.res['test_label_seen'])
        self.test_unseen_label = to_device_tensor(self.res['test_label_unseen'])

        self.len_test_seen = len(self.test_seen_label)
        self.len_test_unseen = len(self.test_unseen_label)

        # Seen-split labels first, then unseen-split labels.
        self.all_test_labels = torch.cat([
            self.test_seen_label,
            self.test_unseen_label
        ], dim=0)

        self.feat_channel = model.feat_channel
        # Total class count. Class ids are used directly as row indices of
        # self.att below, so they are expected to lie in [0, cls_num).
        self.cls_num = len(self.res['train_id']) + len(self.res['test_id'])
        self.attritube_num = model.attritube_num if hasattr(model, 'attritube_num') else model.V.size(1)

        att_seen = to_device_tensor(self.res['att_seen'])
        att_unseen = to_device_tensor(self.res['att_unseen'])

        # Full (cls_num, attr_dim) attribute matrix: row `cls` holds that class's attributes.
        self.att = torch.zeros((self.cls_num, att_seen.size(1)), device=device)
        for i, idx in enumerate(self.seenclasses):
            self.att[idx] = att_seen[i]
        for i, idx in enumerate(self.unseenclasses):
            self.att[idx] = att_unseen[i]

        # Populated by extract_features().
        self.test_seen_features = None
        self.test_unseen_features = None

        # The PU learner's feature dimension is the attribute dimension.
        feature_dim = self.att.size(1)

        self.pu_learner = PULearner(
            feature_dim=feature_dim,
            att_mat=self.att.cpu().numpy(),
            device=device,
            lambda_prob_pu=lambda_prob_pu
        )

    def _get_attribute_matrix(self, model):
        # 使用w2v_att和W矩阵构建映射
        w2v_att = model.w2v_att
        W = model.W

        # 计算属性矩阵
        att_matrix = torch.einsum('lw,wv->lv', w2v_att, W)
        return att_matrix

    def extract_features(self):
        """从PSVMA模型中提取语义特征向量"""
        self.model.eval()
        with torch.no_grad():
            # 提取训练集特征
            train_features = []
            for batch in self.tr_dataloader:
                # 获取批次中的特征
                if isinstance(batch, list) and len(batch) >= 1:
                    batch_feature = batch[0].to(self.device)
                else:
                    batch_feature = batch.to(self.device)

                # 提取PSVMA的语义表示特征
                semantic_feature = self._extract_semantic_feature(batch_feature)
                train_features.append(semantic_feature)

            train_features = torch.cat(train_features, dim=0)

            # 提取测试集特征（已见类）
            test_seen_features = []
            seen_feat_list = []
            for batch in self.ts_loader:
                if isinstance(batch, list) and len(batch) >= 1:
                    batch_feature = batch[0].to(self.device)
                else:
                    batch_feature = batch.to(self.device)

                seen_feat_list.append(batch_feature)
                semantic_feature = self._extract_semantic_feature(batch_feature)
                test_seen_features.append(semantic_feature)

            test_seen_features = torch.cat(test_seen_features, dim=0)

            # 提取测试集特征（未见类）
            test_unseen_features = []
            unseen_feat_list = []
            for batch in self.tu_loader:
                if isinstance(batch, list) and len(batch) >= 1:
                    batch_feature = batch[0].to(self.device)
                else:
                    batch_feature = batch.to(self.device)

                unseen_feat_list.append(batch_feature)
                semantic_feature = self._extract_semantic_feature(batch_feature)
                test_unseen_features.append(semantic_feature)

            test_unseen_features = torch.cat(test_unseen_features, dim=0)

            # 保存特征
            self.test_seen_features = test_seen_features
            self.test_unseen_features = test_unseen_features

        return train_features, test_seen_features, test_unseen_features

    # 对测试集特征进行域适应对齐
    def align_to_train_domain(self, features, train_std, train_mean, feature_type="", adaptive_alpha=True):
        # 计算当前特征的统计信息
        current_mean = features.mean(dim=0, keepdim=True)
        current_std = features.std(dim=0, keepdim=True) + 1e-8

        # 标准化到标准分布
        normalized = (features - current_mean) / current_std

        # 对齐到训练域分布
        aligned = normalized * train_std + train_mean

        # 自适应对齐强度：根据对齐需求调整强度
        if adaptive_alpha:
            # 计算对齐需求：距离目标越远，对齐强度越大
            mean_distance = torch.abs(current_mean - train_mean).mean().item()
            base_alpha = 0.7

            if "未见类" in feature_type and mean_distance > 0.1:
                alpha = min(0.9, base_alpha + mean_distance * 2.0)  # 最高0.9
            else:
                alpha = base_alpha
        else:
            alpha = 0.7

        final_features = alpha * aligned + (1 - alpha) * features

        print(
            f"  {feature_type}特征对齐完成 - 原始均值: {current_mean.mean().item():.4f}, 对齐后均值: {final_features.mean().item():.4f}, 对齐强度: {alpha:.2f}")

        return final_features

    def bidirectional_soft_align(self, train_features, test_seen_features, test_unseen_features):
        """
        双向软对齐：
        - 训练集向测试集轻微靠近
        - 已见类测试集向训练集强烈靠近
        - 未见类测试集被推离
        """
        # 计算统计量
        train_mean = train_features.mean(dim=0, keepdim=True)
        train_std = train_features.std(dim=0, keepdim=True) + 1e-8

        seen_mean = test_seen_features.mean(dim=0, keepdim=True)
        seen_std = test_seen_features.std(dim=0, keepdim=True) + 1e-8

        # 1. 计算中间目标域（加权平均）
        # 训练集权重更大（0.8），测试集权重较小（0.2）
        target_mean = 0.8 * train_mean + 0.2 * seen_mean
        target_std = 0.8 * train_std + 0.2 * seen_std

        print(f"\n双向对齐:")
        print(f"  原始训练集均值: {train_mean.mean().item():.4f}")
        print(f"  原始已见类均值: {seen_mean.mean().item():.4f}")
        print(f"  目标域均值: {target_mean.mean().item():.4f}")

        # 2. 轻微调整训练集（alpha=0.15，保持大部分特性）
        train_normalized = (train_features - train_mean) / train_std
        train_to_target = train_normalized * target_std + target_mean
        train_aligned = 0.15 * train_to_target + 0.85 * train_features

        print(
            f"  对齐后训练集均值: {train_aligned.mean().item():.4f} (移动了 {abs(train_aligned.mean().item() - train_mean.mean().item()):.4f})")

        # 3. 强烈对齐已见类测试集（alpha=0.95）
        seen_normalized = (test_seen_features - seen_mean) / seen_std
        seen_to_target = seen_normalized * target_std + target_mean
        seen_aligned = 0.95 * seen_to_target + 0.05 * test_seen_features

        print(
            f"  对齐后已见类均值: {seen_aligned.mean().item():.4f} (移动了 {abs(seen_aligned.mean().item() - seen_mean.mean().item()):.4f})")

        # 4. 未见类保持当前策略（推离目标域）
        unseen_mean = test_unseen_features.mean(dim=0, keepdim=True)
        unseen_std = test_unseen_features.std(dim=0, keepdim=True) + 1e-8

        unseen_normalized = (test_unseen_features - unseen_mean) / unseen_std
        unseen_to_target = unseen_normalized * target_std + target_mean
        unseen_aligned = 0.80 * unseen_to_target + 0.20 * test_unseen_features

        print(f"  对齐后未见类均值: {unseen_aligned.mean().item():.4f}")

        # 5. 验证对齐效果
        final_train_mean = train_aligned.mean().item()
        final_seen_mean = seen_aligned.mean().item()
        final_unseen_mean = unseen_aligned.mean().item()

        gap_train_seen = abs(final_train_mean - final_seen_mean)
        gap_seen_unseen = abs(final_seen_mean - final_unseen_mean)

        print(f"\n对齐后gap:")
        print(f"  训练集-已见类: {gap_train_seen:.4f} (应该很小)")
        print(f"  已见类-未见类: {gap_seen_unseen:.4f} (应该很大)")

        return train_aligned, seen_aligned, unseen_aligned

    def iterative_align_to_train_domain(self, features, train_std, train_mean,
                                        feature_type="", num_iterations=3):
        """迭代对齐，逐步消除域偏移"""
        aligned_features = features

        for iteration in range(num_iterations):
            current_mean = aligned_features.mean(dim=0, keepdim=True)
            current_std = aligned_features.std(dim=0, keepdim=True) + 1e-8

            # 计算当前距离
            mean_distance = torch.abs(current_mean - train_mean).mean().item()

            # 标准化
            normalized = (aligned_features - current_mean) / current_std

            # 对齐
            realigned = normalized * train_std + train_mean

            # 逐步增强对齐强度
            if "已见类" in feature_type:
                # 第一次对齐用较小的alpha，逐步增强
                alpha = 0.6 + 0.1 * (iteration + 1)  # 0.7, 0.8, 0.9
            else:
                alpha = 0.5 + 0.1 * (iteration + 1)  # 0.6, 0.7, 0.8

            aligned_features = alpha * realigned + (1 - alpha) * aligned_features

            print(f"  迭代 {iteration + 1}/{num_iterations}: "
                  f"均值距离={mean_distance:.4f}, α={alpha:.2f}")

            # 如果距离已经很小，提前停止
            if mean_distance < 0.05:
                print(f"  收敛! 均值距离 < 0.05")
                break

        return aligned_features

    def advanced_align_to_train_domain(self, features, train_features, feature_type=""):
        """高阶统计量对齐（均值、标准差、偏度、峰度）"""

        # 计算训练集的统计量
        train_mean = train_features.mean(dim=0, keepdim=True)
        train_std = train_features.std(dim=0, keepdim=True) + 1e-8

        # 一阶和二阶矩对齐
        current_mean = features.mean(dim=0, keepdim=True)
        current_std = features.std(dim=0, keepdim=True) + 1e-8

        # 标准化
        normalized = (features - current_mean) / current_std

        # 对齐到训练分布
        aligned_basic = normalized * train_std + train_mean

        # 三阶矩对齐（偏度）
        train_centered = train_features - train_mean
        test_centered = aligned_basic - train_mean

        train_skewness = (train_centered ** 3).mean(dim=0, keepdim=True) / (train_std ** 3)
        test_skewness = (test_centered ** 3).mean(dim=0, keepdim=True) / (train_std ** 3)

        # 偏度校正
        skewness_correction = train_skewness - test_skewness
        aligned_skewness = aligned_basic + 0.1 * skewness_correction * train_std

        # 自适应融合
        if "已见类" in feature_type:
            alpha = 0.95
        else:
            alpha = 0.80

        final_features = alpha * aligned_skewness + (1 - alpha) * features

        print(f"  {feature_type}高阶对齐 - α={alpha:.3f}")
        return final_features

    def coral_align(self, source_features, target_features, feature_type=""):
        """
        CORAL域适应：对齐协方差矩阵
        Source: 测试集（需要对齐）
        Target: 训练集（目标分布）
        """
        # 计算均值
        source_mean = source_features.mean(dim=0, keepdim=True)
        target_mean = target_features.mean(dim=0, keepdim=True)

        # 中心化
        source_centered = source_features - source_mean
        target_centered = target_features - target_mean

        # 计算协方差矩阵
        source_cov = torch.mm(source_centered.t(), source_centered) / (source_features.size(0) - 1)
        target_cov = torch.mm(target_centered.t(), target_centered) / (target_features.size(0) - 1)

        # 添加正则化确保可逆
        source_cov = source_cov + torch.eye(source_cov.size(0), device=source_cov.device) * 1e-5
        target_cov = target_cov + torch.eye(target_cov.size(0), device=target_cov.device) * 1e-5

        # Cholesky分解（协方差矩阵的平方根）
        try:
            source_cov_sqrt_inv = torch.linalg.cholesky(source_cov).inverse()
            target_cov_sqrt = torch.linalg.cholesky(target_cov)
        except:
            print(f"  警告: Cholesky分解失败，使用SVD")
            # 使用SVD作为备选
            U_s, S_s, _ = torch.svd(source_cov)
            U_t, S_t, _ = torch.svd(target_cov)
            source_cov_sqrt_inv = U_s @ torch.diag(1.0 / torch.sqrt(S_s + 1e-5)) @ U_s.t()
            target_cov_sqrt = U_t @ torch.diag(torch.sqrt(S_t)) @ U_t.t()

        # CORAL变换
        coral_transform = target_cov_sqrt @ source_cov_sqrt_inv

        # 应用变换
        source_aligned = torch.mm(source_centered, coral_transform.t()) + target_mean

        # 自适应融合
        if "已见类" in feature_type:
            alpha = 0.90
        else:
            alpha = 0.75

        final_features = alpha * source_aligned + (1 - alpha) * source_features

        print(f"  {feature_type}CORAL对齐 - α={alpha:.3f}")
        return final_features

    def _extract_semantic_feature(self, input_feature):
        """Run the PSVMA sub-modules on a batch and return attribute-space features.

        Call sequence (mirrors PSVMA's sub-modules; exact equivalence with
        ``model.forward`` is not shown here — TODO confirm):
        patch embedding + cls token -> backbone_0 -> ``blocks`` ->
        backbone_1 -> ``blocks`` again -> ``avgpool1d`` -> projection by ``V``.

        Args:
            input_feature: batch whose first dimension is the batch size;
                presumably raw images accepted by ``model.backbone_patch`` —
                TODO confirm.

        Returns:
            (batch_size, V.size(1)) tensor from ``einsum('bc,cd->bd', out, V)``.
        """
        batch_size = input_feature.shape[0]
        # Attribute "parts": w2v_att projected through W, then expanded (a view,
        # no copy) so every sample in the batch gets the same parts.
        parts = torch.einsum('lw,wv->lv', self.model.w2v_att, self.model.W)
        parts = parts.expand(batch_size, -1, -1)

        # Stage 1: patch embedding, prepend the cls token (token dim = 1), add the
        # positional embedding through backbone_drop, run backbone_0.
        patches = self.model.backbone_patch(input_feature)
        cls_token = self.model.cls_token.expand(batch_size, -1, -1)
        patches = torch.cat((cls_token, patches), dim=1)
        feats_0 = self.model.backbone_drop(patches + self.model.pos_embed)
        feats_0 = self.model.backbone_0(feats_0)
        # Drop the cls token before the Block module.
        feats_in = feats_0[:, 1:, :]

        # First Block pass: tokens are transposed so dims 1 and 2 swap (presumably
        # the channel-first layout Block expects). Only the first output is used.
        feats_out, _, _ = self.model.blocks(feats_in.transpose(1, 2), parts=parts)

        # Stage 2: re-attach the cls token, add the positional embedding, run
        # backbone_1, strip the cls token and pass through the same Block module
        # again. (Unlike stage 1, backbone_drop is not applied here.)
        patches_1 = torch.cat((cls_token, feats_out.transpose(1, 2)), dim=1)
        feats_1 = self.model.backbone_1(patches_1 + self.model.pos_embed)
        feats_1 = feats_1[:, 1:, :]
        feats_final, _, _ = self.model.blocks(feats_1.transpose(1, 2), parts=parts)

        # Aggregate: view as (batch, feat_channel, -1), pool, flatten per sample.
        out = self.model.avgpool1d(feats_final.view(batch_size, self.feat_channel, -1)).view(batch_size, -1)

        # Project into attribute space with the V matrix.
        semantic_feature = torch.einsum('bc,cd->bd', out, self.model.V)

        return semantic_feature

    def compute_psvma_score(self, features, att_corrected):
        """使用PSVMA原始的评分机制"""
        # 1. L2归一化特征
        feat_norm = torch.norm(features, p=2, dim=1).unsqueeze(1).expand_as(features)
        features_normalized = features.div(feat_norm + 1e-5)

        # 2. L2归一化属性
        att_norm = torch.norm(att_corrected, p=2, dim=1).unsqueeze(1).expand_as(att_corrected)
        att_normalized = att_corrected.div(att_norm + 1e-5)

        # 3. 计算相似度并缩放
        similarity = torch.einsum('bd,nd->bn', features_normalized, att_normalized)

        # 4. 使用PSVMA的缩放因子
        scale = self.model.scale.item() if hasattr(self.model, 'scale') else 20.0
        similarity = similarity * scale

        return similarity

    def compute_logits(self, test_seen_embeds, test_unseen_embeds):
        """计算测试集特征的预测分数 (logits)"""
        # 获取属性数据
        att_seen = self.res['att_seen'].to(self.device)
        att_unseen = self.res['att_unseen'].to(self.device)

        # 构建完整的属性矩阵
        all_att = torch.zeros((self.cls_num, att_seen.size(1)), device=self.device)
        for i, idx in enumerate(self.seenclasses):
            all_att[idx] = att_seen[i]
        for i, idx in enumerate(self.unseenclasses):
            all_att[idx] = att_unseen[i]

        # 归一化属性矩阵
        all_att = F.normalize(all_att, p=2, dim=1)

        # 计算相似度得分，乘以缩放因子
        scale = self.model.scale.item() if hasattr(self.model, 'scale') else 20.0

        # 计算logits (相似度得分)
        test_seen_logits = torch.mm(test_seen_embeds, all_att.t()) * scale
        test_unseen_logits = torch.mm(test_unseen_embeds, all_att.t()) * scale

        # 合并所有测试logits
        all_test_logits = torch.cat([test_seen_logits, test_unseen_logits], dim=0)

        return all_test_logits

    def evaluate_pu(self, optimal_threshold, all_test_features, all_test_labels):
        """PU评估函数"""
        # 构建标签掩码 - 确保在正确的设备上
        len_seen = len(self.test_seen_label)
        len_unseen = len(self.test_unseen_label)

        seen_mask = torch.zeros(len_seen + len_unseen, dtype=torch.bool, device=self.device)
        seen_mask[:len_seen] = True
        unseen_mask = ~seen_mask

        print(f"评估样本分布: 已见类={seen_mask.sum()}, 未见类={unseen_mask.sum()}")

        pu_score = self.pu_learner.predict(all_test_features.to(self.device))
        is_seen_pred = (pu_score > optimal_threshold).to(self.device)

        # === 详细的混淆矩阵分析 ===
        # 确保所有张量都在同一设备上进行计算
        seen_mask = seen_mask.to(self.device)
        unseen_mask = unseen_mask.to(self.device)
        is_seen_pred = is_seen_pred.to(self.device)

        # True Positive: 真实已见类 & 预测已见类
        tp = (seen_mask & is_seen_pred).sum().item()
        # False Positive: 真实未见类 & 预测已见类
        fp = (unseen_mask & is_seen_pred).sum().item()
        # True Negative: 真实未见类 & 预测未见类
        tn = (unseen_mask & ~is_seen_pred).sum().item()
        # False Negative: 真实已见类 & 预测未见类
        fn = (seen_mask & ~is_seen_pred).sum().item()

        print(f"混淆矩阵:")
        print(f"  TP (已见→已见): {tp}")
        print(f"  FP (未见→已见): {fp}")
        print(f"  TN (未见→未见): {tn}")
        print(f"  FN (已见→未见): {fn}")

        # === 计算各种评估指标 ===
        # 召回率 (原有指标)
        seen_recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        unseen_recall = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        # 精确率
        seen_precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        unseen_precision = tn / (tn + fn) if (tn + fn) > 0 else 0.0

        # F1分数
        seen_f1 = 2 * seen_precision * seen_recall / (seen_precision + seen_recall) if (
                                                                                               seen_precision + seen_recall) > 0 else 0.0
        unseen_f1 = 2 * unseen_precision * unseen_recall / (unseen_precision + unseen_recall) if (
                                                                                                         unseen_precision + unseen_recall) > 0 else 0.0

        # 总体准确率
        accuracy = (tp + tn) / (tp + fp + tn + fn) if (tp + fp + tn + fn) > 0 else 0.0

        # 谐波平均数 (原有指标，但更准确的计算)
        harmonic = 2 * seen_recall * unseen_recall / (seen_recall + unseen_recall) if (
                                                                                              seen_recall + unseen_recall) > 0 else 0.0

        print(f"\n=== PU分类评估结果 ===")
        print(f"最佳阈值: {optimal_threshold:.4f}")
        print(f"\n【召回率 (Recall)】")
        print(f"  已见类召回率: {seen_recall:.4f} (TP Rate)")
        print(f"  未见类召回率: {unseen_recall:.4f} (TN Rate)")
        print(f"\n【精确率 (Precision)】")
        print(f"  已见类精确率: {seen_precision:.4f}")
        print(f"  未见类精确率: {unseen_precision:.4f}")
        print(f"\n【F1分数】")
        print(f"  已见类F1: {seen_f1:.4f}")
        print(f"  未见类F1: {unseen_f1:.4f}")
        print(f"\n【综合指标】")
        print(f"  总体准确率: {accuracy:.4f}")
        print(f"  谐波平均数: {harmonic:.4f}")

        return pu_score

    def correct_att(self, pu_scores, all_test_embeds, all_test_logits):
        """基于PU学习的语义属性向量校正"""
        self.model.eval()
        print("\n" + "=" * 50)
        print("开始执行无标签的语义校准评估...")

        with torch.no_grad():
            att = self.att.clone().to("cpu")

            batch_size = 128

            # 分析PU分数分布
            pu_mean, pu_std = pu_scores.mean().item(), pu_scores.std().item()
            print(f"PU分数分布 - 均值={pu_mean:.4f}, 标准差={pu_std:.4f}")

            # 使用均值作为基本阈值
            threshold = pu_mean
            self.threshold = threshold

            if len(self.seenclasses) >= 300:
                threshold = threshold - pu_std * 0.5

            print(f"PU分类阈值: {threshold:.4f}")

            # 基本分类
            pred_seen_mask = pu_scores > threshold
            pred_unseen_mask = ~pred_seen_mask

            # 设置高置信度阈值
            high_conf_seen_threshold = threshold
            high_conf_unseen_threshold = threshold

            # 针对SUN数据集调整参数
            dataset_size = len(self.seenclasses)
            if dataset_size >= 300:
                high_conf_seen_threshold = threshold - 1.5 * pu_std
                high_conf_unseen_threshold = threshold + 1.5 * pu_std

            print(f"高置信度已见类阈值: {high_conf_seen_threshold:.4f}")
            print(f"高置信度未见类阈值: {high_conf_unseen_threshold:.4f}")

            # 筛选高置信度样本
            high_conf_seen_mask = (pu_scores > high_conf_seen_threshold).flatten()
            high_conf_unseen_mask = (pu_scores < high_conf_unseen_threshold).flatten()

            # 统计样本数量
            total_seen = pred_seen_mask.sum().item()
            total_unseen = pred_unseen_mask.sum().item()
            high_conf_seen = high_conf_seen_mask.sum().item()
            high_conf_unseen = high_conf_unseen_mask.sum().item()

            print(f"分类统计:")
            print(f"  总样本数: {len(pu_scores)}")
            print(
                f"  预测为已见类: {total_seen} (其中高置信度: {high_conf_seen}, {high_conf_seen / total_seen * 100:.1f}%)")
            print(
                f"  预测为未见类: {total_unseen} (其中高置信度: {high_conf_unseen}, {high_conf_unseen / total_unseen * 100:.1f}%)")

            att_corrected = att.clone()

            print("为校准获取伪标签...")

            # 获取已见类和未见类的伪标签
            self.pred_seen_indices = torch.nonzero(pred_seen_mask).squeeze()
            self.pred_unseen_indices = torch.nonzero(pred_unseen_mask).squeeze()

            # 获取高置信度的已见类和未见类样本索引
            high_conf_seen_indices = torch.nonzero(high_conf_seen_mask).squeeze()
            high_conf_unseen_indices = torch.nonzero(high_conf_unseen_mask).squeeze()

            min_sample = 2 if dataset_size <= 300 else 1

            # 处理高置信度已见类样本
            if high_conf_seen > 0:
                print(f"处理 {high_conf_seen} 个高置信度的已见类样本")

                # 存储每个类别的样本嵌入和计数
                class_embeddings = {cls.item(): [] for cls in self.seenclasses}
                class_counts = {cls.item(): 0 for cls in self.seenclasses}

                # 分批处理数据
                for i in range(0, len(high_conf_seen_indices), batch_size):
                    end = min(i + batch_size, len(high_conf_seen_indices))
                    batch_indices = high_conf_seen_indices[i:end]

                    # 逐个处理样本
                    for idx in batch_indices:
                        embedding = all_test_embeds[idx:idx + 1].cpu()
                        logits = all_test_logits[idx:idx + 1].cpu()

                        # 只考虑已见类中的最高分类
                        seen_scores = logits[0, self.seenclasses]
                        max_seen_idx = torch.argmax(seen_scores)

                        # pseudo_label：该PU预测为已见类样本在原模型中的类别伪标签
                        pseudo_label = self.seenclasses[max_seen_idx].item()

                        # 存储样本嵌入和更新计数
                        class_embeddings[pseudo_label].append(embedding[0].numpy())
                        class_counts[pseudo_label] += 1

                # 使用累积的嵌入进行已见类语义校准
                print("使用高置信度伪标签校准已见类语义...")

                seen_calibrated_count = 0
                for cls in self.seenclasses.cpu().numpy():
                    sample_count = class_counts[cls]

                    if sample_count >= min_sample:
                        cls_embeds = np.array(class_embeddings[cls])
                        cls_centroid = np.mean(cls_embeds, axis=0)

                        # 计算原始语义与聚类质心的相似度
                        original_att = att[cls].cpu().numpy()
                        cos_sim = np.dot(original_att, cls_centroid) / (
                                np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))

                        # 自适应权重 - 相似度越高，保留越多原始语义
                        alpha = max(self.min, min(self.max, cos_sim))

                        # 校正语义
                        corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
                        corrected_att = corrected_att / np.linalg.norm(corrected_att)

                        # 更新属性矩阵
                        att_corrected[cls] = torch.from_numpy(corrected_att).float()
                        seen_calibrated_count += 1

                        print(
                            f"  类别 {cls}: 使用 {sample_count} 个高置信度样本, 权重 α={alpha:.4f}, 相似度={cos_sim:.4f}")

                if seen_calibrated_count == 0:
                    print("  没有足够的高置信度样本进行已见类校准")

            # 处理高置信度未见类样本
            if high_conf_unseen > 0:
                print(f"处理 {high_conf_unseen} 个高置信度的未见类样本")

                # 存储每个类别的样本嵌入和计数
                class_embeddings = {cls.item(): [] for cls in self.unseenclasses}
                class_counts = {cls.item(): 0 for cls in self.unseenclasses}

                # 分批处理数据
                for i in range(0, len(high_conf_unseen_indices), batch_size):
                    end = min(i + batch_size, len(high_conf_unseen_indices))
                    batch_indices = high_conf_unseen_indices[i:end]

                    # 逐个处理样本
                    for idx in batch_indices:
                        embedding = all_test_embeds[idx:idx + 1].cpu()
                        logits = all_test_logits[idx:idx + 1].cpu()

                        # 只考虑未见类中的最高分类
                        unseen_scores = logits[0, self.unseenclasses]
                        max_unseen_idx = torch.argmax(unseen_scores)

                        # pseudo_label：该PU预测为未见类样本在原模型中的类别伪标签
                        pseudo_label = self.unseenclasses[max_unseen_idx].item()

                        # 存储样本嵌入和更新计数
                        class_embeddings[pseudo_label].append(embedding[0].numpy())
                        class_counts[pseudo_label] += 1

                # 使用累积的嵌入进行未见类语义校准
                print("使用高置信度伪标签校准未见类语义...")

                unseen_calibrated_count = 0
                for cls in self.unseenclasses.cpu().numpy():
                    sample_count = class_counts[cls]

                    if sample_count >= min_sample:
                        cls_embeds = np.array(class_embeddings[cls])
                        cls_centroid = np.mean(cls_embeds, axis=0)

                        # 计算原始语义与聚类质心的相似度
                        original_att = att[cls].cpu().numpy()
                        cos_sim = np.dot(original_att, cls_centroid) / (
                                np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))

                        # 自适应权重 - 相似度越高，保留越多原始语义
                        alpha = max(self.min, min(self.max, cos_sim))

                        # 校正语义
                        corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
                        corrected_att = corrected_att / np.linalg.norm(corrected_att)

                        # 更新属性矩阵
                        att_corrected[cls] = torch.from_numpy(corrected_att).float()
                        unseen_calibrated_count += 1

                        print(
                            f"  类别 {cls}: 使用 {sample_count} 个高置信度样本, 权重 α={alpha:.4f}, 相似度={cos_sim:.4f}")

                if unseen_calibrated_count == 0:
                    print("  没有足够的高置信度样本进行未见类校准")

            # 将校正后的属性矩阵转移到设备上
            att_corrected = att_corrected.to(self.device)
            return att_corrected

    def eval_model(self, att_corrected, all_test_labels, all_test_embeds, pu_scores):
        """使用经过PU分类器校正后的语义属性向量来对PSVMA模型做预测"""
        print("使用校正后的语义进行最终预测...")
        batch_size = 128

        # 获取测试数据
        test_seen_label = all_test_labels[:self.len_test_seen].to(self.device)
        test_unseen_label = all_test_labels[self.len_test_seen:].to(self.device)

        test_seen_embeds = all_test_embeds[:self.len_test_seen].to(self.device)
        test_unseen_embeds = all_test_embeds[self.len_test_seen:].to(self.device)

        # 记录原始样本索引，仅用于最终结果分离
        seen_indices = torch.arange(self.len_test_seen)
        unseen_indices = torch.arange(self.len_test_unseen) + self.len_test_seen

        # 单独处理已见类和未见类，用于最终评估
        seen_correct = 0
        seen_total = 0

        # PU分数增强权重，控制已见类/未见类倾向性
        alpha = self.alpha  # 增强强度

        for i in range(0, self.len_test_seen, batch_size):
            end = min(i + batch_size, self.len_test_seen)
            batch_labels = test_seen_label[i:end]
            # 获取特征嵌入
            features_embed = test_seen_embeds[i:end]

            # 计算与所有类别属性的相似度
            similarity = torch.mm(features_embed, att_corrected.t())

            # 获取PU分数，索引对应原始加载顺序
            idx_start = i
            idx_end = min(i + batch_size, self.len_test_seen)
            if idx_end > idx_start:
                current_pu_scores = pu_scores[seen_indices[idx_start:idx_end]].to(self.device)

                # tanh转换到[-1,1]区间
                # seen_weight = torch.tanh(current_pu_scores-self.threshold).unsqueeze(1)
                seen_weight = torch.sigmoid(current_pu_scores - self.threshold).unsqueeze(1)

                # 分别增强/减弱已见类和未见类的分数
                seen_sim = similarity[:, self.seenclasses]
                unseen_sim = similarity[:, self.unseenclasses]

                # 增强已见类分数，减弱未见类分数
                seen_enhanced = seen_sim * (1.0 + alpha * seen_weight)
                unseen_enhanced = unseen_sim * (1.0 - alpha * seen_weight)

                # 合并分数
                enhanced_sim = torch.zeros_like(similarity)
                enhanced_sim[:, self.seenclasses] = seen_enhanced
                enhanced_sim[:, self.unseenclasses] = unseen_enhanced

                # 获取预测标签
                _, predicted = torch.max(enhanced_sim, dim=1)
                seen_correct += (predicted == batch_labels).sum().item()
                seen_total += len(batch_labels)

            torch.cuda.empty_cache()

        # 处理未见类测试集
        unseen_correct = 0
        unseen_total = 0
        for i in range(0, self.len_test_unseen, batch_size):
            end = min(i + batch_size, self.len_test_unseen)
            batch_labels = test_unseen_label[i:end]

            # 获取特征嵌入
            features_embed = test_unseen_embeds[i:end]

            # 计算与所有类别属性的相似度
            similarity = torch.mm(features_embed, att_corrected.t())

            # 获取PU分数，索引对应原始加载顺序
            idx_start = i
            idx_end = min(i + batch_size, self.len_test_unseen)
            if idx_end > idx_start:
                # 调整索引偏移量
                current_pu_scores = pu_scores[unseen_indices[idx_start:idx_end]].to(self.device)

                # 转换PU分数，负值表示更可能是未见类
                # unseen_weight = torch.tanh(current_pu_scores-self.threshold).unsqueeze(1)
                unseen_weight = torch.sigmoid(current_pu_scores - self.threshold).unsqueeze(1)

                # 分别增强/减弱已见类和未见类的分数
                seen_sim = similarity[:, self.seenclasses]
                unseen_sim = similarity[:, self.unseenclasses]

                seen_enhanced = seen_sim * (1.0 - alpha * unseen_weight)
                unseen_enhanced = unseen_sim * (1.0 + alpha * unseen_weight)

                # 合并分数
                enhanced_sim = torch.zeros_like(similarity)
                enhanced_sim[:, self.seenclasses] = seen_enhanced
                enhanced_sim[:, self.unseenclasses] = unseen_enhanced

                # 获取预测标签
                _, predicted = torch.max(enhanced_sim, dim=1)
                unseen_correct += (predicted == batch_labels).sum().item()
                unseen_total += len(batch_labels)

        # 计算最终性能
        seen_acc = seen_correct / seen_total if seen_total > 0 else 0
        unseen_acc = unseen_correct / unseen_total if unseen_total > 0 else 0
        H = (2 * seen_acc * unseen_acc) / (seen_acc + unseen_acc) if seen_acc + unseen_acc > 0 else 0

        print(f"无标签语义校准评估结果:")
        print(f"  已见类准确率: {seen_acc:.4f} ({seen_correct}/{seen_total})")
        print(f"  未见类准确率: {unseen_acc:.4f} ({unseen_correct}/{unseen_total})")
        print(f"  调和平均值: {H:.4f}")
        print("=" * 50)

        return seen_acc, unseen_acc, H

    def eval_model2(self, att, gamma=1.5):
        """
        Evaluate the model with corrected class attributes via PSVMA's ``eval_zs_gzsl``.

        The seen / unseen rows of ``att`` are selected with ``self.seenclasses`` and
        ``self.unseenclasses`` and substituted into a shallow copy of ``self.res``.
        ``self.res`` itself is not modified.

        Args:
            att: attribute tensor indexed by global class id (rows are selected with
                ``self.seenclasses`` / ``self.unseenclasses``). It presumably has the
                same layout as ``self.att``, i.e. (C_total, D_attr); TODO confirm.
            gamma: value passed to ``eval_zs_gzsl`` in its gamma position.
                Defaults to 1.5, the value this method always used.

        Returns:
            tuple: ``(acc_seen, acc_unseen, H, acc_zs)`` exactly as returned by
            ``eval_zs_gzsl``.
        """
        # Split the full attribute matrix into seen-class and unseen-class rows.
        # detach() so evaluation never carries autograd history of the corrected attributes.
        att_seen_corr = att[self.seenclasses].detach()
        att_unseen_corr = att[self.unseenclasses].detach()

        # Shallow copy: only the two attribute entries are replaced, self.res stays intact.
        new_res = self.res.copy()
        # Moved to CPU to match the CPU tensors the data pipeline provides in res.
        new_res['att_seen'] = att_seen_corr.cpu()
        new_res['att_unseen'] = att_unseen_corr.cpu()

        # Signature: eval_zs_gzsl(tu_loader, ts_loader, res, model, gamma, device)
        acc_seen, acc_unseen, H, acc_zs = eval_zs_gzsl(
            self.tu_loader,
            self.ts_loader,
            new_res,
            self.model,
            gamma,
            self.device,
        )

        print('After attribute correction → '
              f'zsl={acc_zs:.4f}, gzsl: seen={acc_seen:.4f}, '
              f'unseen={acc_unseen:.4f}, H={H:.4f}')
        return acc_seen, acc_unseen, H, acc_zs

    def train_and_evaluate(self, datasets="CUB", epochs=500, batch_size=256, lr=5e-5):
        """
        Run the PU pipeline on cached features: align, train the PU learner, correct, evaluate.

        Features are read from ``cached_features_{datasets}.pt``, a dict with keys
        ``'train'``, ``'seen'`` and ``'unseen'``. This method does not extract
        features itself. The cache was presumably produced by
        ``self.extract_features()`` followed by ``torch.save``; verify before regenerating.

        Args:
            datasets: dataset name, used in the cache file name and visualization tags.
            epochs: forwarded to ``self.pu_learner.optimize``.
            batch_size: forwarded to ``self.pu_learner.optimize``.
            lr: forwarded to ``self.pu_learner.optimize``.

        Returns:
            Whatever ``self.eval_model`` returns (presumably ``(seen_acc, unseen_acc, H)``).
            Previously this method returned ``None``.
        """
        # The features come from a cache, so report loading rather than extraction.
        print("加载缓存特征...")
        cached = torch.load(f'cached_features_{datasets}.pt')
        train_embeds, test_seen_embeds, test_unseen_embeds = cached['train'], cached['seen'], cached['unseen']

        # Hyper-parameters stored on the instance. Presumably they are read by other
        # methods (e.g. correct_att / eval_model); they are not used in this method.
        self.min = 0.3
        self.max = 0.7
        self.alpha = 1.0

        # Per-dimension statistics of the training features act as the target domain.
        train_mean = train_embeds.mean(dim=0, keepdim=True)
        train_std = train_embeds.std(dim=0, keepdim=True) + 1e-8  # epsilon avoids division by zero

        test_seen_features_aligned = self.align_to_train_domain(test_seen_embeds, train_std, train_mean, "已见类测试")
        test_unseen_features_aligned = self.align_to_train_domain(test_unseen_embeds, train_std, train_mean, "未见类测试")
        all_test_logits = self.compute_logits(test_seen_features_aligned, test_unseen_features_aligned)

        # Seen rows first, then unseen rows. This matches the order of self.all_test_labels.
        all_test_embeds = torch.cat([test_seen_features_aligned, test_unseen_features_aligned], dim=0)

        print("开始PU学习训练...")
        self.pu_learner.optimize(
            train_embeds,  # labeled (positive) data: all training features
            all_test_embeds,  # unlabeled data: all test features
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            weight_decay=1e-5
        )

        optimal_threshold = visualize_pu_scores(self, all_test_embeds, self.all_test_labels, f"{datasets}_Prob")
        pu_scores = self.evaluate_pu(optimal_threshold, all_test_embeds, self.all_test_labels)
        corr_att = self.correct_att(pu_scores, all_test_embeds, all_test_logits)
        # Return the metrics instead of discarding them so callers can log or compare runs.
        return self.eval_model(corr_att, self.all_test_labels, all_test_embeds, pu_scores)
