import torch
from tqdm import tqdm
from PUModels.ProbPULearner_Z import PULearner
from helper_func import eval_zs_gzsl
from util.report_model import *
from util.visual import *


class PUTrainer:
    def __init__(self, model, dataloader, device, lambda_prob_pu=1.0):
        """Store the model/data handles and build the probabilistic PU learner.

        Args:
            model: ZSL model; its ``att`` tensor is read here (a detached CPU
                numpy copy is kept as ``self.att_mat`` and ``att.size(1)`` is
                used as the PU learner's ``feature_dim``).
            dataloader: data provider; its ``seenclasses`` / ``unseenclasses``
                are cached. NOTE: its ``batch_size`` attribute is overwritten
                with 50 (``extract_features`` batches with this value).
            device: device handed to the PU learner and used for inference.
            lambda_prob_pu: weight forwarded to ``PULearner``.
        """
        self.threshold = None
        self.pred_seen_indices = None
        self.pred_unseen_indices = None
        self.model = model
        self.dataloader = dataloader
        self.device = device
        # Mutates the caller's dataloader object.
        self.dataloader.batch_size = 50

        # Class-attribute matrix as a CPU numpy array
        self.att_mat = self.model.att.detach().cpu().numpy()
        feature_dim = model.att.size(1)

        # Cache class split information
        self.seenclasses = dataloader.seenclasses
        self.unseenclasses = dataloader.unseenclasses

        self.len_test_unseen = None
        self.len_test_seen = None

        self.pu_learner = PULearner(
            feature_dim=feature_dim,
            att_mat=self.att_mat,
            device=self.device,
            # Forward the caller's value; it used to be hard-coded to 1.0,
            # which silently ignored this constructor argument.
            lambda_prob_pu=lambda_prob_pu,
        )

    def extract_features(self):
        """Embed every split and align the test embeddings to the training domain.

        Training embeddings are gathered from one ``dataloader.next_batch`` call
        per ``batch_size``-sized chunk of ``ntrain``. Test embeddings/logits come
        from contiguous slices of ``dataloader.data[split]['resnet_features']``
        for the ``'test_seen'`` and ``'test_unseen'`` splits, all under
        ``torch.no_grad()``.

        Each test split is then re-standardised to the training per-dimension
        mean/std (mean/std matching, not quantile matching) and blended with its
        unaligned version: weight 0.7 on the aligned features, raised to at most
        0.9 for the unseen split when its mean gap to the training mean exceeds
        0.1.

        Returns:
            tuple: ``(train_features, test_seen_aligned, test_unseen_aligned,
            all_test_logits)``; ``all_test_logits`` holds the model's ``'pred'``
            outputs for the seen test split followed by the unseen split.
        """
        loader = self.dataloader
        step = loader.batch_size

        def embed_split(raw):
            # Run the model over `raw` in contiguous `step`-row slices and
            # collect the 'embed' and 'pred' outputs separately.
            embeds, logits = [], []
            for start in range(0, raw.size(0), step):
                out = self.model(raw[start:start + step].to(self.device))
                embeds.append(out['embed'])
                logits.append(out['pred'])
            return embeds, logits

        with torch.no_grad():
            # Training split: one next_batch() draw per step-sized chunk.
            train_chunks = []
            for _ in range(0, loader.ntrain, step):
                _, raw_batch, _ = loader.next_batch(step)
                train_chunks.append(self.model(raw_batch.to(self.device))['embed'])
            train_features = torch.cat(train_chunks, dim=0)

            seen_embeds, seen_logits = embed_split(loader.data['test_seen']['resnet_features'])
            test_seen_features = torch.cat(seen_embeds, dim=0)

            unseen_embeds, unseen_logits = embed_split(loader.data['test_unseen']['resnet_features'])
            test_unseen_features = torch.cat(unseen_embeds, dim=0)

        all_test_logits = torch.cat(seen_logits + unseen_logits, dim=0)

        print("执行自适应强度的特征域适应对齐...")

        # Target-domain (training) statistics per feature dimension.
        train_mean = train_features.mean(dim=0, keepdim=True)
        train_std = train_features.std(dim=0, keepdim=True) + 1e-8

        def align_to_train_domain(features, feature_type="", adaptive_alpha=True):
            # Standardise with the split's own statistics, rescale to the
            # training statistics, then blend with the unaligned features.
            src_mean = features.mean(dim=0, keepdim=True)
            src_std = features.std(dim=0, keepdim=True) + 1e-8
            aligned = (features - src_mean) / src_std * train_std + train_mean

            alpha = 0.7
            if adaptive_alpha and "未见类" in feature_type:
                # Larger mean gap -> stronger alignment, capped at 0.9.
                gap = torch.abs(src_mean - train_mean).mean().item()
                if gap > 0.1:
                    alpha = min(0.9, 0.7 + gap * 2.0)

            blended = alpha * aligned + (1 - alpha) * features

            print(
                f"  {feature_type}特征对齐完成 - 原始均值: {src_mean.mean().item():.4f}, 对齐后均值: {blended.mean().item():.4f}, 对齐强度: {alpha:.2f}")

            return blended

        seen_aligned = align_to_train_domain(test_seen_features, "已见类测试")
        unseen_aligned = align_to_train_domain(test_unseen_features, "未见类测试")

        print(f"分位数域适应对齐完成:")
        print(f"  训练集均值: {train_mean.mean().item():.4f}")
        print(f"  已见类测试集对齐后均值: {seen_aligned.mean().item():.4f}")
        print(f"  未见类测试集对齐后均值: {unseen_aligned.mean().item():.4f}")

        return train_features, seen_aligned, unseen_aligned, all_test_logits


    def evaluate_pu(self, optimal_threshold, all_test_features, all_test_labels):
        """Print seen-vs-unseen (PU) detection metrics at a given threshold.

        The first ``len(self.test_seen_label)`` rows of ``all_test_features``
        are the ground-truth seen samples and the following
        ``len(self.test_unseen_label)`` rows the unseen ones. A sample is
        predicted "seen" when its PU score is strictly greater than
        ``optimal_threshold``. The confusion matrix, per-group recall /
        precision / F1, overall accuracy and the harmonic mean of the two
        recalls are printed.

        Args:
            optimal_threshold: decision threshold on the PU score (printed with
                ``:.4f``).
            all_test_features: test-seen features followed by test-unseen
                features; passed to ``self.pu_learner.predict``.
            all_test_labels: unused; kept for interface compatibility.

        Returns:
            The raw PU scores from ``self.pu_learner.predict`` (unmodified).
        """
        len_seen = len(self.test_seen_label)
        len_unseen = len(self.test_unseen_label)

        # Ground truth: seen samples come first, unseen samples after them.
        seen_mask = torch.zeros(len_seen + len_unseen, dtype=torch.bool, device=self.device)
        seen_mask[:len_seen] = True
        unseen_mask = ~seen_mask

        print(f"评估样本分布: 已见类={seen_mask.sum()}, 未见类={unseen_mask.sum()}")

        pu_score = self.pu_learner.predict(all_test_features.to(self.device))
        # Flatten so an (N, 1) score column cannot broadcast against the (N,)
        # ground-truth masks into an N x N matrix and corrupt every count.
        is_seen_pred = (pu_score > optimal_threshold).reshape(-1).to(self.device)

        # True Positive: truly seen & predicted seen
        tp = (seen_mask & is_seen_pred).sum().item()
        # False Positive: truly unseen & predicted seen
        fp = (unseen_mask & is_seen_pred).sum().item()
        # True Negative: truly unseen & predicted unseen
        tn = (unseen_mask & ~is_seen_pred).sum().item()
        # False Negative: truly seen & predicted unseen
        fn = (seen_mask & ~is_seen_pred).sum().item()

        print(f"混淆矩阵:")
        print(f"  TP (已见→已见): {tp}")
        print(f"  FP (未见→已见): {fp}")
        print(f"  TN (未见→未见): {tn}")
        print(f"  FN (已见→未见): {fn}")

        def ratio(num, den):
            # Division that reports 0.0 for an empty denominator.
            return num / den if den > 0 else 0.0

        def f1(a, b):
            # Harmonic mean of two rates; 0.0 when both are zero.
            return 2 * a * b / (a + b) if (a + b) > 0 else 0.0

        # Recall
        seen_recall = ratio(tp, tp + fn)
        unseen_recall = ratio(tn, tn + fp)

        # Precision
        seen_precision = ratio(tp, tp + fp)
        unseen_precision = ratio(tn, tn + fn)

        # F1 scores
        seen_f1 = f1(seen_precision, seen_recall)
        unseen_f1 = f1(unseen_precision, unseen_recall)

        # Overall accuracy
        accuracy = ratio(tp + tn, tp + fp + tn + fn)

        # Harmonic mean of the seen and unseen recalls
        harmonic = f1(seen_recall, unseen_recall)

        print(f"\n=== PU分类评估结果 ===")
        print(f"最佳阈值: {optimal_threshold:.4f}")
        print(f"\n【召回率 (Recall)】")
        print(f"  已见类召回率: {seen_recall:.4f} (TP Rate)")
        print(f"  未见类召回率: {unseen_recall:.4f} (TN Rate)")
        print(f"\n【精确率 (Precision)】")
        print(f"  已见类精确率: {seen_precision:.4f}")
        print(f"  未见类精确率: {unseen_precision:.4f}")
        print(f"\n【F1分数】")
        print(f"  已见类F1: {seen_f1:.4f}")
        print(f"  未见类F1: {unseen_f1:.4f}")
        print(f"\n【综合指标】")
        print(f"  总体准确率: {accuracy:.4f}")
        print(f"  谐波平均数: {harmonic:.4f}")

        return pu_score

    def correct_att(self, pu_scores, all_test_embeds, all_test_logits, *,
                    high_conf_seen_threshold=-0.57, high_conf_unseen_threshold=-1.17,
                    alpha_min=0.99, alpha_max=0.99, min_sample=2):
        """Calibrate class attribute vectors using high-confidence PU pseudo-labels.

        Samples whose PU score is above ``high_conf_seen_threshold`` are
        pseudo-labelled with their highest-logit *seen* class; samples below
        ``high_conf_unseen_threshold`` with their highest-logit *unseen* class.
        Every class that collects at least ``min_sample`` pseudo-labelled samples
        has its attribute row replaced by
        ``normalize(alpha * att[c] + (1 - alpha) * centroid_c)``, where
        ``centroid_c`` is the mean embedding of those samples and ``alpha`` is
        the cosine similarity between ``att[c]`` and ``centroid_c`` clamped to
        ``[alpha_min, alpha_max]``.

        Side effects: calls ``self.model.eval()`` and sets ``self.threshold``,
        ``self.pred_seen_indices`` and ``self.pred_unseen_indices`` (1-D index
        tensors of the samples predicted seen / unseen by the mean threshold).

        Args:
            pu_scores: PU score per test sample; a trailing singleton dimension
                is tolerated (masks are flattened).
            all_test_embeds: per-sample embeddings, row-aligned with ``pu_scores``.
            all_test_logits: per-sample class logits, row-aligned with
                ``pu_scores`` and indexable by class id along dim 1.
            high_conf_seen_threshold: scores strictly above this are confident seen.
            high_conf_unseen_threshold: scores strictly below this are confident unseen.
            alpha_min, alpha_max: clamp bounds for the weight on the original
                attribute. The defaults (both 0.99) reproduce the previous fixed
                weight; widen them to make the weight similarity-adaptive.
            min_sample: minimum pseudo-labelled samples a class needs.

        Returns:
            A corrected copy of ``self.model.att`` moved to ``self.device``.
        """
        self.model.eval()
        print("\n" + "=" * 50)
        print("开始执行无标签的语义校准评估...")

        with torch.no_grad():
            # Attribute matrix on the CPU
            att = self.model.att.clone().to("cpu")

            # PU score distribution
            pu_mean, pu_std = pu_scores.mean().item(), pu_scores.std().item()
            print(f"PU分数分布 - 均值={pu_mean:.4f}, 标准差={pu_std:.4f}")

            # The mean is the base threshold. It is stored *before* the
            # large-seen-set adjustment, so self.threshold can differ from the
            # threshold actually applied below.
            threshold = pu_mean
            self.threshold = threshold

            if len(self.seenclasses) >= 300:
                threshold = threshold - pu_std * 0.5

            print(f"PU分类阈值: {threshold:.4f}")

            # Basic seen/unseen split (flattened so (N,) and (N, 1) agree)
            pred_seen_mask = (pu_scores > threshold).flatten()
            pred_unseen_mask = ~pred_seen_mask

            print(f"高置信度已见类阈值: {high_conf_seen_threshold:.4f}")
            print(f"高置信度未见类阈值: {high_conf_unseen_threshold:.4f}")

            # High-confidence selections
            high_conf_seen_mask = (pu_scores > high_conf_seen_threshold).flatten()
            high_conf_unseen_mask = (pu_scores < high_conf_unseen_threshold).flatten()

            total_seen = pred_seen_mask.sum().item()
            total_unseen = pred_unseen_mask.sum().item()
            high_conf_seen = high_conf_seen_mask.sum().item()
            high_conf_unseen = high_conf_unseen_mask.sum().item()

            def pct(part, whole):
                # Percentage that tolerates an empty predicted group.
                return part / whole * 100 if whole else 0.0

            print(f"分类统计:")
            print(f"  总样本数: {len(pu_scores)}")
            print(
                f"  预测为已见类: {total_seen} (其中高置信度: {high_conf_seen}, {pct(high_conf_seen, total_seen):.1f}%)")
            print(
                f"  预测为未见类: {total_unseen} (其中高置信度: {high_conf_unseen}, {pct(high_conf_unseen, total_unseen):.1f}%)")

            att_corrected = att.clone()

            print("为校准获取伪标签...")

            # Always 1-D: a bare squeeze() collapses a single hit into a 0-d
            # tensor, on which len() and iteration fail.
            self.pred_seen_indices = torch.nonzero(pred_seen_mask).flatten()
            self.pred_unseen_indices = torch.nonzero(pred_unseen_mask).flatten()
            high_conf_seen_indices = torch.nonzero(high_conf_seen_mask).flatten()
            high_conf_unseen_indices = torch.nonzero(high_conf_unseen_mask).flatten()

            # Seen classes
            if high_conf_seen > 0:
                print(f"处理 {high_conf_seen} 个高置信度的已见类样本")
                embeds, labels = self._assign_group_pseudo_labels(
                    high_conf_seen_indices, self.seenclasses, all_test_embeds, all_test_logits)
                print("使用高置信度伪标签校准已见类语义...")
                seen_calibrated_count = self._calibrate_group_attributes(
                    att, att_corrected, self.seenclasses, embeds, labels,
                    min_sample, alpha_min, alpha_max)
                if seen_calibrated_count == 0:
                    print("  没有足够的高置信度样本进行已见类校准")

            # Unseen classes
            if high_conf_unseen > 0:
                print(f"处理 {high_conf_unseen} 个高置信度的未见类样本")
                embeds, labels = self._assign_group_pseudo_labels(
                    high_conf_unseen_indices, self.unseenclasses, all_test_embeds, all_test_logits)
                print("使用高置信度伪标签校准未见类语义...")
                unseen_calibrated_count = self._calibrate_group_attributes(
                    att, att_corrected, self.unseenclasses, embeds, labels,
                    min_sample, alpha_min, alpha_max)
                if unseen_calibrated_count == 0:
                    print("  没有足够的高置信度样本进行未见类校准")

            att_corrected = att_corrected.to(self.device)
            return att_corrected

    @staticmethod
    def _assign_group_pseudo_labels(indices, classes, all_test_embeds, all_test_logits):
        """Label each selected sample with its highest-logit class among ``classes``.

        Returns ``(embeds, labels)`` as numpy arrays, row-aligned with ``indices``.
        """
        indices = indices.cpu()
        classes = classes.cpu()
        embeds = all_test_embeds[indices].detach().cpu()
        logits = all_test_logits[indices].detach().cpu()
        # argmax over the restricted columns, mapped back to global class ids
        labels = classes[logits[:, classes].argmax(dim=1)]
        return embeds.numpy(), labels.numpy()

    @staticmethod
    def _calibrate_group_attributes(att, att_corrected, classes, embeds, labels,
                                    min_sample, alpha_min, alpha_max):
        """Blend supported classes' attributes with their embedding centroid in place.

        Writes into ``att_corrected`` and returns the number of classes updated.
        """
        calibrated = 0
        for cls in classes.cpu().numpy():
            members = embeds[labels == cls]
            sample_count = len(members)
            if sample_count < min_sample:
                continue

            cls_centroid = np.mean(members, axis=0)

            # Cosine similarity between the original attribute and the centroid
            original_att = att[cls].cpu().numpy()
            cos_sim = np.dot(original_att, cls_centroid) / (
                    np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))

            # Weight on the original attribute: similarity clamped to bounds
            alpha = max(alpha_min, min(alpha_max, cos_sim))

            corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
            corrected_att = corrected_att / np.linalg.norm(corrected_att)

            att_corrected[cls] = torch.from_numpy(corrected_att).float()
            calibrated += 1

            print(
                f"  类别 {cls}: 使用 {sample_count} 个高置信度样本, 权重 α={alpha:.4f}, 相似度={cos_sim:.4f}")
        return calibrated

    # def correct_att(self, pu_scores, all_test_embeds, all_test_logits):
    #     """基于PU学习的语义属性向量校正"""
    #     self.model.eval()
    #     print("\n" + "=" * 50)
    #     print("开始执行无标签的语义校准评估...")
    #
    #     with torch.no_grad():
    #         # 提取属性矩阵
    #         att = self.model.att.clone().to("cpu")
    #         batch_size = 128
    #
    #         # 分析PU分数分布
    #         pu_mean, pu_std = pu_scores.mean().item(), pu_scores.std().item()
    #         print(f"PU分数分布 - 均值={pu_mean:.4f}, 标准差={pu_std:.4f}")
    #
    #         # 使用均值作为基本阈值
    #         # threshold = pu_mean + pu_std * 0.2
    #         threshold = pu_mean
    #
    #         self.threshold = threshold
    #
    #         # if len(self.seenclasses) >= 300:
    #         #     threshold = threshold - pu_std * 0.5
    #
    #         print(f"PU分类阈值: {threshold:.4f}")
    #
    #         # 基本分类
    #         pred_seen_mask = pu_scores > threshold
    #         pred_unseen_mask = ~pred_seen_mask
    #
    #         # ========== 网格搜索最佳阈值 ==========
    #         print("开始网格搜索最佳高置信度阈值...")
    #
    #         # 定义搜索范围
    #         seen_min = threshold - 1.0 * pu_std
    #         seen_max = threshold + 1.0 * pu_std
    #         unseen_min = threshold - 1.0 * pu_std
    #         unseen_max = threshold + 1.0 * pu_std
    #         step = 0.1
    #
    #         # 生成搜索网格
    #         seen_thresholds = np.arange(seen_min, seen_max + step, step)
    #         unseen_thresholds = np.arange(unseen_min, unseen_max + step, step)
    #
    #         best_H = 0.0
    #         best_seen_threshold = threshold - 0.2 * pu_std
    #         best_unseen_threshold = threshold + 0.1 * pu_std
    #
    #         print(f"搜索范围: seen[{seen_min:.2f}, {seen_max:.2f}], unseen[{unseen_min:.2f}, {unseen_max:.2f}]")
    #         print(
    #             f"总计搜索组合: {len(seen_thresholds)} x {len(unseen_thresholds)} = {len(seen_thresholds) * len(unseen_thresholds)}")
    #
    #         total_combinations = len(seen_thresholds) * len(unseen_thresholds)
    #
    #         with tqdm(total=total_combinations, desc="网格搜索进度", ncols=100) as pbar:
    #             for seen_thresh in seen_thresholds:
    #                 for unseen_thresh in unseen_thresholds:
    #                     # 计算当前阈值组合的H值
    #                     H = self._evaluate_threshold_combination(
    #                         pu_scores, seen_thresh, unseen_thresh, all_test_embeds, all_test_logits, att
    #                     )
    #
    #                     # 更新最佳结果
    #                     if H > best_H:
    #                         best_H = H
    #                         best_seen_threshold = seen_thresh
    #                         best_unseen_threshold = unseen_thresh
    #                         # 更新进度条描述显示当前最佳结果
    #                         pbar.set_postfix({
    #                             'Best_H': f'{best_H:.4f}',
    #                             'Best_Seen': f'{best_seen_threshold:.3f}',
    #                             'Best_Unseen': f'{best_unseen_threshold:.3f}'
    #                         })
    #
    #                     # 更新进度条
    #                     pbar.update(1)
    #
    #         print(f"网格搜索完成！")
    #         print(f"最佳组合: seen_threshold={best_seen_threshold:.4f}, unseen_threshold={best_unseen_threshold:.4f}")
    #         print(f"最佳H值: {best_H:.4f}")
    #
    #         # 使用最佳阈值
    #         high_conf_seen_threshold = best_seen_threshold
    #         high_conf_unseen_threshold = best_unseen_threshold
    #         # ========== 网格搜索结束 ==========
    #
    #         print(f"高置信度已见类阈值: {high_conf_seen_threshold:.4f}")
    #         print(f"高置信度未见类阈值: {high_conf_unseen_threshold:.4f}")
    #
    #         # 筛选高置信度样本
    #         high_conf_seen_mask = (pu_scores > high_conf_seen_threshold).flatten()
    #         high_conf_unseen_mask = (pu_scores < high_conf_unseen_threshold).flatten()
    #
    #         # 统计样本数量
    #         total_seen = pred_seen_mask.sum().item()
    #         total_unseen = pred_unseen_mask.sum().item()
    #         high_conf_seen = high_conf_seen_mask.sum().item()
    #         high_conf_unseen = high_conf_unseen_mask.sum().item()
    #
    #         print(f"分类统计:")
    #         print(f"  总样本数: {len(pu_scores)}")
    #         print(
    #             f"  预测为已见类: {total_seen} (其中高置信度: {high_conf_seen}, {high_conf_seen / total_seen * 100:.1f}%)")
    #         print(
    #             f"  预测为未见类: {total_unseen} (其中高置信度: {high_conf_unseen}, {high_conf_unseen / total_unseen * 100:.1f}%)")
    #
    #         att_corrected = att.clone()
    #
    #         print("为校准获取伪标签...")
    #
    #         # 获取已见类和未见类的伪标签
    #         self.pred_seen_indices = torch.nonzero(pred_seen_mask).squeeze()
    #         self.pred_unseen_indices = torch.nonzero(pred_unseen_mask).squeeze()
    #
    #         # 获取高置信度的已见类和未见类样本索引
    #         high_conf_seen_indices = torch.nonzero(high_conf_seen_mask).squeeze()
    #         high_conf_unseen_indices = torch.nonzero(high_conf_unseen_mask).squeeze()
    #
    #         min_sample = 2 if len(self.seenclasses) <= 300 else 2
    #
    #         # 处理高置信度已见类样本
    #         if high_conf_seen > 0:
    #             print(f"处理 {high_conf_seen} 个高置信度的已见类样本")
    #
    #             # 存储每个类别的样本嵌入和计数
    #             class_embeddings = {cls.item(): [] for cls in self.seenclasses}
    #             class_counts = {cls.item(): 0 for cls in self.seenclasses}
    #
    #             # 分批处理数据
    #             for i in range(0, len(high_conf_seen_indices), batch_size):
    #                 end = min(i + batch_size, len(high_conf_seen_indices))
    #                 batch_indices = high_conf_seen_indices[i:end]
    #
    #                 # 逐个处理样本
    #                 for idx in batch_indices:
    #                     embedding = all_test_embeds[idx:idx + 1].cpu()
    #                     logits = all_test_logits[idx:idx + 1].cpu()
    #
    #                     # 只考虑已见类中的最高分类
    #                     seen_scores = logits[0, self.seenclasses]
    #                     max_seen_idx = torch.argmax(seen_scores)
    #
    #                     # pseudo_label：该PU预测为已见类样本在原ZSL模型中的类别伪标签
    #                     pseudo_label = self.seenclasses[max_seen_idx].item()
    #
    #                     # 存储样本嵌入和更新计数
    #                     class_embeddings[pseudo_label].append(embedding[0].numpy())
    #                     class_counts[pseudo_label] += 1
    #
    #             # 现在使用累积的嵌入进行已见类语义校准
    #             print("使用高置信度伪标签校准已见类语义...")
    #
    #             seen_calibrated_count = 0
    #             for cls in self.seenclasses.cpu().numpy():
    #                 sample_count = class_counts[cls]
    #
    #                 if sample_count >= min_sample:
    #                     cls_embeds = np.array(class_embeddings[cls])
    #                     cls_centroid = np.mean(cls_embeds, axis=0)
    #
    #                     # 计算原始语义与聚类质心的相似度
    #                     original_att = att[cls].cpu().numpy()
    #                     cos_sim = np.dot(original_att, cls_centroid) / (
    #                             np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))
    #
    #                     # 自适应权重 - 相似度越高，保留越多原始语义
    #                     alpha = max(self.min, min(self.max, cos_sim))
    #
    #                     # 校正语义
    #                     corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
    #                     corrected_att = corrected_att / np.linalg.norm(corrected_att)
    #
    #                     # 更新属性矩阵
    #                     att_corrected[cls] = torch.from_numpy(corrected_att).float()
    #                     seen_calibrated_count += 1
    #
    #                     print(
    #                         f"  类别 {cls}: 使用 {sample_count} 个高置信度样本, 权重 α={alpha:.4f}, 相似度={cos_sim:.4f}")
    #
    #         # 处理高置信度未见类样本
    #         if high_conf_unseen > 0:
    #             print(f"处理 {high_conf_unseen} 个高置信度的未见类样本")
    #
    #             # 存储每个类别的样本嵌入和计数
    #             class_embeddings = {cls.item(): [] for cls in self.unseenclasses}
    #             class_counts = {cls.item(): 0 for cls in self.unseenclasses}
    #
    #             # 分批处理数据
    #             for i in range(0, len(high_conf_unseen_indices), batch_size):
    #                 end = min(i + batch_size, len(high_conf_unseen_indices))
    #                 batch_indices = high_conf_unseen_indices[i:end]
    #
    #                 # 逐个处理样本
    #                 for idx in batch_indices:
    #                     embedding = all_test_embeds[idx:idx + 1].cpu()
    #                     logits = all_test_logits[idx:idx + 1].cpu()
    #
    #                     # 只考虑未见类中的最高分类
    #                     unseen_scores = logits[0, self.unseenclasses]
    #                     max_unseen_idx = torch.argmax(unseen_scores)
    #
    #                     # pseudo_label：该PU预测为未见类样本在原ZSL模型中的类别伪标签
    #                     pseudo_label = self.unseenclasses[max_unseen_idx].item()
    #
    #                     # 存储样本嵌入和更新计数
    #                     class_embeddings[pseudo_label].append(embedding[0].numpy())
    #                     class_counts[pseudo_label] += 1
    #
    #             # 现在使用累积的嵌入进行未见类语义校准
    #             print("使用高置信度伪标签校准未见类语义...")
    #
    #             unseen_calibrated_count = 0
    #             for cls in self.unseenclasses.cpu().numpy():
    #                 sample_count = class_counts[cls]
    #
    #                 if sample_count >= min_sample:
    #                     cls_embeds = np.array(class_embeddings[cls])
    #                     cls_centroid = np.mean(cls_embeds, axis=0)
    #
    #                     # 计算原始语义与聚类质心的相似度
    #                     original_att = att[cls].cpu().numpy()
    #                     cos_sim = np.dot(original_att, cls_centroid) / (
    #                             np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))
    #
    #                     # 自适应权重 - 相似度越高，保留越多原始语义
    #                     alpha = max(self.min, min(self.max, cos_sim))
    #
    #                     # 校正语义
    #                     corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
    #                     corrected_att = corrected_att / np.linalg.norm(corrected_att)
    #
    #                     # 更新属性矩阵
    #                     att_corrected[cls] = torch.from_numpy(corrected_att).float()
    #                     unseen_calibrated_count += 1
    #
    #                     print(
    #                         f"  类别 {cls}: 使用 {sample_count} 个高置信度样本, 权重 α={alpha:.4f}, 相似度={cos_sim:.4f}")
    #
    #             if unseen_calibrated_count == 0:
    #                 print("  没有足够的高置信度样本进行未见类校准")
    #
    #         att_corrected = att_corrected.to(self.device)
    #         return att_corrected

    # def _evaluate_threshold_combination(self, pu_scores, seen_thresh, unseen_thresh, all_test_embeds, all_test_logits,
    #                                     att):
    #     """评估特定阈值组合的H值"""
    #     # 获取高置信度样本
    #     high_conf_seen_mask = (pu_scores > seen_thresh).flatten()
    #     high_conf_unseen_mask = (pu_scores < unseen_thresh).flatten()
    #
    #     high_conf_seen = high_conf_seen_mask.sum().item()
    #     high_conf_unseen = high_conf_unseen_mask.sum().item()
    #
    #     # 如果高置信度样本太少，返回0
    #     if high_conf_seen < 5 and high_conf_unseen < 5:
    #         return 0.0
    #
    #     att_corrected = att.clone()
    #     min_sample = 1 if len(self.seenclasses) <= 300 else 2
    #
    #     # 简化的语义校准过程（只做核心计算）
    #     if high_conf_seen > 0:
    #         high_conf_seen_indices = torch.nonzero(high_conf_seen_mask).squeeze()
    #         class_embeddings = {cls.item(): [] for cls in self.seenclasses}
    #         class_counts = {cls.item(): 0 for cls in self.seenclasses}
    #
    #         for idx in high_conf_seen_indices:
    #             embedding = all_test_embeds[idx:idx + 1].cpu()
    #             logits = all_test_logits[idx:idx + 1].cpu()
    #             seen_scores = logits[0, self.seenclasses]
    #             max_seen_idx = torch.argmax(seen_scores)
    #             pseudo_label = self.seenclasses[max_seen_idx].item()
    #             class_embeddings[pseudo_label].append(embedding[0].numpy())
    #             class_counts[pseudo_label] += 1
    #
    #         for cls in self.seenclasses.cpu().numpy():
    #             sample_count = class_counts[cls]
    #             if sample_count >= min_sample:
    #                 cls_embeds = np.array(class_embeddings[cls])
    #                 cls_centroid = np.mean(cls_embeds, axis=0)
    #                 original_att = att[cls].cpu().numpy()
    #                 cos_sim = np.dot(original_att, cls_centroid) / (
    #                         np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))
    #                 alpha = max(self.min, min(self.max, cos_sim))
    #                 corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
    #                 corrected_att = corrected_att / np.linalg.norm(corrected_att)
    #                 att_corrected[cls] = torch.from_numpy(corrected_att).float()
    #
    #     if high_conf_unseen > 0:
    #         high_conf_unseen_indices = torch.nonzero(high_conf_unseen_mask).squeeze()
    #         class_embeddings = {cls.item(): [] for cls in self.unseenclasses}
    #         class_counts = {cls.item(): 0 for cls in self.unseenclasses}
    #
    #         for idx in high_conf_unseen_indices:
    #             embedding = all_test_embeds[idx:idx + 1].cpu()
    #             logits = all_test_logits[idx:idx + 1].cpu()
    #             unseen_scores = logits[0, self.unseenclasses]
    #             max_unseen_idx = torch.argmax(unseen_scores)
    #             pseudo_label = self.unseenclasses[max_unseen_idx].item()
    #             class_embeddings[pseudo_label].append(embedding[0].numpy())
    #             class_counts[pseudo_label] += 1
    #
    #         for cls in self.unseenclasses.cpu().numpy():
    #             sample_count = class_counts[cls]
    #             if sample_count >= min_sample:
    #                 cls_embeds = np.array(class_embeddings[cls])
    #                 cls_centroid = np.mean(cls_embeds, axis=0)
    #                 original_att = att[cls].cpu().numpy()
    #                 cos_sim = np.dot(original_att, cls_centroid) / (
    #                         np.linalg.norm(original_att) * np.linalg.norm(cls_centroid))
    #                 alpha = max(self.min, min(self.max, cos_sim))
    #                 corrected_att = alpha * original_att + (1 - alpha) * cls_centroid
    #                 corrected_att = corrected_att / np.linalg.norm(corrected_att)
    #                 att_corrected[cls] = torch.from_numpy(corrected_att).float()
    #
    #     # 快速评估H值（简化版本的eval_model）
    #     att_corrected = att_corrected.to(self.device)
    #     all_test_labels = torch.cat([
    #         self.dataloader.data['test_seen']['labels'],
    #         self.dataloader.data['test_unseen']['labels']
    #     ], dim=0).to(self.device)
    #
    #     test_seen_label = all_test_labels[:self.len_test_seen].to("cpu")
    #     test_unseen_label = all_test_labels[self.len_test_seen:].to("cpu")
    #     test_seen_embeds = all_test_embeds[:self.len_test_seen].to("cpu")
    #     test_unseen_embeds = all_test_embeds[self.len_test_seen:].to("cpu")
    #
    #     vec_bias = (self.model.mask_bias * self.model.bias).to(self.device)
    #     alpha = self.alpha
    #
    #     # 评估已见类
    #     seen_correct = 0
    #     seen_total = 0
    #     batch_size = 128
    #
    #     for i in range(0, self.len_test_seen, batch_size):
    #         end = min(i + batch_size, self.len_test_seen)
    #         batch_labels = test_seen_label[i:end].to(self.device)
    #         features_embed = test_seen_embeds[i:end].to(self.device)
    #         similarity = torch.mm(features_embed, att_corrected.t()) + vec_bias
    #
    #         seen_indices = torch.arange(self.len_test_seen)
    #         idx_start = i
    #         idx_end = end
    #         if idx_end > idx_start:
    #             current_pu_scores = pu_scores[seen_indices[idx_start:idx_end]].to(self.device)
    #             seen_weight = torch.tanh(current_pu_scores).unsqueeze(1)
    #
    #             seen_sim = similarity[:, self.seenclasses.to(self.device)]
    #             unseen_sim = similarity[:, self.unseenclasses.to(self.device)]
    #             seen_enhanced = seen_sim * (1.0 + alpha * seen_weight)
    #             unseen_enhanced = unseen_sim * (1.0 - alpha * seen_weight)
    #
    #             enhanced_sim = torch.zeros_like(similarity)
    #             enhanced_sim[:, self.seenclasses.to(self.device)] = seen_enhanced
    #             enhanced_sim[:, self.unseenclasses.to(self.device)] = unseen_enhanced
    #
    #             _, predicted = torch.max(enhanced_sim, dim=1)
    #             seen_correct += (predicted == batch_labels).sum().item()
    #             seen_total += len(batch_labels)
    #
    #     # 评估未见类
    #     unseen_correct = 0
    #     unseen_total = 0
    #
    #     for i in range(0, self.len_test_unseen, batch_size):
    #         end = min(i + batch_size, self.len_test_unseen)
    #         batch_labels = test_unseen_label[i:end].to(self.device)
    #         features_embed = test_unseen_embeds[i:end].to(self.device)
    #         similarity = torch.mm(features_embed, att_corrected.t()) + vec_bias
    #
    #         unseen_indices = torch.arange(self.len_test_unseen) + self.len_test_seen
    #         idx_start = i
    #         idx_end = end
    #         if idx_end > idx_start:
    #             current_pu_scores = pu_scores[unseen_indices[idx_start:idx_end]].to(self.device)
    #             unseen_weight = torch.tanh(current_pu_scores).unsqueeze(1)
    #
    #             seen_sim = similarity[:, self.seenclasses.to(self.device)]
    #             unseen_sim = similarity[:, self.unseenclasses.to(self.device)]
    #             seen_enhanced = seen_sim * (1.0 + alpha * unseen_weight)
    #             unseen_enhanced = unseen_sim * (1.0 - alpha * unseen_weight)
    #
    #             enhanced_sim = torch.zeros_like(similarity)
    #             enhanced_sim[:, self.seenclasses.to(self.device)] = seen_enhanced
    #             enhanced_sim[:, self.unseenclasses.to(self.device)] = unseen_enhanced
    #
    #             _, predicted = torch.max(enhanced_sim, dim=1)
    #             unseen_correct += (predicted == batch_labels).sum().item()
    #             unseen_total += len(batch_labels)
    #
    #     # 计算H值
    #     seen_acc = seen_correct / seen_total if seen_total > 0 else 0
    #     unseen_acc = unseen_correct / unseen_total if unseen_total > 0 else 0
    #     H = (2 * seen_acc * unseen_acc) / (seen_acc + unseen_acc) if seen_acc + unseen_acc > 0 else 0
    #
    #     return H


    def eval_model(self, att_corrected, all_test_labels, all_test_embeds, pu_scores):
        """使用经过PU分类器校正后的语义属性向量来对原始的零样本模型做预测"""
        print("使用校正后的语义进行最终预测...")
        batch_size = 128

        # 获取测试数据
        test_seen_label = all_test_labels[:self.len_test_seen].to("cpu")
        test_unseen_label = all_test_labels[self.len_test_seen:].to("cpu")

        test_seen_embeds = all_test_embeds[:self.len_test_seen].to("cpu")
        test_unseen_embeds = all_test_embeds[self.len_test_seen:].to("cpu")

        # 记录原始样本索引，仅用于最终结果分离
        seen_indices = torch.arange(self.len_test_seen)
        unseen_indices = torch.arange(self.len_test_unseen) + self.len_test_seen

        vec_bias = (self.model.mask_bias * self.model.bias).to(self.device)  # (1, 200)

        # 单独处理已见类和未见类，用于最终评估
        seen_correct = 0
        seen_total = 0

        alpha = self.alpha

        for i in range(0, self.len_test_seen, batch_size):
            end = min(i + batch_size, self.len_test_seen)
            batch_labels = test_seen_label[i:end].to(self.device)
            # 获取特征嵌入
            features_embed = test_seen_embeds[i:end].to(self.device)

            # 计算与所有类别属性的相似度
            similarity = torch.mm(features_embed, att_corrected.t())
            similarity = similarity + vec_bias

            # 获取PU分数，索引对应原始加载顺序
            idx_start = i
            idx_end = min(i + batch_size, self.len_test_seen)
            if idx_end > idx_start:
                current_pu_scores = pu_scores[seen_indices[idx_start:idx_end]].to(self.device)

                seen_weight = torch.sigmoid(current_pu_scores-self.threshold).unsqueeze(1)

                # 分开已见类和未见类的分数
                seen_sim = similarity[:, self.seenclasses.to(self.device)]
                unseen_sim = similarity[:, self.unseenclasses.to(self.device)]

                seen_enhanced = seen_sim * (1.0 + alpha * seen_weight)
                unseen_enhanced = unseen_sim * (1.0 - alpha * seen_weight)

                # 合并分数
                enhanced_sim = torch.zeros_like(similarity)
                enhanced_sim[:, self.seenclasses.to(self.device)] = seen_enhanced
                enhanced_sim[:, self.unseenclasses.to(self.device)] = unseen_enhanced

                # 获取预测标签
                _, predicted = torch.max(enhanced_sim, dim=1)
                seen_correct += (predicted == batch_labels).sum().item()
                seen_total += len(batch_labels)

            torch.cuda.empty_cache()

        # 处理未见类测试集
        unseen_correct = 0
        unseen_total = 0
        for i in range(0, self.len_test_unseen, batch_size):
            end = min(i + batch_size, self.len_test_unseen)
            batch_labels = test_unseen_label[i:end].to(self.device)

            # 获取特征嵌入
            features_embed = test_unseen_embeds[i:end].to(self.device)

            # 计算与所有类别属性的相似度
            similarity = torch.mm(features_embed, att_corrected.t())
            similarity = similarity + vec_bias

            # 获取PU分数，索引对应原始加载顺序
            idx_start = i
            idx_end = min(i + batch_size, self.len_test_unseen)
            if idx_end > idx_start:
                # 调整索引偏移量
                current_pu_scores = pu_scores[unseen_indices[idx_start:idx_end]].to(self.device)

                # 转换PU分数
                unseen_weight = torch.sigmoid(current_pu_scores-self.threshold).unsqueeze(1)

                # 分开已见类和未见类的分数
                seen_sim = similarity[:, self.seenclasses.to(self.device)]
                unseen_sim = similarity[:, self.unseenclasses.to(self.device)]

                seen_enhanced = seen_sim * (1.0 - alpha * unseen_weight)
                unseen_enhanced = unseen_sim * (1.0 + alpha * unseen_weight)

                # 合并分数
                enhanced_sim = torch.zeros_like(similarity)
                enhanced_sim[:, self.seenclasses.to(self.device)] = seen_enhanced
                enhanced_sim[:, self.unseenclasses.to(self.device)] = unseen_enhanced

                # 获取预测标签
                _, predicted = torch.max(enhanced_sim, dim=1)
                unseen_correct += (predicted == batch_labels).sum().item()
                unseen_total += len(batch_labels)

        # 8. 计算最终性能
        seen_acc = seen_correct / seen_total if seen_total > 0 else 0
        unseen_acc = unseen_correct / unseen_total if unseen_total > 0 else 0
        H = (2 * seen_acc * unseen_acc) / (seen_acc + unseen_acc) if seen_acc + unseen_acc > 0 else 0

        print(f"语义校准评估结果:")
        print(f"  已见类准确率: {seen_acc:.4f} ({seen_correct}/{seen_total})")
        print(f"  未见类准确率: {unseen_acc:.4f} ({unseen_correct}/{unseen_total})")
        print(f"  调和平均值: {H:.4f}")
        print("=" * 50)

        return seen_acc, unseen_acc, H


    def train_and_evaluate(self, epochs=300, batch_size=256, lr=5e-5):
        """Run the PU pipeline: extract embeddings, fit the PU learner, then score the test set.

        Steps performed in this method:
          1. ``extract_features`` returns the training embeddings, the test-seen
             embeddings, the test-unseen embeddings and the test logits.
          2. The test-seen and test-unseen embeddings and labels are concatenated
             (seen first, then unseen).
          3. The split sizes and the calibration hyper-parameters
             (``min``, ``max``, ``alpha``) are stored on ``self``.
          4. ``self.pu_learner.optimize`` is fitted on the training embeddings
             together with all test embeddings.
          5. A threshold is chosen by ``visualize_pu_scores``. PU scores are then
             computed by ``evaluate_pu`` with that threshold.

        Args:
            epochs: number of PU-learner training epochs (forwarded to ``optimize``).
            batch_size: PU-learner mini-batch size (forwarded to ``optimize``).
            lr: PU-learner learning rate (forwarded to ``optimize``).

        NOTE(review): the dataset name "SUN" is hard-coded below, while several
        inline values are annotated with their CUB settings. Switching datasets
        appears to require editing this method by hand.
        """
        print("提取特征中...")
        train_embeds, test_seen_embeds, test_unseen_embeds, all_test_logits = self.extract_features()
        # Concatenate all test features: test-seen first, then test-unseen.
        all_test_embeds = torch.cat([test_seen_embeds, test_unseen_embeds], dim=0)

        self.test_seen_label = self.dataloader.data['test_seen']['labels']
        self.test_unseen_label = self.dataloader.data['test_unseen']['labels']

        # Labels in the same order as all_test_embeds (seen first, then unseen).
        all_test_labels = torch.cat([
            self.test_seen_label,
            self.test_unseen_label
        ], dim=0).to(self.device)

        # NOTE(review): train_labels (and all_test_logits above) are not used
        # in the visible lines, apart from the commented-out correct_att call.
        train_labels = self.dataloader.data['train_seen']['labels']

        # Split sizes used to separate all_test_* back into seen/unseen parts.
        self.len_test_seen = len(test_seen_embeds)  # e.g. CUB: 1764, SUN: 2580
        self.len_test_unseen = len(test_unseen_embeds)  # e.g. CUB: 2967, SUN: 1440

        # Lower/upper bounds presumably read by the attribute-correction step.
        # With min == max the clamped blending weight is effectively fixed.
        self.min = 0.99     # value used for CUB: 0.99
        self.max = 0.99     # value used for CUB: 0.99
        # Strength of the PU-based re-weighting of seen/unseen class scores in eval_model.
        self.alpha = 0.1

        print("开始PU学习训练...")

        self.pu_learner.optimize(
            train_embeds,
            all_test_embeds,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr
        )
        #
        # save_model_Prob(self, dataset="SUN")
        # load_model_Prob(self, dataset="SUN")

        optimal_threshold = visualize_pu_scores(self, all_test_embeds, all_test_labels,"SUN")
        pu_scores = self.evaluate_pu(optimal_threshold, all_test_embeds, all_test_labels)
        # corr_att = self.correct_att(pu_scores, all_test_embeds, all_test_logits)
        # self.eval_model(corr_att, all_test_labels, all_test_embeds, pu_scores)
        # visualize_distributions(self,train_embeds,test_seen_embeds, test_unseen_embeds,"SUN")