import torch
import torch.nn as nn
import torch.nn.functional as F

class SupConLoss(nn.Module):
    """Supervised contrastive loss over two projected views of a batch.

    Both projections are L2-normalised along dim 1, stacked into one set of
    ``2 * batch`` anchors, and every anchor is contrasted against all other
    anchors using temperature-scaled dot products.

    Args:
        temperature: Divisor applied to the pairwise dot products.
        device: Stored on the instance for backward compatibility. The
            forward pass allocates its helper tensors on the device of the
            incoming projections, so the module also works when the inputs
            do not live on ``device`` (e.g. on a CPU-only host).
    """

    def __init__(self, temperature=0.06, device="cuda:0"):
        super().__init__()

        self.temperature = temperature
        self.device = device

    def forward(self, projection1, projection2, labels=None):
        """Compute the contrastive loss for two views.

        Args:
            projection1: First view; the matrix product below requires a
                2-D tensor ``(batch, dim)``.
            projection2: Second view, same shape as ``projection1``.
            labels: Optional tensor with one label per sample. Samples that
                share a label are treated as positives of each other. When
                ``None``, the only positive of a sample is its other view.

        Returns:
            Tuple ``(loss, loss_dict)`` where ``loss`` is a scalar tensor and
            ``loss_dict`` is ``{"supcon": loss}``.

        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        projection1, projection2 = F.normalize(projection1), F.normalize(projection2)
        features = torch.cat([projection1.unsqueeze(1), projection2.unsqueeze(1)], dim=1)
        batch_size, contrast_count = features.shape[0], features.shape[1]
        # Follow the inputs instead of a fixed device string so that the
        # helper tensors can never end up on a different device than the data.
        device = features.device

        if labels is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        else:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    f"Number of labels ({labels.shape[0]}) does not match "
                    f"the batch size ({batch_size})"
                )
            mask = torch.eq(labels, labels.T).float().to(device)

        # Lay the views out one after another: (contrast_count * batch, dim).
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        anchor_dot_contrast = torch.div(
            torch.matmul(contrast_feature, contrast_feature.T), self.temperature
        )

        # Subtract the row-wise maximum for numerical stability; it is
        # detached so that it does not contribute to the gradient.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the positive mask so that it covers every pair of views.
        mask = mask.repeat(contrast_count, contrast_count)

        # Ones everywhere except the diagonal: an anchor is never contrasted
        # with itself, neither as a positive nor in the denominator.
        logits_mask = 1.0 - torch.eye(
            batch_size * contrast_count, dtype=mask.dtype, device=device
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Average log-probability over the positives of each anchor.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = (-mean_log_prob_pos).mean()

        loss_dict = {"supcon": loss}

        return loss, loss_dict

class SupConCELoss(nn.Module):
    """Weighted sum of a supervised contrastive loss and a cross-entropy loss.

    The total is ``alpha * supcon + (1 - alpha) * ce``.

    Args:
        alpha: Weight of the contrastive term; the cross-entropy term is
            weighted by ``1 - alpha``.
        device: Forwarded to the contrastive loss.
        temperature: Forwarded to the contrastive loss.
        num_classes: Width of the one-hot targets built for cross-entropy.
    """

    def __init__(self, alpha=0.5, device="cuda:0", temperature=0.06, num_classes=10):
        super().__init__()
        self.supcon = SupConLoss(temperature=temperature, device=device)
        self.ce = nn.CrossEntropyLoss()
        self.alpha = alpha
        self.num_classes = num_classes

    def forward(self, projection1, projection2, prediction1, prediction2, target):
        """Return ``(total_loss, {"supcon": ..., "ce": ...})`` for two views.

        ``projection1``/``projection2`` feed the contrastive term, while
        ``prediction1``/``prediction2`` are concatenated along dim 0 and scored
        against the one-hot encoding of ``target`` (repeated once per view).
        """
        stacked_predictions = torch.cat((prediction1, prediction2), dim=0)
        one_hot = F.one_hot(target, num_classes=self.num_classes).float()
        # Both views share the same targets, so the one-hot block is doubled.
        stacked_targets = torch.cat((one_hot, one_hot), dim=0)

        contrastive_term, _ = self.supcon(projection1, projection2, target)
        ce_term = self.ce(stacked_predictions, stacked_targets)
        loss_dict = {"supcon": contrastive_term, "ce": ce_term}

        ce_weight = 1 - self.alpha
        total = self.alpha * loss_dict["supcon"] + ce_weight * loss_dict["ce"]
        return total, loss_dict
    

class AutoEncoderLoss(nn.Module):
    """Blend of a supervised contrastive term and a cross-entropy term.

    Despite the name, no reconstruction term is computed here: the result is
    ``alpha * supcon + (1 - alpha) * ce``.

    Args:
        alpha: Mixing weight of the contrastive term.
        device: Passed through to the contrastive loss.
        temperature: Passed through to the contrastive loss.
        num_classes: Number of classes used for the one-hot targets.
    """

    def __init__(self, alpha=0.5, device="cuda:0", temperature=0.06, num_classes=10):
        super().__init__()
        self.supcon = SupConLoss(temperature=temperature, device=device)
        self.ce = nn.CrossEntropyLoss()
        self.alpha = alpha
        self.num_classes = num_classes

    def forward(self, projection1, projection2, prediction1, prediction2, target):
        """Compute the blended loss.

        Returns:
            Tuple ``(total, loss_dict)`` where ``loss_dict`` holds the two
            unweighted terms under the keys ``"supcon"`` and ``"ce"``.
        """
        logits_both_views = torch.cat([prediction1, prediction2], dim=0)
        encoded = F.one_hot(target, num_classes=self.num_classes).float()
        # The same targets apply to each view's predictions.
        targets_both_views = torch.cat([encoded] * 2, dim=0)

        supcon_value = self.supcon(projection1, projection2, target)[0]
        loss_dict = {
            "supcon": supcon_value,
            "ce": self.ce(logits_both_views, targets_both_views),
        }

        weighted_supcon = self.alpha * loss_dict["supcon"]
        weighted_ce = (1 - self.alpha) * loss_dict["ce"]
        return weighted_supcon + weighted_ce, loss_dict
    
class ConLoss(nn.Module):
    """Contrastive loss over two projected views of a batch.

    Both projections are L2-normalised along dim 1 and every one of the
    ``2 * batch`` resulting anchors is contrasted against all other anchors
    using temperature-scaled dot products.

    Args:
        temperature: Divisor applied to the pairwise dot products.
        device: Stored on the instance for backward compatibility. The
            forward pass allocates its helper tensors on the device of the
            incoming projections, so the module also works when the inputs
            do not live on ``device`` (e.g. on a CPU-only host).
    """

    def __init__(self, temperature=0.06, device="cuda:0"):
        super().__init__()
        self.temperature = temperature
        self.device = device

    def forward(self, projection1, projection2, labels=None):
        """Compute the contrastive loss for two views.

        Args:
            projection1: First view; the matrix product below requires a
                2-D tensor ``(batch, dim)``.
            projection2: Second view, same shape as ``projection1``.
            labels: Optional tensor with one label per sample. Samples that
                share a label count as positives of each other. When
                ``None``, the only positive of a sample is its other view.

        Returns:
            Tuple ``(loss, loss_dict)`` where ``loss`` is a scalar tensor and
            ``loss_dict`` is ``{"con": loss}``.

        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        projection1, projection2 = F.normalize(projection1), F.normalize(projection2)
        features = torch.cat([projection1.unsqueeze(1), projection2.unsqueeze(1)], dim=1)
        batch_size, contrast_count = features.shape[0], features.shape[1]
        # Use the inputs' device rather than a fixed device string, so the
        # masks always live next to the data they are multiplied with.
        device = features.device

        if labels is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        else:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    f"Number of labels ({labels.shape[0]}) does not match "
                    f"the batch size ({batch_size})"
                )
            mask = torch.eq(labels, labels.T).float().to(device)

        # Concatenate the views along the batch axis.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_dot_contrast = torch.div(
            torch.matmul(contrast_feature, contrast_feature.T), self.temperature
        )

        # Row-wise max subtraction keeps exp() in range; detached so it has
        # no influence on the gradient.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        mask = mask.repeat(contrast_count, contrast_count)
        # Zero diagonal: an anchor is excluded from its own positives and
        # from its own softmax denominator.
        logits_mask = 1.0 - torch.eye(
            batch_size * contrast_count, dtype=mask.dtype, device=device
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Average log-probability over the positives of each anchor.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = (-mean_log_prob_pos).mean()

        loss_dict = {"con": loss}

        return loss, loss_dict

class CELoss(nn.Module):
    """Cross-entropy over the predictions of two views sharing one target.

    Args:
        num_classes: Width of the one-hot encoding built from the targets.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.num_classes = num_classes

    def forward(self, prediction1, prediction2, target):
        """Return ``(loss, {"ce": loss})``.

        The two prediction tensors are concatenated along dim 0 and scored
        against the float one-hot encoding of ``target``, repeated once per
        view.
        """
        one_hot = F.one_hot(target, num_classes=self.num_classes).float()
        stacked_targets = torch.cat((one_hot, one_hot), dim=0)
        stacked_predictions = torch.cat((prediction1, prediction2), dim=0)

        ce_value = self.ce(stacked_predictions, stacked_targets)
        loss_dict = {"ce": ce_value}
        return loss_dict["ce"], loss_dict