import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
import math


class PULLPLoss(nn.Module):
    """Positive risk plus a bag-level proportion (LLP) loss on unlabeled bags.

    Args:
        prior: Target value that the mean predicted probability of every
            unlabeled bag is pushed towards.
        loss: Callable applied to the positive outputs ``x_p``.
        lambda_p: Weight of the positive-risk term.
        lambda_u: Weight of the unlabeled proportion term.
    """

    def __init__(self, prior, loss, lambda_p=1.0, lambda_u=1.0):
        super(PULLPLoss, self).__init__()
        self.prior = prior
        self.loss_func = loss
        self.lambda_p = lambda_p
        self.lambda_u = lambda_u

    def forward(self, x_p, t_p, u_bags):
        """Return ``lambda_p * positive_risk + lambda_u * mean_bag_loss``.

        Args:
            x_p: Model outputs for the positive samples.
            t_p: Labels of the positive samples; their sum (at least 1) is the
                normaliser of the positive risk.
            u_bags: Iterable of tensors, one per unlabeled bag. May be empty,
                in which case the unlabeled term is zero.
        """
        # ---- P part ----
        n_positive = max(1., t_p.sum().item())
        y_p = self.loss_func(x_p)
        positive_risk = torch.sum(1 / n_positive * y_p)

        # ---- U part: LLP loss ----
        bce_loss = nn.BCELoss()
        losses_u = []
        for x_u in u_bags:
            probs = torch.sigmoid(x_u)
            pred_prop = probs.mean()
            target = torch.tensor(self.prior, device=pred_prop.device, dtype=pred_prop.dtype)
            losses_u.append(bce_loss(pred_prop, target))

        if losses_u:
            loss_u = torch.stack(losses_u).mean()
        else:
            # torch.stack([]) raises; with no bags there is nothing to
            # penalise, so the unlabeled term contributes zero.
            loss_u = positive_risk.new_zeros(())

        # ---- combined loss ----
        total_loss = self.lambda_p * positive_risk + self.lambda_u * loss_u

        return total_loss


class nnPUSBloss(nn.Module):
    """Non-negative PU loss with log loss (PUSB variant).

    ``x`` is treated as a probability: it is clamped to ``[eps, 1 - eps]``
    and fed through ``-log(x)`` / ``-log(1 - x)``.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        gamma: Scale applied to the negated negative risk when it is clipped.
        beta: Clipping threshold; the negative risk is clipped below ``-beta``.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """
    def __init__(self, prior, gamma=1, beta=0):
        super(nnPUSBloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.gamma = gamma
        self.beta = beta
        self.positive = 1
        self.unlabeled = -1
        self.eps = 1e-7

    def forward(self, x, t):
        """Return the scalar training loss.

        Args:
            x: Predicted probabilities, shape ``(N,)`` or ``(N, 1)``.
            t: Labels indexed as ``t[:, None]``; entries equal to 1 are
                positive, entries equal to -1 are unlabeled.

        Returns:
            ``positive_risk + negative_risk`` when the negative risk is at
            least ``-beta``; otherwise ``-gamma * negative_risk``.
        """
        # clip the predicted value so both logarithms stay finite
        x = torch.clamp(x, min=self.eps, max=1 - self.eps)

        # A 1-D x would broadcast against the (N, 1) masks below into an
        # (N, N) matrix and silently yield a wrong risk; align it first.
        if x.dim() == 1:
            x = x[:, None]
        t = t[:, None]

        # indicator masks: 1 where the sample is positive / unlabeled
        positive = (t == self.positive).float()
        unlabeled = (t == self.unlabeled).float()
        n_positive = max(1., positive.sum().item())
        n_unlabeled = max(1., unlabeled.sum().item())

        y_positive = -torch.log(x)
        y_unlabeled = -torch.log(1 - x)
        positive_risk = torch.sum(self.prior * positive / n_positive *
                                  y_positive)
        negative_risk = torch.sum(
            (unlabeled / n_unlabeled - self.prior * positive / n_positive) *
            y_unlabeled)

        # nnPU correction: when the negative risk drops below -beta, return
        # its negation so a minimisation step pushes it back up.
        if negative_risk.item() < -self.beta:
            return -self.gamma * negative_risk
        return positive_risk + negative_risk


def nnPUSB_loss(x, t, prior):
    """Functional wrapper around ``nnPUSBloss`` for PU learning with selection bias.

    Builds a fresh ``nnPUSBloss(prior=prior)`` (default ``gamma`` and
    ``beta``) and evaluates it once on ``(x, t)``.
    """
    criterion = nnPUSBloss(prior=prior)
    return criterion(x, t)


class nnPUloss(nn.Module):
    """Unbiased / non-negative PU risk for an arbitrary surrogate loss.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable; evaluated on ``x`` and on ``-x``.
        gamma: Scale applied to the negated negative risk when it is clipped.
        beta: Clipping threshold for the negative risk.
        nnpu: When True apply the non-negative correction, otherwise return
            the plain sum of the two risks.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """
    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True):
        super(nnPUloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.gamma = gamma
        self.beta = beta
        self.loss_func = loss
        self.nnpu = nnpu
        self.positive = 1
        self.unlabeled = -1

    def forward(self, x, t):
        """Return ``(x_out, negative_risk)``.

        Args:
            x: Model outputs, shape ``(N,)`` or ``(N, 1)``.
            t: Labels indexed as ``t[:, None]``; 1 marks positive samples and
                -1 marks unlabeled samples.

        Returns:
            Tuple of the loss to minimise and the negative-risk estimate. The
            loss is ``-gamma * negative_risk`` when ``nnpu`` is set and the
            negative risk is below ``-beta``; otherwise it is the sum of the
            positive and negative risks.
        """
        # A 1-D x would broadcast against the (N, 1) masks below into an
        # (N, N) matrix and silently yield a wrong risk; align it first.
        if x.dim() == 1:
            x = x[:, None]
        t = t[:, None]

        # indicator masks: 1 where the sample is positive / unlabeled
        positive = (t == self.positive).float()
        unlabeled = (t == self.unlabeled).float()
        n_positive = max(1., positive.sum().item())
        n_unlabeled = max(1., unlabeled.sum().item())

        y_positive = self.loss_func(x)
        y_unlabeled = self.loss_func(-x)
        positive_risk = torch.sum(self.prior * positive / n_positive *
                                  y_positive)
        negative_risk = torch.sum(
            (unlabeled / n_unlabeled - self.prior * positive / n_positive) *
            y_unlabeled)

        if self.nnpu and negative_risk.item() < -self.beta:
            # nnPU correction: minimise the negated negative risk
            x_out = -self.gamma * negative_risk
        else:
            # plain unbiased PU objective
            x_out = positive_risk + negative_risk
        return x_out, negative_risk


def nnPU_loss(x, t, prior, loss, nnpu=True):
    """Functional wrapper around ``nnPUloss`` for non-negative PU learning.

    Builds ``nnPUloss(prior=prior, loss=loss, nnpu=nnpu)`` with default
    ``gamma`` and ``beta`` and returns whatever it returns for ``(x, t)``.
    """
    criterion = nnPUloss(prior=prior, loss=loss, nnpu=nnpu)
    return criterion(x, t)


class PNloss(nn.Module):
    """Prior-weighted positive/negative (fully supervised) risk.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable; evaluated on ``x`` and on ``-x``.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """
    def __init__(self, prior, loss):
        super(PNloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.positive = 1
        self.negative = -1

    def forward(self, x, t):
        """Return ``prior * R_p(x) + (1 - prior) * R_n(-x)``.

        Args:
            x: Model outputs, shape ``(N,)`` or ``(N, 1)``.
            t: Labels indexed as ``t[:, None]``; 1 marks positive samples and
                -1 marks negative samples.
        """
        # A 1-D x would broadcast against the (N, 1) masks below into an
        # (N, N) matrix and silently yield a wrong risk; align it first.
        if x.dim() == 1:
            x = x[:, None]
        t = t[:, None]

        # indicator masks: 1 where the sample is positive / negative
        positive = (t == self.positive).float()
        negative = (t == self.negative).float()
        n_positive = max(1., positive.sum().item())
        n_negative = max(1., negative.sum().item())

        y_positive = self.loss_func(x)
        y_negative = self.loss_func(-x)
        positive_risk = torch.sum(self.prior * positive / n_positive *
                                  y_positive)
        negative_risk = torch.sum(
            (1. - self.prior) * negative / n_negative * y_negative)

        return positive_risk + negative_risk


def PN_loss(x, t, prior, loss):
    """Functional wrapper around ``PNloss`` for PN learning.

    Builds ``PNloss(prior=prior, loss=loss)`` and evaluates it once on
    ``(x, t)``.
    """
    criterion = PNloss(prior=prior, loss=loss)
    return criterion(x, t)



class PUCEloss(nn.Module):
    """PU risk computed from separate positive and unlabeled batches.

    With ``nnpu`` enabled, a negative risk below ``-beta`` switches the
    returned loss to ``-gamma * negative_risk``. With ``objective=True`` the
    first returned element is the reported objective rather than that loss.
    """

    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True, objective=False):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta
        self.nnpu = nnpu
        self.objective = objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Return ``(loss, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``."""
        # batch sizes are recovered from the label sums, floored at 1
        num_p = max(1., t_p.sum().item())
        num_u = max(1., (- t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / num_p * pos_as_pos)
        negative_risk = (torch.sum(unl_as_neg / num_u)
                         - torch.sum(self.prior / num_p * pos_as_neg))

        # un-weighted component risks, returned for monitoring
        R_p_pos = torch.sum(1 / num_p * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / num_u)
        R_p_neg = torch.sum(1 / num_p * pos_as_neg)

        objective = positive_risk + negative_risk
        x_out = objective
        if self.nnpu and negative_risk.item() < -self.beta:
            # non-negative correction
            objective = positive_risk - self.beta
            x_out = -self.gamma * negative_risk

        head = objective if self.objective else x_out
        return head, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class ABSPUloss(nn.Module):
    """PU loss that uses the absolute value of the negative risk.

    The loss is ``positive_risk + |negative_risk|``; the plain sum is
    returned alongside it as the objective. ``gamma``, ``beta`` and
    ``objective`` are stored but not read by ``forward``.
    """

    def __init__(self, prior, loss, gamma=1, beta=0, objective=False):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta
        self.objective = objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Return ``(loss, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``."""
        num_p = max(1., t_p.sum().item())
        num_u = max(1., (- t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / num_p * pos_as_pos)
        negative_risk = (torch.sum(unl_as_neg / num_u)
                         - torch.sum(self.prior / num_p * pos_as_neg))

        # un-weighted component risks, returned for monitoring
        R_p_pos = torch.sum(1 / num_p * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / num_u)
        R_p_neg = torch.sum(1 / num_p * pos_as_neg)

        objective = positive_risk + negative_risk
        x_out = positive_risk + torch.abs(negative_risk)
        return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class DistPUloss(nn.Module):
    """PU loss of the form ``2 * positive_risk + |R_u_neg - prior|``.

    ``gamma`` and ``beta`` are stored but not read by ``forward``.
    """

    def __init__(self, prior, loss, gamma=1, beta=0):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta

    def forward(self, x_p, x_u, t_p, t_u):
        """Return ``(loss, loss, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``."""
        num_p = max(1., t_p.sum().item())
        num_u = max(1., (- t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / num_p * pos_as_pos)
        negative_risk = (torch.sum(unl_as_neg / num_u)
                         - torch.sum(self.prior / num_p * pos_as_neg))

        # un-weighted component risks, returned for monitoring
        R_p_pos = torch.sum(1 / num_p * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / num_u)
        R_p_neg = torch.sum(1 / num_p * pos_as_neg)

        # the same tensor is reported as both the loss and the objective
        x_out = 2 * positive_risk + torch.abs(R_u_neg - self.prior)
        return x_out, x_out, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class BasePUCEloss(nn.Module):
    """Shared risk computation for PU losses.

    Stores the hyper-parameters and exposes ``calculate_risks``; no
    ``forward`` is defined here. When no ``loss_func`` is supplied,
    ``nn.Sigmoid()`` is used.
    """

    def __init__(self, prior, loss_func=None, gamma=1, beta=0, nnpu=False):
        super().__init__()
        if not (0 < prior < 1):
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.gamma, self.beta = gamma, beta
        self.nnpu = nnpu
        if loss_func is not None:
            self.loss_func = loss_func
        else:
            self.loss_func = nn.Sigmoid()

    def calculate_risks(self, x_p, x_u, t_p, t_u):
        """Return ``(positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``."""
        # batch sizes are recovered from the label sums, floored at 1
        num_p = max(1., t_p.sum().item())
        num_u = max(1., (-t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / num_p * pos_as_pos)
        unlabeled_part = torch.sum(unl_as_neg / num_u)
        positive_part = torch.sum(self.prior / num_p * pos_as_neg)
        negative_risk = unlabeled_part - positive_part

        # un-weighted component risks
        R_p_pos = torch.sum(1 / num_p * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / num_u)
        R_p_neg = torch.sum(1 / num_p * pos_as_neg)

        return positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg



class DynamicPUCELoss(nn.Module):
    """PU loss whose ``gamma`` is re-tuned on every forward pass.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable.
        gamma: Initial value of the dynamic weight.
        beta: Clipping threshold for the negative risk (nnPU mode).
        nnpu: Apply the non-negative correction when True.
        gamma_min: Lower bound for ``gamma``.
        gamma_max: Upper bound for ``gamma``.
        adjustment_rate: Step size of the multiplicative ``gamma`` update.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=False,
                 gamma_min=0.1, gamma_max=5.0, adjustment_rate=0.1):
        super(DynamicPUCELoss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.nnpu = nnpu

        # Parameters for dynamic gamma adjustment
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        self.adjustment_rate = adjustment_rate
        self.moving_avg_pos_risk = None
        self.moving_avg_neg_risk = None
        self.momentum = 0.9  # Momentum for moving average calculation

    def adjust_gamma(self, positive_risk, negative_risk):
        """Update ``self.gamma`` from moving averages of the two risks.

        The ratio ``|avg_pos / avg_neg|`` drives a multiplicative update that
        is clamped to ``[gamma_min, gamma_max]``.
        """
        pos_now = positive_risk.item()
        neg_now = negative_risk.item()

        # Update moving averages (seeded with the first observation)
        if self.moving_avg_pos_risk is None:
            self.moving_avg_pos_risk = pos_now
            self.moving_avg_neg_risk = neg_now
        else:
            self.moving_avg_pos_risk = (self.momentum * self.moving_avg_pos_risk +
                                        (1 - self.momentum) * pos_now)
            self.moving_avg_neg_risk = (self.momentum * self.moving_avg_neg_risk +
                                        (1 - self.momentum) * neg_now)

        # The negative-risk average can be negative, so adding 1e-8 does not
        # by itself rule out a zero denominator; guard it explicitly.
        denominator = self.moving_avg_neg_risk + 1e-8
        if denominator == 0:
            denominator = 1e-8
        risk_ratio = abs(self.moving_avg_pos_risk / denominator)

        if risk_ratio > 1:
            # Positive risk is larger: increase gamma
            self.gamma = min(self.gamma_max,
                             self.gamma * (1 + self.adjustment_rate * (risk_ratio - 1)))
        elif risk_ratio > 0:
            # Negative risk is larger: decrease gamma
            self.gamma = max(self.gamma_min,
                             self.gamma / (1 + self.adjustment_rate * (1 / risk_ratio - 1)))
        elif self.adjustment_rate > 0:
            # risk_ratio is zero (or NaN): 1 / risk_ratio is undefined. The
            # shrink factor grows without bound as the ratio approaches zero,
            # so use the limiting value instead of raising.
            self.gamma = max(self.gamma_min, 0.0)

    def forward(self, x_p, x_u, t_p, t_u):
        """Return ``(loss, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``.

        ``self.gamma`` is updated as a side effect before the loss is formed;
        it is not part of the returned tuple.
        """
        n_positive, n_unlabeled = max(1., t_p.sum().item()), max(1., (-t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)

        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        # Adjust gamma based on current risks
        self.adjust_gamma(positive_risk, negative_risk)

        objective = positive_risk + negative_risk
        if self.nnpu:
            # nnPU learning
            if negative_risk.item() < -self.beta:
                objective = positive_risk - self.beta
                x_out = -self.gamma * negative_risk
            else:
                x_out = objective
        else:
            # uPU learning
            x_out = objective * self.gamma  # Apply dynamic gamma to the objective

        return (x_out, objective, positive_risk, negative_risk,
                R_p_pos, R_u_neg, R_p_neg)


class BalancedPUCEloss(nn.Module):
    """PUCE loss with an explicit balance penalty.

    The penalty is
    ``balance_weight * |R_p_pos - (R_u_neg - prior * R_p_neg) / (1 - prior)|``
    and is added exactly once to both the training loss and the reported
    objective.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable.
        gamma: Scale applied to the negated negative risk when it is clipped.
        beta: Clipping threshold for the negative risk.
        balance_weight: Weight of the balance penalty.
        objective: When True the first returned element is the objective
            instead of the training loss.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """
    def __init__(self, prior, loss, gamma=1, beta=0, balance_weight=1e1, objective=False):
        super(BalancedPUCEloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.balance_weight = balance_weight
        self.objective = objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Return ``(loss, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``."""
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (- t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        balance_term = torch.abs(R_p_pos - (R_u_neg - self.prior*R_p_neg)/(1-self.prior))
        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        # Out-of-place additions: in the unclipped branch x_out and objective
        # are the same tensor, so in-place `+=` on both would apply the
        # penalty twice.
        penalty = self.balance_weight * balance_term
        objective = objective + penalty
        x_out = x_out + penalty

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class BalancedPULBloss(nn.Module):
    """Lower-bounded PU loss with an explicit balance penalty.

    The clipping threshold of the negative risk is ``-beta + bound``, where
    ``bound`` blends the caller-supplied value with
    ``(1 - prior) * R_p_pos`` using ``momentum``. The balance penalty is
    added exactly once to both the training loss and the reported objective.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable.
        gamma: Scale applied to the negated negative risk when it is clipped.
        beta: Offset of the clipping threshold.
        momentum: Weight of the incoming ``bound`` in the moving average.
        balance_weight: Weight of the balance penalty.
        objective: When True the first returned element is the objective
            instead of the training loss.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss, gamma=1, beta=0, momentum=0, balance_weight=1e1, objective=False):
        super(BalancedPULBloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.balance_weight = balance_weight
        self.objective = objective

        self.bound = 0
        self.bound_ema = 0
        self.momentum = momentum    # Exponential decay rate, can be adjusted

    def forward(self, x_p, x_u, t_p, t_u, bound):
        """Return ``(loss, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, bound)``.

        ``bound`` is the previous bound supplied by the caller; the updated
        bound is the last element of the returned tuple.
        """
        # Seed the moving average with the caller-supplied bound
        self.bound = bound
        self.bound_ema = bound

        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (- t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        # Update the bound using an exponential moving average
        current_bound = (1 - self.prior) * R_p_pos.item()
        self.bound_ema = self.bound_ema * self.momentum + current_bound * (1 - self.momentum)
        self.bound = self.bound_ema

        balance_term = torch.norm(R_p_pos - (R_u_neg - self.prior * R_p_neg) / (1 - self.prior), p=2)
        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta + self.bound:
            objective = positive_risk - self.beta + self.bound
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        # Out-of-place additions: in the unclipped branch x_out and objective
        # are the same tensor, so in-place `+=` on both would apply the
        # penalty twice.
        penalty = self.balance_weight * balance_term
        objective = objective + penalty
        x_out = x_out + penalty

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound


class PNCEloss(nn.Module):
    """Prior-weighted PN risk computed from separate positive and negative batches."""

    def __init__(self, prior, loss):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss

    def forward(self, x_p, x_n, t_p, t_n):
        """Return ``(objective, positive_risk, negative_risk)``."""
        # normalisers come from the labels: sum(t_p) and sum(1 - t_n), floored at 1
        num_p = max(1., t_p.sum().item())
        num_n = max(1., (1. - t_n).sum().item())

        pos_loss = self.loss_func(x_p)
        neg_loss = self.loss_func(-x_n)

        positive_risk = torch.sum(self.prior / num_p * pos_loss)
        negative_risk = torch.sum((1 - self.prior) / num_n * neg_loss)

        total = positive_risk + negative_risk
        return total, positive_risk, negative_risk


class PULBloss(nn.Module):
    """PU loss whose negative risk is clipped at a moving lower bound.

    Here ``positive_risk`` is the un-weighted positive risk; the prior is
    applied when the objective is formed. The clipping threshold is
    ``-beta + bound``.
    """

    def __init__(self, prior, loss, gamma, beta, momentum=0, objective=False):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta
        self.objective = objective

        self.bound = 0
        self.bound_ema = 0
        self.momentum = momentum    # weight of the incoming bound in the moving average

    def forward(self, x_p, x_u, t_p, t_u, bound):
        """Return ``(loss, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, bound)``."""
        # start from the bound handed in by the caller
        self.bound = bound
        self.bound_ema = bound

        num_p = max(1., t_p.sum().item())
        num_u = max(1., (- t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        positive_risk = torch.sum(1 / num_p * pos_as_pos)
        negative_risk = (torch.sum(unl_as_neg / num_u)
                         - torch.sum(self.prior / num_p * pos_as_neg))

        # un-weighted component risks, returned for monitoring
        R_p_pos = torch.sum(1 / num_p * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / num_u)
        R_p_neg = torch.sum(1 / num_p * pos_as_neg)

        objective = self.prior * positive_risk + negative_risk

        # blend the incoming bound with the one implied by the current batch
        batch_bound = (1 - self.prior) * positive_risk.item()
        self.bound_ema = self.bound_ema * self.momentum + batch_bound * (1 - self.momentum)
        self.bound = self.bound_ema

        x_out = objective
        if negative_risk.item() < - self.beta + self.bound:
            # lower-bound correction
            objective = self.prior * positive_risk - self.beta + self.bound
            x_out = -self.gamma * negative_risk

        head = objective if self.objective else x_out
        return head, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound



class ScalePULoss(nn.Module):
    """Variance-regularized PU loss.

    Implements ``L = R_PU(f) + lambda_reg * Omega_var(f)``, where
    ``Omega_var`` is ``prior`` times the variance of
    ``loss(f(x)) - loss(-f(x))`` over the positive batch.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable.
        lambda_reg: Weight of the variance regulariser.
        beta: Clipping threshold for the negative risk.
        gamma: Scale of the negative risk in the clipped branch.

    Raises:
        ValueError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss, lambda_reg=0.1, beta=0., gamma=1.):
        super(ScalePULoss, self).__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.loss_func = loss
        self.lambda_reg = lambda_reg
        self.beta = beta
        self.gamma = gamma

        self.positive = 1
        self.unlabeled = -1
        self.min_count = 1.

    def _positive_gap(self, outputs_p):
        """Return ``(loss(f), loss(-f), loss(f) - loss(-f))`` for positive outputs."""
        loss_pos_positive = self.loss_func(outputs_p)
        loss_neg_positive = self.loss_func(-outputs_p)
        return loss_pos_positive, loss_neg_positive, loss_pos_positive - loss_neg_positive

    def forward(self, outputs_p, outputs_u, targets_p, targets_u):
        """Return ``(total_loss, positive_risk, negative_risk, var_reg)``.

        ``targets_p`` and ``targets_u`` are accepted for interface
        compatibility but not read: every risk is a plain mean over its
        batch, so no label-derived counts (and no host synchronisation to
        obtain them) are needed.
        """
        if outputs_p.dim() > 1:
            outputs_p = outputs_p.squeeze()
        if outputs_u.dim() > 1:
            outputs_u = outputs_u.squeeze()

        loss_pos_positive, loss_neg_positive, g_positive = self._positive_gap(outputs_p)
        loss_neg_unlabeled = self.loss_func(-outputs_u)

        # pi * E_P[loss(f(x))]
        positive_risk = self.prior * torch.mean(loss_pos_positive)

        # E_U[loss(-f(x))] - pi * E_P[loss(-f(x))]
        negative_risk_from_unlabeled = torch.mean(loss_neg_unlabeled)
        negative_risk_from_positive = self.prior * torch.mean(loss_neg_positive)
        negative_risk = negative_risk_from_unlabeled - negative_risk_from_positive

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            objective = objective - self.gamma * negative_risk
        else:
            objective = positive_risk + negative_risk

        g_mean = torch.mean(g_positive)
        g_var = torch.mean((g_positive - g_mean) ** 2)
        var_reg = self.prior * g_var

        total_loss = objective + self.lambda_reg * var_reg

        return total_loss, positive_risk, negative_risk, var_reg

    @torch.no_grad()
    def get_variance_stats(self, outputs_p):
        """Return ``{"g_mean", "g_var"}`` of the positive loss gap as Python floats."""
        if outputs_p.dim() > 1:
            outputs_p = outputs_p.squeeze()

        _, _, g_positive = self._positive_gap(outputs_p)

        g_mean = torch.mean(g_positive)
        g_var = torch.mean((g_positive - g_mean) ** 2)

        return {
            "g_mean": g_mean.item(),
            "g_var": g_var.item(),
        }


class ScalePULossVarianceLambda(nn.Module):
    """Variance-regularized PU loss with a variance-dependent weight.

    Implements ``L = R_PU(f) + lambda_eff * Omega_var(f)``. With
    ``smooth_lambda`` the weight is
    ``lambda_reg * sigmoid(10 * (g_var - var_threshold))``; otherwise it is
    ``lambda_reg`` when ``g_var`` exceeds ``var_threshold`` and 0 otherwise.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable.
        lambda_reg: Maximum weight of the variance regulariser.
        beta: Clipping threshold for the negative risk.
        gamma: Scale of the negative risk in the clipped branch.
        var_threshold: Variance level around which the regulariser switches on.
        smooth_lambda: Use the sigmoid gate instead of the hard threshold.

    Raises:
        ValueError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss, lambda_reg=0.1, beta=0., gamma=1.,
                 var_threshold=0.1, smooth_lambda=True):
        super(ScalePULossVarianceLambda, self).__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.loss_func = loss
        self.lambda_reg = lambda_reg
        self.beta = beta
        self.gamma = gamma
        self.var_threshold = var_threshold
        self.smooth_lambda = smooth_lambda

        self.positive = 1
        self.unlabeled = -1
        self.min_count = 1.

    def _positive_gap(self, outputs_p):
        """Return ``(loss(f), loss(-f), loss(f) - loss(-f))`` for positive outputs."""
        loss_pos_positive = self.loss_func(outputs_p)
        loss_neg_positive = self.loss_func(-outputs_p)
        return loss_pos_positive, loss_neg_positive, loss_pos_positive - loss_neg_positive

    def forward(self, outputs_p, outputs_u, targets_p, targets_u):
        """Return ``(total_loss, positive_risk, negative_risk, var_reg)``.

        ``targets_p`` and ``targets_u`` are accepted for interface
        compatibility but not read: every risk is a plain mean over its
        batch, so no label-derived counts (and no host synchronisation to
        obtain them) are needed.
        """
        if outputs_p.dim() > 1:
            outputs_p = outputs_p.squeeze()
        if outputs_u.dim() > 1:
            outputs_u = outputs_u.squeeze()

        loss_pos_positive, loss_neg_positive, g_positive = self._positive_gap(outputs_p)
        loss_neg_unlabeled = self.loss_func(-outputs_u)

        # pi * E_P[loss(f(x))]
        positive_risk = self.prior * torch.mean(loss_pos_positive)

        # E_U[loss(-f(x))] - pi * E_P[loss(-f(x))]
        negative_risk_from_unlabeled = torch.mean(loss_neg_unlabeled)
        negative_risk_from_positive = self.prior * torch.mean(loss_neg_positive)
        negative_risk = negative_risk_from_unlabeled - negative_risk_from_positive

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            objective = objective - self.gamma * negative_risk
        else:
            objective = positive_risk + negative_risk

        g_mean = torch.mean(g_positive)
        g_var = torch.mean((g_positive - g_mean) ** 2)
        var_reg = self.prior * g_var

        if self.smooth_lambda:
            # differentiable gate: the weight itself depends on g_var
            lambda_eff = self.lambda_reg * torch.sigmoid(10.0 * (g_var - self.var_threshold))
        elif g_var.item() > self.var_threshold:
            lambda_eff = self.lambda_reg
        else:
            lambda_eff = 0.0

        total_loss = objective + lambda_eff * var_reg

        return total_loss, positive_risk, negative_risk, var_reg

    @torch.no_grad()
    def get_variance_stats(self, outputs_p):
        """Return ``{"g_mean", "g_var"}`` of the positive loss gap as Python floats."""
        if outputs_p.dim() > 1:
            outputs_p = outputs_p.squeeze()

        _, _, g_positive = self._positive_gap(outputs_p)

        g_mean = torch.mean(g_positive)
        g_var = torch.mean((g_positive - g_mean) ** 2)

        return {
            "g_mean": g_mean.item(),
            "g_var": g_var.item(),
        }


class FocalScalePULoss(nn.Module):
    """Variance-regularized PU loss with a "focal" variance estimate.

    The regulariser uses the variance of ``loss(f(x)) - loss(-f(x))`` over
    the positive batch, restricted or clipped according to ``focal_mode``:

    * ``'topk'``: variance over the ``k`` entries of largest magnitude.
    * ``'sort'``: variance over the ``k`` largest entries.
    * ``'clip'``: variance after clamping entries to the
      ``clip_quantile`` range (default ``(0.1, 0.9)``).
    * anything else: plain variance over all entries.

    ``k`` is ``max(1, int(n * focal_ratio))``.

    Args:
        prior: Class prior, must lie strictly inside (0, 1).
        loss: Surrogate loss callable.
        lambda_reg: Maximum weight of the variance regulariser.
        beta: Clipping threshold for the negative risk.
        gamma: Scale of the negative risk in the clipped branch.
        var_threshold: Variance level around which the regulariser switches on.
        smooth_lambda: Use a sigmoid gate instead of a hard threshold.
        focal_mode: One of the modes listed above.
        focal_ratio: Fraction of entries kept in ``'topk'`` / ``'sort'`` mode.
        clip_quantile: ``(low, high)`` quantiles for ``'clip'`` mode.

    Raises:
        ValueError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss, lambda_reg=0.1, beta=0., gamma=1.,
                 var_threshold=0.1, smooth_lambda=True,
                 focal_mode='clip', focal_ratio=0.5, clip_quantile=None):
        super(FocalScalePULoss, self).__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.loss_func = loss
        self.lambda_reg = lambda_reg
        self.beta = beta
        self.gamma = gamma
        self.var_threshold = var_threshold
        self.smooth_lambda = smooth_lambda

        # Focal variance parameters
        self.focal_mode = focal_mode
        self.focal_ratio = focal_ratio
        self.clip_quantile = clip_quantile

        self.positive = 1
        self.unlabeled = -1
        self.min_count = 1.

    def _compute_focal_variance(self, g_positive):
        """Return the focal variance of ``g_positive`` for the configured mode."""
        # A (1, 1) batch is squeezed to a 0-dim tensor by the callers, which
        # has no shape[0]; treat it as a batch of one.
        if g_positive.dim() == 0:
            g_positive = g_positive.reshape(1)
        n = g_positive.shape[0]

        if self.focal_mode == 'topk':
            k = max(1, int(n * self.focal_ratio))
            g_abs = torch.abs(g_positive)
            _, indices = torch.topk(g_abs, k)
            g_topk = g_positive[indices]
            g_mean = torch.mean(g_topk)
            focal_var = torch.mean((g_topk - g_mean) ** 2)

        elif self.focal_mode == 'sort':
            k = max(1, int(n * self.focal_ratio))
            _, indices = torch.topk(g_positive, k)  # take the k largest
            g_topk = g_positive[indices]
            g_mean = torch.mean(g_topk)
            focal_var = torch.mean((g_topk - g_mean) ** 2)

        elif self.focal_mode == 'clip':
            if self.clip_quantile is None:
                self.clip_quantile = (0.1, 0.9)
            q_low, q_high = self.clip_quantile

            g_sorted = torch.sort(g_positive)[0]
            # int(n * q) equals n when q == 1.0, which is one past the end;
            # keep both indices inside the sorted array.
            idx_low = min(n - 1, int(n * q_low))
            idx_high = min(n - 1, int(n * q_high))
            g_min = g_sorted[idx_low]
            g_max = g_sorted[idx_high]

            # Clip
            g_clipped = torch.clamp(g_positive, min=g_min, max=g_max)
            g_mean = torch.mean(g_clipped)
            focal_var = torch.mean((g_clipped - g_mean) ** 2)

        else:
            g_mean = torch.mean(g_positive)
            focal_var = torch.mean((g_positive - g_mean) ** 2)

        return focal_var

    def forward(self, outputs_p, outputs_u, targets_p, targets_u):
        """Return ``(total_loss, positive_risk, negative_risk, var_reg)``.

        ``targets_p`` and ``targets_u`` are accepted for interface
        compatibility but not read: every risk is a plain mean over its
        batch, so no label-derived counts (and no host synchronisation to
        obtain them) are needed.
        """
        if outputs_p.dim() > 1:
            outputs_p = outputs_p.squeeze()
        if outputs_u.dim() > 1:
            outputs_u = outputs_u.squeeze()

        loss_pos_positive = self.loss_func(outputs_p)
        loss_neg_positive = self.loss_func(-outputs_p)
        g_positive = loss_pos_positive - loss_neg_positive

        loss_neg_unlabeled = self.loss_func(-outputs_u)

        positive_risk = self.prior * torch.mean(loss_pos_positive)
        negative_risk_from_unlabeled = torch.mean(loss_neg_unlabeled)
        negative_risk_from_positive = self.prior * torch.mean(loss_neg_positive)
        negative_risk = negative_risk_from_unlabeled - negative_risk_from_positive

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            objective = objective - self.gamma * negative_risk
        else:
            objective = positive_risk + negative_risk

        focal_var = self._compute_focal_variance(g_positive)
        var_reg = self.prior * focal_var

        if self.smooth_lambda:
            # differentiable gate: the weight itself depends on focal_var
            lambda_eff = self.lambda_reg * torch.sigmoid(10.0 * (focal_var - self.var_threshold))
        elif focal_var.item() > self.var_threshold:
            lambda_eff = self.lambda_reg
        else:
            lambda_eff = 0.0

        total_loss = objective + lambda_eff * var_reg

        return total_loss, positive_risk, negative_risk, var_reg

    @torch.no_grad()
    def get_variance_stats(self, outputs_p):
        """Return ``{"g_mean", "g_var", "focal_var"}`` of the positive loss gap as floats."""
        if outputs_p.dim() > 1:
            outputs_p = outputs_p.squeeze()

        loss_pos_positive = self.loss_func(outputs_p)
        loss_neg_positive = self.loss_func(-outputs_p)
        g_positive = loss_pos_positive - loss_neg_positive

        g_mean = torch.mean(g_positive)
        g_var = torch.mean((g_positive - g_mean) ** 2)

        focal_var = self._compute_focal_variance(g_positive)

        return {
            "g_mean": g_mean.item(),
            "g_var": g_var.item(),
            "focal_var": focal_var.item(),
        }


class GeometricScalePULoss(nn.Module):
    """
    ScalePU with Geometric Regularization

    Objective:
        L = R̂_PU(f) + λ · Ω_var(f) + γ · [Ω_compact(h) + β · Ω_sep(h)]

    where R̂_PU is the (nnPU-corrected) PU risk on the outputs, Ω_var is the
    prior-weighted "focal" variance of the composite loss on positive samples,
    Ω_compact is the mean squared distance of positive features to their
    centroid, and Ω_sep is a similarity-weighted hinge on the squared distance
    of unlabeled features to that centroid.

    Args:
        prior: Class prior; must lie strictly inside (0, 1).
        loss: Callable applied element-wise to the outputs (and to their
            negation). NOTE(review): presumably a surrogate loss such as the
            sigmoid loss - confirm against the caller.
        lambda_reg: Weight λ of the variance regularizer.
        gamma_geo: Weight γ of the geometric regularizer.
        beta_sep: Weight β of the separation term inside the geometric term.
        margin: Hinge margin applied to squared distances in the separation term.
        var_threshold: Focal-variance level at which the variance regularizer
            switches on (hard gate) or is centred (smooth gate).
        smooth_lambda: If true, gate λ with ``sigmoid(10 * (focal_var -
            var_threshold))``; otherwise use a hard 0 / ``lambda_reg`` switch.
        focal_mode: ``'topk'`` (largest ``|g|``), ``'sort'`` (largest ``g``),
            ``'clip'`` (clamp ``g`` to a quantile range); any other value uses
            the plain variance of ``g``.
        focal_ratio: Fraction of samples kept in ``'topk'`` / ``'sort'`` mode.
        clip_quantile: ``(q_low, q_high)`` for ``'clip'`` mode; ``None`` means
            ``(0.1, 0.9)``.
        similarity_temp: Temperature dividing the cosine similarity before the
            sigmoid in the separation weights.
        beta_nnpu: nnPU threshold; the correction triggers when the negative
            risk falls below ``-beta_nnpu``.
        gamma_nnpu: nnPU factor applied to the negative risk when correcting.

    Raises:
        ValueError: If ``prior`` is not in (0, 1).
    """

    # Quantile range used by focal_mode='clip' when clip_quantile is None.
    _DEFAULT_CLIP_QUANTILE = (0.1, 0.9)

    def __init__(self, prior, loss,
                 lambda_reg=0.1,
                 gamma_geo=0.01,
                 beta_sep=1.0,
                 margin=1.0,
                 var_threshold=0.1,
                 smooth_lambda=True,
                 focal_mode='clip',
                 focal_ratio=0.5,
                 clip_quantile=None,
                 similarity_temp=1.0,
                 beta_nnpu=0.,
                 gamma_nnpu=1.):
        super(GeometricScalePULoss, self).__init__()

        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.loss_func = loss

        # Variance regularization parameters
        self.lambda_reg = lambda_reg
        self.var_threshold = var_threshold
        self.smooth_lambda = smooth_lambda

        # Geometric regularization parameters
        self.gamma_geo = gamma_geo
        self.beta_sep = beta_sep
        self.margin = margin
        self.similarity_temp = similarity_temp

        # Focal variance parameters
        self.focal_mode = focal_mode
        self.focal_ratio = focal_ratio
        self.clip_quantile = clip_quantile

        # nnPU parameters
        self.beta_nnpu = beta_nnpu
        self.gamma_nnpu = gamma_nnpu

        self.positive = 1
        self.unlabeled = -1
        self.min_count = 1.

    @staticmethod
    def _as_vector(outputs):
        """Squeeze ``outputs`` like the public methods always did, but never
        return a 0-dim tensor.

        A bare ``squeeze()`` on a ``(1, 1)`` tensor (batch size 1) also drops
        the batch axis; the per-sample code below needs ``shape[0]``.
        """
        if outputs.dim() > 1:
            outputs = outputs.squeeze()
        if outputs.dim() == 0:
            outputs = outputs.unsqueeze(0)
        return outputs

    def _compute_focal_variance(self, g_positive):
        """Return the population variance of a focused subset/version of ``g``.

        Args:
            g_positive: Per-sample composite loss values; indexed along dim 0.

        Returns:
            0-dim tensor with the biased variance of the selected values (see
            ``focal_mode`` in the class docstring).
        """
        n = g_positive.shape[0]

        if self.focal_mode in ('topk', 'sort'):
            # Cap k at n so focal_ratio > 1 cannot make torch.topk fail.
            k = max(1, min(n, int(n * self.focal_ratio)))
            # 'topk' ranks by magnitude, 'sort' by signed value.
            ranking = torch.abs(g_positive) if self.focal_mode == 'topk' else g_positive
            _, indices = torch.topk(ranking, k)
            selected = g_positive[indices]

        elif self.focal_mode == 'clip':
            # Resolve the default locally instead of overwriting the
            # configured attribute during a loss computation.
            quantiles = self.clip_quantile
            if quantiles is None:
                quantiles = self._DEFAULT_CLIP_QUANTILE
            q_low, q_high = quantiles

            g_sorted = torch.sort(g_positive)[0]
            # int(n * 1.0) == n is one past the end; clamp to valid indices.
            last = n - 1
            idx_low = min(last, max(0, int(n * q_low)))
            idx_high = min(last, max(0, int(n * q_high)))

            selected = torch.clamp(g_positive,
                                   min=g_sorted[idx_low],
                                   max=g_sorted[idx_high])

        else:
            selected = g_positive

        return torch.mean((selected - torch.mean(selected)) ** 2)

    def _compute_compactness(self, features_p):
        """Return ``(compactness, centroid)`` for the positive features.

        ``centroid`` is the mean over dim 0 (kept as a leading size-1 axis);
        ``compactness`` is the mean squared Euclidean distance (summed over
        dim 1) of each row to that centroid.
        """
        c_p = torch.mean(features_p, dim=0, keepdim=True)

        distances_sq = torch.sum((features_p - c_p) ** 2, dim=1)
        compactness = torch.mean(distances_sq)

        return compactness, c_p

    def _compute_similarity(self, features, prototype):
        """Return ``sigmoid(cos(features, prototype) / similarity_temp)`` per row.

        Both inputs are L2-normalized along dim 1 before the dot product.
        """
        features_norm = F.normalize(features, p=2, dim=1)
        prototype_norm = F.normalize(prototype, p=2, dim=1)

        cosine_sim = torch.mm(features_norm, prototype_norm.t()).squeeze(1)

        return torch.sigmoid(cosine_sim / self.similarity_temp)

    def _compute_separation(self, features_u, c_p):
        """Return the weighted hinge loss pushing unlabeled features from ``c_p``.

        Each row contributes ``(1 - similarity) * max(0, margin - dist_sq)``,
        where ``dist_sq`` is its squared distance to ``c_p``; the result is the
        mean over rows.
        """
        distances_sq = torch.sum((features_u - c_p) ** 2, dim=1)

        # Rows that look similar to the positive prototype get a small weight.
        weights = 1.0 - self._compute_similarity(features_u, c_p)

        hinge_losses = torch.clamp(self.margin - distances_sq, min=0.0)
        return torch.mean(weights * hinge_losses)

    def forward(self, outputs_p, outputs_u, targets_p, targets_u,
                features_p=None, features_u=None):
        """Compute the regularized PU loss.

        Args:
            outputs_p: Model outputs for positive samples.
            outputs_u: Model outputs for unlabeled samples.
            targets_p: Labels of the positive batch. Kept for interface
                compatibility; the risks below are plain means over the
                batches and do not read it.
            targets_u: Labels of the unlabeled batch. Kept for interface
                compatibility; not read.
            features_p: Optional positive features, indexed ``[row, dim]``;
                enables the compactness term.
            features_u: Optional unlabeled features; enables the separation
                term (only used when ``features_p`` is also given).

        Returns:
            Tuple ``(total_loss, positive_risk, negative_risk, var_reg,
            geo_reg, loss_dict)``. The first five are tensors; ``loss_dict``
            holds detached Python floats for logging.
        """
        outputs_p = self._as_vector(outputs_p)
        outputs_u = self._as_vector(outputs_u)

        loss_pos_positive = self.loss_func(outputs_p)
        loss_neg_positive = self.loss_func(-outputs_p)
        g_positive = loss_pos_positive - loss_neg_positive

        loss_neg_unlabeled = self.loss_func(-outputs_u)

        positive_risk = self.prior * torch.mean(loss_pos_positive)
        negative_risk = (torch.mean(loss_neg_unlabeled)
                         - self.prior * torch.mean(loss_neg_positive))

        if negative_risk.item() < -self.beta_nnpu:
            # Negative risk dropped below the allowed floor: apply the
            # nnPU-style correction instead of minimizing it further.
            objective = positive_risk - self.beta_nnpu
            objective = objective - self.gamma_nnpu * negative_risk
        else:
            objective = positive_risk + negative_risk

        focal_var = self._compute_focal_variance(g_positive)
        var_reg = self.prior * focal_var

        if self.smooth_lambda:
            lambda_eff = self.lambda_reg * torch.sigmoid(
                10.0 * (focal_var - self.var_threshold)
            )
        elif focal_var.item() > self.var_threshold:
            lambda_eff = self.lambda_reg
        else:
            lambda_eff = 0.0

        geo_reg = torch.tensor(0.0, device=outputs_p.device)
        compactness = torch.tensor(0.0, device=outputs_p.device)
        separation = torch.tensor(0.0, device=outputs_p.device)

        if features_p is not None:
            compactness, c_p = self._compute_compactness(features_p)
            geo_reg = compactness

            if features_u is not None:
                separation = self._compute_separation(features_u, c_p)
                geo_reg = compactness + self.beta_sep * separation

        total_loss = objective + lambda_eff * var_reg + self.gamma_geo * geo_reg

        # lambda_eff is a tensor in smooth mode and a plain number otherwise
        # (possibly an int when lambda_reg was given as one).
        if torch.is_tensor(lambda_eff):
            lambda_eff_value = lambda_eff.item()
        else:
            lambda_eff_value = float(lambda_eff)

        loss_dict = {
            'total_loss': total_loss.item(),
            'pu_loss': objective.item(),
            'positive_risk': positive_risk.item(),
            'negative_risk': negative_risk.item(),
            'var_reg': var_reg.item(),
            'focal_var': focal_var.item(),
            'lambda_eff': lambda_eff_value,
            'geo_reg': geo_reg.item(),
            'compactness': compactness.item(),
            'separation': separation.item(),
        }

        return total_loss, positive_risk, negative_risk, var_reg, geo_reg, loss_dict

    @torch.no_grad()
    def get_statistics(self, outputs_p, features_p=None, features_u=None):
        """Return diagnostic floats for the variance and geometric terms.

        Runs under ``torch.no_grad()``.

        Args:
            outputs_p: Model outputs for positive samples.
            features_p: Optional positive features; adds ``"compactness"`` and
                ``"avg_distance_to_centroid"`` (square root of compactness).
            features_u: Optional unlabeled features (used only together with
                ``features_p``); adds ``"separation"`` and
                ``"avg_unlabeled_distance"`` (root mean squared distance to
                the positive centroid).

        Returns:
            dict of Python floats; always contains ``"g_mean"``, ``"g_var"``
            and ``"focal_var"``.
        """
        outputs_p = self._as_vector(outputs_p)

        loss_pos_positive = self.loss_func(outputs_p)
        loss_neg_positive = self.loss_func(-outputs_p)
        g_positive = loss_pos_positive - loss_neg_positive

        # Variance statistics
        g_mean = torch.mean(g_positive)
        g_var = torch.mean((g_positive - g_mean) ** 2)
        focal_var = self._compute_focal_variance(g_positive)

        stats = {
            "g_mean": g_mean.item(),
            "g_var": g_var.item(),
            "focal_var": focal_var.item(),
        }

        if features_p is not None:
            compactness, c_p = self._compute_compactness(features_p)
            stats["compactness"] = compactness.item()
            stats["avg_distance_to_centroid"] = torch.sqrt(compactness).item()

            if features_u is not None:
                separation = self._compute_separation(features_u, c_p)
                stats["separation"] = separation.item()

                distances_sq = torch.sum((features_u - c_p) ** 2, dim=1)
                stats["avg_unlabeled_distance"] = torch.sqrt(torch.mean(distances_sq)).item()

        return stats