import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm


class ProbabilisticPULearner(nn.Module):
    """LBE-PU learner built on ZSL semantic vectors.

    The module holds two MLP heads over the same input features:

    * ``f1_net`` -- its sigmoid output is the quantity the losses call
      ``p_a1`` (presumably the probability that a sample is labeled;
      confirm against the method description).
    * ``f2_net`` -- its sigmoid output is the quantity the losses call
      ``p_y1``; it is also the head used for the prior estimate, the
      pseudo labels, the supervised term and the margin term.

    A running estimate of the class prior ``pi`` is kept in the
    ``pi_estimate`` buffer and updated with an exponential moving average.
    """

    def __init__(self, feature_dim, device,
                 lambda_pu=1.0, lambda_sup=1.0, lambda_consistency=0.3,
                 dropout_rate=0.3, margin=2.0):
        """Build both heads and initialise the prior-tracking state.

        Args:
            feature_dim: Size of the last dimension of the input features.
            device: Device that ``predict`` moves its inputs to.
            lambda_pu: Base weight of the PU loss term.
            lambda_sup: Base weight of the supervised loss term.
            lambda_consistency: Weight of the f1/f2 consistency term.
            dropout_rate: Dropout probability used inside both heads.
            margin: Logit margin used by the margin term in ``forward``.
        """
        super().__init__()

        # NOTE: the layer order inside each Sequential defines the
        # state_dict keys, so it must not be rearranged.
        self.f1_net = nn.Sequential(
            nn.Linear(feature_dim, 384),
            nn.LayerNorm(384),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(384, 192),
            nn.LayerNorm(192),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(192, 96),
            nn.LayerNorm(96),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(96, 48),
            nn.LayerNorm(48),
            nn.LeakyReLU(0.2),
            nn.Linear(48, 1)
        )

        self.f2_net = nn.Sequential(
            nn.Linear(feature_dim, 640),
            nn.LayerNorm(640),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(640, 320),
            nn.LayerNorm(320),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_rate),
            nn.Linear(320, 160),
            nn.LayerNorm(160),
            nn.LeakyReLU(0.2),
            nn.Linear(160, 1)
        )

        self.device = device
        self.lambda_pu = lambda_pu
        self.lambda_sup = lambda_sup
        self.lambda_consistency = lambda_consistency
        self.margin = margin
        # Not used by any method of this class; kept for external callers.
        self.bce = nn.BCEWithLogitsLoss()

        # Running class-prior estimate and its EMA momentum. Registered as
        # buffers so they follow .to(device) and are stored in state_dict.
        self.register_buffer('pi_estimate', torch.tensor(0.5))
        self.register_buffer('pi_momentum', torch.tensor(0.9))
        self.pi_min = 0.1
        self.pi_max = 0.9

        # Written by the training driver; not read by any method here.
        self.pseudo_pos_mask = None
        self.pseudo_neg_mask = None

    def update_class_prior(self, q_features):
        """Update the running class-prior estimate from unlabeled features.

        The mean sigmoid output of ``f2_net`` over ``q_features`` is clamped
        to ``[pi_min, pi_max]`` and blended into ``pi_estimate`` with an
        exponential moving average.

        Args:
            q_features: Unlabeled features, one row per sample.

        Returns:
            The updated prior estimate as a Python float.
        """
        # Remember the caller's mode so evaluation-time calls are not
        # silently switched into training mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                q_features_norm = F.normalize(q_features, dim=1)
                # f2 predicts the share of positives among the unlabeled set.
                logits_q = self.f2_net(q_features_norm).squeeze(-1)
                prob_seen = torch.sigmoid(logits_q)

                # Prior estimate for the current batch.
                current_pi = prob_seen.mean().clamp(self.pi_min, self.pi_max)

                # Exponential-moving-average update. The buffer is re-bound
                # (not modified in place) so graphs that captured the old
                # tensor stay valid.
                blended = (self.pi_momentum * self.pi_estimate
                           + (1 - self.pi_momentum) * current_pi)
                self.pi_estimate = blended.clamp(self.pi_min, self.pi_max)
        finally:
            self.train(was_training)
        return self.pi_estimate.item()

    def compute_semantic_consistency_loss(self, features):
        """Return the mean squared gap between the two heads' probabilities.

        Args:
            features: Input features, one row per sample.

        Returns:
            Scalar tensor ``mean((sigmoid(f1) - sigmoid(f2)) ** 2)``.
        """
        p_a1_given_y1 = torch.sigmoid(self.f1_net(features))
        p_y1 = torch.sigmoid(self.f2_net(features))

        return torch.mean((p_a1_given_y1 - p_y1) ** 2)

    def compute_semantic_coherence_loss(self, p_features, q_features):
        """Return the separation-plus-compactness loss on ``f2_net`` logits.

        The separation part is a hinge that is zero once the mean positive
        logit exceeds the mean unlabeled logit by at least 2.0. The
        compactness part is 0.1 times the variance of the positive logits.

        Args:
            p_features: Labeled-positive features.
            q_features: Unlabeled features.

        Returns:
            Scalar tensor with the combined loss.
        """
        p_logits = self.f2_net(p_features)
        q_logits = self.f2_net(q_features)

        p_mean = torch.mean(p_logits)
        q_mean = torch.mean(q_logits)

        inter_class_separation = F.relu(2.0 - (p_mean - q_mean))

        # Only the positive logits contribute to the compactness term.
        # The unbiased variance of a single element is NaN, which would
        # poison the whole loss, so a one-sample batch contributes zero.
        if p_logits.numel() > 1:
            intra_class_compactness = torch.var(p_logits)
        else:
            intra_class_compactness = p_logits.new_zeros(())

        return inter_class_separation + 0.1 * intra_class_compactness

    def compute_pu_loss_with_prior(self, p_features, q_features):
        """Compute the prior-weighted PU loss.

        Args:
            p_features: Labeled-positive features.
            q_features: Unlabeled features.

        Returns:
            Tuple ``(loss_PU, pi)`` where ``loss_PU`` is a scalar tensor and
            ``pi`` is the current ``pi_estimate`` buffer (a tensor).
        """
        p_features = F.normalize(p_features, dim=1)
        q_features = F.normalize(q_features, dim=1)

        # Labeled positives: both heads should fire.
        p_a1 = torch.sigmoid(self.f1_net(p_features))
        p_y1 = torch.sigmoid(self.f2_net(p_features))
        loss_P_positive = -torch.log(p_a1 * p_y1 + 1e-8).mean()

        # Unlabeled samples.
        p_not_a1 = 1 - torch.sigmoid(self.f1_net(q_features))
        p_y1_U = torch.sigmoid(self.f2_net(q_features))
        p_y_neg1_U = 1 - p_y1_U

        pi = self.pi_estimate

        # Unlabeled samples explained as unlabeled positives.
        loss_U_positive = -torch.log(p_not_a1 * p_y1_U + 1e-8).mean()

        # Unlabeled samples explained as negatives.
        loss_U_negative = -torch.log(p_y_neg1_U + 1e-8).mean()

        # Combine the three terms using the prior as the mixing weight.
        loss_PU = pi * loss_P_positive + (1 - pi) * loss_U_negative + pi * loss_U_positive

        return loss_PU, pi

    def generate_pseudo_labels(self, q_features, tau_high=2.0, tau_low=-2.0):
        """Threshold ``f2_net`` logits into confident pseudo-label masks.

        Both thresholds are shifted by ``(pi - 0.5) * 0.5``: the high one
        upwards and the low one downwards as the prior estimate grows.

        Args:
            q_features: Unlabeled features, one row per sample.
            tau_high: Base logit threshold for pseudo positives.
            tau_low: Base logit threshold for pseudo negatives.

        Returns:
            Tuple ``(pos_mask, neg_mask)`` of boolean tensors with one entry
            per row of ``q_features``.
        """
        was_training = self.training
        self.eval()
        try:
            q_features = F.normalize(q_features, dim=1)
            with torch.no_grad():
                # squeeze(-1) keeps a 1-D mask even for a single sample.
                logits_q = self.f2_net(q_features).squeeze(-1)

                pi = self.pi_estimate.item()

                adjusted_tau_high = tau_high + (pi - 0.5) * 0.5
                adjusted_tau_low = tau_low - (pi - 0.5) * 0.5

                pos_mask = logits_q >= adjusted_tau_high
                neg_mask = logits_q <= adjusted_tau_low
        finally:
            # Restore the caller's mode instead of forcing training mode.
            self.train(was_training)
        return pos_mask, neg_mask

    def forward(self, p_features, q_features, epoch=0):
        """Compute the total training loss for one batch.

        The loss is a weighted sum of the PU loss, a supervised loss, the
        consistency loss, the coherence loss and a margin loss. Every tenth
        epoch (``epoch % 10 == 0``) the class prior is refreshed first.

        Each network call below is a separate forward pass, so in training
        mode every term sees its own dropout mask.

        Args:
            p_features: Labeled-positive features, one row per sample.
            q_features: Unlabeled features, one row per sample.
            epoch: Current epoch index; controls the prior refresh.

        Returns:
            Scalar tensor with the total loss.
        """
        if epoch % 10 == 0:
            self.update_class_prior(q_features)

        p_features = F.normalize(p_features, dim=1)
        q_features = F.normalize(q_features, dim=1)

        # PU loss.
        loss_PU, pi_used = self.compute_pu_loss_with_prior(p_features, q_features)

        # Consistency between the two heads; unlabeled samples count less.
        consistency_loss_p = self.compute_semantic_consistency_loss(p_features)
        consistency_loss_q = self.compute_semantic_consistency_loss(q_features)
        total_consistency_loss = consistency_loss_p + 0.3 * consistency_loss_q

        # Coherence (separation + compactness) loss.
        coherence_loss = self.compute_semantic_coherence_loss(p_features, q_features)

        # Supervised loss: unlabeled samples whose 1 - sigmoid(f1) lies in the
        # top quantile are treated as negatives for f2.
        p_not_a1 = 1 - torch.sigmoid(self.f1_net(q_features))
        neg_weights = p_not_a1.view(-1)

        quantile_threshold = 0.7 + 0.2 * pi_used

        # Zero placeholder on the inputs' own device/dtype.
        sup_loss_q = q_features.new_zeros(())
        if q_features.size(0) > 1:
            threshold = torch.quantile(neg_weights, quantile_threshold)
            easy_neg_mask = neg_weights >= threshold

            if easy_neg_mask.sum() > 0:
                sup_logits_q = self.f2_net(q_features).squeeze(-1)
                targets_q = torch.zeros_like(sup_logits_q)
                sup_loss_q = F.binary_cross_entropy_with_logits(
                    sup_logits_q[easy_neg_mask], targets_q[easy_neg_mask])

        # squeeze(-1) keeps shape (N,) even when N == 1; a bare squeeze()
        # would yield a 0-dim tensor that mismatches the (1,) target.
        pos_logits = self.f2_net(p_features).squeeze(-1)
        pos_sup_loss = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits))

        sup_loss = pos_sup_loss + sup_loss_q

        # Margin term: push positive logits above +margin and unlabeled
        # logits below -margin.
        logits_p = self.f2_net(p_features).squeeze(-1)
        logits_q = self.f2_net(q_features).squeeze(-1)
        pos_margin = F.relu(self.margin - logits_p)
        neg_margin = F.relu(logits_q + self.margin)

        if p_features.size(0) > 1 and q_features.size(0) > 1:
            # Softmax weighting emphasises the largest violations.
            w_p = F.softmax(pos_margin, dim=0)
            w_q = F.softmax(neg_margin, dim=0)
            margin_loss = (pos_margin * w_p).sum() + (neg_margin * w_q).sum()
        else:
            margin_loss = pos_margin.mean() + neg_margin.mean()

        # Total loss. pi_confidence is 1 when pi == 0.5 and falls to 0 as pi
        # approaches 0 or 1; it shifts weight between the PU and supervised
        # terms.
        pi_confidence = 1 - 2 * abs(pi_used - 0.5)

        pu_weight = self.lambda_pu * (0.5 + 0.5 * pi_confidence)
        sup_weight = self.lambda_sup * (1.5 - 0.5 * pi_confidence)

        total_loss = (pu_weight * loss_PU +
                      sup_weight * sup_loss +
                      self.lambda_consistency * total_consistency_loss +
                      0.5 * coherence_loss +
                      0.5 * margin_loss)

        return total_loss

    def predict(self, features):
        """Score samples with the average of both heads, returned as logits.

        Switches the module to eval mode and leaves it there.

        Args:
            features: Input features, one row per sample; moved to
                ``self.device`` before scoring.

        Returns:
            Tensor of log-odds of ``0.5 * sigmoid(f1) + 0.5 * sigmoid(f2)``,
            squeezed of all size-1 dimensions.
        """
        self.eval()
        features = F.normalize(features.to(self.device), dim=1)
        with torch.no_grad():
            f1_score = torch.sigmoid(self.f1_net(features))
            f2_score = torch.sigmoid(self.f2_net(features))
            combined_score = 0.5 * f2_score + 0.5 * f1_score
            # Floor the numerator so a score that underflows to 0 gives a
            # large negative logit instead of -inf.
            combined_logit = torch.log(
                combined_score.clamp(min=1e-8) / (1 - combined_score + 1e-8))
        return combined_logit.squeeze()

    def get_prior_info(self):
        """Return the current prior estimate and its confidence.

        Returns:
            Dict with ``'pi_estimate'`` and ``'pi_confidence'``, where the
            confidence is ``1 - 2 * |pi - 0.5|``.
        """
        pi_value = self.pi_estimate.item()
        return {
            'pi_estimate': pi_value,
            'pi_confidence': 1 - 2 * abs(pi_value - 0.5)
        }


class PULearner:
    """Training and inference wrapper around ``ProbabilisticPULearner``."""

    def __init__(self, feature_dim, att_mat, device, lambda_prob_pu=1.0, lambda_consistency=0.5):
        """Create the wrapped model on ``device``.

        Args:
            feature_dim: Size of the last dimension of the input features.
            att_mat: Class-attribute matrix; stored as a tensor on ``device``
                and not otherwise used by this class.
            device: Device on which the model is placed and batches are run.
            lambda_prob_pu: Passed to the model as its PU-loss weight and
                also used to scale the total loss in ``compute_total_loss``.
            lambda_consistency: Consistency-loss weight of the model.
        """
        self.device = device
        self.feature_dim = feature_dim

        self.lambda_prob_pu = lambda_prob_pu

        # Class-attribute matrix. An existing tensor is detached and cloned
        # rather than passed through torch.tensor(), which warns on tensors.
        if isinstance(att_mat, torch.Tensor):
            self.att_mat = att_mat.detach().clone().to(device)
        else:
            self.att_mat = torch.tensor(att_mat, device=device)

        # Model component.
        self.prob_pu_learner = ProbabilisticPULearner(
            feature_dim, device,
            lambda_prob_pu,
            lambda_consistency=lambda_consistency
        ).to(device)

    def _snapshot_state(self):
        """Return a detached copy of the model's state_dict."""
        # state_dict() hands back references to the live tensors, which the
        # optimizer keeps updating in place; clone to freeze the values.
        return {
            key: value.detach().clone()
            for key, value in self.prob_pu_learner.state_dict().items()
        }

    def compute_total_loss(self, p_features, q_features, epoch=0):
        """Run the model on one batch and scale its loss.

        Args:
            p_features: Labeled-positive batch.
            q_features: Unlabeled batch.
            epoch: Current epoch index, forwarded to the model.

        Returns:
            Tuple ``(total_loss, info)`` where ``info['total']`` is the loss
            as a Python float.
        """
        prob_pu_loss = self.prob_pu_learner(p_features, q_features, epoch)

        # NOTE(review): lambda_prob_pu was already given to the model as its
        # PU weight, so it is applied twice -- confirm this is intended.
        total_loss = self.lambda_prob_pu * prob_pu_loss

        return total_loss, {
            'total': total_loss.item()
        }

    def optimize(self, p_features, q_features, epochs=120, batch_size=256, lr=1e-3,
                 weight_decay=1e-5, early_stop=35, verbose=True):
        """Train the model and restore the weights of the best epoch.

        Both inputs are truncated to the smaller of the two row counts and
        each mini-batch uses the same row indices for both.

        Args:
            p_features: Labeled-positive features, one row per sample.
            q_features: Unlabeled features, one row per sample.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size.
            lr: Base learning rate (``f1_net`` uses ``0.8 * lr``).
            weight_decay: Adam weight decay.
            early_stop: Stop once this many epochs pass without a new best
                average loss.
            verbose: Print progress diagnostics.

        Returns:
            List with the average training loss of every completed epoch.

        Raises:
            ValueError: If either feature tensor has no rows.
        """
        min_samples = min(p_features.size(0), q_features.size(0))
        if min_samples == 0:
            raise ValueError("p_features and q_features must both be non-empty")
        p_features = p_features[:min_samples]
        q_features = q_features[:min_samples]

        p_features = F.normalize(p_features, dim=1)
        q_features = F.normalize(q_features, dim=1)

        optimizer = Adam([
            {'params': self.prob_pu_learner.f1_net.parameters(), 'lr': lr * 0.8},  # f1 learns slightly slower
            {'params': self.prob_pu_learner.f2_net.parameters(), 'lr': lr}  # f2 uses the base rate
        ], weight_decay=weight_decay)

        warmup_epochs = 12
        # Guard the cosine denominator: with epochs == warmup_epochs the
        # final scheduler.step() would otherwise divide by zero.
        cosine_span = max(1, epochs - warmup_epochs)

        def lr_lambda(epoch):
            # Linear warm-up followed by cosine decay.
            # NOTE(review): the multiplier is 0 at epoch 0, so the learning
            # rate is 0 during the first epoch -- confirm this is intended.
            if epoch < warmup_epochs:
                return epoch / warmup_epochs
            return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / cosine_span))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        best_loss = float('inf')
        best_epoch = 0
        # Start from the initial weights so best_state is always defined,
        # even if no epoch runs or every epoch's loss is NaN.
        best_state = {'prob_pu_learner': self._snapshot_state()}
        loss_history = []

        pseudo_label_start = 35
        pseudo_label_interval = 8
        tau_high_start = 2.2
        tau_low_start = -2.2
        tau_high_end = 1.0
        tau_low_end = -1.0

        # Main training loop.
        for epoch in tqdm(range(epochs), desc="训练进度"):
            self.prob_pu_learner.train()

            if epoch >= pseudo_label_start and epoch % pseudo_label_interval == 0:
                # Anneal both thresholds linearly towards their end values.
                progress = min(1.0, (epoch - pseudo_label_start) / (epochs - pseudo_label_start))
                tau_high = tau_high_start - progress * (tau_high_start - tau_high_end)
                tau_low = tau_low_start + progress * (tau_low_end - tau_low_start)

                pos_mask, neg_mask = self.prob_pu_learner.generate_pseudo_labels(
                    q_features.to(self.device), tau_high=tau_high, tau_low=tau_low)
                self.prob_pu_learner.pseudo_pos_mask = pos_mask
                self.prob_pu_learner.pseudo_neg_mask = neg_mask

                if verbose:
                    prior_info = self.prob_pu_learner.get_prior_info()
                    print(f"\n轮次 {epoch + 1}: 更新伪标签")
                    print(
                        f"  高置信正例: {pos_mask.sum().item()}/{len(pos_mask)} ({pos_mask.float().mean().item() * 100:.2f}%)")
                    print(
                        f"  高置信负例: {neg_mask.sum().item()}/{len(neg_mask)} ({neg_mask.float().mean().item() * 100:.2f}%)")
                    print(f"  当前阈值: 正例>{tau_high:.2f}, 负例<{tau_low:.2f}")
                    print(f"  先验估计π: {prior_info['pi_estimate']:.3f}, 置信度: {prior_info['pi_confidence']:.3f}")

            epoch_losses = []

            # Fresh random permutation each epoch, cut into mini-batches.
            # The same indices select rows from both feature tensors.
            for batch_idx in torch.randperm(p_features.size(0)).split(batch_size):
                p_batch = p_features[batch_idx].to(self.device)
                q_batch = q_features[batch_idx].to(self.device)

                optimizer.zero_grad()
                loss, _ = self.compute_total_loss(p_batch, q_batch, epoch)

                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.prob_pu_learner.parameters(), 1.0)

                optimizer.step()

                epoch_losses.append(loss.item())

            scheduler.step()

            avg_loss = np.mean(epoch_losses)
            loss_history.append(avg_loss)

            if verbose and (epoch % 25 == 0 or epoch == epochs - 1):
                print(f"\n训练轮次 {epoch + 1}/{epochs}:")
                print(f"  总损失: {avg_loss:.4f}")
                print(f"  学习率: {scheduler.get_last_lr()[0]:.6f}")

                with torch.no_grad():
                    p_sample = p_features[:min(1000, p_features.size(0))].to(self.device)
                    q_sample = q_features[:min(1000, q_features.size(0))].to(self.device)

                    # predict() leaves the model in eval mode; the next
                    # epoch switches it back to train mode.
                    p_scores = self.prob_pu_learner.predict(p_sample)
                    q_scores = self.prob_pu_learner.predict(q_sample)

                    p_f1 = torch.sigmoid(self.prob_pu_learner.f1_net(F.normalize(p_sample, dim=1)))
                    p_f2 = torch.sigmoid(self.prob_pu_learner.f2_net(F.normalize(p_sample, dim=1)))
                    consistency_score = 1 - torch.mean(torch.abs(p_f1 - p_f2)).item()

                    print(f"  正类分数 - 均值: {p_scores.mean().item():.4f}, 标准差: {p_scores.std().item():.4f}")
                    print(f"  负类分数 - 均值: {q_scores.mean().item():.4f}, 标准差: {q_scores.std().item():.4f}")
                    print(f"  分离度: {(p_scores.mean() - q_scores.mean()).item():.4f}")
                    print(f"  f1-f2一致性: {consistency_score:.4f}")

            if avg_loss < best_loss:
                best_loss = avg_loss
                best_epoch = epoch
                best_state = {
                    'prob_pu_learner': self._snapshot_state(),
                }
            elif epoch - best_epoch > early_stop:
                print(f"\n提前停止于轮次 {epoch + 1}, 最佳轮次: {best_epoch + 1}")
                break

        self.prob_pu_learner.load_state_dict(best_state['prob_pu_learner'])

        return loss_history

    def predict(self, features):
        """Return the model's classification scores for ``features``.

        Args:
            features: Input features, one row per sample.

        Returns:
            Tensor of scores as produced by the wrapped model's ``predict``.
        """
        features = F.normalize(features, dim=1)
        with torch.no_grad():
            scores = self.prob_pu_learner.predict(features)
        return scores