# Thanks to the authors of Co^2L and SimCLRv2.
# Their repositories are https://github.com/chaht01/Co2L
# and https://github.com/google-research/simclr, respectively.
# Our code is largely adapted from their repositories.

import torch
import torch.nn as nn

class SupConLoss(nn.Module):
    """Supervised contrastive loss over two views, optionally restricted
    to anchors whose label belongs to a given set of classes.
    """

    def __init__(self, base_temperature=0.07):
        super(SupConLoss, self).__init__()
        # NOTE: kept for API compatibility; forward() does not read it and
        # scales similarities by args.temperature only.
        self.base_temperature = base_temperature

    def forward(self, args, features, labels=None, current_task_classes=None):
        """Compute the loss.

        Args:
            args: object exposing ``device`` and ``temperature``.
            features: tensor whose first dimension holds the first view of
                every sample followed by the second view, i.e. exactly
                ``2 * labels.shape[0]`` rows.
            labels: one label per sample; samples sharing a label are
                treated as positives of each other.
            current_task_classes: iterable of class ids. Only anchors whose
                label is in it contribute; the others are zeroed before the
                mean. ``None`` means every anchor contributes.

        Returns:
            Scalar tensor: mean over all ``2 * batch_size`` anchors.

        Raises:
            ValueError: if ``labels`` is None.
        """
        if labels is None:
            raise ValueError(
                "SupConLoss requires `labels` to build the positive-pair mask")

        batch_size = labels.shape[0]
        n_views = 2

        # torch.split also rejects inputs that do not hold exactly
        # 2 * batch_size rows.
        view_a, view_b = torch.split(features, [batch_size, batch_size], dim=0)
        contrast_feature = torch.cat([view_a, view_b], dim=0).reshape(
            n_views * batch_size, -1)

        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(args.device)

        anchor_dot_contrast = torch.div(
            torch.matmul(contrast_feature, contrast_feature.T),
            args.temperature)
        # Subtract the detached row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the (batch, batch) label mask to cover both views.
        mask = mask.repeat(n_views, n_views)

        # Zero on the diagonal so an anchor is never contrasted with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * n_views).view(-1, 1).to(args.device),
            0
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over the positives of each anchor. Each anchor
        # always has its other view as a positive, so mask.sum(1) >= 1.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        loss = -mean_log_prob_pos

        if current_task_classes is None:
            # No class restriction requested: every anchor contributes.
            return loss.mean()

        # Boolean OR keeps this a 0/1 mask even if a class is listed twice.
        keep = torch.zeros_like(labels, dtype=torch.bool)
        for current_task_class in current_task_classes:
            keep |= (labels == current_task_class)
        keep = keep.view(-1).to(device=args.device, dtype=loss.dtype)

        # Broadcast the per-sample mask across both views.
        loss = keep * loss.view(n_views, batch_size)
        return loss.mean()


class UnsupConLoss(nn.Module):
    """Two-view contrastive loss in which the positive of each anchor is its
    other view and negatives are samples carrying a different label.
    """

    def __init__(self):
        super(UnsupConLoss, self).__init__()

    def forward(self, args, features, labels=None):
        """Compute the loss.

        Args:
            args: object exposing ``device`` and ``temperature``.
            features: tensor whose first dimension holds the first view of
                every sample followed by the second view, i.e. exactly
                ``2 * labels.shape[0]`` rows.
            labels: one label per sample; samples with the same label as the
                anchor are excluded from its negatives.

        Returns:
            Scalar tensor: mean of the per-anchor losses.

        Raises:
            ValueError: if ``labels`` is None.
        """
        if labels is None:
            raise ValueError(
                "UnsupConLoss requires `labels` to select the negatives")

        batch_size = labels.shape[0]
        n_views = 2
        N = n_views * batch_size

        # torch.split also rejects inputs that do not hold exactly
        # 2 * batch_size rows.
        view_a, view_b = torch.split(features, [batch_size, batch_size], dim=0)
        contrast_feature = torch.cat([view_a, view_b], dim=0).reshape(N, -1)

        labels = labels.contiguous().view(-1, 1)
        # 1 where two samples have different labels, 0 otherwise (this also
        # zeroes the diagonal and the anchor's own second view).
        negative_mask = 1 - torch.eq(labels, labels.T).float().to(args.device)
        negative_mask = negative_mask.repeat(n_views, n_views)

        anchor_dot_contrast = torch.div(
            torch.matmul(contrast_feature, contrast_feature.T),
            args.temperature)
        # Subtract the detached row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Row i < batch_size pairs with column i + batch_size and vice versa.
        sim_i_j = torch.diag(logits, batch_size)
        sim_j_i = torch.diag(logits, -batch_size)
        positive_samples = torch.exp(
            torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1))

        # keepdim=True keeps the per-anchor sum as (N, 1) so it lines up with
        # positive_samples; an (N,) sum would broadcast the denominator to
        # (N, N) and mix one anchor's positive with another's negatives.
        negative_sum = torch.sum(
            torch.exp(logits) * negative_mask, dim=1, keepdim=True)

        loss = -torch.log(positive_samples / (positive_samples + negative_sum))

        return torch.mean(loss)


class IRDLoss(nn.Module):
    """Cross-entropy between the pairwise-similarity distributions produced
    by a reference model (``model2``) and by the current features.
    """

    def __init__(self):
        super(IRDLoss, self).__init__()

    def forward(self, args, features, model2, images):
        """Compute the loss.

        Args:
            args: object exposing ``current_model_temp`` and
                ``prev_model_temp``.
            features: 2-D tensor of current-model features, one row per
                sample.
            model2: callable applied to ``images`` under ``torch.no_grad()``
                to obtain the reference features (same number of rows as
                ``features``).
            images: input passed to ``model2``.

        Returns:
            Scalar tensor: mean over rows of
            ``-sum(p_reference * log p_current)``, where each distribution is
            a softmax over the row's similarities to all *other* samples.
        """
        current_sims = torch.div(
            torch.matmul(features, features.T), args.current_model_temp)
        row_size = current_sims.size(0)

        # Off-diagonal selector built on the similarities' own device, so the
        # loss runs on CPU or on any GPU index.
        off_diagonal = ~torch.eye(
            row_size, dtype=torch.bool, device=current_sims.device)

        # log_softmax equals log(exp(x) / sum(exp(x))) but does not produce
        # -inf when a probability underflows.
        current_log_probs = torch.log_softmax(
            current_sims[off_diagonal].view(row_size, row_size - 1), dim=1)

        with torch.no_grad():
            prev_model_features = model2(images)
            prev_sims = torch.div(
                torch.matmul(prev_model_features, prev_model_features.T),
                args.prev_model_temp)
            prev_probs = torch.softmax(
                prev_sims[off_diagonal].view(row_size, row_size - 1), dim=1)

        loss = (-prev_probs * current_log_probs).sum(1).mean()

        return loss


class NT_Xent(nn.Module):
    """Normalized-temperature cross-entropy loss over a batch that stacks
    two views of every sample along the first dimension.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` boolean mask that is False on the diagonal
        and on the two positive-pair off-diagonals, True elsewhere.
        """
        total = 2 * batch_size
        keep = torch.ones((total, total), dtype=torch.bool)
        keep.fill_diagonal_(False)

        # Entries (i, i + B) and (i + B, i) are the positive pairs.
        first_half = torch.arange(batch_size)
        second_half = first_half + batch_size
        keep[first_half, second_half] = False
        keep[second_half, first_half] = False
        return keep

    def forward(self, args, z):
        """Return the loss for ``z``.

        ``args`` must expose ``unlabeled_temperature``. Row ``i`` of ``z``
        and row ``i + z.shape[0] // 2`` are treated as a positive pair.
        """
        half = z.shape[0] // 2
        total = 2 * half

        pairwise = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0))
        scaled = pairwise / args.unlabeled_temperature

        # Similarity of every row to its positive partner, in row order.
        upper = torch.diag(scaled, half)
        lower = torch.diag(scaled, -half)
        positives = torch.cat((upper, lower), dim=0).reshape(total, 1)

        negatives = scaled[self.mask_correlated_samples(half)].reshape(total, -1)

        # The positive sits in column 0, so the target class is always 0.
        logits = torch.cat((positives, negatives), dim=1)
        targets = torch.zeros(total, dtype=torch.long, device=positives.device)

        summed = self.criterion(logits, targets)
        return summed / total


class URSLLoss(nn.Module):
    """Combines the supervised contrastive loss with the IRD loss, which is
    added (scaled by ``args.td_coeff``) only when ``task_number > 0``.
    """

    def __init__(self):
        super().__init__()
        self.sup_con_loss = SupConLoss()
        self.ird_loss = IRDLoss()
        # Instantiated here but not called by forward().
        self.unsup_con_loss = NT_Xent()

    def forward(self, args, images, labeled_features, labels, model2, current_task_classes, task_number):
        """Return a dict with keys ``'sup_con_loss'``, ``'ird_loss'``,
        ``'unsup_con_loss'`` (always ``0.``) and ``'total'`` (sum of the
        first two).
        """
        supervised_term = self.sup_con_loss(
            args, labeled_features, labels, current_task_classes)

        if task_number > 0:
            distillation_term = self.ird_loss(
                args, labeled_features, model2, images) * args.td_coeff
        else:
            # Task 0 or below: the IRD term is skipped entirely.
            distillation_term = 0.

        return {
            'sup_con_loss': supervised_term,
            'ird_loss': distillation_term,
            'unsup_con_loss': 0.,
            'total': supervised_term + distillation_term,
        }