import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm


# class ProbabilisticPULearner(nn.Module):
#     """改进后的基于概率论的PU学习器，支持未标记样本正负分类"""
#
#     def __init__(self, feature_dim, device,
#                  lambda_pu=1.0, lambda_sup=1.0, dropout_rate=0.3, margin=2.0):
#         super().__init__()
#         # f1: 拟合 P(a=1 | y=1, s)
#         self.f1_net = nn.Sequential(
#             nn.Linear(feature_dim, 512),
#             nn.LayerNorm(512),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(dropout_rate),
#             nn.Linear(512, 256),
#             nn.LayerNorm(256),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(dropout_rate),
#             nn.Linear(256, 128),
#             nn.LayerNorm(128),
#             nn.LeakyReLU(0.2),
#             nn.Linear(128, 1)
#         )
#         # f2: 判别样本是否为已见正/负
#         self.f2_net = nn.Sequential(
#             nn.Linear(feature_dim, 512),
#             nn.LayerNorm(512),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(dropout_rate),
#             nn.Linear(512, 256),
#             nn.LayerNorm(256),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(dropout_rate),
#             nn.Linear(256, 128),
#             nn.LayerNorm(128),
#             nn.LeakyReLU(0.2),
#             nn.Linear(128, 1)
#         )
#
#         self.device = device
#         self.lambda_pu = lambda_pu      # PU 损失权重
#         self.lambda_sup = lambda_sup    # 监督分类损失权重
#         self.margin = margin
#         self.bce = nn.BCEWithLogitsLoss()
#
#     def forward(self, p_features, q_features):
#         # 归一化输入特征
#         p_features = F.normalize(p_features, dim=1)
#         q_features = F.normalize(q_features, dim=1)
#
#         # 1. PU概率论损失
#         # 正样本 PU
#         p_a1 = torch.sigmoid(self.f1_net(p_features))
#         p_y1 = torch.sigmoid(self.f2_net(p_features))
#         loss_P = -torch.log(p_a1 * p_y1 + 1e-8).mean()
#         # 未标记 PU
#         p_not_a1 = 1 - torch.sigmoid(self.f1_net(q_features))
#         p_y1_U = torch.sigmoid(self.f2_net(q_features))
#         p_y_neg1_U = 1 - p_y1_U
#         joint = p_not_a1 * p_y1_U + p_y_neg1_U
#         loss_U = -torch.log(joint + 1e-8).mean()
#
#         # 2. 监督分类损失：将已见类正/负分别当作真实标签
#         # 正样本标签为1，未标记当负例标签为0
#         logits_p = self.f2_net(p_features).squeeze()
#         logits_q = self.f2_net(q_features).squeeze()
#         labels_p = torch.ones_like(logits_p)
#         labels_q = torch.zeros_like(logits_q)
#         # sup_loss = self.bce(logits_p, labels_p) + self.bce(logits_q, labels_q)
#
#         weights = p_not_a1 / (p_not_a1.sum() + 1e-8)
#
#         sup_logits_q = self.f2_net(q_features).squeeze()  # shape: [256]
#         targets_q = torch.zeros_like(sup_logits_q)  # shape: [256]
#         sup_loss_q = (weights.squeeze() * F.binary_cross_entropy_with_logits(
#             sup_logits_q, targets_q, reduction='none')).sum()
#         sup_loss = F.binary_cross_entropy_with_logits(
#             self.f2_net(p_features).squeeze(),
#             torch.ones(p_features.size(0), device=self.device)
#         ) + sup_loss_q
#
#         # 3. Margin约束（可选）
#         pos_margin = F.relu(self.margin - logits_p)
#         neg_margin = F.relu(logits_q + self.margin)
#         w_p = F.softmax(pos_margin, dim=0)
#         w_q = F.softmax(neg_margin, dim=0)
#         margin_loss = (pos_margin * w_p).sum() + (neg_margin * w_q).sum()
#
#         # 全部加权
#         total_loss = self.lambda_pu * (loss_P + loss_U) \
#                      + self.lambda_sup * sup_loss \
#                      + 0.5 * margin_loss
#         return total_loss
#
#     def predict(self, features):
#         """返回未标记样本属于正类的概率分数"""
#         self.eval()
#         features = F.normalize(features.to(self.device), dim=1)
#         with torch.no_grad():
#             logits = self.f2_net(features)
#         return logits.squeeze()
#
#
# class PULearner:
#     def __init__(self, feature_dim, att_mat, device, lambda_prob_pu=1.0):
#         self.device = device
#         self.feature_dim = feature_dim
#
#         # 损失函数权重
#         self.lambda_prob_pu = lambda_prob_pu
#
#         # 存储类别-属性矩阵
#         self.att_mat = torch.tensor(att_mat, device=device)
#
#         # 初始化模型组件
#         self.prob_pu_learner = ProbabilisticPULearner(feature_dim, device, lambda_prob_pu).to(device)
#
#     def compute_total_loss(self, p_features, q_features):
#         # 计算概率论PU学习损失
#         prob_pu_loss = self.prob_pu_learner(p_features, q_features)
#
#         # 总损失
#         total_loss = self.lambda_prob_pu * prob_pu_loss
#
#         return total_loss, {
#             'total': total_loss.item()
#         }
#
#     def optimize(self, p_features, q_features, epochs=100, batch_size=256, lr=1e-3,
#                  weight_decay=1e-5, early_stop=30, verbose=True):
#         """训练PU学习模型"""
#         # 确保两个数据集具有相同的样本数
#         global best_state, loss_components
#         min_samples = min(p_features.size(0), q_features.size(0))
#         p_features = p_features[:min_samples]
#         q_features = q_features[:min_samples]
#
#         # 准备数据
#         p_features = F.normalize(p_features, dim=1)
#         q_features = F.normalize(q_features, dim=1)
#
#         # 设置优化器
#         optimizer = Adam([
#             {'params': self.prob_pu_learner.parameters(), 'lr': lr}
#         ], weight_decay=weight_decay)
#
#         # 使用学习率预热 + 余弦退火调度
#         def lr_lambda(epoch):
#             warmup_epochs = 10
#             if epoch < warmup_epochs:
#                 return epoch / warmup_epochs
#             else:
#                 return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
#
#         # 学习率调度
#         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
#
#         # 训练跟踪
#         best_loss = float('inf')
#         best_epoch = 0
#         loss_history = []
#
#         # 批次生成函数
#         def get_batch_indices(n, batch_size):
#             all_indices = torch.randperm(n)
#             for i in range(0, n, batch_size):
#                 end = min(i + batch_size, n)
#                 yield all_indices[i:end]
#
#         # 主训练循环
#         for epoch in tqdm(range(epochs), desc="训练进度"):
#             self.prob_pu_learner.train()
#
#             epoch_losses = []
#
#             # 获取当前轮次的批次数
#             p_indices = list(get_batch_indices(p_features.size(0), batch_size))
#
#             # 对每个批次进行训练
#             for i in range(len(p_indices)):
#                 p_idx = p_indices[i]
#                 # 使用相同索引确保批次匹配
#                 q_idx = p_idx
#
#                 p_batch = p_features[p_idx].to(self.device)
#                 q_batch = q_features[q_idx].to(self.device)
#
#                 # 计算损失
#                 optimizer.zero_grad()
#                 loss, loss_components = self.compute_total_loss(p_batch, q_batch)
#
#                 # 反向传播
#                 loss.backward()
#
#                 # 梯度裁剪
#                 torch.nn.utils.clip_grad_norm_(self.prob_pu_learner.parameters(), 1.0)  # clip gradients of the PU learner
#
#                 # 参数更新
#                 optimizer.step()
#
#                 epoch_losses.append(loss.item())
#
#             # 更新学习率
#             scheduler.step()
#
#             # 计算平均损失
#             avg_loss = np.mean(epoch_losses)
#             loss_history.append(avg_loss)
#
#             # 输出训练状态
#             if verbose and (epoch % 30 == 0 or epoch == epochs - 1):
#                 print(f"\n训练轮次 {epoch + 1}/{epochs}:")
#                 print(f"  总损失: {avg_loss:.4f}")
#                 print(f"  学习率: {scheduler.get_last_lr()[0]:.6f}")
#
#                 # 计算当前分类分数
#                 with torch.no_grad():
#                     p_scores = self.prob_pu_learner.predict(p_features[:min(1000, p_features.size(0))].to(self.device))
#                     q_scores = self.prob_pu_learner.predict(q_features[:min(1000, q_features.size(0))].to(self.device))
#                     print(f"  正类分数 - 均值: {p_scores.mean().item():.4f}, 标准差: {p_scores.std().item():.4f}")
#                     print(f"  负类分数 - 均值: {q_scores.mean().item():.4f}, 标准差: {q_scores.std().item():.4f}")
#
#             # 早停检查
#             if avg_loss < best_loss:
#                 best_loss = avg_loss
#                 best_epoch = epoch
#                 # 保存最佳模型
#                 best_state = {
#                     'prob_pu_learner': self.prob_pu_learner.state_dict(),
#                 }
#             elif epoch - best_epoch > early_stop:
#                 print(f"\n提前停止于轮次 {epoch + 1}, 最佳轮次: {best_epoch + 1}")
#                 break
#
#         # 恢复最佳模型
#         self.prob_pu_learner.load_state_dict(best_state['prob_pu_learner'])
#
#         return loss_history
#
#     def predict(self, features):
#         # 预测样本的分类分数
#         features = F.normalize(features, dim=1)
#         with torch.no_grad():
#             scores = self.prob_pu_learner.predict(features)
#         return scores


class ProbabilisticPULearner(nn.Module):
    """Probabilistic positive-unlabeled (PU) learner with pseudo-label and easy-negative mining.

    Two identically shaped MLP heads each map a feature vector to one logit:

    * ``f1_net`` -- fits P(a=1 | y=1, s) (notation taken from the original
      author's comment).
    * ``f2_net`` -- discriminates positive from negative samples; its raw logit
      is the score returned by :meth:`predict` and thresholded by
      :meth:`generate_pseudo_labels`.

    Args:
        feature_dim: Size of the last dimension of the input features.
        device: Device stored on the instance; :meth:`predict` moves its input there.
        lambda_pu: Weight of the PU loss term.
        lambda_sup: Weight of the supervised term (currently multiplied by 0
            in :meth:`forward`, so it has no effect on the returned loss).
        dropout_rate: Dropout probability used inside both MLPs.
        margin: Margin used by the (currently zero-weighted) margin term.

    Attributes:
        pseudo_pos_mask / pseudo_neg_mask: Optional boolean masks over the
            unlabeled rows. :meth:`forward` reads their first ``q.size(0)``
            entries, so they must be aligned row-for-row with the unlabeled
            batch passed to ``forward``.
    """

    def __init__(self, feature_dim, device,
                 lambda_pu=1.0, lambda_sup=1.0, dropout_rate=0.3, margin=2.0):
        super().__init__()
        # f1: fits P(a=1 | y=1, s)
        self.f1_net = self._build_mlp(feature_dim, dropout_rate)
        # f2: decides whether a sample is a seen positive / negative
        self.f2_net = self._build_mlp(feature_dim, dropout_rate)

        self.device = device
        self.lambda_pu = lambda_pu
        self.lambda_sup = lambda_sup
        self.margin = margin
        self.bce = nn.BCEWithLogitsLoss()
        # Pseudo-label state (set externally, e.g. by the training loop)
        self.pseudo_pos_mask = None
        self.pseudo_neg_mask = None

    @staticmethod
    def _build_mlp(feature_dim, dropout_rate):
        """Build the feature_dim -> 512 -> 256 -> 128 -> 1 MLP shared by both heads."""
        return nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.LayerNorm(512),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
        )

    def generate_pseudo_labels(self, q_features, tau_high=2.0, tau_low=-2.0):
        """Threshold raw ``f2_net`` logits into pseudo-positive / pseudo-negative masks.

        Args:
            q_features: Unlabeled features, shape ``(N, feature_dim)``.
            tau_high: Logits ``>= tau_high`` are marked pseudo-positive.
            tau_low: Logits ``<= tau_low`` are marked pseudo-negative.

        Returns:
            ``(pos_mask, neg_mask)``: two boolean tensors of shape ``(N,)``.

        The logits are computed in eval mode (dropout disabled); the module's
        previous train/eval mode is restored afterwards so that calling this
        in the middle of training does not silently disable dropout.
        """
        was_training = self.training
        self.eval()
        try:
            q_features = F.normalize(q_features, dim=1)
            with torch.no_grad():
                # squeeze(-1) keeps a 1-D result even for a single sample
                logits_q = self.f2_net(q_features).squeeze(-1)
        finally:
            self.train(was_training)

        pos_mask = logits_q >= tau_high
        neg_mask = logits_q <= tau_low
        return pos_mask, neg_mask

    def _supervised_loss(self, logits_p, logits_q, neg_weights):
        """Return the supervised BCE term built from easy negatives and pseudo-positives.

        Args:
            logits_p: ``f2_net`` logits of the positive batch, shape ``(Np,)``.
            logits_q: ``f2_net`` logits of the unlabeled batch, shape ``(Nq,)``.
            neg_weights: ``1 - sigmoid(f1_net(q))``, shape ``(Nq,)``; larger
                values are treated as more confident negatives.
        """
        n_q = logits_q.size(0)
        if n_q <= 1:
            # Too few unlabeled samples for a quantile: supervise positives only.
            return F.binary_cross_entropy_with_logits(
                logits_p, torch.ones_like(logits_p))

        zero = logits_q.new_zeros(())

        # Keep only the top 20% most confident negatives.
        threshold = torch.quantile(neg_weights, 0.8)
        easy_neg_mask = neg_weights >= threshold

        if self.pseudo_pos_mask is None or self.pseudo_neg_mask is None:
            # No pseudo-labels yet: easy negatives only.
            if easy_neg_mask.any():
                selected = logits_q[easy_neg_mask]
                return F.binary_cross_entropy_with_logits(
                    selected, torch.zeros_like(selected))
            return zero

        # The first n_q mask entries are taken as the labels of this batch.
        pos_mask = self.pseudo_pos_mask.reshape(-1)[:n_q].to(logits_q.device)

        pos_loss = zero
        if pos_mask.any():
            pos_logits = logits_q[pos_mask]
            pos_loss = F.binary_cross_entropy_with_logits(
                pos_logits, torch.ones_like(pos_logits))

        # Easy negatives, excluding anything pseudo-labelled positive.
        final_neg_mask = easy_neg_mask & ~pos_mask
        neg_loss = zero
        if final_neg_mask.any():
            neg_logits = logits_q[final_neg_mask]
            neg_loss = F.binary_cross_entropy_with_logits(
                neg_logits, torch.zeros_like(neg_logits))

        return pos_loss + neg_loss

    def forward(self, p_features, q_features, epoch=0):
        """Compute the training loss for one positive batch and one unlabeled batch.

        Args:
            p_features: Labeled-positive features, shape ``(Np, feature_dim)``.
            q_features: Unlabeled features, shape ``(Nq, feature_dim)``.
            epoch: Accepted for interface compatibility; not used.

        Returns:
            Scalar tensor ``lambda_pu * (loss_P + loss_U)``. The supervised
            and margin terms are still evaluated but are multiplied by 0, so
            they contribute neither value nor gradient.
        """
        # L2-normalise the input features
        p_features = F.normalize(p_features, dim=1)
        q_features = F.normalize(q_features, dim=1)

        # One pass per (network, input). squeeze(-1) rather than squeeze() so a
        # batch of one sample stays 1-D and matches the shape of its BCE target.
        f1_logits_p = self.f1_net(p_features).squeeze(-1)
        logits_p = self.f2_net(p_features).squeeze(-1)
        f1_logits_q = self.f1_net(q_features).squeeze(-1)
        logits_q = self.f2_net(q_features).squeeze(-1)

        # 1. PU loss
        # Positive part: -log(sigmoid(f1) * sigmoid(f2))
        p_a1 = torch.sigmoid(f1_logits_p)
        p_y1 = torch.sigmoid(logits_p)
        loss_P = -torch.log(p_a1 * p_y1 + 1e-8).mean()

        # Unlabeled part: -log((1 - sigmoid(f1)) * sigmoid(f2) + (1 - sigmoid(f2)))
        p_not_a1 = 1 - torch.sigmoid(f1_logits_q)
        p_y1_U = torch.sigmoid(logits_q)
        p_y_neg1_U = 1 - p_y1_U
        joint = p_not_a1 * p_y1_U + p_y_neg1_U
        loss_U = -torch.log(joint + 1e-8).mean()

        # 2. Difficulty-aware supervised loss
        sup_loss = self._supervised_loss(logits_p, logits_q, p_not_a1)

        # 3. Margin constraint
        pos_margin = F.relu(self.margin - logits_p)
        neg_margin = F.relu(logits_q + self.margin)
        if p_features.size(0) > 1 and q_features.size(0) > 1:
            # Softmax weighting emphasises the largest violations.
            w_p = F.softmax(pos_margin, dim=0)
            w_q = F.softmax(neg_margin, dim=0)
            margin_loss = (pos_margin * w_p).sum() + (neg_margin * w_q).sum()
        else:
            margin_loss = pos_margin.mean() + neg_margin.mean()

        # Total loss. NOTE: the supervised and margin terms are deliberately
        # kept at weight 0 here, exactly as in the original implementation.
        total_loss = self.lambda_pu * (loss_P + loss_U) \
            + 0 * self.lambda_sup * sup_loss + 0 * margin_loss
        return total_loss

    def predict(self, features):
        """Return the raw ``f2_net`` logit (positive-class score) for each sample.

        The module is switched to eval mode and stays in eval mode afterwards.
        The result has all size-1 dimensions squeezed out, so a single input
        row yields a 0-dim tensor.
        """
        self.eval()
        features = F.normalize(features.to(self.device), dim=1)
        with torch.no_grad():
            logits = self.f2_net(features)
        return logits.squeeze()


class PULearner:
    """Training and inference wrapper around a ``ProbabilisticPULearner``.

    Args:
        feature_dim: Size of the last dimension of the input features.
        att_mat: Class-attribute matrix; stored as a tensor on ``device``.
        device: Device the model lives on and batches are moved to.
        lambda_prob_pu: Loss weight. It is passed to the inner learner as its
            ``lambda_pu`` and is also applied again in
            :meth:`compute_total_loss`.
    """

    def __init__(self, feature_dim, att_mat, device, lambda_prob_pu=1.0):
        self.device = device
        self.feature_dim = feature_dim

        # Loss weight
        self.lambda_prob_pu = lambda_prob_pu

        # Class-attribute matrix. A tensor input is copied explicitly because
        # torch.tensor(tensor) emits a copy-construct warning.
        if torch.is_tensor(att_mat):
            self.att_mat = att_mat.detach().clone().to(device)
        else:
            self.att_mat = torch.tensor(att_mat, device=device)

        # Model component
        self.prob_pu_learner = ProbabilisticPULearner(feature_dim, device, lambda_prob_pu).to(device)

    @staticmethod
    def _lr_factor(epoch, epochs, warmup_epochs=10):
        """Learning-rate multiplier: linear warm-up followed by cosine annealing.

        Warm-up uses ``(epoch + 1) / warmup_epochs`` so that the very first
        epoch trains with a non-zero learning rate. The cosine span is clamped
        to at least one epoch to avoid 0/0 when ``epochs == warmup_epochs``.
        """
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        decay_span = max(1, epochs - warmup_epochs)
        progress = min(1.0, (epoch - warmup_epochs) / decay_span)
        return 0.5 * (1 + np.cos(np.pi * progress))

    @staticmethod
    def _iter_batches(n, batch_size):
        """Yield index tensors that cover a random permutation of ``range(n)``."""
        order = torch.randperm(n)
        for start in range(0, n, batch_size):
            yield order[start:start + batch_size]

    @staticmethod
    def _snapshot_state(module):
        """Return a detached copy of ``module.state_dict()``.

        ``state_dict()`` hands back references to the live parameters, so the
        tensors must be cloned for the snapshot to survive later updates.
        """
        return {name: value.detach().clone()
                for name, value in module.state_dict().items()}

    def compute_total_loss(self, p_features, q_features):
        """Return ``(loss, components)`` for one positive / unlabeled batch pair.

        ``components`` is ``{'total': float}``.
        NOTE(review): ``lambda_prob_pu`` is also the inner learner's
        ``lambda_pu``, so the weight is effectively applied twice here, whereas
        :meth:`optimize` calls the learner directly and applies it once.
        Harmless at the default of 1.0; confirm the intent for other values.
        """
        # Probabilistic PU loss from the inner learner
        prob_pu_loss = self.prob_pu_learner(p_features, q_features)

        # Total loss
        total_loss = self.lambda_prob_pu * prob_pu_loss

        return total_loss, {
            'total': total_loss.item()
        }

    def optimize(self, p_features, q_features, epochs=100, batch_size=256, lr=1e-3,
                 weight_decay=1e-5, early_stop=30, verbose=True):
        """Train the PU learner, with periodic pseudo-label refresh and early stopping.

        Args:
            p_features: Labeled-positive features, shape ``(Np, feature_dim)``.
            q_features: Unlabeled features, shape ``(Nq, feature_dim)``.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size.
            lr: Base learning rate for Adam.
            weight_decay: Adam weight decay.
            early_stop: Stop once the epoch loss has not improved for more
                than this many epochs.
            verbose: Print progress and pseudo-label statistics.

        Returns:
            List with the mean training loss of every epoch that was run.

        Raises:
            ValueError: If either feature set is empty.

        Both sets are truncated to the size of the smaller one, and the same
        shuffled indices are used for the positive and the unlabeled batch.
        After training, the weights of the best (lowest mean loss) epoch are
        restored.
        """
        # Make both sets the same size
        min_samples = min(p_features.size(0), q_features.size(0))
        if min_samples == 0:
            raise ValueError("p_features and q_features must each contain at least one sample")

        p_features = F.normalize(p_features[:min_samples], dim=1)
        q_features = F.normalize(q_features[:min_samples], dim=1)

        learner = self.prob_pu_learner

        # Optimizer and warm-up + cosine schedule
        optimizer = Adam([
            {'params': learner.parameters(), 'lr': lr}
        ], weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda epoch: self._lr_factor(epoch, epochs))

        # Training bookkeeping
        best_loss = float('inf')
        best_epoch = 0
        best_state = None
        loss_history = []

        # Pseudo-label schedule
        pseudo_label_start = 30      # first epoch with pseudo-labels
        pseudo_label_interval = 10   # refresh interval in epochs
        tau_high_start = 2.0         # initial positive threshold
        tau_low_start = -2.0         # initial negative threshold
        tau_high_end = 1.0           # final positive threshold
        tau_low_end = -1.0           # final negative threshold

        # Dataset-wide masks live here; the learner only ever sees the slice
        # that belongs to the current batch. Start clean so labels from an
        # earlier call are never applied to this data.
        global_pos_mask = None
        global_neg_mask = None
        learner.pseudo_pos_mask = None
        learner.pseudo_neg_mask = None

        for epoch in tqdm(range(epochs), desc="训练进度"):
            # Refresh pseudo-labels every pseudo_label_interval epochs once
            # pseudo_label_start has been reached.
            if epoch >= pseudo_label_start and epoch % pseudo_label_interval == 0:
                # Interpolate both thresholds linearly towards their end values.
                progress = min(1.0, (epoch - pseudo_label_start) / (epochs - pseudo_label_start))
                tau_high = tau_high_start - progress * (tau_high_start - tau_high_end)
                tau_low = tau_low_start + progress * (tau_low_end - tau_low_start)

                pos_mask, neg_mask = learner.generate_pseudo_labels(
                    q_features.to(self.device), tau_high=tau_high, tau_low=tau_low)
                pos_mask = pos_mask.reshape(-1)
                neg_mask = neg_mask.reshape(-1)
                global_pos_mask, global_neg_mask = pos_mask, neg_mask

                if verbose:
                    print(f"\n轮次 {epoch + 1}: 更新伪标签")
                    print(
                        f"  高置信正例: {pos_mask.sum().item()}/{len(pos_mask)} ({pos_mask.float().mean().item() * 100:.2f}%)")
                    print(
                        f"  高置信负例: {neg_mask.sum().item()}/{len(neg_mask)} ({neg_mask.float().mean().item() * 100:.2f}%)")
                    print(f"  当前阈值: 正例>{tau_high:.2f}, 负例<{tau_low:.2f}")

            # Enter train mode only now: pseudo-label generation (and the
            # previous epoch's predict call) may have switched the model to eval.
            learner.train()

            epoch_losses = []
            for batch_idx in self._iter_batches(min_samples, batch_size):
                # The same indices pair up the positive and unlabeled batch.
                p_batch = p_features[batch_idx].to(self.device)
                q_batch = q_features[batch_idx].to(self.device)

                # Hand the learner the pseudo-labels of exactly these rows;
                # batches are shuffled, so a prefix slice would be misaligned.
                if global_pos_mask is not None:
                    mask_idx = batch_idx.to(global_pos_mask.device)
                    learner.pseudo_pos_mask = global_pos_mask[mask_idx]
                    learner.pseudo_neg_mask = global_neg_mask[mask_idx]

                optimizer.zero_grad()
                loss = learner(p_batch, q_batch, epoch=epoch)
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(learner.parameters(), 1.0)

                optimizer.step()
                epoch_losses.append(loss.item())

            # Advance the learning-rate schedule
            scheduler.step()

            avg_loss = np.mean(epoch_losses)
            loss_history.append(avg_loss)

            # Progress report
            if verbose and (epoch % 30 == 0 or epoch == epochs - 1):
                print(f"\n训练轮次 {epoch + 1}/{epochs}:")
                print(f"  总损失: {avg_loss:.4f}")
                print(f"  学习率: {scheduler.get_last_lr()[0]:.6f}")

                # Current scores on (at most) the first 1000 samples of each set
                with torch.no_grad():
                    p_scores = learner.predict(p_features[:min(1000, p_features.size(0))].to(self.device))
                    q_scores = learner.predict(q_features[:min(1000, q_features.size(0))].to(self.device))
                    print(f"  正类分数 - 均值: {p_scores.mean().item():.4f}, 标准差: {p_scores.std().item():.4f}")
                    print(f"  负类分数 - 均值: {q_scores.mean().item():.4f}, 标准差: {q_scores.std().item():.4f}")

            # Early stopping
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_epoch = epoch
                # Snapshot a real copy of the best weights
                best_state = self._snapshot_state(learner)
            elif epoch - best_epoch > early_stop:
                print(f"\n提前停止于轮次 {epoch + 1}, 最佳轮次: {best_epoch + 1}")
                break

        # Restore the best weights (absent if no epoch produced a finite,
        # improving loss or if epochs == 0).
        if best_state is not None:
            learner.load_state_dict(best_state)

        # Leave the dataset-wide masks on the learner, not the last batch slice.
        learner.pseudo_pos_mask = global_pos_mask
        learner.pseudo_neg_mask = global_neg_mask

        return loss_history

    def predict(self, features):
        """Return the learner's classification score for each row of ``features``."""
        features = F.normalize(features, dim=1)
        with torch.no_grad():
            scores = self.prob_pu_learner.predict(features)
        return scores

