import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from functools import partial
import numpy as np
import torchvision.transforms as transforms
import random
import pdb
def accuracy(output, target):
    """
    Return the top-1 accuracy of ``output`` against ``target`` as a percentage.

    output: 2-D tensor of scores; the prediction of each row is the index of
        its largest entry along dim 1.
    target: tensor of ground-truth indices, one per row of ``output``.

    The result is a Python float in [0, 100]. No gradient is tracked.
    """
    with torch.no_grad():
        predicted = output.max(dim=1)[1]
        num_correct = predicted.eq(target).cpu().sum().item()
    return 100. * num_correct / target.size(0)

class CLEAN(nn.Module):
    """
    MoCo-style contrastive model with estimated soft labels and image mixup.

    The module keeps a query encoder, a momentum-updated key encoder and a
    queue of past key features (MoCo, https://arxiv.org/abs/1911.05722).
    ``forward`` computes either the plain contrastive loss
    (``loss_type == 'moco'``) or that loss combined with a soft-label loss on
    additional crops (``loss_type == 'CLEAN'``).

    Every loss uses the dictionary ``[current keys, queue]``, so the positive
    of sample ``i`` sits at column ``i`` of the logits.
    NOTE(review): the label buffers are built for ``args.batch_size`` rows, so
    the batch seen by each process presumably has to be exactly that size.
    """
    def __init__(self, base_encoder, args, master_flag=False, logger=None):
        """
        Build the two encoders, the key queue and the fixed label buffers.

        base_encoder: callable invoked as ``base_encoder(num_classes=dim)``;
            the returned network must expose a linear ``fc`` head.
        args: namespace providing moco_dim, moco_k, moco_t, moco_m, teach_t,
            sharp_probability, teacher_weight, lam, loss_type, sys,
            warmup_epoch, strong_crop_num, weak_crop_num and batch_size.
        master_flag: when true (and a logger is supplied) the configuration
            is logged once during construction.
        logger: object with an ``info`` method used for that log output.
        """
        super(CLEAN, self).__init__()

        # feature dim, queue size, student temperature, key-encoder momentum
        self.dim, self.K, self.T, self.m = args.moco_dim, args.moco_k, args.moco_t, args.moco_m
        # temperature and sharpening exponent used by label_estimation
        self.teach_t, self.sharp = args.teach_t, args.sharp_probability

        # lambda_weight balances the two losses; lam parameterizes Beta(lam, lam)
        self.lambda_weight, self.lam = args.teacher_weight, args.lam
        self.loss_type = args.loss_type
        self.sys = args.sys

        self.warmup_epoch = args.warmup_epoch
        self.prior = 1.0  # overwritten on every forward() call

        self.strong_crop_num = args.strong_crop_num
        self.weak_crop_num = args.weak_crop_num

        self.master_flag, self.logger = master_flag, logger

        self.criterion = nn.CrossEntropyLoss().cuda()

        # Targets 0..batch_size-1: sample i is its own positive (column i).
        self.register_buffer("labels_mixup_this", torch.arange(args.batch_size, dtype=torch.long))
        # One-hot version of the targets over the whole dictionary [keys, queue].
        self.register_buffer(
            "onehot_labels_mixup_this",
            torch.zeros(args.batch_size, self.K + args.batch_size).scatter_(
                1, self.labels_mixup_this.view(-1, 1), 1))

        # Guard on the logger too: logger defaults to None.
        if self.master_flag and self.logger is not None:
            parameter = 'dim%d_K%d_T%.4f_m%.4f_teachT%.4f_sharp%.4f_weight%.4f_lam%.4f_type%s_warm%d' % (
                self.dim, self.K, self.T, self.m, self.teach_t, self.sharp,
                self.lambda_weight, self.lam, self.loss_type, self.warmup_epoch)
            self.logger.info('--------------------------------------inner parameters-------------------------------------------')
            self.logger.info(parameter)
            self.logger.info(self.labels_mixup_this)
            _, predic = torch.max(self.onehot_labels_mixup_this, dim=1)
            self.logger.info(predic)
            self.logger.info('--------------------------------------inner parameters-------------------------------------------')

        self.softmax = nn.Softmax(dim=1)
        self.log_softmax = nn.LogSoftmax(dim=1)
        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=self.dim)
        self.encoder_k = base_encoder(num_classes=self.dim)

        # replace the linear head by a 2-layer MLP that ends in the original fc
        dim_mlp = self.encoder_q.fc.weight.shape[1]
        self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
        self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: one unit-norm column per stored key
        self.register_buffer("queue", torch.randn(self.dim, self.K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        ``param_k = m * param_k + (1 - m) * param_q``.
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns the shuffled slice of the gathered batch assigned to this
        rank and the index tensor needed to undo the shuffle.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus so every rank uses rank 0's permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***

        ``idx_unshuffle`` is the second value returned by
        ``_batch_shuffle_ddp``.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Gather ``keys`` from all processes and write them into the queue at
        the current pointer, then advance the pointer modulo ``K``.

        The queue size must be a multiple of the gathered batch size.
        """
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def get_momentum_feature(self, im_k):
        """
        Encode ``im_k`` with the key encoder and return L2-normalized keys.

        The batch is shuffled across processes before encoding and restored
        afterwards; no gradient is tracked.
        """
        im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
        k = self.encoder_k(im_k)  # keys: NxC
        k = F.normalize(k, dim=1)
        # undo shuffle
        k = self._batch_unshuffle_ddp(k, idx_unshuffle)
        return k

    def get_class_probability(self, q, k):
        """
        Return the similarity logits of ``q`` against the dictionary
        ``[k, queue]``.

        q: N x C features, k: M x C features. The result is N x (M + K):
        the first M columns are dot products with the rows of ``k``, the
        remaining K columns are dot products with the (detached) queue.
        Despite the name, the values are raw logits; no temperature or
        softmax is applied here.
        """
        # logits against the in-batch keys: NxM
        l_pos = torch.einsum('nc,ck->nk', [q, k.T])
        # logits against the queue: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        # logits: Nx(M+K)
        return torch.cat([l_pos, l_neg], dim=1)

    @torch.no_grad()
    def get_mixup_label(self, local_labels):
        """
        Mix each row of ``local_labels`` with a randomly chosen other row.

        The mixing coefficient is drawn from Beta(lam, lam) and folded to be
        at least 0.5, so the original row keeps the larger weight.

        Returns (mixed labels, permutation used, coefficient, accuracy of the
        mixed labels with respect to the identity targets).
        """
        beta = np.random.beta(self.lam, self.lam)
        idx = torch.randperm(local_labels.shape[0])
        beta = max(beta, 1 - beta)
        labels = beta * local_labels + (1 - beta) * local_labels[idx]
        acc3 = accuracy(labels, self.labels_mixup_this)
        return labels, idx, beta, acc3

    @torch.no_grad()
    def get_mixup_feature(self, x, y, idx, beta):
        """
        Return ``beta * x + (1 - beta) * y[idx]`` without tracking gradients.
        """
        return beta * x + (1 - beta) * y[idx]

    @torch.no_grad()
    def _sharpened_softmax(self, logits):
        """
        Softmax over dim 1, raised to the power ``1 / sharp`` and
        renormalized so every row sums to one again.
        """
        probs = self.softmax(logits)
        probs = probs ** (1 / self.sharp)
        return probs / probs.sum(dim=1, keepdim=True)

    @torch.no_grad()
    def _label_mix_weights(self, probs, probs2):
        """
        Turn the row-wise confidences of ``probs`` and ``probs2`` into the
        mixing weights ``(lambda1, lambda2)``, each of shape N x 1.

        The confidence of ``probs2`` is capped row by row at the confidence
        of ``probs``. ``lambda1 + lambda2 < 1``; the remainder is the weight
        given to the one-hot label by the caller.
        """
        confidence_score = torch.max(probs, 1)[0].unsqueeze(1)
        confidence_score2 = torch.max(probs2, 1)[0].unsqueeze(1)
        # Element-wise cap. Indexing with the (M, 2) output of torch.nonzero
        # would also overwrite row 0, because the column index 0 would be
        # interpreted as a row index.
        confidence_score2 = torch.min(confidence_score2, confidence_score)

        score_all = self.prior * (confidence_score2 + confidence_score) + 1
        lambda1 = self.prior * confidence_score / score_all
        lambda2 = self.prior * confidence_score2 / score_all
        return lambda1, lambda2

    @torch.no_grad()
    def label_estimation(self, k, k_all, one_hot_label):
        """
        Estimate soft labels over the dictionary ``[k_all, queue]``.

        k: features whose labels are estimated.
        k_all: features forming the in-batch part of the dictionary.
        one_hot_label: one-hot targets with the same shape as the logits.

        Returns (soft labels, accuracy of the sharpened distribution,
        accuracy of the soft labels, mean of lambda1, mean of lambda2).
        """
        # Step 1: sharpened distribution of the keys over the dictionary.
        logits_momentum = self.get_class_probability(k, k_all) / self.teach_t
        probs = self._sharpened_softmax(logits_momentum)

        # Step 2: same distribution with the sample itself removed from the
        # dictionary, which raises the probability of the remaining entries.
        logits_momentum = logits_momentum.scatter_(1, self.labels_mixup_this.view(-1, 1), -float("Inf"))
        probs2 = self._sharpened_softmax(logits_momentum)

        # Step 3: combine the one-hot label, probs and probs2 into new labels.
        acc = accuracy(probs, self.labels_mixup_this)
        lambda1, lambda2 = self._label_mix_weights(probs, probs2)

        local_labels = probs * lambda1 + probs2 * lambda2 + one_hot_label * (1 - lambda1 - lambda2)
        confidence_score_mean, confidence2_score_mean = lambda1.mean(), lambda2.mean()

        acc2 = accuracy(local_labels, self.labels_mixup_this)
        return local_labels, acc, acc2, confidence_score_mean, confidence2_score_mean

    def get_mixup_loss(self, x, label, k_all):
        """
        Soft-label cross entropy of the query encoding of ``x`` against the
        dictionary ``[k_all, queue]`` at temperature ``T``.

        The mean over all elements times the number of columns equals the
        per-row sum over the dictionary, averaged over the rows.
        """
        outputs = self.encoder_q(x)
        outputs = F.normalize(outputs, dim=1)  # L2-normalize the query features
        outputs = self.get_class_probability(outputs, k_all)
        loss = - torch.mean((self.log_softmax(outputs / self.T) * label)) * outputs.shape[1]
        return loss

    def get_constrast_loss_single(self, qs_w, im_k, k, local_label, idx_shuffle, beta):
        """
        Mixup loss for one crop: mix ``qs_w`` with the permuted key images and
        score the mixture against ``local_label``.

        ``im_k`` is resized to the size of ``qs_w`` first when their last
        dimensions differ.
        """
        im_shape = im_k.shape[3]
        if im_shape != qs_w.shape[3]:  # not same size
            with torch.no_grad():
                im_k_i = F.interpolate(im_k, size=qs_w.shape[3])
        else:
            im_k_i = im_k
        qw = self.get_mixup_feature(qs_w, im_k_i, idx_shuffle, beta)
        return self.get_mixup_loss(qw, local_label, k)

    def contrast_loss_mixup_label(self, im_q, im_k, qs):
        """
        Contrastive loss on ``(im_q, im_k)`` plus a soft-label loss on ``qs``.

        qs: one crop batch or a list of crop batches; for a list the
            soft-label loss is averaged over the crops.

        With probability 0.5 the crops are mixed with the key images and
        scored against mixed labels, otherwise they are scored directly
        against the estimated labels (``acc3`` is then 0).

        Returns (loss, test_acc, acc, acc2, acc3, mean lambda1, mean lambda2,
        keys). The queue is not updated here.
        """
        q = self.encoder_q(im_q)  # queries: NxC
        q = F.normalize(q, dim=1)
        k = self.get_momentum_feature(im_k)

        # Step 1. vanilla contrastive loss
        logits = self.get_class_probability(q, k) / self.T
        con_loss = self.criterion(logits, self.labels_mixup_this)
        test_acc = accuracy(logits, self.labels_mixup_this)  # prediction accuracy

        # Step 2. estimate soft labels from the keys themselves
        local_labels, acc, acc2, confidence_score_mean, confidence2_score_mean = self.label_estimation(
            k, k, self.onehot_labels_mixup_this)

        # Step 3. choose the per-crop loss for this call
        if random.uniform(0, 1) < 0.5:
            local_label, idx_shuffle, beta, acc3 = self.get_mixup_label(local_labels)
            crop_loss = partial(self.get_constrast_loss_single, im_k=im_k, k=k,
                                local_label=local_label, idx_shuffle=idx_shuffle, beta=beta)
        else:
            crop_loss = partial(self.get_single_loss_new, k=k, local_labels=local_labels)
            acc3 = 0

        # Step 4. soft-label loss, averaged over the crops when several are given
        if isinstance(qs, list):
            loss_qs = sum(crop_loss(crop) for crop in qs) / len(qs)
        else:
            loss_qs = crop_loss(qs)

        # Step 5. all loss
        loss = (1 - self.lambda_weight) * con_loss + self.lambda_weight * loss_qs
        return loss, test_acc, acc, acc2, acc3, confidence_score_mean, confidence2_score_mean, k

    def get_single_loss_new(self, qs, k, local_labels):
        """
        Soft-label cross entropy of the query encoding of ``qs`` against the
        dictionary ``[k, queue]`` at temperature ``T`` (no mixup).
        """
        qss = self.encoder_q(qs)  # queries: NxC
        qss = F.normalize(qss, dim=1)
        logits_qs = self.get_class_probability(qss, k) / self.T
        loss_q = - torch.mean((self.log_softmax(logits_qs) * local_labels)) * local_labels.shape[1]
        return loss_q

    def contrast_loss(self, im_q, im_k):
        """
        Plain contrastive loss on ``(im_q, im_k)``.

        Returns (loss, accuracy) and enqueues the keys.
        """
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = F.normalize(q, dim=1)
        k = self.get_momentum_feature(im_k)

        logits = self.get_class_probability(q, k) / self.T
        loss = self.criterion(logits, self.labels_mixup_this)
        acc = accuracy(logits, self.labels_mixup_this)

        self._dequeue_and_enqueue(k)
        return loss, acc

    def forward(self, q_w, k_w, qs_s, prior):
        """
        Run one training step.

        q_w, k_w: two views of the batch used as query and key images.
        qs_s: extra crop batch or list of crop batches (used by 'CLEAN' only).
        prior: scalar stored in ``self.prior`` and used to weight the
            estimated labels.

        Returns (loss, test_acc, acc, acc2, acc3, mean lambda1, mean lambda2);
        the last five entries are 0 for ``loss_type == 'moco'``.

        Raises ValueError for an unsupported ``loss_type``.
        """
        # Fail before touching the key encoder instead of hitting an
        # unbound local at the return statement.
        if self.loss_type not in ('moco', 'CLEAN'):
            raise ValueError("unsupported loss_type %r; expected 'moco' or 'CLEAN'" % (self.loss_type,))

        self.prior = prior

        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

        if self.loss_type == 'moco':
            loss, test_acc = self.contrast_loss(q_w, k_w)
            acc, acc2, acc3, confidence_score_mean, confidence2_score_mean = 0, 0, 0, 0, 0
        else:
            loss, test_acc, acc, acc2, acc3, confidence_score_mean, confidence2_score_mean, k1 = \
                self.contrast_loss_mixup_label(q_w, k_w, qs_s)
            new_keys = [k1]
            if self.sys:
                # symmetric pass with the roles of the two views swapped
                loss2, test_acc2, acc21, acc22, acc32, confidence_score_mean2, confidence2_score_mean2, k2 = \
                    self.contrast_loss_mixup_label(k_w, q_w, qs_s)
                loss, test_acc, acc = 0.5 * (loss + loss2), 0.5 * (test_acc + test_acc2), 0.5 * (acc + acc21)
                acc2, acc3 = 0.5 * (acc2 + acc22), 0.5 * (acc3 + acc32)
                confidence_score_mean = 0.5 * (confidence_score_mean + confidence_score_mean2)
                confidence2_score_mean = 0.5 * (confidence2_score_mean + confidence2_score_mean2)
                new_keys.append(k2)
            # The queue is updated only after all losses have been computed,
            # so both passes see the same queue.
            for keys in new_keys:
                self._dequeue_and_enqueue(keys)

        return loss, test_acc, acc, acc2, acc3, confidence_score_mean, confidence2_score_mean


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process of the default process group and
    concatenate the pieces along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # one receive buffer per rank, shaped like the local tensor
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
