import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
import math


class nnPUSBloss(nn.Module):
    """Non-negative PU loss built on the log loss, used for PUSB learning.

    ``forward`` clips ``x`` into ``[eps, 1 - eps]`` and scores it with
    ``-log(x)`` (treated as positive) and ``-log(1 - x)`` (treated as
    negative), so ``x`` is presumably a probability -- confirm against the
    model head. ``t`` marks labelled-positive samples with ``+1`` and
    unlabelled samples with ``-1``.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance; the correction branch is taken when the negative
            risk is strictly below ``-beta``.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, gamma=1, beta=0):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.gamma, self.beta = prior, gamma, beta
        # Label codes looked up in ``t``.
        self.positive, self.unlabeled = 1, -1
        # Clipping margin that keeps both logarithms finite.
        self.eps = 1e-7

    def forward(self, x, t):
        """Return the scalar tensor to back-propagate.

        Args:
            x: Predictions; per the wrapper's documentation the expected
                shape is (N, 1).
            t: Labels of shape (N,), ``+1`` for positive, ``-1`` for
                unlabelled.

        Returns:
            ``-gamma * negative_risk`` when the negative risk is below
            ``-beta``; otherwise ``positive_risk + negative_risk``.
        """
        # Keep the predictions strictly inside (0, 1) before taking logs.
        probs = torch.clamp(x, min=self.eps, max=1 - self.eps)

        # Column vector so the masks broadcast against an (N, 1) prediction.
        labels = t[:, None]
        pos_mask = (labels == self.positive).float()
        unl_mask = (labels == self.unlabeled).float()
        # Guard against empty groups so the divisions below stay finite.
        n_pos = max(1., pos_mask.sum().item())
        n_unl = max(1., unl_mask.sum().item())

        loss_as_pos = -torch.log(probs)
        loss_as_neg = -torch.log(1 - probs)

        pos_weight = self.prior * pos_mask / n_pos
        positive_risk = torch.sum(pos_weight * loss_as_pos)
        negative_risk = torch.sum((unl_mask / n_unl - pos_weight) * loss_as_neg)

        # nnPU correction: a negative risk below -beta is pushed back up.
        if negative_risk.item() < -self.beta:
            return -self.gamma * negative_risk
        return positive_risk + negative_risk


def nnPUSB_loss(x, t, prior):
    """Functional shortcut for :class:`nnPUSBloss` with default ``gamma``/``beta``.

    A fresh ``nnPUSBloss(prior=prior)`` is built on every call and applied
    to ``(x, t)``. The two risk terms it combines are

        .. math::
            R_p^+ = -\\pi E_1[\\log f(x)], \\quad
            R^- = -E_X[\\log(1 - f(x))] + \\pi E_1[\\log(1 - f(x))]

    and the returned value is ``R_p^+ + R^-`` unless ``R^-`` drops below
    ``-beta``, in which case ``-gamma * R^-`` is returned instead.

    Args:
        x (~torch.Tensor): Input variable.
            The shape of ``x`` should be (:math:`N`, 1).
        t (~torch.Tensor): Target variable, ``+1`` for positive and ``-1``
            for unlabeled samples.
            The shape of ``t`` should be (:math:`N`, ).
        prior (float): Class prior of the positive class, in (0, 1).

    Returns:
        ~torch.Tensor: A scalar tensor holding the PU loss.

    Note:
        No regularisation term R(f) is added here; in PyTorch an L2 penalty
        is normally applied through the optimiser's ``weight_decay``.

    See:
        Masahiro Kato and Takeshi Teshima and Junya Honda.
        "Learning from Positive and Unlabeled Data with a Selection Bias."
        International Conference on Learning Representations. 2019.
    """
    criterion = nnPUSBloss(prior=prior)
    return criterion(x, t)


class nnPUloss(nn.Module):
    """PU risk estimator with an optional non-negative (nnPU) correction.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to ``x`` (sample treated as
            positive) and to ``-x`` (sample treated as negative).
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance; the correction is applied when the negative risk
            is strictly below ``-beta``.
        nnpu: When false the correction is skipped and the plain sum of
            the two risks is always returned.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.gamma, self.beta = prior, gamma, beta
        self.loss_func = loss
        self.nnpu = nnpu
        # Label codes looked up in ``t``.
        self.positive, self.unlabeled = 1, -1

    def forward(self, x, t):
        """Return ``(x_out, negative_risk)``.

        Args:
            x: Model outputs; per the wrapper's documentation the expected
                shape is (N, 1).
            t: Labels of shape (N,), ``+1`` for positive, ``-1`` for
                unlabelled.

        Returns:
            A pair whose first element is the tensor to back-propagate and
            whose second element is the negative-risk estimate.
        """
        # Column vector so the masks broadcast against an (N, 1) output.
        labels = t[:, None]
        pos_mask = (labels == self.positive).float()
        unl_mask = (labels == self.unlabeled).float()
        # Guard against empty groups so the divisions below stay finite.
        n_pos = max(1., pos_mask.sum().item())
        n_unl = max(1., unl_mask.sum().item())

        loss_as_pos = self.loss_func(x)
        loss_as_neg = self.loss_func(-x)

        pos_weight = self.prior * pos_mask / n_pos
        positive_risk = torch.sum(pos_weight * loss_as_pos)
        negative_risk = torch.sum((unl_mask / n_unl - pos_weight) * loss_as_neg)

        # nnPU correction: only when enabled and the risk fell below -beta.
        if self.nnpu and negative_risk.item() < -self.beta:
            return -self.gamma * negative_risk, negative_risk
        return positive_risk + negative_risk, negative_risk


def nnPU_loss(x, t, prior, loss, nnpu=True):
    """Functional shortcut for :class:`nnPUloss` with default ``gamma``/``beta``.

    A fresh ``nnPUloss`` is built on every call and applied to ``(x, t)``.
    The two risk terms it combines are

        .. math::
            R_p^+ = \\pi E_1[l(f(x))], \\quad
            R^- = E_X[l(-f(x))] - \\pi E_1[l(-f(x))]

    Args:
        x (~torch.Tensor): Input variable.
            The shape of ``x`` should be (:math:`N`, 1).
        t (~torch.Tensor): Target variable, ``+1`` for positive and ``-1``
            for unlabeled samples.
            The shape of ``t`` should be (:math:`N`, ).
        prior (float): Class prior of the positive class, in (0, 1).
        loss (callable): Loss function.
            The loss function should be non-increasing.
        nnpu (bool): Apply the non-negative correction when true;
            otherwise always use the plain sum of the two risks.

    Returns:
        tuple: ``(x_out, negative_risk)`` exactly as returned by
            ``nnPUloss.forward``; both are scalar tensors.

    See:
        Ryuichi Kiryo, Gang Niu, Marthinus Christoffel du Plessis, and Masashi Sugiyama.
        "Positive-Unlabeled Learning with Non-Negative Risk Estimator."
        Advances in neural information processing systems. 2017.
        du Plessis, Marthinus Christoffel, Gang Niu, and Masashi Sugiyama.
        "Convex formulation for learning from positive and unlabeled data."
        Proceedings of The 32nd International Conference on Machine Learning. 2015.
    """
    criterion = nnPUloss(prior=prior, loss=loss, nnpu=nnpu)
    return criterion(x, t)


class PNloss(nn.Module):
    """Prior-weighted loss for ordinary positive/negative (PN) learning.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to ``x`` for positives and to
            ``-x`` for negatives.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        # Label codes looked up in ``t``.
        self.positive, self.negative = 1, -1

    def forward(self, x, t):
        """Return ``prior * mean positive loss + (1 - prior) * mean negative loss``.

        Args:
            x: Model outputs; per the wrapper's documentation the expected
                shape is (N, 1).
            t: Labels of shape (N,), ``+1`` for positive, ``-1`` for
                negative.
        """
        # Column vector so the masks broadcast against an (N, 1) output.
        labels = t[:, None]
        pos_mask = (labels == self.positive).float()
        neg_mask = (labels == self.negative).float()
        # Guard against empty groups so the divisions below stay finite.
        n_pos = max(1., pos_mask.sum().item())
        n_neg = max(1., neg_mask.sum().item())

        loss_as_pos = self.loss_func(x)
        loss_as_neg = self.loss_func(-x)

        positive_risk = torch.sum(self.prior * pos_mask / n_pos * loss_as_pos)
        negative_risk = torch.sum(
            (1. - self.prior) * neg_mask / n_neg * loss_as_neg)
        return positive_risk + negative_risk


def PN_loss(x, t, prior, loss):
    """Functional shortcut for :class:`PNloss`.

    A fresh ``PNloss`` is built on every call and applied to ``(x, t)``:

        .. math::
            \\pi E_p[l(f(x))] + (1 - \\pi) E_n[l(-f(x))]

    Args:
        x (~torch.Tensor): Input variable.
            The shape of ``x`` should be (:math:`N`, 1).
        t (~torch.Tensor): Target variable, ``+1`` for positive and ``-1``
            for negative samples.
            The shape of ``t`` should be (:math:`N`, ).
        prior (float): Class prior of the positive class, in (0, 1).
        loss (callable): Loss function.
            The loss function should be non-increasing.

    Returns:
        ~torch.Tensor: A scalar tensor holding the PN loss.
    """
    criterion = PNloss(prior=prior, loss=loss)
    return criterion(x, t)


# class nnPUloss1(nn.Module):
#     """Loss function for PU learning."""
#     def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True):
#         super(nnPUloss1, self).__init__()
#         if not 0 < prior < 1:
#             raise NotImplementedError("The class prior should be in (0, 1)")
#         self.prior = prior
#         self.gamma = gamma
#         self.beta = beta
#         self.loss_func = loss
#         self.nnpu = nnpu
#         self.positive = 1
#         self.unlabeled = -1
#     def forward(self, x, t):
#         t = t[:, None]
#         # positive: if positive,1 else 0
#         # unlabeled: if unlabeled,1 else 0
#         positive, unlabeled = (t == self.positive).float(), (
#             t == self.unlabeled).float()
#         n_positive, n_unlabeled = max(1.,
#                                       positive.sum().item()), max(
#                                           1.,
#                                           unlabeled.sum().item())
#         # y_positive = self.loss_func(x)
#         # y_unlabeled = self.loss_func(-x)
#         positive_risk = torch.sum(self.prior * positive / n_positive *
#                                   torch.log(1 + torch.exp(-x)))
#         negative_risk = torch.sum(
#             (unlabeled / n_unlabeled - self.prior * positive / n_positive) *
#             torch.log(1 + torch.exp(x)))
#         objective = positive_risk + negative_risk
#         # if self.nnpu:
#         #     # nnPU learning
#         #     if negative_risk.item() < -self.beta:
#         #         objective = positive_risk - self.beta
#         #         x_out = -self.gamma * negative_risk
#         #     else:
#         #         x_out = objective
#         # else:
#         #     x_out = objective
#         return objective

# def nnPU_loss1(x, t, prior, loss, nnpu=True):
#     """wrapper of loss function for non-negative PU learning
#         .. math::
#             L_[\\pi E_1[l(f(x))]+\\max(E_X[l(-f(x))]-\\pi E_1[l(-f(x))], \\beta)
#     Args:
#         x (~torch.Variable): Input variable.
#             The shape of ``x`` should be (:math:`N`, 1).
#         t (~torch.Variable): Target variable for regression.
#             The shape of ``t`` should be (:math:`N`, ).
#         prior (float): Constant variable for class prior.
#         loss (~torch.function): loss function.
#             The loss function should be non-increasing.
#     Returns:
#         ~torch.Variable: A variable object holding a scalar array of the
#             PU loss.
#     See:
#         Ryuichi Kiryo, Gang Niu, Marthinus Christoffel du Plessis, and Masashi Sugiyama.
#         "Positive-Unlabeled Learning with Non-Negative Risk Estimator."
#         Advances in neural information processing systems. 2017.
#         du Plessis, Marthinus Christoffel, Gang Niu, and Masashi Sugiyama.
#         "Convex formulation for learning from positive and unlabeled data."
#         Proceedings of The 32nd International Conference on Machine Learning. 2015.
#     """
#     return nnPUloss1(prior=prior, loss=loss, nnpu=nnpu)(x, t)


class PUCEloss(nn.Module):
    """PU risk estimator over separate positive and unlabelled batches.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance; the correction is applied when the negative risk
            is strictly below ``-beta``.
        nnpu: Enable the non-negative correction; when false the plain
            (uPU) sum of the two risks is used.
        objective: When true the first returned element is the reported
            objective instead of the back-propagation target.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True, objective=False):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta
        self.nnpu, self.objective = nnpu, objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the PU loss and its individual risk components.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)`` where ``out`` is ``objective`` when
            ``self.objective`` is set and the back-propagation target
            otherwise.
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * pos_as_neg)

        positive_risk = torch.sum(self.prior / n_positive * pos_as_pos)
        negative_risk = R_u_neg - torch.sum(
            self.prior / n_positive * pos_as_neg)

        # Default (uPU) outcome; overridden by the nnPU correction below.
        objective = positive_risk + negative_risk
        x_out = objective
        if self.nnpu and negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            x_out = -self.gamma * negative_risk

        first = objective if self.objective else x_out
        return (first, objective, positive_risk, negative_risk,
                R_p_pos, R_u_neg, R_p_neg)


class ABSPUloss(nn.Module):
    """PU loss that uses the absolute value of the negative risk (ABS-PU).

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Stored but not read by ``forward``.
        beta: Stored but not read by ``forward``.
        objective: Stored but not read by ``forward``.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0, objective=False):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta
        self.objective = objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the ABS-PU loss and its individual risk components.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(x_out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)`` with ``x_out = positive_risk +
            |negative_risk|`` and ``objective = positive_risk +
            negative_risk``.
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * pos_as_neg)

        positive_risk = torch.sum(self.prior / n_positive * pos_as_pos)
        negative_risk = R_u_neg - torch.sum(
            self.prior / n_positive * pos_as_neg)

        return (positive_risk + torch.abs(negative_risk),
                positive_risk + negative_risk,
                positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)


class DistPUloss(nn.Module):
    """PU loss of the form ``2 * positive_risk + |R_u_neg - prior|``.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Stored but not read by ``forward``.
        beta: Stored but not read by ``forward``.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0):
        super().__init__()
        if not (0 < prior < 1):
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior, self.loss_func = prior, loss
        self.gamma, self.beta = gamma, beta

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the loss and the individual risk components.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(x_out, x_out, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)``; the loss occupies both of the first two
            slots.
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * pos_as_neg)

        positive_risk = torch.sum(self.prior / n_positive * pos_as_pos)
        negative_risk = torch.sum(unl_as_neg / n_unlabeled) - torch.sum(
            self.prior / n_positive * pos_as_neg)

        # Mean unlabelled loss is pulled towards the class prior.
        loss_value = 2 * positive_risk + torch.abs(R_u_neg - self.prior)
        return (loss_value, loss_value, positive_risk, negative_risk,
                R_p_pos, R_u_neg, R_p_neg)


class FOPULoss(torch.nn.Module):
    """Non-negative PU loss plus a class-balance ("fairness") regulariser.

    The PU part is computed as in the other ``forward(x_p, x_u, t_p, t_u)``
    losses of this file; the regulariser is computed on the concatenated
    batch by :meth:`compute_fairness_loss` and added once to both the
    reported objective and the back-propagation target.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs. ``None``
            selects ``softplus(-x)``.
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance; the correction is applied when the negative risk
            is strictly below ``-beta``.
        lam_f: Weight of the class-balance regulariser.
        objective: When true the first returned element is the reported
            objective instead of the back-propagation target.

    Raises:
        ValueError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0, lam_f=0.01, objective=False):
        super(FOPULoss, self).__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss if loss is not None else lambda x: F.softplus(-x)
        self.gamma = gamma
        self.beta = beta
        self.objective = objective
        self.lam_f = lam_f

    def compute_fairness_loss(self, output, target):
        """Return ``lam_f`` times the class-balance regulariser.

        Args:
            output: Model outputs for the whole batch, indexed along dim 0.
            target: Labels for the same samples; entries equal to ``1``
                form one group and entries equal to ``-1`` the other.

        Returns:
            A scalar tensor. When either group is empty the regulariser is
            undefined (it would divide by zero), so a zero tensor is
            returned instead.
        """
        def kappa(z):
            return torch.relu(1 + z)

        def delta(z):
            return 1 - torch.relu(1 - z)

        N = len(output)

        # Labels are only used for indexing, so they can leave the graph.
        y = target.cpu().detach().numpy()
        idx_pos = np.where(y == 1)[0]
        idx_neg = np.where(y == -1)[0]

        # Without both groups every ratio below would be 0/0 or x/0.
        if len(idx_pos) == 0 or len(idx_neg) == 0:
            return output.new_zeros(())

        # Fraction of the batch in each group.
        p_pos = len(idx_pos) / N
        p_neg = len(idx_neg) / N

        pred_pos = output[idx_pos]
        pred_neg = output[idx_neg]

        # Difference in the fraction of non-negative outputs per group; it
        # only selects which surrogate is used and carries no gradient.
        tpr_diff = (torch.sum(pred_pos >= 0) / len(idx_pos) -
                    torch.sum(pred_neg >= 0) / len(idx_neg))

        if tpr_diff.item() > 0:
            fairness_loss = ((torch.sum(kappa(pred_pos)) / p_pos +
                              torch.sum(kappa(-pred_neg)) / p_neg) / N - 1)
        else:
            fairness_loss = -1 * ((torch.sum(delta(pred_pos)) / p_pos +
                                   torch.sum(delta(-pred_neg)) / p_neg) / N - 1)

        return self.lam_f * fairness_loss

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the regularised PU loss and its risk components.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)`` where ``out`` is ``objective`` when
            ``self.objective`` is set and the back-propagation target
            otherwise. Both include the regulariser exactly once.
        """
        # The regulariser looks at positives and unlabelled samples together.
        x_combined = torch.cat([x_p, x_u], dim=0)
        t_combined = torch.cat([t_p, t_u], dim=0)

        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)

        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        fairness_loss = self.compute_fairness_loss(x_combined, t_combined)

        # Out-of-place adds: in the else branch ``x_out`` and ``objective``
        # are the same tensor, so in-place ``+=`` would apply every update
        # to both and double the result.
        objective = objective + fairness_loss
        x_out = x_out + fairness_loss

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class BasePUCEloss(nn.Module):
    """Holds the PU-loss hyper-parameters and computes the risk components.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss_func: Callable applied element-wise to the model outputs.
            ``None`` selects ``nn.Sigmoid()``.
        gamma: Stored; not read by ``calculate_risks``.
        beta: Stored; not read by ``calculate_risks``.
        nnpu: Stored; not read by ``calculate_risks``.

    Raises:
        ValueError: If ``prior`` is not strictly between 0 and 1.
    """

    def __init__(self, prior, loss_func=None, gamma=1, beta=0, nnpu=False):
        super().__init__()
        if not (0 < prior < 1):
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.gamma, self.beta, self.nnpu = gamma, beta, nnpu
        if loss_func is None:
            loss_func = nn.Sigmoid()
        self.loss_func = loss_func

    def calculate_risks(self, x_p, x_u, t_p, t_u):
        """Return ``(positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg)``.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        pos_as_pos = self.loss_func(x_p)
        pos_as_neg = self.loss_func(-x_p)
        unl_as_neg = self.loss_func(-x_u)

        # Unweighted mean losses.
        R_p_pos = torch.sum(1 / n_positive * pos_as_pos)
        R_u_neg = torch.sum(unl_as_neg / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * pos_as_neg)

        # Prior-weighted risks.
        positive_risk = torch.sum(self.prior / n_positive * pos_as_pos)
        negative_risk = torch.sum(unl_as_neg / n_unlabeled) - torch.sum(
            self.prior / n_positive * pos_as_neg)

        return positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class DynamicPUCELoss(nn.Module):
    """PU loss whose ``gamma`` is re-tuned on every forward pass.

    ``gamma`` is adapted from exponential moving averages of the positive
    and negative risks and kept within ``[gamma_min, gamma_max]``. The
    current value is available as ``self.gamma``; it is not part of the
    tuple returned by ``forward``.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Initial value of the adaptive scale.
        beta: Tolerance; with ``nnpu`` enabled the correction is applied
            when the negative risk is strictly below ``-beta``.
        nnpu: Enable the non-negative correction.
        gamma_min: Lower bound for ``gamma``.
        gamma_max: Upper bound for ``gamma``.
        adjustment_rate: Step size of the multiplicative gamma update.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """

    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=False,
                 gamma_min=0.1, gamma_max=5.0, adjustment_rate=0.1):
        super(DynamicPUCELoss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.nnpu = nnpu

        # Parameters for dynamic gamma adjustment
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        self.adjustment_rate = adjustment_rate
        # Moving averages start unset and are seeded by the first batch.
        self.moving_avg_pos_risk = None
        self.moving_avg_neg_risk = None
        self.momentum = 0.9  # Momentum for moving average calculation

    def adjust_gamma(self, positive_risk, negative_risk):
        """Update ``self.gamma`` from the balance of the two risks.

        Args:
            positive_risk: Scalar tensor with the current positive risk.
            negative_risk: Scalar tensor with the current negative risk.
        """
        # Update moving averages
        if self.moving_avg_pos_risk is None:
            self.moving_avg_pos_risk = positive_risk.item()
            self.moving_avg_neg_risk = negative_risk.item()
        else:
            self.moving_avg_pos_risk = (self.momentum * self.moving_avg_pos_risk +
                                        (1 - self.momentum) * positive_risk.item())
            self.moving_avg_neg_risk = (self.momentum * self.moving_avg_neg_risk +
                                        (1 - self.momentum) * negative_risk.item())

        # |positive / negative|; the small constant avoids dividing by an
        # exactly-zero negative average.
        risk_ratio = abs(self.moving_avg_pos_risk / (self.moving_avg_neg_risk + 1e-8))

        if risk_ratio > 1:
            # Positive risk is larger, increase gamma to give more weight to negative risk
            self.gamma = min(self.gamma_max,
                             self.gamma * (1 + self.adjustment_rate * (risk_ratio - 1)))
        else:
            # Negative risk is larger, decrease gamma to give less weight to negative risk
            if risk_ratio == 0:
                # 1 / risk_ratio would raise ZeroDivisionError. In the limit
                # the divisor grows without bound (for a positive rate), so
                # gamma collapses to its floor; with a zero rate the update
                # is a no-op.
                shrunk = 0.0 if self.adjustment_rate > 0 else self.gamma
            else:
                shrunk = self.gamma / (1 + self.adjustment_rate * (1 / risk_ratio - 1))
            self.gamma = max(self.gamma_min, shrunk)

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the PU loss after re-tuning ``gamma`` on this batch.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(x_out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)``.
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive, n_unlabeled = max(1., t_p.sum().item()), max(1., (-t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)

        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)

        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        # Adjust gamma based on current risks
        self.adjust_gamma(positive_risk, negative_risk)

        objective = positive_risk + negative_risk
        if self.nnpu:
            # nnPU learning
            if negative_risk.item() < -self.beta:
                objective = positive_risk - self.beta
                x_out = -self.gamma * negative_risk
            else:
                x_out = objective
        else:
            # uPU learning
            x_out = objective * self.gamma  # Apply dynamic gamma to the objective

        # gamma itself is not returned; read it from ``self.gamma``.
        return (x_out, objective, positive_risk, negative_risk,
                R_p_pos, R_u_neg, R_p_neg)


class BalancedPUCEloss(nn.Module):
    """Non-negative PU loss with an explicit balance penalty.

    The penalty is ``balance_weight * |R_p_pos - (R_u_neg - prior * R_p_neg)
    / (1 - prior)|`` and is added once to both the reported objective and
    the back-propagation target.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance; the correction is applied when the negative risk
            is strictly below ``-beta``.
        balance_weight: Weight of the balance penalty.
        objective: When true the first returned element is the reported
            objective instead of the back-propagation target.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0, balance_weight=1e1, objective=False):
        super(BalancedPUCEloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.balance_weight = balance_weight
        self.objective = objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the penalised PU loss and its risk components.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)`` where ``out`` is ``objective`` when
            ``self.objective`` is set and the back-propagation target
            otherwise.
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        balance_term = torch.abs(R_p_pos - (R_u_neg - self.prior*R_p_neg)/(1-self.prior))
        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        # Out-of-place adds: in the else branch ``x_out`` and ``objective``
        # are the same tensor, so two in-place ``+=`` would add the penalty
        # twice to both.
        penalty = self.balance_weight * balance_term
        objective = objective + penalty
        x_out = x_out + penalty

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class BalancedPULBloss(nn.Module):
    """Balanced PU loss whose correction threshold is shifted by a bound.

    The correction branch is taken when the negative risk is strictly below
    ``-beta + bound``, where ``bound`` blends the caller-supplied value with
    ``(1 - prior) * R_p_pos`` of the current batch. The balance penalty is
    ``balance_weight * ||R_p_pos - (R_u_neg - prior * R_p_neg) /
    (1 - prior)||_2`` and is added once to both outputs.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance used in the correction threshold.
        momentum: Weight of the caller-supplied bound in the blend; with
            the default ``0`` only the current batch's bound is used.
        balance_weight: Weight of the balance penalty.
        objective: When true the first returned element is the reported
            objective instead of the back-propagation target.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """

    def __init__(self, prior, loss, gamma=1, beta=0, momentum=0, balance_weight=1e1, objective=False):
        super(BalancedPULBloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.balance_weight = balance_weight
        self.objective = objective

        self.bound = 0
        self.bound_ema = 0
        self.momentum = momentum    # Exponential decay rate, can be adjusted

    def forward(self, x_p, x_u, t_p, t_u, bound):
        """Compute the penalised PU loss, its components and the new bound.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).
            bound: Previous bound; it seeds the moving average on every
                call, so the caller is presumably expected to feed back the
                returned value -- confirm against the training loop.

        Returns:
            ``(out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg, bound)`` where ``bound`` is the updated
            Python float.
        """
        # The moving average is re-seeded from the caller's value each call.
        self.bound = bound
        self.bound_ema = bound

        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (- t_u).sum().item())
        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        # Blend the previous bound with this batch's bound.
        current_bound = (1 - self.prior) * R_p_pos.item()
        self.bound_ema = self.bound_ema * self.momentum + current_bound * (1 - self.momentum)
        self.bound = self.bound_ema

        balance_term = torch.norm(R_p_pos - (R_u_neg - self.prior * R_p_neg) / (1 - self.prior), p=2)
        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta + self.bound:
            objective = positive_risk - self.beta + self.bound
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        # Out-of-place adds: in the else branch ``x_out`` and ``objective``
        # are the same tensor, so two in-place ``+=`` would add the penalty
        # twice to both.
        penalty = self.balance_weight * balance_term
        objective = objective + penalty
        x_out = x_out + penalty

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound


class BalancedPUCEloss2(nn.Module):
    """Non-negative PU loss penalising ``|positive_risk - negative_risk|``.

    The penalty is scaled by ``balance_weight`` and added once to both the
    reported objective and the back-propagation target.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance; the correction is applied when the negative risk
            is strictly below ``-beta``.
        balance_weight: Weight of the balance penalty.
        objective: When true the first returned element is the reported
            objective instead of the back-propagation target.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """
    def __init__(self, prior, loss, gamma=1, beta=0, balance_weight=1e-1, objective=False):
        super(BalancedPUCEloss2, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.balance_weight = balance_weight
        self.objective = objective

    def forward(self, x_p, x_u, t_p, t_u):
        """Compute the penalised PU loss and its risk components.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).

        Returns:
            ``(out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)`` where ``out`` is ``objective`` when
            ``self.objective`` is set and the back-propagation target
            otherwise.
        """
        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (- t_u).sum().item())

        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        balance_term = torch.abs(positive_risk - negative_risk)
        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta:
            objective = positive_risk - self.beta
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        # Out-of-place adds: in the else branch ``x_out`` and ``objective``
        # are the same tensor, so two in-place ``+=`` would add the penalty
        # twice to both.
        penalty = self.balance_weight * balance_term
        objective = objective + penalty
        x_out = x_out + penalty

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class BalancedPULBloss2(nn.Module):
    """Bound-shifted PU loss penalising ``|positive_risk - negative_risk|``.

    The correction branch is taken when the negative risk is strictly below
    ``-beta + bound``, where ``bound`` blends the caller-supplied value with
    ``(1 - prior) * R_p_pos`` of the current batch. The penalty is scaled by
    ``balance_weight`` and added once to both outputs.

    Args:
        prior: Class prior of the positive class; must lie in (0, 1).
        loss: Callable applied element-wise to the model outputs.
        gamma: Scale applied to the negated negative risk in the
            correction branch.
        beta: Tolerance used in the correction threshold.
        momentum: Weight of the caller-supplied bound in the blend; with
            the default ``0`` only the current batch's bound is used.
        balance_weight: Weight of the balance penalty.
        objective: When true the first returned element is the reported
            objective instead of the back-propagation target.

    Raises:
        NotImplementedError: If ``prior`` is not strictly between 0 and 1.
    """

    def __init__(self, prior, loss, gamma=1, beta=0, momentum=0, balance_weight=1e-1, objective=False):
        super(BalancedPULBloss2, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.balance_weight = balance_weight
        self.objective = objective

        self.bound = 0
        self.bound_ema = 0
        self.momentum = momentum    # Exponential decay rate, can be adjusted

    def forward(self, x_p, x_u, t_p, t_u, bound):
        """Compute the penalised PU loss, its components and the new bound.

        Args:
            x_p: Model outputs for the positive batch.
            x_u: Model outputs for the unlabelled batch.
            t_p: Labels of the positive batch; their sum is used as the
                sample count (assumes every entry is +1 -- TODO confirm).
            t_u: Labels of the unlabelled batch; the negated sum is used as
                the sample count (assumes every entry is -1 -- TODO confirm).
            bound: Previous bound; it seeds the moving average on every
                call, so the caller is presumably expected to feed back the
                returned value -- confirm against the training loop.

        Returns:
            ``(out, objective, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg, bound)`` where ``bound`` is the updated
            Python float.
        """
        # The moving average is re-seeded from the caller's value each call.
        self.bound = bound
        self.bound_ema = bound

        # Sample counts, floored at 1 so empty batches do not divide by zero.
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (- t_u).sum().item())
        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        # Unweighted mean losses, returned for monitoring.
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        # Blend the previous bound with this batch's bound.
        current_bound = (1 - self.prior) * R_p_pos.item()
        self.bound_ema = self.bound_ema * self.momentum + current_bound * (1 - self.momentum)
        self.bound = self.bound_ema

        balance_term = torch.abs(positive_risk - negative_risk)
        objective = positive_risk + negative_risk

        if negative_risk.item() < -self.beta + self.bound:
            objective = positive_risk - self.beta + self.bound
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        # Out-of-place adds: in the else branch ``x_out`` and ``objective``
        # are the same tensor, so two in-place ``+=`` would add the penalty
        # twice to both.
        penalty = self.balance_weight * balance_term
        objective = objective + penalty
        x_out = x_out + penalty

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound


class CurriculumPUCEloss(nn.Module):
    """PUCEloss with curriculum learning strategy"""
    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=False, curriculum_rate=0.3):
        super(CurriculumPUCEloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")

        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.nnpu = nnpu
        self.curriculum_rate = curriculum_rate

    def get_curriculum_weight(self, epoch, max_epochs):
        # Gradually increase the weight from 0 to 1
        return min(1.0, epoch / (self.curriculum_rate * max_epochs))

    def forward(self, x_p, x_u, t_p, t_u, epoch, max_epochs):
        # Calculate basic statistics
        n_positive = max(1., t_p.sum().item())
        n_unlabeled = max(1., (-t_u).sum().item())

        # Calculate losses
        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)

        # Calculate risks
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)

        # Apply curriculum learning weight to negative risk
        curr_weight = self.get_curriculum_weight(epoch, max_epochs)
        weighted_negative_risk = curr_weight * negative_risk

        # Calculate additional metrics for monitoring
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        # Combine risks with curriculum weighting
        objective = positive_risk + weighted_negative_risk

        if self.nnpu:
            # Non-negative PU learning with curriculum
            if weighted_negative_risk.item() < -self.beta:
                objective = positive_risk - self.beta
                x_out = -self.gamma * negative_risk
            else:
                x_out = objective
        else:
            # Unbiased PU learning with curriculum
            x_out = objective

        return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


class RegularizedPUCEloss(BasePUCEloss):
    """PU risk with an added L2-norm penalty on the raw model outputs.

    The penalty is ``reg_lambda * (||x_p||_2 + ||x_u||_2)``, computed on the
    outputs passed to ``forward``.
    """

    def __init__(self, prior, loss_func=None, gamma=1, beta=0,
                 reg_lambda=0.01):
        super().__init__(prior, loss_func, gamma, beta)
        self.reg_lambda = reg_lambda

    def forward(self, x_p, x_u, t_p, t_u):
        """Return the regularized objective and the individual risk terms.

        Returns:
            tuple: ``(objective, objective, pos_risk, neg_risk, R_p_pos,
            R_u_neg, R_p_neg)``; the objective appears twice.
        """
        # NOTE(review): calculate_risks and self.nnpu are not defined in this
        # class; presumably both come from BasePUCEloss - confirm.
        risks = self.calculate_risks(x_p, x_u, t_p, t_u)
        pos_risk, neg_risk, R_p_pos, R_u_neg, R_p_neg = risks

        output_norm = torch.norm(x_p, 2) + torch.norm(x_u, 2)

        if self.nnpu and neg_risk.item() < -self.beta:
            # Negative risk fell below -beta: replace it with the constant.
            total = pos_risk - self.beta + self.reg_lambda * output_norm
        else:
            total = pos_risk + neg_risk + self.reg_lambda * output_norm

        return total, total, pos_risk, neg_risk, R_p_pos, R_u_neg, R_p_neg


class CombinedPUCEloss(nn.Module):
    """Weighted average of several PU-learning loss strategies.

    A sub-loss is built for each strategy whose weight is positive. On every
    call the outputs of the strategies that ran are averaged with those
    weights.

    Args:
        prior: Class prior forwarded to every strategy.
        loss_func: Per-sample loss callable forwarded to every strategy.
        gamma: Forwarded to every strategy.
        beta: Forwarded to every strategy.
        balance_weight: Extra argument for the ``'balanced'`` strategy.
        reg_lambda: Extra argument for the ``'regularized'`` strategy.
        curriculum_rate: Warm-up fraction for the ``'curriculum'`` strategy.
        gamma_min: Extra argument for the ``'dynamic'`` strategy.
        gamma_max: Extra argument for the ``'dynamic'`` strategy.
        strategy_weights: Mapping from strategy name (``'dynamic'``,
            ``'balanced'``, ``'regularized'``, ``'curriculum'``) to weight.
            ``None`` gives every strategy weight 1.0. The mapping is copied.

    Raises:
        ValueError: If no strategy weight is greater than 0.
    """

    # Read-only template; every instance works on its own copy.
    _DEFAULT_STRATEGY_WEIGHTS = {'dynamic': 1.0, 'balanced': 1.0,
                                 'regularized': 1.0, 'curriculum': 1.0}

    def __init__(self, prior, loss_func=None, gamma=1, beta=0,
                 balance_weight=0.5, reg_lambda=0.01, curriculum_rate=0.3,
                 gamma_min=0.1, gamma_max=10.0,
                 strategy_weights=None):
        super(CombinedPUCEloss, self).__init__()

        # A dict default in the signature would be shared by all instances;
        # default to None and copy instead.
        if strategy_weights is None:
            strategy_weights = self._DEFAULT_STRATEGY_WEIGHTS
        self.strategy_weights = dict(strategy_weights)

        # Build a sub-loss only for strategies with a positive weight.
        if self.strategy_weights.get('dynamic', 0) > 0:
            self.dynamic_loss = DynamicPUCEloss(
                prior, loss_func, gamma, beta, gamma_min, gamma_max)

        if self.strategy_weights.get('balanced', 0) > 0:
            self.balanced_loss = BalancedPUCEloss(
                prior, loss_func, gamma, beta, balance_weight)

        if self.strategy_weights.get('curriculum', 0) > 0:
            # CurriculumPUCEloss takes ``nnpu`` as its fifth positional
            # parameter, so the rate must go by keyword; passed positionally
            # it was consumed as a truthy ``nnpu`` flag and the configured
            # rate was ignored. ``nnpu=True`` keeps the non-negative
            # correction that the old call enabled in effect.
            self.curriculum_loss = CurriculumPUCEloss(
                prior, loss_func, gamma, beta, nnpu=True,
                curriculum_rate=curriculum_rate)

        if self.strategy_weights.get('regularized', 0) > 0:
            self.regularized_loss = RegularizedPUCEloss(
                prior, loss_func, gamma, beta, reg_lambda)

        # Total configured weight; kept as an attribute for existing users.
        self.weight_sum = sum(w for w in self.strategy_weights.values() if w > 0)
        if self.weight_sum == 0:
            raise ValueError("At least one strategy weight must be greater than 0")

    def forward(self, x_p, x_u, t_p, t_u, epoch=None, max_epochs=None):
        """Run the enabled strategies and return their weighted average.

        The curriculum strategy runs only when both ``epoch`` and
        ``max_epochs`` are given. The average is normalized by the weights of
        the strategies that ran, so a skipped strategy does not shrink the
        result.

        Returns:
            tuple: ``(loss, loss, positive_risk, negative_risk, R_p_pos,
            R_u_neg, R_p_neg)``, each a weighted average over the strategies
            that ran.

        Raises:
            ValueError: If no strategy could be evaluated for this call.
        """
        schedule_given = epoch is not None and max_epochs is not None
        # (weight key, sub-loss attribute, extra call args, can run this call)
        plan = (
            ('dynamic', 'dynamic_loss', (), True),
            ('balanced', 'balanced_loss', (), True),
            ('regularized', 'regularized_loss', (), True),
            ('curriculum', 'curriculum_loss', (epoch, max_epochs),
             schedule_given),
        )

        # One bucket each for: loss, pos risk, neg risk, R_p_pos, R_u_neg,
        # R_p_neg.
        buckets = [[] for _ in range(6)]
        active_weight = 0
        for key, attr, extra_args, can_run in plan:
            weight = self.strategy_weights.get(key, 0)
            if not (can_run and weight > 0):
                continue
            out, _, pos, neg, r_pp, r_un, r_pn = getattr(self, attr)(
                x_p, x_u, t_p, t_u, *extra_args)
            for bucket, value in zip(buckets, (out, pos, neg, r_pp, r_un, r_pn)):
                bucket.append(value * weight)
            active_weight += weight

        if not active_weight > 0:
            raise ValueError(
                "No strategy could be evaluated; pass epoch and max_epochs "
                "to enable the curriculum strategy")

        combined = [sum(bucket) / active_weight for bucket in buckets]
        return (combined[0], combined[0], combined[1], combined[2],
                combined[3], combined[4], combined[5])


class PNCEloss(nn.Module):
    """Class-prior-weighted risk for positive/negative (PN) training.

    Args:
        prior: Class prior; must lie strictly between 0 and 1.
        loss: Callable applied to model outputs to get per-sample losses.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss):
        super().__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.loss_func = loss
        self.prior = prior

    def forward(self, x_p, x_n, t_p, t_n):
        """Return ``(total_risk, positive_risk, negative_risk)``.

        The positive count is ``sum(t_p)`` and the negative count is
        ``sum(1 - t_n)``, each floored at 1.
        """
        pos_count = max(1., t_p.sum().item())
        neg_count = max(1., (1. - t_n).sum().item())

        pos_losses = self.loss_func(x_p)
        neg_losses = self.loss_func(-x_n)

        positive_risk = (self.prior / pos_count * pos_losses).sum()
        negative_risk = ((1 - self.prior) / neg_count * neg_losses).sum()

        total_risk = positive_risk + negative_risk
        return total_risk, positive_risk, negative_risk


class PULBloss(nn.Module):
    """Loss function for PU learning."""
    def __init__(self, prior, loss, gamma, beta, momentum=0, objective=False):
        super(PULBloss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.objective = objective

        self.bound = 0
        self.bound_ema = 0
        self.momentum = momentum    # Exponential decay rate, can be adjusted

    def forward(self, x_p, x_u, t_p, t_u, bound):
        # Initialize beta EMA
        self.bound = bound
        self.bound_ema = bound

        n_positive, n_unlabeled = max(1.,t_p.sum().item()), \
                                  max(1., (- t_u).sum().item())
        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(1 / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        objective = self.prior * positive_risk + negative_risk
        # Update args.beta using exponential moving average
        current_bound = (1 - self.prior) * positive_risk.item()
        self.bound_ema = self.bound_ema * self.momentum + current_bound * (1 - self.momentum)
        self.bound = self.bound_ema

        # objective = self.prior * positive_risk + torch.clamp(negative_risk, -self.beta+self.bound)
        # nnPU-lower bound learning
        if negative_risk.item() < - self.beta + self.bound:
            # negative_risk_ = torch.clamp(negative_risk, min=self.beta)
            # objective = positive_risk + negative_risk_
            objective = self.prior * positive_risk - self.beta + self.bound
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound


class PULBloss2(nn.Module):
    """Loss function for PU learning."""
    def __init__(self, prior, loss, gamma, beta, momentum=0, objective=False):
        super(PULBloss2, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.objective = objective

        self.bound = 0
        self.bound_ema = 0
        self.momentum = momentum    # Exponential decay rate, can be adjusted

    def forward(self, x_p, x_u, t_p, t_u, bound):
        # Initialize beta EMA
        self.bound = bound
        self.bound_ema = bound

        n_positive, n_unlabeled = max(1.,t_p.sum().item()), \
                                  max(1., (- t_u).sum().item())
        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(1 / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive * y_u)
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive * y_u)

        objective = self.prior * positive_risk + negative_risk
        # Update args.beta using exponential moving average
        # current_bound = (1 - self.prior) * positive_risk.item()
        current_bound = self.prior * positive_risk.item()
        self.bound_ema = self.bound_ema * self.momentum + current_bound * (1 - self.momentum)
        self.bound = self.bound_ema

        # nnPU-lower bound learning
        if negative_risk.item() < - self.beta + self.bound:
            objective = self.prior * positive_risk - self.beta + self.bound
            x_out = -self.gamma * negative_risk
        else:
            x_out = objective

        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg, self.bound


class SoftnnPUloss(nn.Module):
    """PU loss over a single batch with per-sample (soft) positive labels.

    Each sample contributes to the positive term with weight ``t`` and to the
    unlabeled term with weight ``1 - t``.

    Args:
        prior: Class prior; must lie strictly between 0 and 1.
        loss: Callable applied to model outputs to get per-sample losses.
        gamma: Scale of the negated negative risk returned when the
            non-negative correction triggers.
        beta: The correction triggers when the negative risk is below
            ``-beta``.
        nnpu: Enable the non-negative correction.

    Raises:
        NotImplementedError: If ``prior`` is not in (0, 1).
    """

    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True):
        super().__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.nnpu = nnpu

    def forward(self, x, t):
        """Return the scalar loss for outputs ``x`` and soft labels ``t``."""
        pos_mass = max(1., t.sum().item())
        unl_mass = max(1., (1. - t).sum().item())

        loss_as_pos = self.loss_func(x)
        loss_as_neg = self.loss_func(-x)

        positive_risk = (self.prior * t / pos_mass * loss_as_pos).sum()
        neg_weights = (1. - t) / unl_mass - self.prior * t / pos_mass
        negative_risk = (neg_weights * loss_as_neg).sum()

        if self.nnpu and negative_risk.item() < -self.beta:
            # Non-negative correction: return the scaled negated negative
            # risk instead of the full objective.
            return -self.gamma * negative_risk
        return positive_risk + negative_risk


def soft_nnPU_loss(x, t, prior, loss, nnpu=True):
    """Functional wrapper around ``SoftnnPUloss``.

    Builds a ``SoftnnPUloss`` with the given ``prior``, ``loss`` and ``nnpu``
    (``gamma`` and ``beta`` keep that class's defaults) and applies it once.

    Args:
        x (~torch.Tensor): Model outputs.
        t (~torch.Tensor): Soft positive labels, one per output.
        prior (float): Class prior, strictly between 0 and 1.
        loss (callable): Per-sample loss function.
        nnpu (bool): Whether to apply the non-negative correction.

    Returns:
        ~torch.Tensor: Scalar PU loss.

    See:
        Ryuichi Kiryo, Gang Niu, Marthinus Christoffel du Plessis, and Masashi Sugiyama.
        "Positive-Unlabeled Learning with Non-Negative Risk Estimator."
        Advances in neural information processing systems. 2017.
        du Plessis, Marthinus Christoffel, Gang Niu, and Masashi Sugiyama.
        "Convex formulation for learning from positive and unlabeled data."
        Proceedings of The 32nd International Conference on Machine Learning. 2015.
    """
    criterion = SoftnnPUloss(prior=prior, loss=loss, nnpu=nnpu)
    return criterion(x, t)


class PUCEloss2(nn.Module):
    """Loss function for PU learning."""
    def __init__(self, prior, loss, gamma=1, beta=0, nnpu=True, objective=False):
        super(PUCEloss2, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.loss_func = loss
        self.gamma = gamma
        self.beta = beta
        self.nnpu = nnpu
        self.objective = objective

    def forward(self, x_p, x_u, x_p2, t_p, t_u, t_p2):
        n_positive, n_unlabeled, n_positive2 = max(1., t_p.sum().item()), \
                                               max(1., (- t_u).sum().item()), \
                                               max(1., t_p2.sum().item())
        y_p = self.loss_func(x_p)
        y_u = self.loss_func(-x_p2)
        y_unlabeled = self.loss_func(-x_u)
        positive_risk = torch.sum(self.prior / n_positive * y_p)
        negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
            self.prior / n_positive2 * y_u)
        R_p_pos = torch.sum(1 / n_positive * y_p)
        R_u_neg = torch.sum(y_unlabeled / n_unlabeled)
        R_p_neg = torch.sum(1 / n_positive2 * y_u)

        objective = positive_risk + negative_risk
        if self.nnpu:
            # nnPU learning
            if negative_risk.item() < -self.beta:
                objective = positive_risk - self.beta
                x_out = -self.gamma * negative_risk
            else:
                x_out = objective
        else:
            # uPU learning
            x_out = objective
        if self.objective:
            return objective, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg
        else:
            return x_out, objective, positive_risk, negative_risk, R_p_pos, R_u_neg, R_p_neg


# class FairPULoss(torch.nn.Module):
#     """
#     PU learning loss with fairness regularization.
#     Combines non-negative PU learning with demographic parity (DDP) or equality of opportunity (DEO).
#     """
#
#     def __init__(self, prior, loss_func=None, gamma=1, beta=0, nnpu=True,
#                  fairness='deo', lam_f=0.01, lam_penalty=0.001, penalty=True):
#         super(FairPULoss, self).__init__()
#         if not 0 < prior < 1:
#             raise ValueError("The class prior should be in (0, 1)")
#         self.prior = prior
#         self.loss_func = loss_func if loss_func is not None else lambda x: F.softplus(-x)
#         self.gamma = gamma
#         self.beta = beta
#         self.nnpu = nnpu
#         self.fairness = fairness
#         self.lam_f = lam_f
#         self.lam_penalty = lam_penalty
#         self.penalty = penalty
#
#     def compute_fairness_loss(self, output, target, sensitive, baselines=None):
#         """Computes fairness regularization loss (DDP or DEO)"""
#         kappa = lambda z: torch.relu(1 + z)
#         delta = lambda z: 1 - torch.relu(1 - z)
#         N = len(output)
#
#         # Convert to numpy for indexing
#         y = target.cpu().detach().numpy()
#         sensitive = sensitive.cpu().detach().numpy()
#
#         # Calculate class distributions
#         p0 = np.sum(sensitive == -1) / N
#         p1 = np.sum(sensitive == 1) / N
#         p11 = np.sum((sensitive == 1) & (y == 1)) / N
#         p01 = np.sum((sensitive == -1) & (y == 1)) / N
#         p10 = np.sum((sensitive == 1) & (y == -1)) / N
#         p00 = np.sum((sensitive == -1) & (y == -1)) / N
#
#         fairness_loss = 0
#         if self.fairness == 'ddp':
#             # Demographic Parity
#             idx_0 = np.where(sensitive == -1)[0]
#             idx_1 = np.where(sensitive == 1)[0]
#             pred_0 = output[idx_0]
#             pred_1 = output[idx_1]
#
#             pred_0_bin = torch.where(pred_0 < 0, 1, 0)
#             pred_1_bin = torch.where(pred_1 >= 0, 1, 0)
#
#             ddp = torch.sum(pred_1_bin) / p1 / N + torch.sum(pred_0_bin) / p0 / N - 1
#
#             if ddp.item() > 0:
#                 fairness_loss = ((torch.sum(kappa(output[idx_1]) / p1) +
#                                   torch.sum(kappa(-output[idx_0]) / p0)) / N - 1)
#             else:
#                 fairness_loss = -1 * ((torch.sum(delta(output[idx_1]) / p1) +
#                                        torch.sum(delta(-output[idx_0]) / p0)) / N - 1)
#
#         elif self.fairness == 'deo':
#             # Equality of Opportunity
#             idx_00 = np.where((sensitive == -1) & (y == -1))[0]
#             idx_01 = np.where((sensitive == -1) & (y == 1))[0]
#             idx_10 = np.where((sensitive == 1) & (y == -1))[0]
#             idx_11 = np.where((sensitive == 1) & (y == 1))[0]
#
#             pred_00 = output[idx_00]
#             pred_01 = output[idx_01]
#             pred_10 = output[idx_10]
#             pred_11 = output[idx_11]
#
#             pred_00_bin = torch.where(pred_00 < 0, 1, 0)
#             pred_01_bin = torch.where(pred_01 < 0, 1, 0)
#             pred_10_bin = torch.where(pred_10 >= 0, 1, 0)
#             pred_11_bin = torch.where(pred_11 >= 0, 1, 0)
#
#             deo = (torch.sum(pred_11_bin) / p11 + torch.sum(pred_01_bin) / p01 +
#                    torch.sum(pred_10_bin) / p10 + torch.sum(pred_00_bin) / p00) / N - 2
#
#             if deo.item() >= 0:
#                 fairness_loss = (torch.sum(kappa(output[idx_11]) / p11) +
#                                  torch.sum(kappa(-output[idx_01]) / p01) +
#                                  torch.sum(kappa(output[idx_10]) / p10) +
#                                  torch.sum(kappa(-output[idx_00]) / p00)) / N - 2
#             else:
#                 fairness_loss = -1 * (torch.sum(delta(output[idx_11]) / p11) +
#                                       torch.sum(delta(-output[idx_01]) / p01) +
#                                       torch.sum(delta(output[idx_10]) / p10) +
#                                       torch.sum(delta(-output[idx_00]) / p00)) / N - 2
#
#         fairness_loss = self.lam_f * fairness_loss
#
#         # Add performance penalty if enabled
#         if self.penalty:
#             pred_probs = torch.sigmoid(output)
#
#             # Calculate TPR and FPR for each group
#             def get_rates(idx):
#                 if len(idx) == 0:
#                     return torch.tensor(0.0).to(output.device)
#                 return pred_probs[idx].mean()
#
#             # Group -1
#             tpr0 = get_rates(idx_01)
#             fpr0 = get_rates(idx_00)
#
#             # Group 1
#             tpr1 = get_rates(idx_11)
#             fpr1 = get_rates(idx_10)
#
#             if baselines is None:
#                 baselines = (tpr1.item(), tpr0.item(), fpr1.item(), fpr0.item())
#
#             tpr1_baseline, tpr0_baseline, fpr1_baseline, fpr0_baseline = baselines
#
#             # Calculate penalties
#             tpr_penalty_1 = torch.max(torch.tensor(0.0).to(output.device),
#                                       tpr1_baseline - tpr1)
#             tpr_penalty_0 = torch.max(torch.tensor(0.0).to(output.device),
#                                       tpr0_baseline - tpr0)
#             fpr_penalty_1 = torch.max(torch.tensor(0.0).to(output.device),
#                                       fpr1 - fpr1_baseline)
#             fpr_penalty_0 = torch.max(torch.tensor(0.0).to(output.device),
#                                       fpr0 - fpr0_baseline)
#
#             penalty = tpr_penalty_1 + tpr_penalty_0 + fpr_penalty_1 + fpr_penalty_0
#             fairness_loss += self.lam_penalty * penalty
#
#             # Update baselines
#             baselines = (max(tpr1.item(), tpr1_baseline),
#                          max(tpr0.item(), tpr0_baseline),
#                          min(fpr1.item(), fpr1_baseline),
#                          min(fpr0.item(), fpr0_baseline))
#
#         return fairness_loss, baselines
#
#     def forward(self, x_p, x_u, t_p, t_u, s_p, s_u, baselines=None):
#         """
#         Forward pass computing both PU and fairness losses.
#
#         Args:
#             x_p: Positive samples
#             x_u: Unlabeled samples
#             t_p: Positive labels
#             t_u: Unlabeled labels
#             s_p: Sensitive attributes for positive samples
#             s_u: Sensitive attributes for unlabeled samples
#             baselines: Previous performance metrics for penalty term
#
#         Returns:
#             total_loss: Combined PU and fairness loss
#             pu_loss: PU learning component
#             fairness_loss: Fairness regularization component
#             baselines: Updated performance metrics
#         """
#         # Combine data for fairness computation
#         x_combined = torch.cat([x_p, x_u], dim=0)
#         t_combined = torch.cat([t_p, t_u], dim=0)
#         s_combined = torch.cat([s_p, s_u], dim=0)
#
#         # Compute standard PU loss
#         n_positive = max(1., t_p.sum().item())
#         n_unlabeled = max(1., (-t_u).sum().item())
#
#         y_p = self.loss_func(x_p)
#         y_u = self.loss_func(-x_p)
#         y_unlabeled = self.loss_func(-x_u)
#
#         positive_risk = torch.sum(self.prior / n_positive * y_p)
#         negative_risk = torch.sum(y_unlabeled / n_unlabeled) - torch.sum(
#             self.prior / n_positive * y_u)
#
#         if self.nnpu and negative_risk.item() < -self.beta:
#             pu_loss = positive_risk - self.beta
#         else:
#             pu_loss = positive_risk + negative_risk
#
#         # Compute fairness regularization
#         fairness_loss, updated_baselines = self.compute_fairness_loss(
#             x_combined, t_combined, s_combined, baselines)
#
#         # Combine losses
#         total_loss = pu_loss + fairness_loss
#
#         return total_loss, pu_loss, fairness_loss, updated_baselines


# adapted from https://github.com/valerystrizh/pytorch-histogram-loss
class LabelDistributionLoss(nn.Module):
    """Histogram-matching loss between predicted-score and proxy distributions.

    Sigmoid scores of positive and unlabeled samples are turned into soft
    histograms over ``num_bins + 1`` evenly spaced points in [0, 1]. Each
    histogram is compared with a fixed proxy histogram using ``dist``.

    Args:
        prior: Class prior; mixes the positive and negative proxies for the
            unlabeled target and scales the unlabeled term by
            ``1 / (2 * prior)``.
        device: Device on which the grid and proxy tensors are stored.
        num_bins: Number of bins; the grid has ``num_bins + 1`` points.
        proxy: Proxy type; only ``'polar'`` is implemented (all positive mass
            at the last grid point, all negative mass at the first).
        dist: Distance type; only ``'L1'`` is implemented.

    Raises:
        NotImplementedError: For an unknown ``dist`` or ``proxy``.
    """

    def __init__(self, prior, device, num_bins=1, proxy='polar', dist='L1'):
        super(LabelDistributionLoss, self).__init__()
        self.prior = prior
        self.frac_prior = 1.0 / (2 * prior)

        self.step = 1 / num_bins  # bin width. predicted scores in [0, 1].
        self.device = device
        self.t_size = num_bins + 1
        # linspace guarantees exactly ``t_size`` grid points. The length of
        # ``torch.arange(0, 1 + step, step)`` depends on float rounding and
        # can come out one longer (e.g. num_bins=6), which breaks the
        # broadcast in ``histogram``.
        self.t = torch.linspace(0, 1, self.t_size).view(
            1, -1).requires_grad_(False)

        self.dist = None
        if dist == 'L1':
            self.dist = F.l1_loss
        else:
            raise NotImplementedError(
                "The distance: {} is not defined!".format(dist))

        # proxy
        proxy_p, proxy_n = None, None
        if proxy == 'polar':
            proxy_p = np.zeros(self.t_size, dtype=float)
            proxy_n = np.zeros_like(proxy_p)
            proxy_p[-1] = 1
            proxy_n[0] = 1
        else:
            raise NotImplementedError(
                "The proxy: {} is not defined!".format(proxy))

        # Target for unlabeled data: prior-weighted mix of both proxies.
        proxy_mix = prior * proxy_p + (1 - prior) * proxy_n
        print('#### Label Distribution Loss ####')
        print('ground truth P:')
        print(proxy_p)
        print('ground truth U:')
        print(proxy_mix)

        # to torch tensor
        self.proxy_p = torch.from_numpy(proxy_p).requires_grad_(False).float()
        self.proxy_mix = torch.from_numpy(proxy_mix).requires_grad_(
            False).float()

        # to device
        self.t = self.t.to(self.device)
        self.proxy_p = self.proxy_p.to(self.device)
        self.proxy_mix = self.proxy_mix.to(self.device)

    def histogram(self, scores):
        """Return the soft histogram of ``scores`` (an ``(N, 1)`` column).

        Each score adds ``step - |score - t|`` to every grid point ``t``
        within one bin width and nothing to the others. The result is divided
        by ``N * step``.
        """
        scores_rep = scores.repeat(1, self.t_size)

        hist = torch.abs(scores_rep - self.t)

        inds = (hist > self.step)
        hist = self.step - hist  # switch values
        hist[inds] = 0

        return hist.sum(dim=0) / (len(scores) * self.step)

    def forward(self, outputs, labels):
        """Return ``l_p + frac_prior * l_u`` for raw ``outputs``.

        Samples with ``labels == 1`` form the positive histogram and samples
        with ``labels == 0`` the unlabeled one. A group with no samples
        contributes 0.
        """
        scores = torch.sigmoid(outputs)

        scores = scores.view_as(labels)
        s_p = scores[labels == 1].view(-1, 1)
        s_u = scores[labels == 0].view(-1, 1)

        l_p = 0
        l_u = 0
        if s_p.numel() > 0:
            hist_p = self.histogram(s_p)
            l_p = self.dist(hist_p, self.proxy_p, reduction='mean')
        if s_u.numel() > 0:
            hist_u = self.histogram(s_u)
            l_u = self.dist(hist_u, self.proxy_mix, reduction='mean')

        return l_p + self.frac_prior * l_u


def label_distribution_loss(x, t, prior):
    """Functional wrapper: build a ``LabelDistributionLoss`` on ``x``'s device
    with default settings and apply it to outputs ``x`` and labels ``t``."""
    criterion = LabelDistributionLoss(prior, device=x.device)
    return criterion(x, t)
