
import torch as th
import torch.nn as nn
from torch.nn.parameter import Parameter
from pytorch_metric_learning import losses as metric_learning_losses
import torch.nn.functional as F

def weighted_loss(loss, y_batch, reweight_factors):
    """Return the mean of per-sample losses after class-dependent reweighting.

    Args:
        loss: Tensor of per-sample losses; it is multiplied elementwise by the
            weights derived from ``y_batch``.
        y_batch: Tensor of labels. Entries equal to 1 receive the minority
            factor; every other entry receives the majority factor.
        reweight_factors: Pair ``(min_factor, maj_factor)``.

    Returns:
        Scalar tensor holding ``mean(weights * loss)``.
    """
    min_factor, maj_factor = reweight_factors
    min_idx = y_batch == 1
    maj_idx = th.logical_not(min_idx)
    # Move the weights next to ``loss`` (as weighted_loss_mean does) so labels
    # living on another device do not break the multiplication.
    weights = (min_idx * min_factor + maj_idx * maj_factor).to(loss.device)
    return th.mean(weights * loss)

def weighted_loss_mean(loss, y_batch, reweight_factors):
    """Average per-sample losses, scaling each one by its class factor.

    Args:
        loss: Tensor of per-sample losses.
        y_batch: Tensor of labels; label 1 selects the minority factor and any
            other label selects the majority factor.
        reweight_factors: Pair ``(min_factor, maj_factor)``.

    Returns:
        Scalar tensor: the mean of the reweighted losses.
    """
    minority_weight, majority_weight = reweight_factors
    is_minority = y_batch.eq(1)
    per_sample_weight = is_minority * minority_weight + (~is_minority) * majority_weight
    # The weights are moved onto the loss tensor's device before multiplying.
    return (per_sample_weight.to(loss.device) * loss).mean()

def weighted_mse_loss(pred, target, y_batch, reweight_factors):
    """Class-reweighted mean squared error.

    The squared error is averaged over dimension 1 for every row, each row is
    scaled by its class factor, and the scaled values are averaged.

    Args:
        pred: Prediction tensor; reduced over ``dim=1``.
        target: Tensor subtracted elementwise from ``pred``.
        y_batch: Tensor of labels; label 1 selects the minority factor and any
            other label selects the majority factor.
        reweight_factors: Pair ``(min_factor, maj_factor)``.

    Returns:
        Scalar tensor with the reweighted MSE.
    """
    minority_weight, majority_weight = reweight_factors
    is_minority = y_batch.eq(1)
    per_sample_weight = (is_minority * minority_weight + (~is_minority) * majority_weight).to(pred.device)
    squared_error = (pred - target) ** 2
    per_sample_mse = squared_error.mean(dim=1)
    return (per_sample_weight * per_sample_mse).mean()

class MetricLearnLoss(nn.Module):
    def __init__(self, loss_type, device, latent_dim, reweight_loss, reweight_factors, label_smoothing):
        super().__init__()
        self.label_smoothing = label_smoothing
        self.device = device
        self.loss_type = loss_type
        self.latent_dim = latent_dim
        self.centers_per_class = 1
        self.reweight_loss = reweight_loss
        self.reweight_factors = reweight_factors
        if self.loss_type == 'const_center':
            self.center_loss = self.option_1_loss
            self.min_center = Parameter(th.ones(self.latent_dim, device=self.device) * 1.5)
            self.maj_center = Parameter(th.ones(self.latent_dim, device=self.device) * (-1.5))
            self.normalize = False
            self.lr = 0.0
        elif self.loss_type == 'option_2':
            self.center_loss = self.option_2_loss
            self.min_center = Parameter(th.ones(self.latent_dim, device=self.device) * 0.5)
            self.maj_center = Parameter(th.ones(self.latent_dim, device=self.device) * (-0.5))
            self.linear_classifier = nn.Linear(latent_dim, 2, device=self.device)
            self.normalize = False
            self.lr = 0.005
        elif self.loss_type == 'ce':
            self.center_loss = self.option_3_loss
            self.min_center = Parameter(th.ones(self.latent_dim, device=self.device) * 1.5)
            self.maj_center = Parameter(th.ones(self.latent_dim, device=self.device) * (-1.5))
            self.normalize = False
            self.lr = 0.005
        elif self.loss_type == 'linear_plus_ce':
            self.center_loss = self.linear_plus_ce
            self.linear_classifier = nn.Linear(latent_dim, 2, device=self.device)
            self.normalize = False
            self.lr = 0.05
        elif self.loss_type == 'ce_plus_center':
            self.center_loss = self.option_5_loss
            self.min_center = Parameter(th.ones(self.latent_dim, device=self.device) * 1.5)
            self.maj_center = Parameter(th.ones(self.latent_dim, device=self.device) * (-1.5))
            self.normalize = False
            self.lr = 0.005
        elif self.loss_type == 'linear_plus_sgd_centers':
            self.center_loss = self.option_6_loss
            self.min_center = Parameter(th.ones(self.latent_dim, device=self.device) * 1.5)
            self.maj_center = Parameter(th.ones(self.latent_dim, device=self.device) * (-1.5))
            self.linear_classifier = nn.Linear(latent_dim, 2, device=self.device)
            self.normalize = False
            self.lr = 0.05
        elif self.loss_type == 'normalized_softmax':
            self.center_loss = metric_learning_losses.NormalizedSoftmaxLoss(num_classes=2, embedding_size=self.latent_dim) #num_classes=2, embedding_size=self.latent_dim)
            self.normalize = True
            self.lr = 0.005
        elif self.loss_type == 'soft_triple':
            self.center_loss = SoftTriple(la=10, gamma=0.1, tau=0.0, margin=0,dim=self.latent_dim, cN=2, K=2, device=self.device)
            self.normalize = False
            self.lr = 0.05
        elif self.loss_type == 'focal':
            self.center_loss = self.focal_v2  # self.sigmoid_focal_loss
            self.linear_classifier = nn.Linear(latent_dim, 2, device=self.device)
            self.alpha = th.tensor([0.25, 0.75]).to(self.device)  # 0.25
            self.gamma = 2
            self.reduction = "mean"
            self.normalize = False
            self.lr = 0.005
        else:
            raise Exception('center loss option not supported')

    ## OPTION 1
    def option_1_loss(self, z, y_batch):
        centers = th.zeros((0, self.latent_dim), device=self.device)
        for i in range(y_batch.shape[0]):
            if y_batch[i] == 0:
                centers = th.cat((centers, self.maj_center[None, :]), 0)
            else:
                centers = th.cat((centers, self.min_center[None, :]), 0)
        if self.reweight_loss:
            return weighted_mse_loss(z, centers, y_batch, self.reweight_factors)
        else:
            return nn.functional.mse_loss(z, centers)

    ## OPTION 2
    def option_2_loss(self, z, y_batch):
        # center
        centers = th.zeros((0, self.latent_dim), device=self.device)
        for i in range(y_batch.shape[0]):
            if y_batch[i] == 0:
                centers = th.cat((centers, self.maj_center[None, :]), 0)
            else:
                centers = th.cat((centers, self.min_center[None, :]), 0)
        center_loss = nn.functional.mse_loss(z, centers)
        # classify
        scores = self.linear_classifier(z)
        classify_loss = nn.functional.cross_entropy(scores, y_batch.type(th.LongTensor).to(self.device))

        return center_loss + classify_loss

    ## OPTION 3
    def option_3_loss(self, z, y_batch):
        cos = nn.CosineSimilarity(dim=1, eps=1e-08)
        d_min_center = z.matmul(self.min_center.t())  #cos(z, self.min_center)  #z.matmul(self.min_center.t())
        d_maj_center = z.matmul(self.maj_center.t())  #cos(z, self.maj_center)  #z.matmul(self.maj_center.t())
        distances = th.concat((d_maj_center[None, :], d_min_center[None, :]), dim=0).t()
        if self.reweight_loss:
            b4_weight = nn.functional.cross_entropy(-distances, y_batch.type(th.LongTensor).to(self.device), reduction='none')
            contrastive_center_loss = weighted_loss_mean(b4_weight, y_batch, self.reweight_factors)
        else:
            contrastive_center_loss = nn.functional.cross_entropy(-distances, y_batch.type(th.LongTensor).to(self.device))
        return contrastive_center_loss

    ## Only linear classifier (CE)
    def linear_plus_ce(self, z, y_batch):
        # classify
        scores = self.linear_classifier(z)
        classify_loss = nn.functional.cross_entropy(scores, y_batch, label_smoothing=self.label_smoothing)
        return classify_loss

    ## OPTION 5
    def option_5_loss(self, z, y_batch, epoch_centers):
        d_min_center = z.matmul(self.min_center.t())
        d_maj_center = z.matmul(self.maj_center.t())
        distances = th.concat((d_maj_center[None, :], d_min_center[None, :]), dim=0).t()
        if self.reweight_loss:
            b4_weight = nn.functional.cross_entropy(-distances, y_batch.type(th.LongTensor).to(self.device), reduction='none')
            contrastive_center_loss = weighted_loss_mean(b4_weight, y_batch, self.reweight_factors)
        else:
            contrastive_center_loss = nn.functional.cross_entropy(-distances, y_batch.type(th.LongTensor).to(self.device))

        # center
        maj_center, min_center = epoch_centers
        centers = th.zeros((0, self.latent_dim), device=self.device)
        for i in range(y_batch.shape[0]):
            if y_batch[i] == 0:
                centers = th.cat((centers, maj_center[None, :]), 0)
            else:
                centers = th.cat((centers, min_center[None, :]), 0)
        center_loss = nn.functional.mse_loss(z, centers)

        return contrastive_center_loss + center_loss

    ## OPTION 6
    def option_6_loss(self, z, y_batch):
        # classify
        scores = self.linear_classifier(z)
        classify_loss = nn.functional.cross_entropy(scores, y_batch.type(th.LongTensor).to(self.device))
        # center
        centers = th.zeros((0, self.latent_dim), device=self.device)
        for i in range(y_batch.shape[0]):
            if y_batch[i] == 0:
                centers = th.cat((centers, self.maj_center[None, :]), 0)
            else:
                centers = th.cat((centers, self.min_center[None, :]), 0)
        center_loss = nn.functional.mse_loss(z, centers)
        return classify_loss + center_loss


    def sigmoid_focal_loss(self, z, y):
        """
        Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

        Args:
            inputs (Tensor): A float tensor of arbitrary shape.
                    The predictions for each example.
            targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                    classification label for each element in inputs
                    (0 for the negative class and 1 for the positive class).
            alpha (float): Weighting factor in range (0,1) to balance
                    positive vs negative examples or -1 for ignore. Default: ``0.25``.
            gamma (float): Exponent of the modulating factor (1 - p_t) to
                    balance easy vs hard examples. Default: ``2``.
            reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                    ``'none'``: No reduction will be applied to the output.
                    ``'mean'``: The output will be averaged.
                    ``'sum'``: The output will be summed. Default: ``'none'``.
        Returns:
            Loss tensor with the reduction option applied.
        """
        y = th.zeros(y.shape[0], 2).type(y.type()).scatter_(1, y.reshape(y.shape[0], 1), 1).float()
        scores = self.linear_classifier(z)
        p = th.sigmoid(scores)
        ce_loss = F.binary_cross_entropy_with_logits(scores, y, reduction="none")
        p_t = p * y + (1 - p) * (1 - y)
        loss = ce_loss * ((1 - p_t) ** self.gamma)
        if self.alpha >= 0:
            alpha_t = self.alpha * y + (1 - self.alpha) * (1 - y)
            loss = alpha_t * loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss

    def focal_v2(self, z, y):
        logits = self.linear_classifier(z)
        logpt = F.log_softmax(logits, dim=1)
        pt = th.exp(logpt)
        logpt = (1-pt)**self.gamma * logpt
        loss = F.nll_loss(logpt, y, self.alpha)
        return loss

    def forward(self, z, y_batch):
        y_batch = y_batch.type(th.LongTensor).to(self.device)
        if self.normalize:
            z = F.normalize(z)
        return self.center_loss(z, y_batch)

    """
    # Soft Triplet
    self.soft_triplet_loss = SoftTriple(la=self.hparams['la'], gamma=self.hparams['gamma'], tau=0.01,
                                        margin=self.hparams['margin'], dim=self.latent_dim,
                                        cN=2, K=self.hparams['K'], device=self.device)
    self.soft_triplet_centers_optim = torch.optim.Adam(self.soft_triplet_loss.parameters(), lr=0.1)
    """
    # Soft Triplet
    # Softmax with temperature T: softmax(x / T)
    # Larger T -> larger output variance (less confidence)
    #
    # tau = regularization factor for the adaptive number of centers (prioritizes center merges)
    # gamma = temperature of x_i's class similarity (max over its similarities to each of the class centers)
    #         Larger -> more intra-class variation / smoothing (less concentrated around each cluster center)
    # lambda = 1 / temperature of the triplet loss.
    #          Affects inter-class variation. Larger -> lower temperature -> more inter-class separation
    # margin = delta from the triplet constraint (S_i,y - S_i,* >= delta)

    """
    # soft triple
    parser.add_argument('-lambda_soft_triplet', type=float, default=0.0)
    parser.add_argument('-la', type=float, default=25)
    parser.add_argument('-gamma', type=float, default=0.1)
    parser.add_argument('-margin', type=float, default=0.1)
    parser.add_argument('-K', type=int, default=10)
    """

    # Implementation of SoftTriple Loss
    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn.parameter import Parameter
    from torch.nn import init

    class SoftTriple(nn.Module):
        """SoftTriple-style classification loss with ``K`` learnable centers per class.

        Every input row is compared (dot product) with all ``cN * K`` centers.
        The similarities to the ``K`` centers of one class are combined with a
        softmax-weighted sum, and the resulting per-class similarities are fed
        to a label-smoothed cross-entropy. When ``tau > 0`` and ``K > 1`` a
        regularizer over pairs of same-class centers is added.

        Only module-level names of this file (``th``, ``nn``, ``F``,
        ``Parameter``) are used: names bound in the enclosing class body are
        not visible from inside these methods.
        """

        def __init__(self, la, gamma, tau, margin, dim, cN, K, device):
            """Create the center matrix and the same-class pair mask.

            Args:
                la: Stored inverted (``1 / la``); the class similarities are
                    multiplied by the stored value before the cross-entropy.
                gamma: Stored inverted (``1 / gamma``); scales the similarities
                    inside the softmax over the centers of one class.
                tau: Weight of the center regularizer (inactive when ``<= 0``).
                margin: Value subtracted from the target-class similarity.
                dim: Embedding size.
                cN: Number of classes.
                K: Number of centers per class.
                device: Device on which the parameter and mask are created.
            """
            super().__init__()
            self.la = 1. / la  # inverse of the given value
            self.gamma = 1. / gamma
            self.tau = tau
            self.margin = margin
            self.cN = cN
            self.K = K
            self.device = device
            self.fc = Parameter(th.empty((dim, cN * K), dtype=th.float, device=self.device))
            nn.init.kaiming_uniform_(self.fc, a=5 ** 0.5)
            # Boolean mask selecting each unordered pair of distinct centers
            # that belong to the same class (strict upper triangle per class).
            weight = th.zeros(cN * K, cN * K, dtype=th.bool, device=self.device)
            for i in range(cN):
                for j in range(K):
                    weight[i * K + j, i * K + j + 1:(i + 1) * K] = True
            self.weight = weight

        @property
        def centers(self):
            """Center matrix used by ``forward``: the raw (un-normalized) ``fc``.

            Exposed read-only. Assigning ``fc`` to a plain attribute would
            register the same Parameter under a second name and change
            ``state_dict()`` after the first forward pass.
            """
            return self.fc

        def forward(self, input, target):
            """Return the loss for a batch.

            Args:
                input: Embeddings of shape ``(N, dim)``.
                target: Integer class indices, one per row of ``input``.

            Returns:
                Scalar loss tensor (classification term plus, when active,
                ``tau`` times the center regularizer).
            """
            # An earlier variant L2-normalized both the centers and the input;
            # the raw values are used here.
            centers = self.fc

            simInd = input.matmul(centers)  # (N, dim) x (dim, cN*K) -> (N, cN*K)
            simStruc = simInd.reshape(-1, self.cN, self.K)  # (N, cN, K)
            prob = F.softmax(simStruc * self.gamma, dim=2)  # softmax over the K centers of each class
            simClass = th.sum(prob * simStruc, dim=2)  # (N, cN)
            marginM = th.zeros_like(simClass)
            rows = th.arange(marginM.shape[0], device=marginM.device)
            marginM[rows, target] = self.margin
            lossClassify = F.cross_entropy(self.la * (simClass - marginM), target,
                                           label_smoothing=0.1)
            if self.tau > 0 and self.K > 1:
                simCenter = centers.t().matmul(centers)
                reg = -th.sum(simCenter[self.weight] / (self.cN * self.K * (self.K - 1.)))
                return lossClassify + self.tau * reg
            return lossClassify