import torch
import torch.nn.functional as F
from torch import autograd, nn


class DSAMLoss(nn.Module):
    """Metric-learning loss: a k-nearest-positive pull term plus a margin
    ranking term between every negative and the hardest positive.

    For every anchor ``i`` in the batch:

    * Affinity distances ``d(i, j) = exp(|2 - 2 * cos(f_i, f_j)|) - 1`` are
      computed on L2-normalised features.
    * Pull term: the ``k`` positives with the smallest ``d(i, .)`` are
      selected, and ``sqrt(sum_k ||f_i - f_k||^2)`` is computed on the *raw*
      (un-normalised) features; this is averaged over anchors.
    * Rank term: ``max(0, d(i, hardest_pos) - d(i, neg) + margin)`` averaged
      over the anchor's negatives, summed over anchors, weighted by ``gama``.

    The batch must be balanced: every sample needs the same (non-zero) number
    of positives (other samples sharing its label) and of negatives, as a
    P x K identity sampler produces. ``k`` must not exceed the number of
    positives per sample (images per identity minus one).

    Args:
        k: number of nearest positives used in the pull term.
        gama: weight of the rank term.
        margin: margin passed to ``nn.MarginRankingLoss``.
        **kwargs: accepted and ignored.
    """

    # Lower bound applied before sqrt: d sqrt(x)/dx is infinite at x == 0,
    # which turns into NaN gradients when an anchor coincides with its
    # nearest positives.
    _SQRT_EPS = 1e-12

    def __init__(self, k=4, gama=0.7, margin=1, **kwargs):
        super().__init__()
        self.name = 'dsam_loss'
        self.k = k
        self.gama = gama
        self.margin = margin
        # reduction='none' keeps one value per (anchor, negative) pair so the
        # per-anchor mean can be taken in forward().
        self.ranking_loss = nn.MarginRankingLoss(margin=self.margin, reduction='none')

    @staticmethod
    def _per_row_count(mask, what):
        """Return the number of True entries per row of ``mask``.

        The count must be identical (and non-zero) for every row so that the
        masked values can be reshaped into an ``(N, count)`` matrix without
        mixing entries of different anchors.
        """
        counts = mask.sum(dim=1)
        count = int(counts[0])
        if count == 0 or bool((counts != count).any()):
            raise ValueError(
                'Every sample needs the same non-zero number of {}; '
                'use a balanced (P x K) batch.'.format(what))
        return count

    def forward(self, features, labels):
        """Compute the scalar DSAM loss.

        Args:
            features: ``(N, M)`` embedding matrix, M being the embedding size.
            labels: ``(N,)`` identity labels; moved to ``features``' device.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: if the batch is empty, the batch sizes differ, the
                batch is unbalanced / lacks positives or negatives, or ``k``
                exceeds the number of positives per sample.
        """
        inputs = features
        device = inputs.device
        targets = labels.to(device)
        n = targets.size(0)
        if inputs.size(0) != n:
            raise ValueError("Batch size doesn't match!")
        if n == 0:
            raise ValueError('Empty batch!')

        # Positives share the anchor's label (the anchor itself excluded);
        # negatives carry a different label.
        same_label = targets.unsqueeze(1) == targets.unsqueeze(0)
        eye = torch.eye(n, dtype=torch.bool, device=device)
        positive = same_label & ~eye
        negative = ~same_label

        num_pos = self._per_row_count(positive, 'positives')
        num_neg = self._per_row_count(negative, 'negatives')
        if num_pos < self.k:
            raise ValueError(
                'The K is not correct: k={} but each sample has only {} '
                'positive(s); k must not exceed imgs per id - 1.'.format(self.k, num_pos))

        # Affinity distances in [0, e^4 - 1]. abs() guards against tiny
        # negative values from float round-off when cos ~ 1; exp(.) - 1
        # smooths the curve.
        normal_data = F.normalize(inputs, p=2, dim=1)
        cos_affin = torch.mm(normal_data, normal_data.t())
        dist = torch.exp(torch.abs(2 - 2 * cos_affin)) - 1

        # Gather per-anchor positive / negative distances as dense rows.
        all_indx = torch.arange(n, device=device).expand(n, n)
        all_pos = torch.masked_select(dist, positive).reshape(n, num_pos)
        all_pos_indx = torch.masked_select(all_indx, positive).reshape(n, num_pos)
        all_neg = torch.masked_select(dist, negative).reshape(n, num_neg)

        # Hardest (furthest) positive per anchor, shape (N, 1).
        furthest_pos, _ = torch.max(all_pos, dim=1, keepdim=True)

        # k nearest positives (ascending distance), mapped to batch indices.
        _, sorted_pos = torch.sort(all_pos, dim=1)
        topk_indx = torch.gather(all_pos_indx, dim=1, index=sorted_pos[:, :self.k])

        # Pull term on the raw features: (N, k, M) - (N, 1, M) -> (N, k).
        topk_features = inputs[topk_indx]
        sq_dist = torch.sum(torch.pow(topk_features - inputs.unsqueeze(1), 2), dim=-1)
        k_loss = torch.sum(sq_dist, dim=-1)
        pull_loss = torch.mean(torch.sqrt(k_loss.clamp_min(self._SQRT_EPS)))

        # Rank term: max(0, furthest_pos - neg + margin), broadcast over the
        # negatives, averaged per anchor and summed over anchors.
        rank_targets = torch.ones_like(furthest_pos)
        pair_loss = self.ranking_loss(all_neg, furthest_pos, rank_targets)
        rank_loss_sum = torch.sum(torch.mean(pair_loss, dim=-1))

        return pull_loss + self.gama * rank_loss_sum


# Disabled smoke test, kept inside a module-level string literal so it never
# runs. It builds a 15 x 10 feature matrix with 3 identities (5 samples each),
# evaluates DSAMLoss(k=3, gama=0.5) on it and back-propagates the loss.
# As written it moves the tensors to the GPU with .cuda(), so a CUDA device
# is required to run it.
# NOTE(review): the commented-out os.environ line would also need `import os`.
'''
if __name__ == '__main__':
    #os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    test_data = torch.tensor([[0.3274, 0.7735, 0.9956, 0.7842, 0.4577, 0.6470, 0.0823, 0.9673, 0.3825,
                               0.7163],
                              [0.0315, 0.6901, 0.9562, 0.1160, 0.8447, 0.5705, 0.7071, 0.7587, 0.1891,
                               0.1761],
                              [0.8884, 0.2450, 0.2447, 0.5723, 0.5399, 0.0726, 0.8185, 0.9131, 0.3868,
                               0.5923],
                              [0.2351, 0.9697, 0.3939, 0.4850, 0.9209, 0.9906, 0.3148, 0.5653, 0.6483,
                               0.8033],
                              [0.8049, 0.7296, 0.1281, 0.8276, 0.7527, 0.9656, 0.6210, 0.4035, 0.2817,
                               0.3487],
                              [0.6257, 0.5918, 0.1143, 0.4469, 0.6396, 0.0085, 0.8807, 0.2351, 0.7974,
                               0.9814],
                              [0.5231, 0.6928, 0.2384, 0.3017, 0.4021, 0.2700, 0.7104, 0.3938, 0.1699,
                               0.6352],
                              [0.0192, 0.4571, 0.9019, 0.6188, 0.3298, 0.3388, 0.6135, 0.4648, 0.0384,
                               0.1048],
                              [0.2673, 0.3736, 0.4060, 0.1326, 0.5498, 0.4194, 0.7090, 0.8811, 0.8007,
                               0.9837],
                              [0.4428, 0.3966, 0.0393, 0.8389, 0.9310, 0.1132, 0.7999, 0.4951, 0.5424,
                               0.7863],
                              [0.5766, 0.5799, 0.1552, 0.6766, 0.3904, 0.0461, 0.1842, 0.1944, 0.7140,
                               0.0719],
                              [0.5579, 0.5411, 0.1006, 0.6253, 0.4840, 0.7095, 0.0425, 0.2194, 0.7106,
                               0.0101],
                              [0.8706, 0.8008, 0.7208, 0.2794, 0.4129, 0.4737, 0.1588, 0.3700, 0.2302,
                               0.2864],
                              [0.1609, 0.9241, 0.5703, 0.4612, 0.3916, 0.7450, 0.9096, 0.6010, 0.2301,
                               0.5432],
                              [0.7445, 0.8887, 0.6237, 0.9825, 0.1251, 0.9804, 0.0224, 0.3665, 0.6042,
                               0.5760]], requires_grad=True)
    labels = torch.tensor([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
    loss_func = DSAMLoss(k=3, gama=0.5)
    loss = loss_func(test_data.cuda(), labels.cuda())
    loss.backward()
    print(loss)
'''