import torch
import torch.nn as nn
from functools import partial
from torchvision.models import resnet
import torch.nn.functional as F
import pdb
import numpy as np
class SplitBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that computes batch statistics separately per sub-batch.

    In the batch-statistics path the input of shape ``(N, C, H, W)`` is
    reshaped to ``(N / num_splits, C * num_splits, H, W)`` before calling
    ``batch_norm``.  Samples whose batch index is congruent modulo
    ``num_splits`` therefore end up in the same group of ``C`` channels and are
    normalized together, independently of the other groups.  ``N`` must be
    divisible by ``num_splits`` for the reshape back to ``(N, C, H, W)`` to
    succeed.

    Args:
        num_features: number of channels ``C`` (forwarded to ``BatchNorm2d``).
        num_splits: number of independent groups the batch is divided into.
        **kw: remaining keyword arguments forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def forward(self, input):
        """Normalize ``input`` (``N x C x H x W``) and return the same shape."""
        N, C, H, W = input.shape

        if not self.training and self.track_running_stats:
            # Evaluation: plain batch norm over the whole batch using the
            # stored running statistics.
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)

        splits = self.num_splits

        # running_mean / running_var are None when track_running_stats=False,
        # weight / bias are None when affine=False; only tile what exists.
        has_stats = self.running_mean is not None and self.running_var is not None
        running_mean_split = self.running_mean.repeat(splits) if has_stats else None
        running_var_split = self.running_var.repeat(splits) if has_stats else None
        weight_split = self.weight.repeat(splits) if self.weight is not None else None
        bias_split = self.bias.repeat(splits) if self.bias is not None else None

        # reshape (rather than view) so that non-contiguous inputs, e.g.
        # channels_last tensors, are accepted as well.
        outcome = nn.functional.batch_norm(
            input.reshape(-1, C * splits, H, W), running_mean_split, running_var_split,
            weight_split, bias_split,
            True, self.momentum, self.eps).reshape(N, C, H, W)

        if has_stats:
            # batch_norm updated the tiled statistics in place; fold them back
            # into the module's buffers by averaging over the splits.
            self.running_mean.data.copy_(running_mean_split.view(splits, C).mean(dim=0))
            self.running_var.data.copy_(running_var_split.view(splits, C).mean(dim=0))
        return outcome

class MLPHead(nn.Module):
    """Two-layer perceptron head: Linear -> ReLU -> Linear.

    Args:
        in_channels: size of the last dimension of the input.
        mlp_hidden_size: width of the hidden layer.
        projection_size: size of the last dimension of the output.
    """

    def __init__(self, in_channels, mlp_hidden_size, projection_size):
        super().__init__()

        # The layers are created in network order and kept at positions
        # 0, 1 and 2 of the Sequential container.
        # NOTE: a BatchNorm1d(mlp_hidden_size) after the first Linear layer is
        # currently disabled.
        hidden_layer = nn.Linear(in_channels, mlp_hidden_size)
        activation = nn.ReLU(inplace=True)
        output_layer = nn.Linear(mlp_hidden_size, projection_size)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Apply the head to ``x`` and return the projected tensor."""
        projected = self.net(x)
        return projected

class ModelBase(nn.Module):
    """
    Common CIFAR ResNet recipe.
    Comparing with ImageNet ResNet recipe, it:
    (i) replaces conv1 with kernel=3, str=1
    (ii) removes pool1

    The final fully connected layer of the torchvision ResNet is dropped as
    well and replaced by an ``MLPHead`` projection head.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        mlp_hidden_size: hidden width of the projection head.
        feature_dim: output size of the projection head.
        arch: name of a constructor in ``torchvision.models.resnet``
            (for example ``'resnet50'``); required.
        bn_splits: number of splits for ``SplitBatchNorm``; values ``<= 1``
            select the regular ``nn.BatchNorm2d``.
    """
    def __init__(self, input_shape=(32, 32, 3), mlp_hidden_size=2048, feature_dim=128, arch=None, bn_splits=16):
        super(ModelBase, self).__init__()
        # use split batchnorm
        norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
        resnet_arch = getattr(resnet, arch)
        net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer)

        layers = []
        for name, module in net.named_children():
            if name == 'conv1':
                # Take the channel count from input_shape so the stem agrees
                # with the dummy input used in _get_shape.
                module = nn.Conv2d(input_shape[2], 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(module, (nn.MaxPool2d, nn.Linear)):
                continue
            layers.append(module)
        self.net = nn.Sequential(*layers)

        # The batch-statistics path of SplitBatchNorm needs a batch size that
        # is a multiple of bn_splits; two samples per split is the smallest
        # batch that still gives every norm layer more than one value per
        # channel.
        probe_batch = 2 * max(int(bn_splits), 1)
        self.projection_head = MLPHead(
            self._get_shape(input_shape, batch_size=probe_batch), mlp_hidden_size, feature_dim)

    def _get_shape(self, shape, batch_size=512):
        """Return the flattened per-sample feature size produced by ``self.net``.

        A zero batch of ``batch_size`` images with layout
        ``(shape[2], shape[0], shape[1])`` is pushed through the backbone.
        The pass runs under ``torch.no_grad()`` because only the output shape
        is needed, so no autograd graph is kept.
        """
        codes = torch.zeros(batch_size, shape[2], shape[0], shape[1])
        with torch.no_grad():
            feature = self.net(codes)
        return feature[0].view(-1).shape[0]

    def forward(self, x):
        """Encode ``x`` with the backbone and project the flattened features."""
        x = self.net(x)
        x = torch.flatten(x, 1)
        x = self.projection_head(x)
        return x


def accuracy(output, target):
    """Return the top-1 accuracy of ``output`` against ``target`` in percent.

    The prediction for each row is the index of its largest entry along
    dimension 1; the result is ``100 * matches / target.size(0)``.
    """
    with torch.no_grad():
        predicted = torch.max(output, 1)[1]
        matches = predicted.eq(target).cpu().sum().item()
    total = target.size(0)
    return 100. * matches / total


class ModelMoCo(nn.Module):
    """Momentum-contrast model with two losses: ``'moco'`` and ``'CLEAN'``.

    A query encoder is trained by gradient; a key encoder with the same
    architecture is updated as a moving average of the query encoder's
    parameters.  Normalized keys are stored in a ring-buffer ``queue`` of
    ``K`` columns that supplies the negative logits.

    Args:
        input_shape: ``(height, width, channels)`` passed on to ``ModelBase``.
        args: required namespace-like object.  The attributes read here are
            ``k`` (queue length), ``m`` (momentum), ``t`` (temperature),
            ``symmetric``, ``loss_type`` (``'moco'`` or ``'CLEAN'``),
            ``teacher_weight``, ``sharp_probability``, ``lam``, ``teach_T``,
            ``batch_size``, ``mlp_hidden_size``, ``dim``, ``arch`` and
            ``bn_splits``.
    """

    def __init__(self, input_shape=(32, 32, 3), args=None):
        super(ModelMoCo, self).__init__()
        self.K = args.k
        self.m = args.m
        self.T = args.t
        self.symmetric = args.symmetric
        self.loss_type = args.loss_type
        self.teacher_weight = args.teacher_weight
        self.sharp = args.sharp_probability

        self.lam = args.lam
        self.teach_T = args.teach_T
        self.prior = 1.0

        self.batch_size = args.batch_size

        # Cached targets for the nominal batch size.  They are placed on the
        # GPU when one is available and on the CPU otherwise, so that the
        # model can also be constructed on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.labels = torch.zeros(self.batch_size, dtype=torch.long, device=device)
        self.onehot_labels = torch.zeros(
            self.batch_size, self.K + 1, device=device
        ).scatter_(1, self.labels.view(-1, 1), 1)
        self.labels_mixup = torch.arange(self.batch_size, dtype=torch.long, device=device)
        self.onehot_labels_mixup = torch.zeros(
            self.batch_size, self.K + self.batch_size, device=device
        ).scatter_(1, self.labels_mixup.view(-1, 1), 1)

        # create the encoders
        self.encoder_q = ModelBase(input_shape=input_shape, mlp_hidden_size=args.mlp_hidden_size, feature_dim=args.dim, arch=args.arch, bn_splits=args.bn_splits)
        self.encoder_k = ModelBase(input_shape=input_shape, mlp_hidden_size=args.mlp_hidden_size, feature_dim=args.dim, arch=args.arch, bn_splits=args.bn_splits)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue (dim x K) and its write pointer
        self.register_buffer("queue", torch.zeros(args.dim, self.K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        ``param_k = m * param_k + (1 - m) * param_q`` for every parameter pair.
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write ``keys`` (``N x C``) into the queue at the current pointer.

        The queue is a ring buffer: when the batch does not fit before the end
        of the queue, the remainder wraps around to the beginning.

        Raises:
            ValueError: if the batch has more rows than the queue has columns.
        """
        batch_size = keys.shape[0]
        if batch_size > self.K:
            raise ValueError(
                "cannot enqueue %d keys into a queue of length %d" % (batch_size, self.K))

        ptr = int(self.queue_ptr)
        end = ptr + batch_size

        # replace the keys at ptr (dequeue and enqueue)
        if end <= self.K:
            self.queue[:, ptr:end] = keys.t()  # transpose
        else:
            # split the batch: the tail of the queue first, then wrap around
            head = self.K - ptr
            self.queue[:, ptr:] = keys[:head].t()
            self.queue[:, :end - self.K] = keys[head:].t()

        self.queue_ptr[0] = end % self.K  # move pointer

    def _mixup_targets(self, n, device):
        """Return ``(index labels, one-hot labels)`` for ``n`` in-batch positives.

        Row ``i`` has its positive at column ``i`` of an ``n x (K + n)``
        logit matrix.  The tensors cached in ``__init__`` are reused when they
        match ``n`` and ``device``; otherwise fresh ones are built.
        """
        if n == self.batch_size and self.labels_mixup.device == device:
            return self.labels_mixup, self.onehot_labels_mixup
        labels = torch.arange(n, dtype=torch.long, device=device)
        onehot = torch.zeros(n, self.K + n, device=device).scatter_(1, labels.view(-1, 1), 1)
        return labels, onehot

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x):
        """
        Batch shuffle, for making use of BatchNorm.

        Returns the shuffled batch and the index that restores the order.
        """
        # random shuffle index, drawn on the CPU and moved to x's device
        idx_shuffle = torch.randperm(x.shape[0]).to(x.device)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        return x[idx_unshuffle]

    def contrastive_loss(self, im_q, im_k):
        """Cross-entropy loss with one positive key and ``K`` queue negatives.

        Returns:
            ``(loss, acc, q, k)`` where ``acc`` is always ``0`` and ``q`` / ``k``
            are the L2-normalized query / key features (``N x C``).
        """
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # L2-normalize

        # key features, computed without gradient on a shuffled batch
        k = self.get_momentum_feature(self.encoder_k, im_k)

        # logits: Nx(1+K), positive in column 0
        logits = self.get_class_probability(q, k)
        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits, labels)
        acc = 0
        return loss, acc, q, k

    def get_momentum_feature(self, encoder_k, im_k):
        """Return L2-normalized features of ``im_k`` from ``encoder_k``.

        The batch is shuffled before the forward pass and restored afterwards;
        no gradient is recorded.
        """
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
            k = encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # L2-normalize
            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)
        return k

    def get_class_probability(self, q, k):
        """Return ``N x (1 + K)`` logits: row-wise ``q . k`` followed by ``q . queue``."""
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1)
        return logits

    def get_class_probability_mixup(self, q, k):
        """Return ``N x (N + K)`` logits: all pairs ``q . k`` followed by ``q . queue``."""
        l_pos = torch.einsum('nc,ck->nk', [q, k.T])
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1)
        return logits

    @torch.no_grad()
    def label_estimation(self, k):
        """Build soft targets for the mixup loss from the key features ``k``.

        Two sharpened softmax distributions over the ``N + K`` candidates are
        computed from ``k`` against itself and the queue: one including each
        sample's own column and one with that column masked out.  They are
        blended with the one-hot in-batch target using weights derived from
        their maxima and ``self.prior``.

        Returns:
            ``(labels, acc)``: the ``N x (N + K)`` soft targets and their
            top-1 accuracy against the in-batch index.
        """
        labels_mixup, onehot_labels_mixup = self._mixup_targets(k.shape[0], k.device)

        ## estimate the current label
        logits_momentum = self.get_class_probability_mixup(k, k) / self.teach_T
        probs = F.softmax(logits_momentum, dim=1)

        ## sharpen
        probs = probs ** (1 / self.sharp)
        probs = probs / probs.sum(dim=1, keepdim=True)

        # same estimate with each sample's own column excluded
        logits_momentum = logits_momentum.scatter_(1, labels_mixup.view(-1, 1), -float("Inf"))
        probs2 = F.softmax(logits_momentum, dim=1)
        probs2 = probs2 ** (1 / self.sharp)
        probs2 = probs2 / probs2.sum(dim=1, keepdim=True)

        guess_score, _ = torch.max(probs, 1)
        guess_score = guess_score.unsqueeze(1)

        guess_score2, _ = torch.max(probs2, 1)
        guess_score2 = guess_score2.unsqueeze(1)

        # the three weights lambda1, lambda2 and (1 - lambda1 - lambda2) sum to one
        score_all = self.prior * guess_score2 + self.prior * guess_score + 1
        lambda1, lambda2 = self.prior * guess_score / score_all, self.prior * guess_score2 / score_all

        labels = probs * lambda1 + probs2 * lambda2 + onehot_labels_mixup * (1 - lambda1 - lambda2)
        acc = accuracy(labels, labels_mixup)
        return labels, acc

    def clearn_contrastive_loss(self, im_q, im_k):
        """Contrastive loss combined with a mixup loss on estimated soft labels.

        Returns:
            ``(loss, acc, q, k)`` where ``acc`` comes from ``label_estimation``
            and ``q`` / ``k`` are the L2-normalized query / key features.
        """
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # normalize
        k = self.get_momentum_feature(self.encoder_k, im_k)

        # Use the actual batch size so a smaller final batch also works.
        n = q.shape[0]
        targets, _ = self._mixup_targets(n, q.device)

        logits = self.get_class_probability_mixup(q, k)
        logits /= self.T

        ## vanilla contrastive loss
        loss = F.cross_entropy(logits, targets)

        ## estimate the current label
        labels, acc = self.label_estimation(k)

        # mixup: blend each query image with a randomly chosen key image and
        # blend the soft labels with the same coefficient
        with torch.no_grad():
            mix = np.random.beta(self.lam, self.lam)
            idx = torch.randperm(n).to(im_q.device)
            inputs_s = mix * im_q + (1 - mix) * im_k[idx]
            labels = mix * labels + (1 - mix) * labels[idx]

        ## mixup loss
        outputs = self.encoder_q(inputs_s)
        outputs = nn.functional.normalize(outputs, dim=1)  # L2-normalize
        outputs = self.get_class_probability_mixup(outputs, k) / self.T
        loss_teacher = -torch.sum(F.log_softmax(outputs, dim=1) * labels)

        # combine two loss
        loss = (1 - self.teacher_weight) * loss + self.teacher_weight * loss_teacher / n

        return loss, acc, q, k

    def forward(self, weak1, weak2, prior=1.0):
        """
        Input:
            weak1: a batch of images (first view)
            weak2: a batch of images (second view)
            prior: confidence for label combination
        Output:
            loss, acc
        Raises:
            ValueError: if ``self.loss_type`` is neither 'moco' nor 'CLEAN'
        """
        if self.loss_type == 'moco':
            loss_fn = self.contrastive_loss
        elif self.loss_type == 'CLEAN':
            loss_fn = self.clearn_contrastive_loss
        else:
            raise ValueError(
                "unknown loss_type %r (expected 'moco' or 'CLEAN')" % (self.loss_type,))

        self.prior = prior
        self._momentum_update_key_encoder()  # runs without gradient

        if self.symmetric:  # symmetric loss: each view serves once as query
            loss_12, acc1, _, k2 = loss_fn(weak1, weak2)
            loss_21, acc2, _, k1 = loss_fn(weak2, weak1)
            loss, acc = (loss_12 + loss_21) * 0.5, (acc1 + acc2) * 0.5
            self._dequeue_and_enqueue(k1)
            self._dequeue_and_enqueue(k2)
        else:  # asymmetric loss
            loss, acc, _, k = loss_fn(weak1, weak2)
            self._dequeue_and_enqueue(k)
        return loss, acc