import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import pdb


class NTXent(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    The batch passed to :meth:`forward` must contain ``multiplier`` views of
    every sample, arranged so that the positives of row ``i`` are the rows
    ``(i + k * n // multiplier) % n`` for ``k = 1 .. multiplier - 1``, where
    ``n`` is the number of rows.
    """

    LARGE_NUMBER = 1e9

    def __init__(self, tau=1., gpu=None, multiplier=2):
        """Store the temperature ``tau`` and the number of views per sample.

        ``gpu`` is accepted but never read by this class.
        """
        super().__init__()
        self.multiplier = multiplier
        self.tau = tau
        # Constant divisor applied to the final loss (currently a no-op).
        self.norm = 1.

    def forward(self, z, get_map=False):
        """Compute the NT-Xent loss for a batch of embeddings.

        Args:
            z: 2-D tensor with one embedding per row; the number of rows must
                be a multiple of ``self.multiplier``.
            get_map: unused; the metric code that consumed it is disabled.

        Returns:
            A 4-tuple ``(loss, 0, 0, 0)``. Only the first entry is computed;
            the remaining three are constant placeholders.
        """
        views = self.multiplier
        batch = z.shape[0]
        assert batch % views == 0

        # Unit-length rows divided by sqrt(tau): every dot product below is a
        # cosine similarity already scaled by 1 / tau.
        embeddings = F.normalize(z, p=2, dim=1) / np.sqrt(self.tau)
        similarity = embeddings @ embeddings.t()

        # A row must never pick itself: push the diagonal so low that it gets
        # numerically zero probability under the softmax.
        anchors = np.arange(batch)
        similarity[anchors, anchors] = -self.LARGE_NUMBER
        log_probs = F.log_softmax(similarity, dim=1)

        # Column of every positive for every anchor; starting at k = 1 skips
        # the (i, i) pair that k = 0 would produce.
        offsets = np.arange(1, views) * batch // views
        positive_cols = ((anchors[:, None] + offsets[None, :]) % batch).reshape(-1)
        anchor_rows = np.repeat(anchors, views - 1)

        # TODO: maybe different terms for each process should only be computed here...
        positive_log_probs = log_probs[anchor_rows, positive_cols]
        loss = -positive_log_probs.sum() / batch / (views - 1) / self.norm

        # Accuracy / mAP reporting is disabled, hence the zero placeholders.
        return loss, 0, 0, 0

    
    
class UConLoss(nn.Module):
    """Contrastive objective over several datasets concatenated in one batch.

    Three log-sum-exp similarity terms are combined:

    * ``loss1`` (augmentation): similarity between different views of the
      same sample, computed inside each dataset block;
    * ``loss2`` (dataset): similarity to every other row of the same dataset
      block;
    * ``loss3`` (out-of-distribution): similarity to rows that belong to a
      different dataset.

    The returned loss is ``-(reg_aug * loss1 + reg_dataset * loss2
    - reg_ood * loss3)``, i.e. the first two terms are maximised and the
    third is minimised.
    """

    LARGE_NUMBER = 1e9

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def stabilize(self, feature_matrix):
        """Subtract each row's maximum (detached from the graph) from that row.

        Kept as a utility for numerical stability; ``forward`` does not call it.
        """
        logits_max, _ = torch.max(feature_matrix, dim=1, keepdim=True)
        return feature_matrix - logits_max.detach()

    def forward(self, features, batch_sizes, n_views, n_datasets, reg_aug = 1, reg_dataset = 1, reg_ood = 0):
        """Compute the combined loss and its three components.

        Args:
            features: 2-D tensor, one embedding per row. Rows of dataset ``d``
                form one contiguous block of ``batch_sizes[d] * n_views`` rows,
                blocks appearing in dataset order. Inside a block, rows ``i``
                and ``j`` are treated as views of the same sample when
                ``i % batch_sizes[d] == j % batch_sizes[d]``.
            batch_sizes: per-dataset number of samples; only the first
                ``n_datasets`` entries are read.
            n_views: number of views per sample.
            n_datasets: number of dataset blocks in ``features``.
            reg_aug: weight of the augmentation term; ``0`` skips it.
            reg_dataset: weight of the dataset term; ``0`` skips it.
            reg_ood: weight of the out-of-distribution term; ``0`` skips it.

        Returns:
            ``(loss, loss1, loss2, loss3)``. ``loss1`` and ``loss2`` are
            averaged over ``n_datasets``; a skipped term is reported as zero.
        """
        # Use the tensor's own device: torch.device("cuda") would resolve to
        # the *current* CUDA device, which is not necessarily the one holding
        # ``features`` in a multi-GPU process.
        device = features.device

        features = F.normalize(features, p=2, dim=1)
        n_rows = features.shape[0]

        # Pairwise cosine similarities scaled by 1 / temperature.
        similarities = torch.div(torch.matmul(features, features.T), self.temperature)

        # Self-similarity must never contribute. Each dataset block sits on
        # the main diagonal, so masking the full diagonal once covers every
        # block as well as the OOD term.
        diagonal = torch.arange(n_rows, device=device)
        similarities[diagonal, diagonal] = -self.LARGE_NUMBER

        # int() lets batch_sizes hold plain ints, numpy ints or 0-d tensors.
        samples_per_dataset = [int(batch_sizes[d]) for d in range(n_datasets)]

        loss1 = 0
        loss2 = 0
        start = 0
        for n_samples in samples_per_dataset:
            block_rows = n_samples * n_views
            end = start + block_rows
            logits = similarities[start:end, start:end]

            # 1 everywhere except on the diagonal of the block.
            self_contrast_mask = (~torch.eye(block_rows, dtype=torch.bool, device=device)).float()

            if reg_aug != 0:
                # 1 where two rows are views of the same sample, self excluded.
                same_sample = torch.eye(n_samples, dtype=torch.float32, device=device)
                aug_mask = same_sample.repeat(n_views, n_views) * self_contrast_mask
                loss1 += self.get_loss(logits, aug_mask)

            if reg_dataset != 0:
                loss2 += self.get_loss(logits, self_contrast_mask)

            start = end

        loss1 = loss1 / n_datasets
        loss2 = loss2 / n_datasets

        loss3 = 0
        if reg_ood != 0:
            # Dataset id of every row, as a column vector.
            dataset_ids = torch.tensor(
                [d for d, n_samples in enumerate(samples_per_dataset) for _ in range(n_samples * n_views)],
                dtype=torch.long,
                device=device,
            ).view(-1, 1)
            # 1 where the two rows come from different datasets.
            ood_mask = (~torch.eq(dataset_ids, dataset_ids.T)).float()
            loss3 = self.get_loss(similarities, ood_mask)

        # max aug_sim, max dataset_sim, min ood_sim
        maximize_loss = reg_aug * loss1 + reg_dataset * loss2 - reg_ood * loss3
        loss = -maximize_loss

        return loss, loss1, loss2, loss3

    def get_loss(self, logits, mask):
        """Return the row-averaged ``log(sum_j exp(logits[i, j]) * mask[i, j])``.

        A row whose mask is entirely zero contributes ``log(0) = -inf``.
        """
        exp_logits = torch.exp(logits) * mask
        sum_exp_logits = exp_logits.sum(1)
        return torch.log(sum_exp_logits).mean()
    
    
def accuracy_clustering(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: 2-D score tensor, one row per example and one column per
            class/cluster; the top ``max(topk)`` columns of each row are the
            predictions.
        target: tensor with one ground-truth column index per example.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk`` holding
        the percentage (0-100) of examples whose target is among the ``k``
        highest-scoring columns.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores per row, transposed to (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: ``correct`` can inherit the transposed
            # (non-contiguous) layout of ``pred``, and view(-1) on the first
            # k > 1 rows of such a tensor raises a RuntimeError.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    
class ClusteringLoss(nn.Module):
    """Pulls embeddings towards the centroid of their assigned cluster.

    Two variants are available:

    * ``use_v2_loss=True``: temperature-scaled cross-entropy of every
      embedding against all centroids, with its assigned centroid as target.
    * ``use_v2_loss=False``: mean squared Euclidean distance between every
      embedding and its assigned centroid, summed over all clusterings in
      ``cluster_result``.

    Both variants expect ``z`` to hold exactly two views of the indexed
    samples stacked along dim 0 (all first views, then all second views),
    because the positive centroids are concatenated with themselves once.
    """

    def __init__(self, num_clusters, use_v2_loss = False, temperature = 0.1):
        super().__init__()
        # Stored for reference only; not read by forward().
        self.num_clusters = num_clusters
        self.use_v2_loss = use_v2_loss
        self.temperature = temperature

    def forward(self, z, index, cluster_result):
        """Compute the clustering loss.

        Args:
            z: 2-D embedding tensor; rows are L2-normalised here before use.
            index: indices of the samples in the batch, used to look up their
                cluster assignment.
            cluster_result: mapping with keys ``'im2cluster'`` and
                ``'centroids'``. With ``use_v2_loss`` both values are indexed
                directly (assignment vector and centroid matrix); otherwise
                they are iterated in lockstep as sequences of such pairs
                (presumably one pair per clustering granularity; verify
                against the caller).

        Returns:
            ``(loss, accp)`` where ``accp`` is always ``0``.
        """
        losses = 0
        accp = 0

        z = F.normalize(z, dim=1)

        if self.use_v2_loss:
            centroids = cluster_result['centroids']
            cluster_id = cluster_result['im2cluster'][index]
            # Assigned centroid for every row of z (two stacked views).
            positives = centroids[cluster_id]
            positives = torch.cat([positives, positives], dim=0)

            pos_logits = torch.sum(z * positives, dim=1) / self.temperature
            all_logits = torch.matmul(z, centroids.T) / self.temperature

            # -log(exp(pos) / sum_k exp(logit_k)) per sample. The normaliser
            # must be taken per row (dim=1); summing over the whole matrix
            # would mix every sample of the batch into each denominator.
            # logsumexp also avoids overflow at small temperatures.
            loss = torch.logsumexp(all_logits, dim=1) - pos_logits
            losses = loss.mean()
        else:
            for im2cluster, prototypes in zip(cluster_result['im2cluster'], cluster_result['centroids']):
                # Assigned prototype for every row of z (two stacked views).
                cluster_id = im2cluster[index]
                positives = prototypes[cluster_id]
                positives = torch.cat([positives, positives], dim=0)
                # Squared Euclidean distance to the assigned prototype.
                loss = torch.norm(z - positives, dim=1) ** 2
                losses += loss.mean()

        return losses, accp