import torch
import torch.nn.functional as F

import torch.distributed as dist

def _moco_branch_loss(
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negatives: torch.Tensor,
    temperature: float,
) -> torch.Tensor:
    """InfoNCE term: cross-entropy over [positive | negatives] logits with target 0."""
    # Row-wise dot product of each anchor with its own positive -> Nx1.
    l_pos = torch.einsum('nc,nc->n', [anchor, positive]).unsqueeze(-1)
    # Similarity of each anchor with every queue column -> NxK.
    l_neg = torch.einsum('nc,ck->nk', [anchor, negatives])

    # logits: Nx(1+K); the positive always sits in column 0.
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature

    # Build the targets on the logits' device instead of forcing CUDA, so the
    # loss also works on CPU and on non-default GPUs.
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, targets)


def dual_moco_loss_func(
    query1: torch.Tensor,
    key1: torch.Tensor,
    queue1: torch.Tensor,
    distill1: torch.Tensor,
    query2: torch.Tensor,
    key2: torch.Tensor,
    queue2: torch.Tensor,
    distill2: torch.Tensor,
    temperature=0.1,
    distill_temperature=0.1,
    loss_alpha=0.5,
    ablation_wo_p_queue=False,
    ablation_wo_s_queue=False,
) -> torch.Tensor:
    """Computes a two-term MoCo-style contrastive loss.

    The first term contrasts ``query1`` against ``key1`` (positive) and the
    queue(s); the second contrasts ``distill1`` against ``query2`` (positive)
    and the queue(s). Both terms are cross-entropies whose target is the
    positive logit in column 0.

    Args:
        query1 (torch.Tensor): NxC anchors of the first term.
        key1 (torch.Tensor): NxC positives paired row-wise with ``query1``.
        queue1 (torch.Tensor): CxK1 negatives, one sample per column.
        distill1 (torch.Tensor): NxC anchors of the second term.
        query2 (torch.Tensor): NxC positives paired row-wise with ``distill1``.
        key2 (torch.Tensor): accepted for interface compatibility; not used.
        queue2 (torch.Tensor): CxK2 negatives, one sample per column.
        distill2 (torch.Tensor): accepted for interface compatibility; not used.
        temperature (float, optional): divisor applied to the logits of the
            first term. Defaults to 0.1.
        distill_temperature (float, optional): divisor applied to the logits
            of the second term. Defaults to 0.1.
        loss_alpha (float, optional): mixing weight; the result is
            ``2 * loss_alpha * loss1 + 2 * (1 - loss_alpha) * loss2``, so the
            default 0.5 yields ``loss1 + loss2``. Defaults to 0.5.
        ablation_wo_p_queue (bool, optional): if True, the first term uses
            only ``queue1`` as negatives instead of both queues.
        ablation_wo_s_queue (bool, optional): if True, the second term uses
            only ``queue2`` as negatives instead of both queues.

    Returns:
        torch.Tensor: scalar loss located on the device of the inputs.
    """
    # The concatenated queue is built lazily so that a fully ablated call
    # never requires queue1 and queue2 to be concatenable.
    if ablation_wo_p_queue:
        negatives1 = queue1
    else:
        negatives1 = torch.cat([queue1, queue2], dim=1)
    loss1 = _moco_branch_loss(query1, key1, negatives1, temperature)

    if ablation_wo_s_queue:
        negatives2 = queue2
    else:
        negatives2 = torch.cat([queue1, queue2], dim=1)
    loss2 = _moco_branch_loss(distill1, query2, negatives2, distill_temperature)

    # The factor 2 keeps the default (alpha = 0.5) equal to loss1 + loss2.
    loss = loss1 * 2 * loss_alpha + loss2 * 2 * (1 - loss_alpha)

    return loss
    
#     loss = loss_alpha*loss1 + (1-loss_alpha)*loss2

def simclr_dual_loss_func(
    p1: torch.Tensor,
    p2: torch.Tensor,
    z1: torch.Tensor,
    z2: torch.Tensor,
    f1: torch.Tensor,
    f2: torch.Tensor,
    temperature: float = 0.1,
    labels: torch.Tensor = None,
) -> torch.Tensor:
    """Computes the sum of two SimCLR-style contrastive terms.

    Each pair (p1, p2), (z1, z2), (f1, f2) is concatenated along the batch
    dimension and L2-normalised along the last dimension, giving 2B rows.

    * Term 1 uses ``z`` as anchors against the candidates ``[z | f]``; the
      positive of row ``i`` is the other view of the same sample in ``z``
      (column ``(i + B) mod 2B``).
    * Term 2 uses ``p`` as anchors against the candidates ``[f | p.z]``, i.e.
      ``[p.f | p.z]`` similarities; the positive of row ``i`` is column ``i``
      of the ``p.f`` block.

    In both terms the same-index columns (``i`` and ``2B + i``) are left out
    of the softmax denominator.

    Args:
        p1, p2 (torch.Tensor): BxD tensors, anchors of the second term.
        z1, z2 (torch.Tensor): BxD tensors, anchors of the first term.
        f1, f2 (torch.Tensor): BxD tensors, extra candidates for both terms.
        temperature (float, optional): divisor applied to all similarities.
            Defaults to 0.1.
        labels (torch.Tensor, optional): accepted but not used.

    Returns:
        torch.Tensor: scalar loss ``loss1 + loss2``.
    """
    device = z1.device
    batch = z1.size(0)
    total = 2 * batch

    def _stack_views(first, second):
        # Concatenate both views along the batch axis, then unit-normalise.
        return F.normalize(torch.cat([first, second], dim=0), dim=-1)

    p = _stack_views(p1, p2)
    z = _stack_views(z1, z2)
    f = _stack_views(f1, f2)

    def _scaled_similarity(lhs, rhs):
        return torch.einsum("if, jf -> ij", lhs, rhs) / temperature

    # Each anchor sees two candidate blocks side by side: 2B x 4B logits.
    candidates_z = torch.cat(
        (_scaled_similarity(z, z), _scaled_similarity(z, f)), dim=1
    )
    candidates_p = torch.cat(
        (_scaled_similarity(p, f), _scaled_similarity(p, z)), dim=1
    )

    same_index = torch.eye(total, dtype=torch.bool, device=device)
    # Row i is paired with row (i + B) mod 2B: the other view of that sample.
    other_view = torch.roll(same_index, shifts=batch, dims=1)
    no_positive = torch.zeros_like(same_index)

    # Positives only ever live in the first candidate block.
    positives_z = torch.cat((other_view, no_positive), dim=1)
    positives_p = torch.cat((same_index, no_positive), dim=1)

    # Same-index columns are dropped from the denominator in both blocks.
    keep = torch.cat((~same_index, ~same_index), dim=1)

    def _masked_log_prob(logits):
        # Subtract the (detached) row maximum for numerical stability.
        row_max, _ = torch.max(logits, dim=1, keepdim=True)
        shifted = logits - row_max.detach()
        denominator = (torch.exp(shifted) * keep).sum(1, keepdim=True)
        return shifted - torch.log(denominator)

    log_prob_z = _masked_log_prob(candidates_z)
    log_prob_p = _masked_log_prob(candidates_p)

    # Both positive masks mark exactly one column per row, so this single
    # count serves as the divisor for both terms.
    positives_per_row = positives_z.sum(1)

    # Mean log-likelihood over positives, negated and averaged over rows.
    loss1 = -((positives_z * log_prob_z).sum(1) / positives_per_row).mean()
    loss2 = -((positives_p * log_prob_p).sum(1) / positives_per_row).mean()

    return loss1 + loss2

def off_diagonal(x):
    """Returns the off-diagonal elements of a square matrix as a 1-D tensor.

    Elements are returned in row-major order. Raises ``AssertionError`` if
    ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    side = rows

    flat = x.flatten()
    # In the flat layout the diagonal sits at indices k * (side + 1). After
    # dropping the last element (itself a diagonal entry) and viewing the rest
    # as (side - 1, side + 1), every remaining diagonal entry is in column 0.
    shifted = flat[:-1].view(side - 1, side + 1)
    # Discard column 0 (the diagonal) and flatten what is left.
    return shifted[:, 1:].flatten()

def barlow_dual_loss_func3(
    z1: torch.Tensor,
    z2: torch.Tensor,
    z1_prev: torch.Tensor,
    z2_prev: torch.Tensor,
    z1_distill: torch.Tensor,
    z2_distill: torch.Tensor,
    lamb: float = 5e-3,
    scale_loss: float = 0.025,
) -> torch.Tensor:
    """Computes the sum of two Barlow-Twins-style terms with extra penalties.

    All six inputs are standardised per feature with a non-affine
    ``BatchNorm1d`` created inside the call, and four DxD cross-correlation
    matrices are formed (each divided by N):

    * ``corr1``      = z1^T z2
    * ``corr1_prev`` = z1^T z2_prev
    * ``corr2``      = z1_distill^T z1_prev
    * ``corr2_prev`` = z1_distill^T z2

    Each term is ``scale_loss * (sum((corr - I)^2, off-diagonal entries
    weighted by lamb) + lamb * sum(off_diagonal(corr_prev)^2))``. When
    ``torch.distributed`` is initialised, the four matrices are averaged over
    all processes first.

    Args:
        z1 (torch.Tensor): NxD Tensor; N and D are read from this argument.
        z2 (torch.Tensor): NxD Tensor.
        z1_prev (torch.Tensor): NxD Tensor.
        z2_prev (torch.Tensor): NxD Tensor.
        z1_distill (torch.Tensor): NxD Tensor.
        z2_distill (torch.Tensor): NxD Tensor; it is normalised but does not
            enter any correlation matrix.
        lamb (float, optional): weight of the off-diagonal entries.
            Defaults to 5e-3.
        scale_loss (float, optional): final scaling factor of each term.
            Defaults to 0.025.

    Returns:
        torch.Tensor: scalar loss ``loss1 + loss2``.
    """
    N, D = z1.size()
    bn = torch.nn.BatchNorm1d(D, affine=False).to(z1.device)

    # Standardise every input with the same freshly created batch-norm layer.
    z1, z2, z1_prev, z2_prev, z1_distill, z2_distill = [
        bn(feats)
        for feats in (z1, z2, z1_prev, z2_prev, z1_distill, z2_distill)
    ]

    def _cross_corr(lhs, rhs):
        # Empirical DxD cross-correlation matrix.
        return torch.einsum("bi, bj -> ij", lhs, rhs) / N

    corr1 = _cross_corr(z1, z2)
    corr1_prev = _cross_corr(z1, z2_prev)
    corr2 = _cross_corr(z1_distill, z1_prev)
    corr2_prev = _cross_corr(z1_distill, z2)

    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        # Average each matrix across processes, in place.
        for matrix in (corr1, corr2, corr1_prev, corr2_prev):
            dist.all_reduce(matrix)
            matrix /= world_size

    def _term(corr, corr_prev):
        identity = torch.eye(D, device=corr.device)
        squared_gap = (corr - identity).pow(2)
        # Down-weight everything that is not on the diagonal.
        squared_gap[~identity.bool()] *= lamb
        prev_off_diag = off_diagonal(corr_prev).pow_(2).sum()
        return scale_loss * (squared_gap.sum() + prev_off_diag * lamb)

    loss1 = _term(corr1, corr1_prev)
    loss2 = _term(corr2, corr2_prev)

    return loss1 + loss2
