import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

def logmeanexp(inputs):
    """Return ``log(mean(exp(inputs)))`` over all elements of ``inputs``.

    The global maximum is subtracted before exponentiating and added back
    afterwards, so large entries do not overflow in ``exp``.
    """
    peak = inputs.max()
    shifted = inputs - peak
    return peak + torch.log(torch.mean(torch.exp(shifted)))

class RenyiSCL(nn.Module):
    """Supervised contrastive model with an (alpha, gamma)-parameterised loss.

    A single encoder embeds two views of every image.  The embeddings are
    L2-normalised and compared against the keys gathered from all processes
    and, when ``K > 0``, against a fixed-size queue of earlier keys together
    with their labels.

    Args:
        backbone: callable invoked as ``backbone(num_classes=mlp_dim)`` that
            returns the encoder module.
        alpha: weight of the positive term inside the second loss term;
            ``1 - alpha`` weights the negative term.
        gamma: factor scaling the logits inside the exponentials;
            ``gamma == 1`` selects a separate closed form.
        dim: embedding dimension (number of rows of each queue).
        mlp_dim: width passed to the backbone and to the projector hook.
        temp: temperature dividing every similarity logit.
        n_cls: number of classes used for one-hot label encoding.
        K: number of queue slots; ``0`` disables the queues entirely.
    """

    def __init__(self, backbone, alpha=1.0, gamma=2.0, dim=256, mlp_dim=4096, temp=0.5, n_cls=1000, K=16384):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.temp = temp
        self.n_cls = n_cls
        self.K = K
        # build encoders
        self.source_encoder = backbone(num_classes=mlp_dim)
        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        if self.K > 0:
            # Queues start as random unit-norm columns with random labels and
            # are progressively overwritten by ``_dequeue_and_enqueue``.
            self.register_buffer("queue1", torch.randn(dim, self.K))
            self.queue1 = nn.functional.normalize(self.queue1, dim=0)
            self.register_buffer("queue2", torch.randn(dim, self.K))
            self.queue2 = nn.functional.normalize(self.queue2, dim=0)
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

            self.register_buffer("queue_label", torch.randint(0, self.n_cls, (self.K,)))

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        """Build a bias-free MLP of ``num_layers`` linear layers.

        Every layer except the last is followed by BatchNorm1d and ReLU.  The
        last layer is followed by a non-affine BatchNorm1d when ``last_bn``
        is true, and by nothing otherwise.

        Returns:
            An ``nn.Sequential`` holding the layers in order.
        """
        mlp = []
        for l in range(num_layers):
            dim1 = input_dim if l == 0 else mlp_dim
            dim2 = output_dim if l == num_layers - 1 else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if l < num_layers - 1:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Hook for subclasses to attach projection heads; a no-op here."""

    @torch.no_grad()
    def _dequeue_and_enqueue(self, key1, key2, targets):
        """Overwrite the next queue slots with the gathered keys and labels.

        ``key1`` is written to ``queue1``, ``key2`` to ``queue2`` and
        ``targets`` to ``queue_label``, all at the current pointer, which is
        then advanced modulo ``K``.

        Raises:
            AssertionError: if ``K`` is not a multiple of the gathered batch
                size (the write is never split across the end of the queue).
        """
        key1 = concat_all_gather(key1)
        key2 = concat_all_gather(key2)
        targets = concat_all_gather(targets)

        batch_size = key1.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0, \
            "queue size K must be a multiple of the gathered batch size"  # for simplicity

        # replace the key at ptr (dequeue and enqueue)
        self.queue1[:, ptr:ptr + batch_size] = key1.T
        self.queue2[:, ptr:ptr + batch_size] = key2.T
        self.queue_label[ptr:ptr + batch_size] = targets.clone().detach()
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    def _renyi_loss(self, logits, mask_pos):
        """Compute the loss from similarity logits and a positive-pair mask.

        Args:
            logits: ``(n_query, n_key)`` temperature-scaled similarities.
            mask_pos: same shape as ``logits``; 1 where query and key share a
                label, 0 otherwise.

        Returns:
            Scalar tensor ``loss_1 + loss_2``.
        """
        mask_neg = torch.ones_like(mask_pos) - mask_pos

        pos = logits * mask_pos
        neg = logits * mask_neg

        if self.gamma == 1:
            loss_1 = -1 * pos.sum() / mask_pos.sum()
            loss_2 = torch.log(logits.exp().mean(dim=1)).mean()
        else:
            n_pos = mask_pos.sum(dim=1, keepdim=True)
            # A row without negatives would otherwise give 0 / 0 = NaN, and
            # NaN survives multiplication by (1 - alpha) == 0.  The counts
            # are whole numbers, so clamping only affects the zero case,
            # where the numerator is exactly 0 as well.
            n_neg = mask_neg.sum(dim=1, keepdim=True).clamp(min=1)

            # Masked-out entries of ``pos`` / ``neg`` are 0, so their exp is
            # 1; subtracting the complementary mask removes those ones and
            # leaves a sum over the selected entries only.
            e_pos_1 = (((self.gamma - 1) * pos).exp() - mask_neg).sum(dim=1, keepdim=True) / n_pos
            loss_1 = - 1 * torch.log(e_pos_1).mean() / (self.gamma - 1)
            e_pos = ((self.gamma * pos).exp() - mask_neg).sum(dim=1, keepdim=True) / n_pos
            e_neg = ((self.gamma * neg).exp() - mask_pos).sum(dim=1, keepdim=True) / n_neg
            loss_2 = torch.log(self.alpha * e_pos + (1 - self.alpha) * e_neg).mean() / self.gamma

        return loss_1 + loss_2

    def contrastive_loss(self, q, k, y, queue=None, queue_label=None):
        """Contrast queries ``q`` against gathered keys ``k`` (and a queue).

        Args:
            q: query embeddings, one row per sample.
            k: key embeddings; gathered across processes before use.
            y: integer class labels of the local samples.
            queue: optional ``(dim, K)`` matrix of extra keys.
            queue_label: labels of the queue columns; required when ``queue``
                is given.

        Returns:
            Scalar loss tensor.
        """
        k_all = concat_all_gather(k)
        logits = torch.mm(q, k_all.t()) / self.temp

        y_s = y.clone().detach()
        y_s = F.one_hot(y_s, num_classes=self.n_cls).float()
        y_t = y.clone().detach()
        y_t = concat_all_gather(y_t)
        y_t = F.one_hot(y_t, num_classes=self.n_cls).float()

        # One-hot dot products are 1 exactly when the two labels match.
        mask_pos = torch.mm(y_s, y_t.t())
        if queue is not None:
            logits_queue = torch.einsum('nc,ck->nk', [q, queue.clone().detach()]) / self.temp
            logits = torch.cat([logits, logits_queue], dim=1)

            y_queue = F.one_hot(queue_label.clone().detach(), num_classes=self.n_cls).float()
            mask_pos_queue = torch.mm(y_s, y_queue.t())
            mask_pos = torch.cat([mask_pos, mask_pos_queue], dim=1)

        return self._renyi_loss(logits, mask_pos)

    def forward(self, images, targets):
        """Return the training loss for two views of a labelled batch.

        Args:
            images: indexable with the two views at positions 0 and 1.
            targets: integer class labels shared by both views.

        With ``K > 0`` the loss is the sum of both query/key directions and
        the queues are updated afterwards; with ``K == 0`` only the
        view-0-as-query direction is evaluated.
        """
        f_s = self.source_encoder(images[0])
        f_t = self.source_encoder(images[1])
        f_s = nn.functional.normalize(f_s, dim=1)
        f_t = nn.functional.normalize(f_t, dim=1)

        if self.K > 0:
            loss = self.contrastive_loss(f_s, f_t, targets, self.queue1, self.queue_label) + \
                   self.contrastive_loss(f_t, f_s, targets, self.queue2, self.queue_label)
            # queue1 stores view-1 keys and queue2 view-0 keys, matching the
            # key argument each queue is paired with above.
            self._dequeue_and_enqueue(f_t, f_s, targets)
        else:
            loss = self.contrastive_loss(f_s, f_t, targets)

        return loss


class RenyiSCL_ResNet(RenyiSCL):
    """RenyiSCL for backbones that expose their final linear layer as ``fc``."""

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Replace the encoder's ``fc`` layer with a 2-layer projection MLP.

        The projector maps the ``fc`` layer's input width to ``dim`` through
        one hidden layer of width ``mlp_dim``.
        """
        encoder = self.source_encoder
        # The second axis of the fc weight is the width of the features fed
        # into it, which becomes the projector's input width.
        feature_dim = encoder.fc.weight.shape[1]
        # Remove the original fc before installing the projector in its place.
        del encoder.fc
        encoder.fc = self._build_mlp(2, feature_dim, mlp_dim, dim)



class RenyiSCL_MC(nn.Module):
    """Multi-view supervised contrastive model with an (alpha, gamma) loss.

    One encoder embeds every view of an image.  The first two views act as
    keys; each of them is contrasted against every other view, using the keys
    gathered from all processes and, when ``K > 0``, a fixed-size queue of
    earlier keys with their labels.

    Args:
        backbone: callable invoked as ``backbone(num_classes=mlp_dim)`` that
            returns the encoder module.
        alpha: weight of the positive term inside the second loss term;
            ``1 - alpha`` weights the negative term.
        gamma: factor scaling the logits inside the exponentials;
            ``gamma == 1`` selects the unscaled form.
        dim: embedding dimension (number of rows of each queue).
        mlp_dim: width passed to the backbone and to the projector hook.
        temp: temperature dividing every similarity logit.
        n_cls: number of classes used for one-hot label encoding.
        K: number of queue slots; ``0`` disables the queues entirely.
    """

    def __init__(self, backbone, alpha=1.0, gamma=2.0, dim=256, mlp_dim=4096, temp=0.5, n_cls=1000, K=16384):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.temp = temp
        self.n_cls = n_cls
        self.K = K
        # build encoders
        self.source_encoder = backbone(num_classes=mlp_dim)
        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        if self.K > 0:
            # Queues start as random unit-norm columns with random labels and
            # are progressively overwritten by ``_dequeue_and_enqueue``.
            self.register_buffer("queue1", torch.randn(dim, self.K))
            self.queue1 = nn.functional.normalize(self.queue1, dim=0)
            self.register_buffer("queue2", torch.randn(dim, self.K))
            self.queue2 = nn.functional.normalize(self.queue2, dim=0)
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

            self.register_buffer("queue_label", torch.randint(0, self.n_cls, (self.K,)))

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        """Build a bias-free MLP of ``num_layers`` linear layers.

        Every layer except the last is followed by BatchNorm1d and ReLU.  The
        last layer is followed by a non-affine BatchNorm1d when ``last_bn``
        is true, and by nothing otherwise.

        Returns:
            An ``nn.Sequential`` holding the layers in order.
        """
        mlp = []
        for l in range(num_layers):
            dim1 = input_dim if l == 0 else mlp_dim
            dim2 = output_dim if l == num_layers - 1 else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if l < num_layers - 1:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Hook for subclasses to attach projection heads; a no-op here."""

    @torch.no_grad()
    def _dequeue_and_enqueue(self, key1, key2, targets):
        """Overwrite the next queue slots with the gathered keys and labels.

        ``key1`` is written to ``queue1``, ``key2`` to ``queue2`` and
        ``targets`` to ``queue_label``, all at the current pointer, which is
        then advanced modulo ``K``.

        Raises:
            AssertionError: if ``K`` is not a multiple of the gathered batch
                size (the write is never split across the end of the queue).
        """
        key1 = concat_all_gather(key1)
        key2 = concat_all_gather(key2)
        targets = concat_all_gather(targets)

        batch_size = key1.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0, \
            "queue size K must be a multiple of the gathered batch size"  # for simplicity

        # replace the key at ptr (dequeue and enqueue)
        self.queue1[:, ptr:ptr + batch_size] = key1.T
        self.queue2[:, ptr:ptr + batch_size] = key2.T
        self.queue_label[ptr:ptr + batch_size] = targets.clone().detach()
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    def _renyi_loss(self, logits, mask_pos):
        """Compute the loss from similarity logits and a positive-pair mask.

        Args:
            logits: ``(n_query, n_key)`` temperature-scaled similarities.
            mask_pos: same shape as ``logits``; 1 where query and key share a
                label, 0 otherwise.

        Returns:
            Scalar tensor ``loss_1 + loss_2``.
        """
        mask_neg = torch.ones_like(mask_pos) - mask_pos

        pos = logits * mask_pos
        neg = logits * mask_neg

        n_pos = mask_pos.sum(dim=1, keepdim=True)
        # A row without negatives would otherwise give 0 / 0 = NaN, and NaN
        # survives multiplication by (1 - alpha) == 0.  The counts are whole
        # numbers, so clamping only affects the zero case, where the
        # numerator is exactly 0 as well.
        n_neg = mask_neg.sum(dim=1, keepdim=True).clamp(min=1)

        # Masked-out entries of ``pos`` / ``neg`` are 0, so their exp is 1;
        # subtracting the complementary mask removes those ones and leaves a
        # sum over the selected entries only.
        if self.gamma == 1:
            loss_1 = -1 * pos.sum() / mask_pos.sum()
            e_pos = (pos.exp() - mask_neg).sum(dim=1, keepdim=True) / n_pos
            e_neg = (neg.exp() - mask_pos).sum(dim=1, keepdim=True) / n_neg
            loss_2 = torch.log(self.alpha * e_pos + (1 - self.alpha) * e_neg).mean()
        else:
            e_pos_1 = (((self.gamma - 1) * pos).exp() - mask_neg).sum(dim=1, keepdim=True) / n_pos
            loss_1 = - 1 * torch.log(e_pos_1).mean() / (self.gamma - 1)
            e_pos = ((self.gamma * pos).exp() - mask_neg).sum(dim=1, keepdim=True) / n_pos
            e_neg = ((self.gamma * neg).exp() - mask_pos).sum(dim=1, keepdim=True) / n_neg
            loss_2 = torch.log(self.alpha * e_pos + (1 - self.alpha) * e_neg).mean() / self.gamma

        return loss_1 + loss_2

    def contrastive_loss(self, q, k, y, queue=None, queue_label=None):
        """Contrast queries ``q`` against gathered keys ``k`` (and a queue).

        Args:
            q: query embeddings, one row per sample.
            k: key embeddings; gathered across processes before use.
            y: integer class labels of the local samples.
            queue: optional ``(dim, K)`` matrix of extra keys.
            queue_label: labels of the queue columns; required when ``queue``
                is given.

        Returns:
            Scalar loss tensor.
        """
        k_all = concat_all_gather(k)
        logits = torch.mm(q, k_all.t()) / self.temp

        y_s = y.clone().detach()
        y_s = F.one_hot(y_s, num_classes=self.n_cls).float()
        y_t = y.clone().detach()
        y_t = concat_all_gather(y_t)
        y_t = F.one_hot(y_t, num_classes=self.n_cls).float()

        # One-hot dot products are 1 exactly when the two labels match.
        mask_pos = torch.mm(y_s, y_t.t())
        if queue is not None:
            logits_queue = torch.einsum('nc,ck->nk', [q, queue.clone().detach()]) / self.temp
            logits = torch.cat([logits, logits_queue], dim=1)

            y_queue = F.one_hot(queue_label.clone().detach(), num_classes=self.n_cls).float()
            mask_pos_queue = torch.mm(y_s, y_queue.t())
            mask_pos = torch.cat([mask_pos, mask_pos_queue], dim=1)

        return self._renyi_loss(logits, mask_pos)

    def forward(self, images, targets):
        """Return the training loss for several views of a labelled batch.

        Args:
            images: sequence of views; the first two serve as keys, so at
                least two views are needed.
            targets: integer class labels shared by all views.

        Each of the two key views is contrasted against every other view,
        and the accumulated loss is divided by ``len(images) - 1``.  With
        ``K > 0`` the queues are then refreshed with views 0 and 1.
        """
        f_list = []
        for img in images:
            f = self.source_encoder(img)
            f = nn.functional.normalize(f, dim=1)
            f_list.append(f)

        loss = 0.0
        for idx_g in range(2):
            # queue1 accompanies view 0 as key, queue2 accompanies view 1.
            if self.K > 0:
                queue = self.queue1 if idx_g == 0 else self.queue2
                queue_label = self.queue_label
            else:
                queue = queue_label = None

            for idx_l, f_query in enumerate(f_list):
                if idx_g == idx_l:
                    continue
                loss += self.contrastive_loss(f_query, f_list[idx_g], targets, queue, queue_label)

        if self.K > 0:
            self._dequeue_and_enqueue(f_list[0], f_list[1], targets)

        loss = loss / (len(images) - 1)

        return loss


class RenyiSCL_MC_ResNet(RenyiSCL_MC):
    """RenyiSCL_MC for backbones that expose their final linear layer as ``fc``."""

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Replace the encoder's ``fc`` layer with a 2-layer projection MLP.

        The projector maps the ``fc`` layer's input width to ``dim`` through
        one hidden layer of width ``mlp_dim``.
        """
        encoder = self.source_encoder
        # The second axis of the fc weight is the width of the features fed
        # into it, which becomes the projector's input width.
        feature_dim = encoder.fc.weight.shape[1]
        # Remove the original fc before installing the projector in its place.
        del encoder.fc
        encoder.fc = self._build_mlp(2, feature_dim, mlp_dim, dim)
        

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate along dim 0.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised (single-process runs), a detached copy of ``tensor`` is
    returned, i.e. the result a world of size one would produce.

    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    returned tensor never carries gradient back to ``tensor``.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # Mirror the gathered path: fresh storage, no autograd history.
        return tensor.detach().clone()

    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)

class SLMLP(nn.Module):
    """Encoder followed by a weight-normalised linear classifier.

    The encoder output is L2-normalised, passed through a bias-free linear
    layer whose weight-norm magnitude is fixed to 1, divided by ``temp`` and
    scored with cross-entropy.

    Args:
        backbone: callable invoked as ``backbone(num_classes=mlp_dim)`` that
            returns the encoder module.
        temp: temperature dividing the classifier logits.
        dim: input width of the classifier layer.
        mlp_dim: width passed to the backbone and to the projector hook.
        n_cls: number of output classes.
    """

    def __init__(self, backbone, temp=0.2, dim=256, mlp_dim=4096, n_cls=1000):
        super().__init__()
        self.temp = temp
        self.n_cls = n_cls
        # build encoders
        self.source_encoder = backbone(num_classes=mlp_dim)
        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        # Classifier head: the magnitude part of the weight norm is pinned to
        # 1 and frozen, so only the direction part is trainable.
        classifier = nn.utils.weight_norm(nn.Linear(dim, n_cls, bias=False))
        classifier.weight_g.data.fill_(1)
        classifier.weight_g.requires_grad = False
        self.last_layer = classifier

        self.criterion = nn.CrossEntropyLoss()

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        """Assemble a bias-free MLP with ``num_layers`` linear layers.

        Hidden layers are each followed by BatchNorm1d and ReLU; the final
        layer gets a non-affine BatchNorm1d only when ``last_bn`` is true.
        """
        final_idx = num_layers - 1
        layers = []
        for idx in range(num_layers):
            is_final = idx == final_idx
            in_features = mlp_dim if idx else input_dim
            out_features = output_dim if is_final else mlp_dim

            layers.append(nn.Linear(in_features, out_features, bias=False))

            if not is_final:
                layers += [nn.BatchNorm1d(out_features), nn.ReLU(inplace=True)]
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                layers.append(nn.BatchNorm1d(out_features, affine=False))

        return nn.Sequential(*layers)

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Hook for subclasses to attach a projection head; a no-op here."""
        pass

    def forward(self, images, targets):
        """Return the cross-entropy loss averaged over all views in ``images``.

        Args:
            images: sequence of input batches (one per view).
            targets: class labels shared by every view.
        """
        total = 0.0
        for view in images:
            feats = nn.functional.normalize(self.source_encoder(view), dim=1)
            logits = self.last_layer(feats) / self.temp
            total += self.criterion(logits, targets)
        total /= len(images)
        return total


class SLMLP_ResNet(SLMLP):
    """SLMLP for backbones that expose their final linear layer as ``fc``."""

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        """Replace the encoder's ``fc`` layer with a 2-layer projection MLP.

        The projector maps the ``fc`` layer's input width to ``dim`` through
        one hidden layer of width ``mlp_dim``.
        """
        encoder = self.source_encoder
        # The second axis of the fc weight is the width of the features fed
        # into it, which becomes the projector's input width.
        feature_dim = encoder.fc.weight.shape[1]
        # Remove the original fc before installing the projector in its place.
        del encoder.fc
        encoder.fc = self._build_mlp(2, feature_dim, mlp_dim, dim)
