from PUModels.ProbPULearner_X import PULearner
# from util.helper_func import eval_zs_gzsl
from report_model import *
from visual import *


class PUTrainer:
    """Train a ``PULearner`` on model embeddings and evaluate seen/unseen separation.

    Embeddings of the training set and of the seen+unseen test sets are produced
    by ``model``; the test embeddings are shifted toward the training-set feature
    statistics; then ``pu_learner.optimize`` is called with the training embeddings
    and all test embeddings (presumably the positive and unlabeled sets of PU
    learning -- confirm against ``PULearner``).
    """

    def __init__(self, model, dataloader, device, lambda_prob_pu=1.0):
        """Set up the trainer and its PU learner.

        Args:
            model: callable returning a dict with ``'embed'`` and ``'pred'`` entries;
                must expose an ``att`` tensor (the class-attribute matrix).
            dataloader: exposes ``seenclasses``, ``unseenclasses``, ``ntrain``,
                ``next_batch`` and a ``data`` dict. NOTE: its ``batch_size`` is
                overwritten to 50 here.
            device: torch device used for feature extraction and evaluation.
            lambda_prob_pu: weight forwarded to ``PULearner``.
        """
        self.threshold = None
        self.pred_seen_indices = None
        self.pred_unseen_indices = None
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.lambda_prob_pu = lambda_prob_pu
        # Mutates the caller's dataloader: every batched pass below uses 50.
        self.dataloader.batch_size = 50

        # Class-attribute matrix, kept on CPU as numpy for the PU learner.
        self.att_mat = self.model.att.detach().cpu().numpy()
        feature_dim = model.att.size(1)

        # Class split information.
        self.seenclasses = dataloader.seenclasses
        self.unseenclasses = dataloader.unseenclasses

        # Populated by train_and_evaluate().
        self.len_test_unseen = None
        self.len_test_seen = None
        self.test_seen_label = None
        self.test_unseen_label = None

        self.pu_learner = PULearner(
            feature_dim=feature_dim,
            att_mat=self.att_mat,
            device=self.device,
            # Previously hard-coded to 1.0, silently ignoring the argument.
            lambda_prob_pu=lambda_prob_pu,
        )

    def _embed_in_batches(self, features):
        """Run ``self.model`` over ``features`` in chunks of ``dataloader.batch_size``.

        Returns:
            (embeds, logits): row-wise concatenation of the model's ``'embed'``
            outputs, and the list of per-chunk ``'pred'`` outputs.
        """
        step = self.dataloader.batch_size
        embeds, logits = [], []
        with torch.no_grad():
            for start in range(0, features.size(0), step):
                out = self.model(features[start:start + step].to(self.device))
                embeds.append(out['embed'])
                logits.append(out['pred'])
        return torch.cat(embeds, dim=0), logits

    @staticmethod
    def _align_to_train_domain(features, train_mean, train_std, feature_type="",
                               adaptive_alpha=True):
        """Blend ``features`` with a copy re-standardised to the training statistics.

        Each column of ``features`` is standardised with its own mean/std, then
        rescaled with ``train_std`` / shifted by ``train_mean``. The result is mixed
        with the original as ``alpha * aligned + (1 - alpha) * features``.
        ``alpha`` is 0.7, except when ``adaptive_alpha`` is set, ``feature_type``
        contains "未见类" (unseen) and the mean absolute gap between the column
        means exceeds 0.1: then ``alpha = min(0.9, 0.7 + 2 * gap)``.
        """
        current_mean = features.mean(dim=0, keepdim=True)
        current_std = features.std(dim=0, keepdim=True) + 1e-8  # avoid divide-by-zero

        normalized = (features - current_mean) / current_std
        aligned = normalized * train_std + train_mean

        alpha = 0.7
        if adaptive_alpha:
            # The further the test mean is from the train mean, the stronger the pull.
            mean_distance = torch.abs(current_mean - train_mean).mean().item()
            if "未见类" in feature_type and mean_distance > 0.1:
                alpha = min(0.9, alpha + mean_distance * 2.0)

        final_features = alpha * aligned + (1 - alpha) * features

        print(
            f"  {feature_type}特征对齐完成 - 原始均值: {current_mean.mean().item():.4f}, 对齐后均值: {final_features.mean().item():.4f}, 对齐强度: {alpha:.2f}")

        return final_features

    def extract_features(self):
        """Extract train/test embeddings and align the test ones to the train domain.

        Returns:
            (train_features, test_seen_aligned, test_unseen_aligned, all_test_logits)
            where ``all_test_logits`` concatenates the seen then unseen ``'pred'``
            outputs.
        """
        step = self.dataloader.batch_size
        with torch.no_grad():
            train_features = []
            # NOTE(review): the loop index is unused; next_batch() chooses the samples.
            # If it samples randomly this pass may repeat/skip training samples.
            for _ in range(0, self.dataloader.ntrain, step):
                _, batch_feature, _ = self.dataloader.next_batch(step)
                out = self.model(batch_feature.to(self.device))
                train_features.append(out['embed'])
            train_features = torch.cat(train_features, dim=0)

        test_seen_features, test_seen_logits = self._embed_in_batches(
            self.dataloader.data['test_seen']['resnet_features'])
        test_unseen_features, test_unseen_logits = self._embed_in_batches(
            self.dataloader.data['test_unseen']['resnet_features'])

        all_test_logits = torch.cat(test_seen_logits + test_unseen_logits, dim=0)

        print("执行自适应强度的特征域适应对齐...")

        # Training-set statistics serve as the target domain.
        train_mean = train_features.mean(dim=0, keepdim=True)
        train_std = train_features.std(dim=0, keepdim=True) + 1e-8

        # Mean/std alignment, applied to seen and unseen test features separately.
        test_seen_features_aligned = self._align_to_train_domain(
            test_seen_features, train_mean, train_std, "已见类测试")
        test_unseen_features_aligned = self._align_to_train_domain(
            test_unseen_features, train_mean, train_std, "未见类测试")

        print(f"分位数域适应对齐完成:")
        print(f"  训练集均值: {train_mean.mean().item():.4f}")
        print(f"  已见类测试集对齐后均值: {test_seen_features_aligned.mean().item():.4f}")
        print(f"  未见类测试集对齐后均值: {test_unseen_features_aligned.mean().item():.4f}")

        return train_features, test_seen_features_aligned, test_unseen_features_aligned, all_test_logits

    @staticmethod
    def _binary_metrics(tp, fp, tn, fn):
        """Derive seen/unseen metrics from confusion counts ("seen" = positive).

        Every ratio falls back to 0.0 when its denominator is zero.

        Returns:
            dict with keys seen_recall, unseen_recall, seen_precision,
            unseen_precision, seen_f1, unseen_f1, accuracy, harmonic (harmonic
            mean of the two recalls).
        """
        def safe_div(num, den):
            return num / den if den > 0 else 0.0

        def f1(precision, recall):
            return safe_div(2 * precision * recall, precision + recall)

        seen_recall = safe_div(tp, tp + fn)
        unseen_recall = safe_div(tn, tn + fp)
        seen_precision = safe_div(tp, tp + fp)
        unseen_precision = safe_div(tn, tn + fn)
        return {
            'seen_recall': seen_recall,
            'unseen_recall': unseen_recall,
            'seen_precision': seen_precision,
            'unseen_precision': unseen_precision,
            'seen_f1': f1(seen_precision, seen_recall),
            'unseen_f1': f1(unseen_precision, unseen_recall),
            'accuracy': safe_div(tp + tn, tp + fp + tn + fn),
            'harmonic': f1(seen_recall, unseen_recall),
        }

    def evaluate_pu(self, optimal_threshold, all_test_features, all_test_labels):
        """Threshold PU scores and print a seen-vs-unseen confusion analysis.

        The first ``len(self.test_seen_label)`` rows of ``all_test_features`` are
        treated as truly seen, the rest as truly unseen. A score strictly greater
        than ``optimal_threshold`` is predicted "seen". ``all_test_labels`` is not
        used (kept for interface compatibility).

        Returns:
            The raw scores from ``self.pu_learner.predict``.

        Raises:
            ValueError: if the number of scores does not match the number of
                test samples.
        """
        len_seen = len(self.test_seen_label)
        len_unseen = len(self.test_unseen_label)

        seen_mask = torch.zeros(len_seen + len_unseen, dtype=torch.bool, device=self.device)
        seen_mask[:len_seen] = True
        unseen_mask = ~seen_mask

        print(f"评估样本分布: 已见类={seen_mask.sum().item()}, 未见类={unseen_mask.sum().item()}")

        pu_score = self.pu_learner.predict(all_test_features.to(self.device))
        # Flatten so an (N, 1) score cannot broadcast against the (N,) masks.
        is_seen_pred = (pu_score > optimal_threshold).reshape(-1).to(self.device)
        if is_seen_pred.numel() != seen_mask.numel():
            raise ValueError(
                f"PU score count {is_seen_pred.numel()} != test sample count {seen_mask.numel()}")

        tp = (seen_mask & is_seen_pred).sum().item()     # seen -> seen
        fp = (unseen_mask & is_seen_pred).sum().item()   # unseen -> seen
        tn = (unseen_mask & ~is_seen_pred).sum().item()  # unseen -> unseen
        fn = (seen_mask & ~is_seen_pred).sum().item()    # seen -> unseen

        print(f"混淆矩阵:")
        print(f"  TP (已见→已见): {tp}")
        print(f"  FP (未见→已见): {fp}")
        print(f"  TN (未见→未见): {tn}")
        print(f"  FN (已见→未见): {fn}")

        m = self._binary_metrics(tp, fp, tn, fn)

        print(f"\n=== PU分类评估结果 ===")
        print(f"最佳阈值: {optimal_threshold:.4f}")
        print(f"\n【召回率 (Recall)】")
        print(f"  已见类召回率: {m['seen_recall']:.4f} (TP Rate)")
        print(f"  未见类召回率: {m['unseen_recall']:.4f} (TN Rate)")
        print(f"\n【精确率 (Precision)】")
        print(f"  已见类精确率: {m['seen_precision']:.4f}")
        print(f"  未见类精确率: {m['unseen_precision']:.4f}")
        print(f"\n【F1分数】")
        print(f"  已见类F1: {m['seen_f1']:.4f}")
        print(f"  未见类F1: {m['unseen_f1']:.4f}")
        print(f"\n【综合指标】")
        print(f"  总体准确率: {m['accuracy']:.4f}")
        print(f"  谐波平均数: {m['harmonic']:.4f}")

        return pu_score

    def train_and_evaluate(self, epochs=300, batch_size=256, lr=5e-5, dataset_name="CUB"):
        """Extract features, train the PU learner, then visualize and evaluate.

        Args:
            epochs, batch_size, lr: forwarded to ``pu_learner.optimize``.
            dataset_name: label passed to the visualization helpers
                (previously hard-coded to "CUB").

        Returns:
            The PU scores returned by ``evaluate_pu``.
        """
        print("提取特征中...")
        train_embeds, test_seen_embeds, test_unseen_embeds, _ = self.extract_features()
        # Seen test samples first, then unseen; evaluate_pu relies on this order.
        all_test_embeds = torch.cat([test_seen_embeds, test_unseen_embeds], dim=0)

        self.test_seen_label = self.dataloader.data['test_seen']['labels']
        self.test_unseen_label = self.dataloader.data['test_unseen']['labels']

        all_test_labels = torch.cat([
            self.test_seen_label,
            self.test_unseen_label
        ], dim=0).to(self.device)

        self.len_test_seen = len(test_seen_embeds)  # CUB:1764  SUN:2580
        self.len_test_unseen = len(test_unseen_embeds)  # CUB:2967  SUN:1440

        print("开始PU学习训练...")

        self.pu_learner.optimize(
            train_embeds,
            all_test_embeds,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr
        )

        # save_model_Prob(self,"CUB")

        optimal_threshold = visualize_pu_scores(self, all_test_embeds, all_test_labels, dataset_name)
        pu_scores = self.evaluate_pu(optimal_threshold, all_test_embeds, all_test_labels)

        visualize_distributions(self, train_embeds, test_seen_embeds, test_unseen_embeds, dataset_name)
        return pu_scores