import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SimCLRLoss(nn.Layer):
    """Contrastive loss over two views per sample, as used by SimCLR.

    Modified from https://github.com/wvangansbeke/Unsupervised-Classification.

    Args:
        temperature: Value the anchor/contrast dot products are divided by.
        reduction: ``"mean"`` averages the per-anchor losses into one value;
            ``"none"`` returns the per-anchor losses unreduced.
    """

    def __init__(self, temperature, reduction="mean"):
        super(SimCLRLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, features):
        """
        input:
            - features: hidden feature representation of shape [b, 2, dim]
        output:
            - loss: loss computed according to SimCLR; a single value for
              reduction "mean", one value per anchor for reduction "none"
        raises:
            - ValueError: if ``self.reduction`` is neither "mean" nor "none"
        """
        # Fail before doing any tensor work when the reduction is unsupported.
        if self.reduction not in ("mean", "none"):
            raise ValueError("The reduction must be mean or none!")

        b, n, _ = features.shape
        assert n == 2

        # Contrast set: all first views followed by all second views -> [2b, dim].
        contrast_features = paddle.concat(paddle.unbind(features, axis=1), axis=0)
        # Anchors are the first view of every sample -> [b, dim].
        anchor = features[:, 0]

        # Dot product
        dot_product = paddle.matmul(anchor, contrast_features.T) / self.temperature

        # Log-sum trick for numerical stability
        logits_max = paddle.max(dot_product, axis=1, keepdim=True)
        logits = dot_product - logits_max.detach()

        # Masks are built in one shot instead of writing the diagonal element
        # by element. Columns [0, b) of the logits compare an anchor with the
        # first views (column i being the anchor itself); columns [b, 2b)
        # compare it with the second views (column b + i being its positive).
        eye = paddle.eye(b, dtype=paddle.float32)
        ones = paddle.ones_like(eye)
        # 1 everywhere except the anchor-vs-itself entry.
        logits_mask = paddle.concat([ones - eye, ones], axis=1)
        # 1 only at the anchor's second view.
        mask = paddle.concat([paddle.zeros_like(eye), eye], axis=1)

        # Log-softmax over every contrast except the anchor itself
        exp_logits = paddle.exp(logits) * logits_mask
        log_prob = logits - paddle.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood for positive
        per_anchor_loss = -((mask * log_prob).sum(1) / mask.sum(1))
        if self.reduction == "mean":
            return per_anchor_loss.mean()
        return per_anchor_loss


class RCELoss(nn.Layer):
    """Reverse Cross Entropy Loss.

    Swaps the roles of prediction and label in cross entropy: the softmax
    of the logits weights the log of the (clipped) one-hot target.

    Args:
        num_classes: Depth of the one-hot encoding built from ``target``.
        reduction: ``"mean"`` averages the per-sample losses; any other value
            leaves them unreduced.
    """

    def __init__(self, num_classes=10, reduction="mean"):
        super().__init__()
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, x, target):
        """Return the reverse cross entropy of logits ``x`` against ``target``.

        ``target`` is passed to ``F.one_hot``, so it is presumably a tensor of
        integer class indices -- confirm against callers.
        """
        # Keep probabilities away from 0 (lower bound 1e-7).
        predicted = paddle.clip(F.softmax(x, axis=-1), min=1e-7, max=1.0)
        # Clip the zeros of the one-hot target to 1e-4 so log() stays finite.
        label_dist = paddle.clip(
            F.one_hot(target, self.num_classes), min=1e-4, max=1.0
        )
        per_sample = -1 * paddle.sum(predicted * paddle.log(label_dist), axis=-1)

        return per_sample.mean() if self.reduction == "mean" else per_sample


class SCELoss(nn.Layer):
    """Symmetric Cross Entropy.

    Weighted sum of the standard cross entropy and the reverse cross entropy
    (``RCELoss``): ``alpha * CE + beta * RCE``.

    Args:
        alpha: Weight of the cross-entropy term.
        beta: Weight of the reverse-cross-entropy term.
        num_classes: Forwarded to ``RCELoss``.
        reduction: Forwarded to both ``paddle.nn.CrossEntropyLoss`` and
            ``RCELoss``.
    """

    def __init__(self, alpha=0.1, beta=1, num_classes=10, reduction="mean"):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, x, target):
        """Return ``alpha * CE(x, target) + beta * RCE(x, target)``."""
        # The criteria are built on every call, so they always reflect the
        # current values of self.reduction / self.num_classes.
        forward_criterion = paddle.nn.CrossEntropyLoss(reduction=self.reduction)
        reverse_criterion = RCELoss(
            num_classes=self.num_classes, reduction=self.reduction
        )
        forward_term = forward_criterion(x, target)
        reverse_term = reverse_criterion(x, target)

        return self.alpha * forward_term + self.beta * reverse_term


class MixMatchLoss(nn.Layer):
    """SemiLoss in MixMatch.

    Modified from https://github.com/YU1ut/MixMatch-pytorch/blob/master/train.py.

    Args:
        rampup_length: Number of epochs over which the unlabeled-loss weight
            grows linearly from 0 to ``lambda_u``; 0 disables the ramp-up.
        lambda_u: Final weight of the unlabeled loss.
    """

    def __init__(self, rampup_length, lambda_u=75):
        super(MixMatchLoss, self).__init__()
        self.rampup_length = rampup_length
        self.lambda_u = lambda_u
        self.current_lambda_u = lambda_u

    def linear_rampup(self, epoch):
        """Update ``current_lambda_u`` for ``epoch`` and return the ramp factor.

        The factor is ``epoch / rampup_length`` clipped to [0, 1], or 1.0 when
        ``rampup_length`` is 0. Both paths now refresh ``current_lambda_u`` and
        return the factor as a float.
        """
        if self.rampup_length == 0:
            # No ramp-up configured: the full weight applies from the start.
            ramp = 1.0
        else:
            ramp = float(np.clip(epoch / self.rampup_length, 0.0, 1.0))
        self.current_lambda_u = ramp * self.lambda_u
        return ramp

    def forward(self, xoutput, xtarget, uoutput, utarget, epoch):
        """Compute the labeled and unlabeled MixMatch loss terms.

        ``xtarget`` and ``utarget`` are combined elementwise with the
        (log-)softmax of the outputs, so they are presumably per-class
        distributions with the same shape as the outputs -- confirm against
        callers.

        Returns:
            A tuple ``(Lx, Lu, current_lambda_u)``: the cross entropy of
            ``xoutput`` against ``xtarget``, the mean squared error between
            ``softmax(uoutput)`` and ``utarget``, and the unlabeled-loss
            weight for ``epoch``.
        """
        self.linear_rampup(epoch)
        uprob = F.softmax(uoutput, axis=1)
        # Cross entropy against a target distribution, averaged over the batch.
        Lx = -paddle.mean(paddle.sum(F.log_softmax(xoutput, axis=1) * xtarget, axis=1))
        # Mean squared error over every element of the unlabeled predictions.
        Lu = paddle.mean((uprob - utarget) ** 2)

        return Lx, Lu, self.current_lambda_u
