import torch
import torch.nn.functional as F

class InfoNCELoss(torch.nn.Module):
    """InfoNCE contrastive loss over a stack of augmented views.

    The input is expected to hold ``n_views`` groups of ``batch_size`` rows
    stacked along dim 0, so that rows ``i``, ``i + batch_size``,
    ``i + 2 * batch_size``, ... are views of the same sample.

    Args:
        n_views: Number of views per sample stacked in the input (>= 2).
        temperature: Divisor applied to the cosine-similarity logits.
    """

    def __init__(self, n_views, temperature):
        super().__init__()
        self.n_views = n_views
        self.temperature = temperature

    def forward(self, features):
        """Compute InfoNCE logits, targets and loss.

        Args:
            features: 2-D tensor with ``n_views * batch_size`` rows.

        Returns:
            Tuple ``(logits, labels, loss_val)``. In ``logits`` the positive
            similarities occupy the leading column(s) and the negatives the
            rest; ``labels`` is all zeros (class 0 is the first positive);
            ``loss_val`` is the cross-entropy of ``logits`` against ``labels``.

        Raises:
            ValueError: If ``n_views < 2`` or the number of rows is not a
                positive multiple of ``n_views``.
        """
        num_rows = features.size(0)
        if self.n_views < 2 or num_rows == 0 or num_rows % self.n_views != 0:
            raise ValueError(
                f"features.size(0)={num_rows} must be a positive multiple of "
                f"n_views={self.n_views} (with n_views >= 2)."
            )
        # Derive the per-view batch size from n_views instead of assuming 2.
        batch_size = num_rows // self.n_views
        device = features.device

        # sample_ids[r] is the index of the sample that row r is a view of.
        sample_ids = torch.arange(batch_size, device=device).repeat(self.n_views)
        positive_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        # Cosine similarity between every pair of rows.
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Discard the main diagonal (self-similarity) from both matrices.
        off_diagonal = ~torch.eye(num_rows, dtype=torch.bool, device=device)
        positive_mask = positive_mask[off_diagonal].view(num_rows, -1)
        similarity_matrix = similarity_matrix[off_diagonal].view(num_rows, -1)

        # Positives first, then negatives, so the target class is index 0.
        positives = similarity_matrix[positive_mask].view(num_rows, -1)
        negatives = similarity_matrix[~positive_mask].view(num_rows, -1)

        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(num_rows, dtype=torch.long, device=device)

        loss_val = F.cross_entropy(logits, labels)
        return logits, labels, loss_val


class CoInfoNCELoss(torch.nn.Module):
    """Cross-entropy between source and target InfoNCE distributions.

    InfoNCE-style logits are built independently for the source and target
    features; the loss is the cross-entropy of the target's log-softmax
    against the source's softmax, used as soft targets.

    Args:
        n_views: Number of views per sample stacked in the inputs (>= 2).
        temperature: Divisor applied to the cosine-similarity logits.
    """

    def __init__(self, n_views, temperature):
        super().__init__()
        self.n_views = n_views
        self.temperature = temperature

    def construct_logits(self, features):
        """Build InfoNCE logits and all-zero targets for one feature set.

        Args:
            features: 2-D tensor with ``n_views * batch_size`` rows, where
                rows ``i``, ``i + batch_size``, ... are views of one sample.

        Returns:
            Tuple ``(logits, labels)``: positives occupy the leading
            column(s) of ``logits``, negatives the rest; ``labels`` is a
            long tensor of zeros with one entry per row.

        Raises:
            ValueError: If ``n_views < 2`` or the number of rows is not a
                positive multiple of ``n_views``.
        """
        num_rows = features.size(0)
        if self.n_views < 2 or num_rows == 0 or num_rows % self.n_views != 0:
            raise ValueError(
                f"features.size(0)={num_rows} must be a positive multiple of "
                f"n_views={self.n_views} (with n_views >= 2)."
            )
        # Derive the per-view batch size from n_views instead of assuming 2.
        batch_size = num_rows // self.n_views
        device = features.device

        # sample_ids[r] is the index of the sample that row r is a view of.
        sample_ids = torch.arange(batch_size, device=device).repeat(self.n_views)
        positive_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        # Cosine similarity between every pair of rows.
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Discard the main diagonal (self-similarity) from both matrices.
        off_diagonal = ~torch.eye(num_rows, dtype=torch.bool, device=device)
        positive_mask = positive_mask[off_diagonal].view(num_rows, -1)
        similarity_matrix = similarity_matrix[off_diagonal].view(num_rows, -1)

        # Positives first, then negatives, so the target class is index 0.
        positives = similarity_matrix[positive_mask].view(num_rows, -1)
        negatives = similarity_matrix[~positive_mask].view(num_rows, -1)

        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(num_rows, dtype=torch.long, device=device)
        return logits, labels

    def forward(self, target_feats, source_feats):
        """Compute the source-to-target distribution-matching loss.

        Args:
            target_feats: Features whose log-softmax distribution is scored.
            source_feats: Features whose softmax distribution provides the
                soft targets. No ``detach`` is applied here, so gradients
                also flow through the source branch.

        Returns:
            Tuple ``(source_logits, source_labels, target_logits,
            target_labels, loss_val)``.
        """
        source_logits, source_labels = self.construct_logits(source_feats)
        target_logits, target_labels = self.construct_logits(target_feats)

        # Soft-target cross-entropy, averaged over rows.
        soft_targets = F.softmax(source_logits, dim=1)
        log_probs = F.log_softmax(target_logits, dim=1)
        loss_val = -(soft_targets * log_probs).sum() / source_logits.size(0)
        return source_logits, source_labels, target_logits, target_labels, loss_val


class CMDLoss(torch.nn.Module):
    """Weighted sum of three losses between target and source features.

    The three terms, in the order of ``weights``, are:
      0. the co-InfoNCE loss between target and source,
      1. the mean-squared error between target and source features,
      2. the InfoNCE loss computed on the target features alone.

    Args:
        n_views: Forwarded to the InfoNCE / co-InfoNCE sub-losses.
        temperature: Forwarded to the InfoNCE / co-InfoNCE sub-losses.
        weights: Sequence of exactly three scalars, one per term.
    """

    def __init__(self, n_views, temperature, weights):
        super().__init__()
        self.info_nce_loss = InfoNCELoss(n_views, temperature)
        self.co_info_nce_loss = CoInfoNCELoss(n_views, temperature)
        self.weights = weights
        assert len(weights) == 3, 'There should be three weights for co_info/feat_mse/ssl_info.'

    def forward(self, target_feats, source_feats):
        """Return logits/labels for both branches and the combined loss.

        The returned target logits and labels are the ones produced by the
        InfoNCE term on ``target_feats``; they replace those returned by
        the co-InfoNCE call.

        Returns:
            Tuple ``(source_logits, source_labels, target_logits,
            target_labels, total_loss)``.
        """
        w_co_info, w_mse, w_ssl = self.weights[0], self.weights[1], self.weights[2]

        # Term 0: co-InfoNCE between the two feature sets.
        co_info_outputs = self.co_info_nce_loss(target_feats, source_feats)
        source_logits, source_labels, target_logits, target_labels, co_info_loss = co_info_outputs
        co_info_term = co_info_loss * w_co_info

        # Term 1: feature-space mean-squared error.
        mse_term = F.mse_loss(target_feats, source_feats, reduction='mean') * w_mse

        # Term 2: InfoNCE on the target branch (also supplies the target outputs).
        target_logits, target_labels, ssl_loss = self.info_nce_loss(target_feats)
        ssl_term = ssl_loss * w_ssl

        total_loss = co_info_term + mse_term + ssl_term
        return source_logits, source_labels, target_logits, target_labels, total_loss

class CLIPCMDLoss(torch.nn.Module):
    """CMD loss (co-InfoNCE + feature MSE + InfoNCE) plus a CLIP-style term.

    The four terms, in the order of ``weights``, are:
      0. the co-InfoNCE loss between target and source,
      1. the mean-squared error between target and source features,
      2. the InfoNCE loss computed on the target features alone,
      3. the symmetric CLIP contrastive loss between source and target.

    Args:
        n_views: Forwarded to the InfoNCE / co-InfoNCE sub-losses.
        temperature: Forwarded to the sub-losses; its reciprocal is also
            used as the fixed logit scale of the CLIP term.
        weights: Sequence of exactly four scalars, one per term.
    """

    def __init__(self, n_views, temperature, weights):
        super().__init__()
        self.info_nce_loss = InfoNCELoss(n_views, temperature)
        self.co_info_nce_loss = CoInfoNCELoss(n_views, temperature)
        self.weights = weights
        # Fixed (non-learnable) multiplier applied to the cosine similarities.
        self.logit_scale = torch.tensor(1 / temperature, requires_grad=False)
        assert len(weights) == 4, 'There should be four weights for co_info/feat_mse/ssl_info/clip.'

    def forward_clip(self, target_feats, source_feats):
        """Symmetric CLIP contrastive loss between paired source/target rows.

        Row ``i`` of ``source_feats`` is treated as the match of row ``i`` of
        ``target_feats``.

        Returns:
            Tuple ``(logits_per_source, logits_per_target, labels,
            loss_val)`` where ``logits_per_target`` is the transpose of
            ``logits_per_source`` and ``loss_val`` averages the
            cross-entropy taken in both directions.
        """
        device = target_feats.device

        # L2-normalize rows; F.normalize avoids NaNs for zero-norm rows.
        source_feats = F.normalize(source_feats, dim=1)
        target_feats = F.normalize(target_feats, dim=1)

        # Scaled cosine similarity as logits.
        logits_per_source = self.logit_scale * source_feats @ target_feats.t()
        logits_per_target = logits_per_source.t()

        labels = torch.arange(logits_per_target.size(0), device=device)
        loss_source = F.cross_entropy(logits_per_source, labels)
        # The target-side loss must use the transposed logits; reusing
        # logits_per_source here would make the loss one-directional.
        loss_target = F.cross_entropy(logits_per_target, labels)
        loss_val = (loss_source + loss_target) / 2

        return logits_per_source, logits_per_target, labels, loss_val

    def forward(self, target_feats, source_feats):
        """Return logits/labels for both branches and the combined loss.

        The returned target logits and labels are the ones produced by the
        InfoNCE term on ``target_feats``. The CLIP term contributes only
        to the loss.

        Returns:
            Tuple ``(source_logits, source_labels, target_logits,
            target_labels, total_loss)``.
        """
        # CMD terms: co-InfoNCE, feature MSE, InfoNCE on the target.
        source_logits, source_labels, _, _, co_info_loss_val = self.co_info_nce_loss(target_feats, source_feats)
        dist_loss_val = F.mse_loss(target_feats, source_feats, reduction='mean')
        target_logits, target_labels, ssl_loss_val = self.info_nce_loss(target_feats)

        # CLIP term.
        _, _, _, clip_loss_val = self.forward_clip(target_feats, source_feats)

        total_loss = (
            (co_info_loss_val * self.weights[0])
            + (dist_loss_val * self.weights[1])
            + (ssl_loss_val * self.weights[2])
            + (clip_loss_val * self.weights[3])
        )
        return source_logits, source_labels, target_logits, target_labels, total_loss

def info_nce_loss(features, n_views, temperature):
    """Build InfoNCE logits and targets for a stack of augmented views.

    Args:
        features: 2-D tensor with ``n_views * batch_size`` rows, where rows
            ``i``, ``i + batch_size``, ... are views of the same sample.
        n_views: Number of views per sample stacked in ``features`` (>= 2).
        temperature: Divisor applied to the cosine-similarity logits.

    Returns:
        Tuple ``(logits, labels)``: positives occupy the leading column(s)
        of ``logits``, negatives the rest; ``labels`` is a long tensor of
        zeros with one entry per row.

    Raises:
        ValueError: If ``n_views < 2`` or the number of rows is not a
            positive multiple of ``n_views``.
    """
    num_rows = features.size(0)
    if n_views < 2 or num_rows == 0 or num_rows % n_views != 0:
        raise ValueError(
            f"features.size(0)={num_rows} must be a positive multiple of "
            f"n_views={n_views} (with n_views >= 2)."
        )
    # The rows already contain every view, so the per-view batch size is
    # num_rows / n_views; using num_rows itself would make the label matrix
    # n_views times larger than the similarity matrix.
    batch_size = num_rows // n_views
    device = features.device

    # sample_ids[r] is the index of the sample that row r is a view of.
    sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
    positive_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    # Cosine similarity between every pair of rows.
    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # Discard the main diagonal (self-similarity) from both matrices.
    off_diagonal = ~torch.eye(num_rows, dtype=torch.bool, device=device)
    positive_mask = positive_mask[off_diagonal].view(num_rows, -1)
    similarity_matrix = similarity_matrix[off_diagonal].view(num_rows, -1)

    # Positives first, then negatives, so the target class is index 0.
    positives = similarity_matrix[positive_mask].view(num_rows, -1)
    negatives = similarity_matrix[~positive_mask].view(num_rows, -1)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(num_rows, dtype=torch.long, device=device)
    return logits, labels
