from datetime import datetime
import logging
import os
import sys
import torch
import numpy as np

from tensorboardX import SummaryWriter


def cdists(batch):
    '''
    Pairwise squared Euclidean distances between the rows of ``batch``.

    batch: the size of (N, M)
            N: number of images
            M: length of each row vector
    returns: (N, N) tensor whose [i, j] entry is sum((batch[j] - batch[i]) ** 2)
    '''
    rows_a = batch[None]        # insert a leading axis  -> (1, N, M)
    rows_b = batch[:, None]     # insert a second axis   -> (N, 1, M)
    delta = rows_a - rows_b     # broadcasts to (N, N, M)
    return (delta * delta).sum(dim=-1)


def cosine_sim(batch):
    '''
    Pairwise cosine similarity between the rows of ``batch``.

    returns: tensor whose [i, j] entry is the cosine similarity of batch[j] and batch[i]
             (computed along the last dimension)
    '''
    left, right = batch.unsqueeze(0), batch.unsqueeze(1)
    return torch.nn.functional.cosine_similarity(left, right, dim=-1)


def contrastive_loss(batch, labels, temperature=0.1):
    """Supervised contrastive loss over a batch of embeddings.

    With s = cosine_sim(batch) / temperature, every ordered pair (i, j) where
    labels[i] == labels[j] and i != j contributes

        -log(exp(s_ij) / (exp(s_ij) + sum_k exp(s_ik)))

    where k ranges over the samples whose label differs from labels[i].
    The loss is the mean of these terms over all such pairs.

    Args:
        batch: embeddings, one row per sample (assumed (N, D) — TODO confirm).
        labels: 1-D tensor of N labels, on the same device as ``batch``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor on ``batch.device``; zero when no positive pair exists.
    """
    device = batch.device
    dists = torch.exp(cosine_sim(batch) / temperature)

    same_iden_ = labels.unsqueeze(0) == labels.unsqueeze(1)
    other_iden = ~same_iden_

    # A sample is not its own positive: drop the diagonal. Built directly on
    # the input's device instead of branching on the CUDA version string
    # (which is None on CPU-only builds).
    not_itself = ~torch.eye(same_iden_.size(0), dtype=torch.bool, device=device)
    same_iden = same_iden_ & not_itself

    n = torch.sum(same_iden)
    if n == 0:
        return torch.zeros((), device=device)

    zero = torch.zeros((), dtype=dists.dtype, device=device)
    pos_dists = torch.where(same_iden, dists, zero)
    # Per-anchor sum over negatives, kept as a column for broadcasting.
    other_sum = torch.where(other_iden, dists, zero).sum(dim=-1, keepdim=True)

    v = pos_dists / (pos_dists + other_sum)
    # Replace exact zeros so the log below never produces -inf.
    v = v.masked_fill(v == 0, 1e-16)

    return -torch.sum(torch.log(v)[same_iden]) / n


def batchhard(batch, idens, margin=0.1, n=None, soft=False):
    """Batch-hard triplet loss on squared Euclidean distances.

    For each anchor, the hardest positive is the largest distance to a sample
    with the same identity (the anchor itself included, at distance 0). The
    hardest negative is the smallest distance to a sample with a different
    identity. The per-anchor term is pos + margin - neg, passed through a hinge
    (or softplus when ``soft``).

    Args:
        batch: embeddings passed to ``cdists`` (assumed (N, D) — TODO confirm).
        idens: 1-D tensor of N identity labels on the same device as ``batch``.
        margin: margin added to the positive distance.
        n: optional normaliser; when given, the summed loss is divided by it
            instead of averaging over anchors.
        soft: use softplus(x) = log(1 + exp(x)) instead of max(x, 0).

    Returns:
        Scalar loss tensor.
    """
    dists = cdists(batch)

    same_iden = idens.unsqueeze(0) == idens.unsqueeze(1)
    other_iden = ~same_iden

    # Created on dists' device/dtype rather than a hard-coded .cuda().
    infs = torch.full_like(dists, float('inf'))

    # Mask non-positives with -inf so max() ignores them.
    pos = torch.where(same_iden, dists, -infs).max(dim=1).values
    # Mask non-negatives with +inf so min() ignores them.
    neg = torch.where(other_iden, dists, infs).min(dim=1).values

    diff = (pos + margin) - neg
    if soft:
        # softplus is the overflow-safe form of log(exp(x) + 1).
        diff = torch.nn.functional.softplus(diff)
    else:
        diff = diff.clamp(min=0)

    if n is None:
        return torch.mean(diff)
    return torch.sum(diff) * 1.0 / n


def batchmean(batch, idens, margin=0.1, n=None, soft=False):
    """Triplet-style loss comparing averaged positive and negative distances.

    For each anchor row of the squared-distance matrix, positive entries
    (same identity, diagonal included) and negative entries (different
    identity) are each averaged over the row. The per-anchor term is
    pos + margin - neg, passed through a hinge (or softplus when ``soft``).

    NOTE(review): ``mean(dim=1)`` divides by N (every column), not by the
    number of positives/negatives in the row, so ``pos``/``neg`` are scaled
    sums rather than true masked means. Kept as-is to preserve existing loss
    values — confirm whether this is intended.

    Args:
        batch: embeddings passed to ``cdists`` (assumed (N, D) — TODO confirm).
        idens: 1-D tensor of N identity labels on the same device as ``batch``.
        margin: margin added to the positive term.
        n: optional normaliser; when given, the summed loss is divided by it
            instead of averaging over anchors.
        soft: use softplus(x) = log(1 + exp(x)) instead of max(x, 0).

    Returns:
        Scalar loss tensor.
    """
    dists = cdists(batch)

    same_iden = idens.unsqueeze(0) == idens.unsqueeze(1)
    other_iden = ~same_iden

    # Fill value on dists' device/dtype rather than a hard-coded .cuda().
    zero = torch.zeros((), dtype=dists.dtype, device=dists.device)
    pos = torch.where(same_iden, dists, zero).mean(dim=1)
    neg = torch.where(other_iden, dists, zero).mean(dim=1)

    diff = (pos + margin) - neg
    if soft:
        # softplus is the overflow-safe form of log(exp(x) + 1).
        diff = torch.nn.functional.softplus(diff)
    else:
        diff = diff.clamp(min=0)

    if n is None:
        return torch.mean(diff)
    return torch.sum(diff) * 1.0 / n


def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]
    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]
    """
    # Check that i, j and k are distinct
    indices_equal = torch.eye(labels.size(0)).bool().cuda()
    indices_not_equal = ~indices_equal
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    j_not_equal_k = indices_not_equal.unsqueeze(0)

    distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k


    label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)

    valid_labels = ~i_equal_k & i_equal_j

    return valid_labels & distinct_indices


def batch_all(batch, idens, margin=0.1, n=None, soft=False):
    """Build the batch-all triplet loss over a batch of embeddings.

    Every valid triplet (anchor i, positive j, negative k) has distinct
    indices, idens[i] == idens[j] and idens[i] != idens[k]. Each contributes
    max(d(i, j) - d(i, k) + margin, 0), where d is the squared Euclidean
    distance from ``cdists``. The summed loss is averaged over the triplets
    whose loss exceeds 1e-16 (the "positive" triplets).

    Args:
        batch: tensor of shape (batch_size, embed_dim).
        idens: labels of the batch, of shape (batch_size,).
        margin: margin for the triplet loss.
        n: unused; accepted for signature compatibility with ``batchhard``.
        soft: unused (no soft variant is applied); accepted for signature
            compatibility with ``batchhard``.

    Returns:
        Scalar tensor containing the triplet loss (0 when no triplet is positive).
    """
    pairwise_dist = cdists(batch)

    # triplet_loss[i, j, k] = d(i, j) - d(i, k) + margin, built by broadcasting
    # (batch_size, batch_size, 1) against (batch_size, 1, batch_size).
    triplet_loss = pairwise_dist.unsqueeze(2) - pairwise_dist.unsqueeze(1) + margin

    # Zero the invalid triplets, then drop the easy (negative-loss) ones.
    mask = _get_triplet_mask(idens)
    triplet_loss = (mask.float() * triplet_loss).clamp(min=0)

    # Count positive triplets without materialising them.
    num_positive_triplets = (triplet_loss > 1e-16).sum()

    # The epsilon keeps the division finite when no triplet is positive.
    return triplet_loss.sum() / (num_positive_triplets + 1e-16)
