import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import einsum
from einops import rearrange

import numpy as np


class Embed(nn.Module):
    """Channel projection: a bias-free 1x1 convolution followed by BatchNorm.

    Note: the ``l2norm`` attribute is a ``BatchNorm2d`` layer, not an L2
    normalisation; the attribute name is kept so existing checkpoints load.
    """

    def __init__(self, in_channels=512, out_channels=128):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.l2norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        projected = self.conv2d(x)
        return self.l2norm(projected)
    
class AttnEmbed(nn.Module):
    """Channel projection: a bias-free 3x3 convolution (padding 1, so the
    spatial size is preserved) followed by BatchNorm.

    Note: ``l2norm`` is a ``BatchNorm2d`` layer; the name is kept for
    checkpoint compatibility.
    """

    def __init__(self, in_channels=512, out_channels=128):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.l2norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        projected = self.conv2d(x)
        return self.l2norm(projected)

class TaTLoss(nn.Module):
    """Attention-based feature-matching (L2) loss between student and teacher maps.

    The teacher feature map is used as the attention query (``q_proj`` is an
    identity). Keys and values are 3x3-conv + BatchNorm projections
    (``AttnEmbed``) of the student feature map. The student values are
    re-aggregated with the softmax attention, and the result is compared to the
    teacher feature map with an MSE loss.

    Note: this module owns learnable parameters (``k_proj`` and ``v_proj``),
    so it must be optimised together with the student.

    Args:
        heads: number of attention heads; ``channels`` must be divisible by it.
        channels: channel count of both feature maps.
        factor: multiplies the attention logits *and* the final loss.
        use_pool: if True, both feature maps are 2x2 average-pooled first.
    """

    def __init__(self, heads=1, channels=1920, factor=0.2, use_pool=False):
        super().__init__()
        self.heads = heads
        self.use_pool = use_pool
        self.q_proj = nn.Identity()
        self.k_proj = AttnEmbed(channels, channels)
        self.v_proj = AttnEmbed(channels, channels)
        self.factor = factor
        # Kept for backward compatibility only: the attention logits are
        # scaled by ``factor`` and this value is never read by the loss.
        self.scale = channels ** -0.5

        if use_pool:
            self.pool = nn.AvgPool2d(2)

    def batch_loss(self, f_s, f_t):
        """Return the MSE between attention-aggregated student values and the teacher map.

        Args:
            f_s: ``(student_feature, query, value)``; the student feature is unused here.
            f_t: ``(teacher_feature, key)``; the teacher feature is the MSE target.
        """
        _, q, v = f_s
        t, k = f_t
        heads = self.heads  # heads * d == channels
        height, width = v.shape[-2:]

        q = rearrange(q, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        k = rearrange(k, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)

        # ``factor`` acts as the logit scale (temperature) of the attention.
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.factor  # (b, heads, hw, hw)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # (b, heads, hw, d)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)  # (b, c, h, w)

        # Feature matching.
        return F.mse_loss(out, t)

    def forward(self, feat_S, feat_T):
        """Return ``factor * batch_loss`` for a student/teacher feature-map pair."""
        if self.use_pool:
            feat_S = self.pool(feat_S)
            feat_T = self.pool(feat_T)
        q = self.q_proj(feat_T)
        v = self.v_proj(feat_S)
        k = self.k_proj(feat_S)
        return self.batch_loss((feat_S, q, v), (feat_T, k)) * self.factor
    

class NewTaTLoss(nn.Module):
    """Attention-based feature-matching (L2) loss, 1x1-projection variant.

    The teacher feature map is used as the attention query (``q_proj`` is an
    identity). Keys and values are 1x1-conv + BatchNorm projections
    (``Embed``) of the student feature map. Attention logits are unscaled. The
    aggregated student values are compared to the teacher map with an MSE loss.

    Note: this module owns learnable parameters (``k_proj`` and ``v_proj``),
    so it must be optimised together with the student.

    Args:
        heads: number of attention heads; ``channels`` must be divisible by it.
        channels: channel count of both feature maps.
        factor: weight of the final loss.
        use_pool: if True, both feature maps are 2x2 average-pooled first.
    """

    def __init__(self, heads=1, channels=1920, factor=0.2, use_pool=False):
        super().__init__()
        self.heads = heads
        self.use_pool = use_pool
        self.q_proj = nn.Identity()
        self.k_proj = Embed(channels, channels)
        self.v_proj = Embed(channels, channels)
        self.factor = factor

        if use_pool:
            self.pool = nn.AvgPool2d(2)

    def batch_loss(self, f_s, f_t):
        """Return the MSE between attention-aggregated student values and the teacher map.

        Args:
            f_s: ``(student_feature, query, value)``; the student feature is unused here.
            f_t: ``(teacher_feature, key)``; the teacher feature is the MSE target.
        """
        _, q, v = f_s
        t, k = f_t
        heads = self.heads  # heads * d == channels
        height, width = v.shape[-2:]

        q = rearrange(q, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        k = rearrange(k, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)  # (b, heads, hw, hw)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # (b, heads, hw, d)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)  # (b, c, h, w)

        # Feature matching.
        return F.mse_loss(out, t)

    def forward(self, feat_S, feat_T):
        """Return ``factor * batch_loss`` for a student/teacher feature-map pair."""
        if self.use_pool:
            feat_S = self.pool(feat_S)
            feat_T = self.pool(feat_T)
        q = self.q_proj(feat_T)
        v = self.v_proj(feat_S)
        k = self.k_proj(feat_S)
        return self.batch_loss((feat_S, q, v), (feat_T, k)) * self.factor


class MGDLoss(nn.Module):
    """PyTorch version of `Masked Generative Distillation`

    A random channel-wise mask is applied to the (channel-aligned) student
    feature map. A small 1x1-conv generator then reconstructs the teacher
    feature map from it, and the reconstruction MSE is the loss.

    Args:
        student_channels (int): Number of channels in the student's feature map.
        teacher_channels (int): Number of channels in the teacher's feature map.
        alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00007
        lambda_mgd (float, optional): Masked ratio, i.e. the probability that
            a (sample, channel) slice of the student map is zeroed.
            Defaults to 0.15
    """

    def __init__(self,
                 student_channels,
                 teacher_channels,
                 alpha_mgd=0.00007,
                 lambda_mgd=0.15,
                 ):
        super(MGDLoss, self).__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd

        # 1x1 conv to bring the student to the teacher's channel count.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

        # Generator that reconstructs teacher features from masked student features.
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=1, stride=1, padding=0))

    def forward(self,
                preds_S,
                preds_T):
        """Forward function.

        Args:
            preds_S (Tensor): Bs*C*H*W, student's feature map
            preds_T (Tensor): Bs*C*H*W, teacher's feature map

        Returns:
            Tensor: scalar loss, ``alpha_mgd * get_dis_loss(...)``.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]

        if self.align is not None:
            preds_S = self.align(preds_S)

        return self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd

    def get_dis_loss(self, preds_S, preds_T):
        """Mask the student map channel-wise, regenerate it, and compare it to the teacher map."""
        N, C, H, W = preds_T.shape

        # Channel-wise mask: each (sample, channel) slice is zeroed with
        # probability ``lambda_mgd``. Sampled directly on the feature device
        # and cast to the feature dtype. (A spatial (N, 1, H, W) mask is the
        # alternative masking scheme.)
        rand = torch.rand((N, C, 1, 1), device=preds_S.device)
        mask = (rand >= self.lambda_mgd).to(preds_S.dtype)

        masked_fea = preds_S * mask
        new_fea = self.generation(masked_fea)

        # NOTE(review): reduction='mean' already averages over all N*C*H*W
        # elements, so the extra ``/ N`` scales the loss down by the batch
        # size once more. Confirm that ``alpha_mgd`` was tuned for this form;
        # the published MGD code appears to use reduction='sum' here.
        dis_loss = F.mse_loss(new_fea, preds_T, reduction='mean') / N

        return dis_loss
    

class ICKDLoss(nn.Module):
    """Inter-Channel Correlation distillation loss.

    The student map is projected to the teacher's channel count with
    ``Embed`` (1x1 conv + BatchNorm). For each sample, the channel Gram
    matrix (C x C) of student and teacher is computed and L2-normalised
    row-wise. The summed squared difference is divided by ``bsz * h * w``.
    """

    def __init__(self,
                 student_channels,
                 teacher_channels,):
        super(ICKDLoss, self).__init__()
        self.embed_s = Embed(student_channels, teacher_channels)

    def forward(self, g_s, g_t):
        """Return the inter-channel correlation loss for student map ``g_s`` and teacher map ``g_t``."""
        return self.batch_loss(g_s, g_t)

    def batch_loss(self, f_s, f_t):
        """Compute the loss; ``f_t`` must have the teacher channel count and the same spatial size."""
        f_s = self.embed_s(f_s)
        bsz, ch, h, w = f_s.shape

        # reshape (not view) so non-contiguous feature maps are accepted.
        f_s = f_s.reshape(bsz, ch, -1)
        f_t = f_t.reshape(bsz, ch, -1)

        # Per-sample channel Gram matrices, rows L2-normalised: (bsz, ch, ch).
        emd_s = F.normalize(torch.bmm(f_s, f_s.transpose(1, 2)), dim=2)
        emd_t = F.normalize(torch.bmm(f_t, f_t.transpose(1, 2)), dim=2)

        G_diff = emd_s - emd_t
        return G_diff.pow(2).sum() / (h * w * bsz)
    

class TatWithMiniBatch(nn.Module):
    """Attention feature-matching loss plus a cross-image pixel-similarity KD term.

    The first term reconstructs the teacher map from student values by
    attention (teacher query, 1x1-projected student keys/values) and compares
    the result with MSE, weighted by ``factor``. The second term compares the
    temperature-softmaxed cosine similarities between every pair of pixels
    across every pair of images in the batch (student vs teacher) with KL
    divergence.
    """

    def __init__(self, temperature=1.0, factor=0.001, channels=1920, heads=1):
        super(TatWithMiniBatch, self).__init__()
        self.temperature = temperature
        self.heads = heads
        self.q_proj = nn.Identity()
        self.k_proj = Embed(channels, channels)
        self.v_proj = Embed(channels, channels)
        self.factor = factor

    def pair_wise_sim_map_speed(self, fea_0, fea_1):
        """Return pixel dot products for every image pair, flattened to (B*B*HW, HW).

        Entry ``[b, d, i, j]`` of the intermediate (B, B, HW, HW) map is the
        dot product of pixel ``i`` of image ``b`` in ``fea_0`` with pixel ``j``
        of image ``d`` in ``fea_1``.
        """
        B, C, H, W = fea_0.size()

        fea_0 = fea_0.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        fea_1 = fea_1.reshape(B, C, -1)                  # (B, C, HW)

        sim_map_0_1 = torch.einsum('bic,dcj->bdij', fea_0, fea_1)
        return sim_map_0_1.reshape(-1, sim_map_0_1.shape[-1])

    def batch_loss(self, f_s, f_t):
        """MSE between attention-aggregated student values and the teacher map.

        Args:
            f_s: ``(student_feature, query, value)``; the student feature is unused here.
            f_t: ``(teacher_feature, key)``; the teacher feature is the MSE target.
        """
        _, q, v = f_s
        t, k = f_t
        heads = self.heads  # heads * d == channels
        height, width = v.shape[-2:]

        q = rearrange(q, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        k = rearrange(k, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)  # (b, heads, hw, hw)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # (b, heads, hw, d)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)  # (b, c, h, w)

        # Feature matching.
        return F.mse_loss(out, t)

    def forward(self, feat_S, feat_T):
        """Return ``KL(cross-image similarity) + factor * attention MSE``."""
        q = self.q_proj(feat_T)
        v = self.v_proj(feat_S)
        k = self.k_proj(feat_S)
        chsim_loss = self.batch_loss((feat_S, q, v), (feat_T, k)) * self.factor

        feat_S = F.normalize(feat_S, p=2, dim=1)
        feat_T = F.normalize(feat_T, p=2, dim=1)

        s_sim_map = self.pair_wise_sim_map_speed(feat_S, feat_S)
        t_sim_map = self.pair_wise_sim_map_speed(feat_T, feat_T)

        p_s = F.log_softmax(s_sim_map / self.temperature, dim=1)
        p_t = F.softmax(t_sim_map / self.temperature, dim=1)

        # No explicit accumulator tensor, so the result stays on the inputs' device.
        sim_dis = F.kl_div(p_s, p_t, reduction='batchmean')

        return sim_dis + chsim_loss


class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    q, k and v are 1x1-conv + BatchNorm projections (``Embed``) of the input.
    Logits are unscaled dot products, softmaxed over key positions. The output
    has the same shape as the input.
    """

    def __init__(self, channels=1920, heads=1):
        super(SelfAttention, self).__init__()
        self.q_proj = Embed(channels, channels)
        self.k_proj = Embed(channels, channels)
        self.v_proj = Embed(channels, channels)
        self.heads = heads

    def forward(self, x):
        """Return the attended feature map; ``channels`` must be divisible by ``heads``."""
        heads = self.heads  # heads * d == channels
        _, _, height, width = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = rearrange(q, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        k = rearrange(k, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)
        v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)  # (b, heads, hw, d)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)  # (b, heads, hw, hw)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)  # (b, heads, hw, d)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)  # (b, c, h, w)
        # Return the attention output. Returning ``x`` would discard all of
        # the above and leave the q/k/v projections without gradients.
        return out


class Tat4MiniBatch(nn.Module):
    """Cross-image pixel-similarity KD on self-attended student/teacher features.

    Each feature map first goes through its own ``SelfAttention`` module
    (default 1920 channels) and is L2-normalised per pixel. Temperature-softmaxed
    similarity rows between every pixel pair across every image pair are then
    matched (student to teacher) with KL divergence.
    """

    def __init__(self, temperature=1.0):
        super(Tat4MiniBatch, self).__init__()
        self.temperature = temperature
        self.att_s = SelfAttention()
        self.att_t = SelfAttention()

    def pair_wise_sim_map_speed(self, fea_0, fea_1):
        """Return pixel dot products for every image pair, flattened to (B*B*HW, HW)."""
        B, C, H, W = fea_0.size()

        fea_0 = fea_0.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        fea_1 = fea_1.reshape(B, C, -1)                  # (B, C, HW)

        sim_map_0_1 = torch.einsum('bic,dcj->bdij', fea_0, fea_1)  # (B, B, HW, HW)
        return sim_map_0_1.reshape(-1, sim_map_0_1.shape[-1])

    def forward(self, feat_S, feat_T):
        """Return the KL divergence between student and teacher similarity distributions."""
        feat_S = self.att_s(feat_S)
        feat_T = self.att_t(feat_T)

        feat_S = F.normalize(feat_S, p=2, dim=1)
        feat_T = F.normalize(feat_T, p=2, dim=1)

        s_sim_map = self.pair_wise_sim_map_speed(feat_S, feat_S)
        t_sim_map = self.pair_wise_sim_map_speed(feat_T, feat_T)

        p_s = F.log_softmax(s_sim_map / self.temperature, dim=1)
        p_t = F.softmax(t_sim_map / self.temperature, dim=1)

        # No hard-coded CUDA accumulator: the result lives on the inputs' device.
        return F.kl_div(p_s, p_t, reduction='batchmean')

class IntraImageMiniBatch(nn.Module):
    """Intra-image pixel-similarity distillation.

    For each image, the (HW x HW) cosine similarities between its own pixels
    are computed for student and teacher. Each row is turned into a
    distribution with a temperature softmax, and student rows are matched to
    teacher rows with KL divergence (``batchmean`` over all B*HW rows).
    """

    def __init__(self, temperature=1.0):
        super(IntraImageMiniBatch, self).__init__()
        self.temperature = temperature

    def pair_wise_sim_map_speed(self, fea_0, fea_1):
        """Return per-image pixel dot products, flattened to (B*HW, HW)."""
        B, C, H, W = fea_0.size()

        fea_0 = fea_0.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        fea_1 = fea_1.reshape(B, C, -1)                  # (B, C, HW)

        sim_map = torch.bmm(fea_0, fea_1)  # (B, HW, HW)
        return sim_map.reshape(-1, sim_map.shape[-1])

    def forward(self, feat_S, feat_T, type='fast'):
        """Return KL(teacher || student) over intra-image similarity rows.

        Args:
            feat_S, feat_T: student / teacher maps with the same shape.
            type: accepted but ignored.
        """
        feat_S = F.normalize(feat_S, p=2, dim=1)
        feat_T = F.normalize(feat_T, p=2, dim=1)

        s_sim_map = self.pair_wise_sim_map_speed(feat_S, feat_S)
        t_sim_map = self.pair_wise_sim_map_speed(feat_T, feat_T)

        p_s = F.log_softmax(s_sim_map / self.temperature, dim=1)
        p_t = F.softmax(t_sim_map / self.temperature, dim=1)

        # Computed on the inputs' device (no hard-coded CUDA accumulator).
        return F.kl_div(p_s, p_t, reduction='batchmean')
    
class CriterionMiniBatchCrossImagePair(nn.Module):
    """Cross-image pixel-similarity distillation within a mini-batch.

    Features are L2-normalised per pixel. For every ordered pair of images
    (i, j) in the batch, the cosine similarity between each pixel of image i
    and every pixel of image j forms a distribution (temperature softmax). The
    student distributions are matched to the teacher ones with KL divergence.

    Args:
        temperature: softmax temperature for both similarity distributions.
        pooling: if True (fast path only), add the same loss computed on 2x2
            average-pooled features.
        factor: multiplier applied to the final loss.
    """

    def __init__(self, temperature, pooling=False, factor=100.0):
        super(CriterionMiniBatchCrossImagePair, self).__init__()
        self.temperature = temperature
        self.pooling = pooling
        self.factor = factor

    def pair_wise_sim_map(self, fea_0, fea_1):
        """Return (HW, HW) pixel dot products between two single (C, H, W) maps."""
        C, H, W = fea_0.size()

        fea_0 = fea_0.reshape(C, -1).transpose(0, 1)  # (HW, C)
        fea_1 = fea_1.reshape(C, -1).transpose(0, 1)  # (HW, C)

        return torch.mm(fea_0, fea_1.transpose(0, 1))

    def pair_wise_sim_map_speed(self, fea_0, fea_1):
        """Return pixel dot products for every image pair, flattened to (B*B*HW, HW).

        Entry ``[b, d, i, j]`` of the intermediate (B, B, HW, HW) map is the
        dot product of pixel ``i`` of image ``b`` with pixel ``j`` of image ``d``.
        """
        B, C, H, W = fea_0.size()

        fea_0 = fea_0.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        fea_1 = fea_1.reshape(B, C, -1)                  # (B, C, HW)

        sim_map_0_1 = torch.einsum('bic,dcj->bdij', fea_0, fea_1)
        return sim_map_0_1.reshape(-1, sim_map_0_1.shape[-1])

    def _sim_kl(self, s_sim_map, t_sim_map):
        """KL(teacher || student) between row-wise temperature softmaxes, 'batchmean' over rows."""
        p_s = F.log_softmax(s_sim_map / self.temperature, dim=1)
        p_t = F.softmax(t_sim_map / self.temperature, dim=1)
        return F.kl_div(p_s, p_t, reduction='batchmean')

    def forward(self, feat_S, feat_T, type='fast'):
        """Return ``factor`` times the cross-image similarity KD loss.

        Args:
            feat_S, feat_T: student / teacher maps with the same shape (B, C, H, W).
            type: 'fast' computes all image pairs with a single einsum; any
                other value loops over the B*B pairs. Both give the same
                unpooled loss; only the fast path adds the pooled term.
        """
        B = feat_S.size(0)
        use_pool = self.pooling and type == 'fast'

        if use_pool:
            # Pool the raw features (ceil_mode keeps odd borders), then normalise.
            feat_S_pool = F.normalize(
                F.avg_pool2d(feat_S, kernel_size=2, stride=2, padding=0, ceil_mode=True), p=2, dim=1)
            feat_T_pool = F.normalize(
                F.avg_pool2d(feat_T, kernel_size=2, stride=2, padding=0, ceil_mode=True), p=2, dim=1)

        feat_S = F.normalize(feat_S, p=2, dim=1)
        feat_T = F.normalize(feat_T, p=2, dim=1)

        if type == 'fast':
            sim_dis = self._sim_kl(self.pair_wise_sim_map_speed(feat_S, feat_S),
                                   self.pair_wise_sim_map_speed(feat_T, feat_T))
            if use_pool:
                sim_dis = sim_dis + self._sim_kl(
                    self.pair_wise_sim_map_speed(feat_S_pool, feat_S_pool),
                    self.pair_wise_sim_map_speed(feat_T_pool, feat_T_pool))
        else:
            sim_dis = 0
            for i in range(B):
                for j in range(B):
                    sim_dis = sim_dis + self._sim_kl(
                        self.pair_wise_sim_map(feat_S[i], feat_S[j]),
                        self.pair_wise_sim_map(feat_T[i], feat_T[j]))
            # Average over pairs, matching 'batchmean' over all rows in the fast path.
            sim_dis = sim_dis / (B * B)
        return sim_dis * self.factor

class MemoryBasedCrossImagePair(nn.Module):
    """Memory-based cross-image pixel-similarity distillation.

    Two queue buffers (``teacher_pixel_queue`` / ``student_pixel_queue``) hold
    L2-normalised pixel embeddings sampled from past batches at matching
    positions. Each current pixel's similarities to a random subset of queue
    rows are turned into a distribution. The student's distribution is
    matched to the teacher's with KL divergence, and the result is multiplied
    by 10.
    """

    def __init__(self, pixel_memory_size=20000, pixel_contrast_size=4096, contrast_kd_temperature=1.0, contrast_temperature=1.0, dim=1024):
        super(MemoryBasedCrossImagePair, self).__init__()
        self.contrast_kd_temperature = contrast_kd_temperature
        self.contrast_temperature = contrast_temperature

        self.pixel_memory_size = pixel_memory_size
        self.pixel_update_freq = 128  # max pixels enqueued per call
        self.pixel_contrast_size = pixel_contrast_size
        self.dim = dim

        self.register_buffer("teacher_pixel_queue", torch.zeros(self.pixel_memory_size, self.dim))
        self.teacher_pixel_queue = nn.functional.normalize(self.teacher_pixel_queue, p=2, dim=1)
        self.register_buffer("teacher_queue_ptr", torch.zeros(1, dtype=torch.long))
        self.register_buffer("student_pixel_queue", torch.zeros(self.pixel_memory_size, self.dim))
        self.student_pixel_queue = nn.functional.normalize(self.student_pixel_queue, p=2, dim=1)
        self.register_buffer("student_queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient.

        Without an initialised process group (e.g. single-process runs) the
        input is returned unchanged instead of raising.
        """
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            return tensor
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
        return torch.cat(tensors_gather, dim=0)

    def _enqueue(self, queue, queue_ptr, feat):
        """Write the rows of ``feat`` (K, dim), L2-normalised, into ``queue`` at ``queue_ptr``.

        If the rows do not fit before the end, they overwrite the last K rows
        and the pointer wraps to 0.
        """
        K = feat.shape[0]
        ptr = int(queue_ptr[0])
        feat = nn.functional.normalize(feat, p=2, dim=1)
        if ptr + K >= self.pixel_memory_size:
            queue[-K:, :] = feat
            queue_ptr[0] = 0
        else:
            queue[ptr:ptr + K, :] = feat
            queue_ptr[0] = (queue_ptr[0] + K) % self.pixel_memory_size

    def _dequeue_and_enqueue(self, t_keys, s_keys):
        """Push up to ``pixel_update_freq`` randomly chosen pixels into both queues.

        The same pixel positions are used for teacher and student, so the two
        queues stay aligned row by row.

        Args:
            t_keys, s_keys: (B, C, H, W) teacher / student maps, expected detached.
        """
        t_keys = self.concat_all_gather(t_keys)
        s_keys = self.concat_all_gather(s_keys)

        feat_dim = t_keys.size(1)
        # (B, C, H, W) -> (C, B*H*W): one column per pixel. A plain
        # ``.view(C, -1)`` on the NCHW tensor would mix channels of different
        # pixels whenever B > 1.
        t_pixels = t_keys.permute(1, 0, 2, 3).reshape(feat_dim, -1)
        s_pixels = s_keys.permute(1, 0, 2, 3).reshape(feat_dim, -1)

        num_pixel = t_pixels.shape[1]
        perm = torch.randperm(num_pixel)
        K = min(num_pixel, self.pixel_update_freq)
        idx = perm[:K]

        self._enqueue(self.teacher_pixel_queue, self.teacher_queue_ptr, t_pixels[:, idx].t())
        self._enqueue(self.student_pixel_queue, self.student_queue_ptr, s_pixels[:, idx].t())

    def contrast_sim_kd(self, s_logits, t_logits):
        """KL(teacher || student) between row-wise softmaxes of 2-D logits, 'batchmean' over rows."""
        p_s = F.log_softmax(s_logits / self.contrast_kd_temperature, dim=1)
        p_t = F.softmax(t_logits / self.contrast_kd_temperature, dim=1)
        return F.kl_div(p_s, p_t, reduction='batchmean')

    def forward(self, s_feats, t_feats):
        """Return ``10 *`` KL between student and teacher pixel-to-memory similarity distributions."""
        t_feats = F.normalize(t_feats, p=2, dim=1)
        s_feats = F.normalize(s_feats, p=2, dim=1)

        B, C, H, W = s_feats.shape

        self._dequeue_and_enqueue(t_feats.detach().clone(), s_feats.detach().clone())

        # Sample the same queue rows for teacher and student.
        pixel_queue_size = self.teacher_pixel_queue.shape[0]
        pixel_index = torch.randperm(pixel_queue_size)[:self.pixel_contrast_size]
        device = t_feats.device
        t_X_pixel_contrast = self.teacher_pixel_queue[pixel_index, :].to(device)
        s_X_pixel_contrast = self.student_pixel_queue[pixel_index, :].to(device)

        t_feats = t_feats.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        s_feats = s_feats.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        t_pixel_logits = torch.div(torch.matmul(t_feats, t_X_pixel_contrast.T), self.contrast_temperature)
        t_pixel_logits = t_pixel_logits.reshape(-1, t_pixel_logits.shape[-1])
        s_pixel_logits = torch.div(torch.matmul(s_feats, s_X_pixel_contrast.T), self.contrast_temperature)
        s_pixel_logits = s_pixel_logits.reshape(-1, s_pixel_logits.shape[-1])

        pixel_sim_dis = self.contrast_sim_kd(s_pixel_logits, t_pixel_logits.detach())

        return pixel_sim_dis * 10

class StudentSegContrast(nn.Module):
    """Pixel-level memory-based contrastive distillation.

    Student features are projected to the teacher's channel count by
    ``project_head``, and both maps are L2-normalised per pixel. Each pixel's
    similarities to a random sample of teacher pixel embeddings from a memory
    queue (logits divided by ``contrast_temperature``) are softmaxed with
    ``contrast_kd_temperature``. The student distribution is matched to the
    teacher's with KL divergence, and the result is weighted by ``factor``.

    The ``region_*`` arguments are stored but unused: only pixel-level
    contrast is computed.
    """

    def __init__(self, pixel_memory_size=20000, region_memory_size=2000, region_contrast_size=1024, pixel_contrast_size=4096,
                 contrast_kd_temperature=1.0, contrast_temperature=0.1, s_channels=1920, t_channels=1920, factor=0.1):
        super(StudentSegContrast, self).__init__()
        self.contrast_kd_temperature = contrast_kd_temperature
        self.contrast_temperature = contrast_temperature
        self.dim = t_channels

        self.project_head = nn.Sequential(
            nn.Conv2d(s_channels, t_channels, 1, bias=False),
            nn.SyncBatchNorm(t_channels),
            nn.ReLU(True),
            nn.Conv2d(t_channels, t_channels, 1, bias=False)
        )

        self.region_memory_size = region_memory_size
        self.pixel_memory_size = pixel_memory_size
        self.pixel_update_freq = 128  # max pixels enqueued per call
        self.region_update_freq = 32
        self.pixel_contrast_size = pixel_contrast_size
        self.region_contrast_size = region_contrast_size

        self.factor = factor

        # Teacher pixel memory, initialised with random unit vectors.
        self.register_buffer("teacher_pixel_queue", torch.randn(self.pixel_memory_size, self.dim))
        self.teacher_pixel_queue = nn.functional.normalize(self.teacher_pixel_queue, p=2, dim=1)
        self.register_buffer("pixel_queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient.

        Without an initialised process group (e.g. single-process runs) the
        input is returned unchanged instead of raising.
        """
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            return tensor
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
        return torch.cat(tensors_gather, dim=0)

    def _dequeue_and_enqueue(self, keys):
        """Push up to ``pixel_update_freq`` random teacher pixels (L2-normalised) into the memory.

        If the rows do not fit before the end of the queue, they overwrite the
        last K rows and the pointer wraps to 0.

        Args:
            keys: (B, C, H, W) teacher feature map, expected detached.
        """
        pixel_queue = self.teacher_pixel_queue

        keys = self.concat_all_gather(keys)
        feat_dim = keys.size(1)

        # (B, C, H, W) -> (C, B*H*W): one column per pixel. A plain
        # ``.view(C, -1)`` on the NCHW tensor would mix channels of different
        # pixels whenever B > 1.
        this_feat = keys.permute(1, 0, 2, 3).reshape(feat_dim, -1)

        num_pixel = this_feat.shape[1]
        perm = torch.randperm(num_pixel)
        K = min(num_pixel, self.pixel_update_freq)
        feat = nn.functional.normalize(this_feat[:, perm[:K]].t(), p=2, dim=1)
        ptr = int(self.pixel_queue_ptr[0])

        if ptr + K >= self.pixel_memory_size:
            pixel_queue[-K:, :] = feat
            self.pixel_queue_ptr[0] = 0
        else:
            pixel_queue[ptr:ptr + K, :] = feat
            self.pixel_queue_ptr[0] = (self.pixel_queue_ptr[0] + K) % self.pixel_memory_size

    def contrast_sim_kd(self, s_logits, t_logits):
        """KL(teacher || student) over the last dim of (B, HW, K) logits.

        'batchmean' divides by B, and the extra division by ``HW``
        (``p_s.shape[1]``) yields a per-pixel average. The result is scaled by
        ``contrast_kd_temperature ** 2``.
        """
        p_s = F.log_softmax(s_logits / self.contrast_kd_temperature, dim=2)
        p_t = F.softmax(t_logits / self.contrast_kd_temperature, dim=2)
        sim_dis = F.kl_div(p_s, p_t, reduction='batchmean') * self.contrast_kd_temperature ** 2
        return sim_dis / p_s.shape[1]

    def forward(self, s_feats, t_feats):
        """Return ``factor`` times the pixel-to-memory contrastive KD loss."""
        t_feats = F.normalize(t_feats, p=2, dim=1)
        s_feats = self.project_head(s_feats)
        s_feats = F.normalize(s_feats, p=2, dim=1)

        B, C, H, W = s_feats.shape

        self._dequeue_and_enqueue(t_feats.detach().clone())

        pixel_queue_size = self.teacher_pixel_queue.shape[0]
        pixel_index = torch.randperm(pixel_queue_size)[:self.pixel_contrast_size]
        t_X_pixel_contrast = self.teacher_pixel_queue[pixel_index, :]

        t_feats = t_feats.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        s_feats = s_feats.reshape(B, C, -1).transpose(1, 2)  # (B, HW, C)
        t_pixel_logits = torch.div(torch.matmul(t_feats, t_X_pixel_contrast.T), self.contrast_temperature)
        s_pixel_logits = torch.div(torch.matmul(s_feats, t_X_pixel_contrast.T), self.contrast_temperature)

        pixel_sim_dis = self.contrast_sim_kd(s_pixel_logits, t_pixel_logits.detach()) * self.factor
        return pixel_sim_dis

class ChannelNorm(nn.Module):
    """Flatten the spatial dimensions of a 4-D map: (n, c, h, w) -> (n, c, h*w).

    Despite the name, no normalisation is applied; the values are unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, featmap):
        batch, channels, _, _ = featmap.shape
        return featmap.reshape((batch, channels, -1))


class CriterionCWD(nn.Module):
    """Channel-wise distillation loss between a projected student map and a teacher map.

    Args:
        s_channels / t_channels: student / teacher channel counts. The student
            map is always projected to ``t_channels`` by a bias-free 1x1 conv.
        norm_type: 'channel' (flatten spatial dims via ChannelNorm),
            'spatial' (softmax over dim 1), 'channel_mean' (spatial mean per
            channel); anything else means no normalisation.
        divergence: 'mse' (summed squared error) or 'kl'.
        temperature: features are divided by it before normalisation, and the
            loss is scaled by ``temperature ** 2``.
    """

    def __init__(self, s_channels, t_channels, norm_type='none', divergence='mse', temperature=1.0):
        super(CriterionCWD, self).__init__()
        # define normalize function
        if norm_type == 'channel':
            self.normalize = ChannelNorm()
        elif norm_type == 'spatial':
            self.normalize = nn.Softmax(dim=1)
        elif norm_type == 'channel_mean':
            self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)
        else:
            self.normalize = None
        self.norm_type = norm_type
        self.divergence = divergence
        # Stored for every divergence: forward() divides by it and scales by it.
        self.temperature = temperature
        # define loss function
        if divergence == 'mse':
            self.criterion = nn.MSELoss(reduction='sum')
        elif divergence == 'kl':
            self.criterion = nn.KLDivLoss(reduction='sum')
        self.conv = nn.Conv2d(s_channels, t_channels, kernel_size=1, bias=False)

    def forward(self, preds_S, preds_T):
        """Return the scalar distillation loss.

        Raises:
            ValueError: if ``divergence`` is neither 'mse' nor 'kl'.
        """
        n, c, h, w = preds_S.shape

        preds_S = self.conv(preds_S)

        if self.normalize is not None:
            norm_s = self.normalize(preds_S / self.temperature)
            norm_t = self.normalize(preds_T.detach() / self.temperature)
        else:
            # Whole batch: the normalisation below divides by the batch size n.
            norm_s = preds_S
            norm_t = preds_T.detach()

        if self.divergence == 'mse':
            loss = self.criterion(norm_s, norm_t)
        elif self.divergence == 'kl':
            p_s = F.log_softmax(norm_s, dim=-1)
            p_t = F.softmax(norm_t, dim=-1)
            # NOTE(review): 'batchmean' already divides by the leading (batch)
            # dim, so together with the division below the batch size is
            # divided out twice. Kept as-is to preserve existing loss scaling.
            loss = F.kl_div(p_s, p_t, reduction='batchmean')
        else:
            raise ValueError(f"Unsupported divergence: {self.divergence!r}")

        if self.norm_type == 'channel' or self.norm_type == 'channel_mean':
            loss = loss / (n * c)
        else:
            loss = loss / (n * h * w)

        return loss * (self.temperature ** 2) * 0.1


class PatchSim(nn.Module):
    """Calculate the similarity in selected patches.

    For each selected query pixel, compute the similarity between its feature
    vector and every feature vector in a ``patch_size`` x ``patch_size``
    window around it (clamped to stay inside the map). When ``patch_size`` is
    None, or the window does not fit inside the map, every pixel of the map
    is used as a key instead.

    Args:
        patch_nums: number of query pixels sampled; <= 0 uses every pixel as
            a query against the whole map.
        patch_size: side length of the square key window, or None for the
            whole map.
        norm: if True, features are L2-normalised over channels and the raw
            dot product is returned; otherwise features are divided by
            sqrt(C) and the similarity is ``tanh(dot / 10)``.
    """

    def __init__(self, patch_nums=256, patch_size=None, norm=True):
        super(PatchSim, self).__init__()
        self.patch_nums = patch_nums
        self.patch_size = patch_size
        self.use_norm = norm

    def forward(self, feat, patch_ids=None):
        """
        Calculate the similarity for selected patches.

        Args:
            feat: (B, C, H, W) feature map.
            patch_ids: optional 1-D tensor of row-major flat pixel indices
                (``row * W + col``) to use as queries; sampled when None.

        Returns:
            (patch_sim, patch_ids). When ``patch_ids`` is not None,
            ``patch_sim`` has shape (B, num_queries, num_keys).
        """
        B, C, H, W = feat.size()
        # Centre each channel over the spatial positions before comparing.
        feat = feat - feat.mean(dim=[-2, -1], keepdim=True)
        feat = F.normalize(feat, dim=1) if self.use_norm else feat / np.sqrt(C)
        query, key, patch_ids = self.select_patch(feat, patch_ids=patch_ids)
        patch_sim = query.bmm(key) if self.use_norm else torch.tanh(query.bmm(key) / 10)
        if patch_ids is not None:
            patch_sim = patch_sim.view(B, len(patch_ids), -1)

        return patch_sim, patch_ids

    def select_patch(self, feat, patch_ids=None):
        """
        Select the query vectors and their key sets.

        Returns:
            (feat_query, feat_key, patch_ids), shaped for ``feat_query.bmm(feat_key)``.
        """
        B, C, H, W = feat.size()  # H: size of dim 2 (rows), W: size of dim 3 (cols)

        if self.patch_nums <= 0:
            feat_query = feat.reshape(B, C, H * W).permute(0, 2, 1)  # B x N x C
            feat_key = feat.reshape(B, C, H * W)  # B x C x N
            return feat_query, feat_key, patch_ids

        # Row-major flattening: flat index = row * W + col.
        feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)  # B x N x C
        if patch_ids is None:
            patch_ids = torch.randperm(feat_reshape.size(1), device=feat.device)
            patch_ids = patch_ids[:int(min(self.patch_nums, patch_ids.size(0)))]
        feat_query = feat_reshape[:, patch_ids, :]  # B x Num x C
        Num = feat_query.size(1)

        p = self.patch_size
        if p is not None and p < H and p < W:
            # Decode with the column count W (the size of the last flattened
            # dim), so non-square maps get the correct window.
            rows, cols = patch_ids // W, patch_ids % W
            # Centre the window on the query pixel, clamped inside the map.
            start_r = (rows - p // 2).clamp(min=0, max=H - p)
            start_c = (cols - p // 2).clamp(min=0, max=W - p)
            feat_key = torch.stack(
                [feat[:, :, r:r + p, c:c + p]
                 for r, c in zip(start_r.tolist(), start_c.tolist())],
                dim=1,
            )  # B x Num x C x p x p
            feat_key = feat_key.reshape(B * Num, C, p * p)  # (B*Num) x C x (p*p)
            feat_query = feat_query.reshape(B * Num, 1, C)  # (B*Num) x 1 x C
        else:
            # No patch size, or the patch does not fit: use every pixel as a key.
            feat_key = feat.reshape(B, C, H * W)  # B x C x N

        return feat_query, feat_key, patch_ids


class SpatialCorrelativeLoss(nn.Module):
    """
    learnable patch-based spatially-correlative loss with contrastive learning

    Patch self-similarity maps of shape (B, Num, N) are computed for the source
    and target features by ``PatchSim``. When ``use_conv`` is set, both inputs
    first pass through the same 1x1-conv projection head. The maps are then
    compared according to ``loss_mode``:

    * ``'info'``: InfoNCE-style cross entropy with temperature ``T``. Patch i
      of the source should match patch i of the target.
    * ``'l1'``: ``criterion`` (L1 if ``norm`` else SmoothL1) between the two
      maps. Only entries where the target similarity is at least the
      ``N // 4``-th largest value are kept, and the result is rescaled by
      ``N / (N // 4)``.
    * ``'cos'``: ``criterion`` between 1 and the cosine similarity of those
      same masked maps.
    """
    def __init__(self, loss_mode='cos', patch_nums=256, patch_size=32, norm=True, use_conv=True, T=0.1, input_nc=128):
        super(SpatialCorrelativeLoss, self).__init__()
        self.patch_sim = PatchSim(patch_nums=patch_nums, patch_size=patch_size, norm=norm)
        self.patch_size = patch_size
        self.patch_nums = patch_nums
        self.norm = norm
        self.use_conv = use_conv
        self.conv_init = False  # not read by this class
        self.loss_mode = loss_mode
        self.T = T  # temperature for the 'info' logits
        self.criterion = nn.L1Loss() if norm else nn.SmoothL1Loss()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        if use_conv:
            # Projection head shared by source and target: input_nc -> max(32, input_nc // 4).
            output_nc = max(32, input_nc // 4)
            self.conv = nn.Sequential(*[nn.Conv2d(input_nc, output_nc, kernel_size=1),
                                        nn.ReLU(),
                                        nn.Conv2d(output_nc, output_nc, kernel_size=1)])

    def cal_sim(self, f_src, f_tgt, patch_ids=None):
        """
        calculate the similarity map using the fixed/learned query and key
        :param f_src: feature map from source domain
        :param f_tgt: feature map from target domain
        :param patch_ids: optional patch positions. When None, ``PatchSim``
            chooses them for the source, and the same ids are reused for the target.
        :return: (sim_src, sim_tgt) similarity maps computed at the same patches
        """
        if self.use_conv:
            f_src, f_tgt = self.conv(f_src), self.conv(f_tgt)
        sim_src, patch_ids = self.patch_sim(f_src, patch_ids)
        sim_tgt, patch_ids = self.patch_sim(f_tgt, patch_ids)

        return sim_src, sim_tgt

    def compare_sim(self, sim_src, sim_tgt):
        """
        measure the shape distance between the same shape and different inputs
        :param sim_src: the shape similarity map from source input image, (B, Num, N)
        :param sim_tgt: the shape similarity map from target output image, (B, Num, N)
        :return: scalar loss
        :raises NotImplementedError: if ``self.loss_mode`` is not 'info', 'l1' or 'cos'
        """
        _, Num, N = sim_src.size()
        if self.loss_mode == 'info':
            sim_src = F.normalize(sim_src, dim=-1)
            sim_tgt = F.normalize(sim_tgt, dim=-1)
            # (B, Num, Num) patch-to-patch scores flattened to (B*Num, Num) logits.
            sam_self = (sim_src.bmm(sim_tgt.permute(0, 2, 1))).view(-1, Num) / self.T
            # Row b*Num + i must select column i (the same patch position).
            labels = torch.arange(0, sam_self.size(0), dtype=torch.long, device=sim_src.device) % Num
            return self.cross_entropy_loss(sam_self, labels)

        if self.loss_mode not in ('l1', 'cos'):
            raise NotImplementedError('loss mode [%s] is not implemented' % self.loss_mode)

        tgt_sorted, _ = sim_tgt.sort(dim=-1, descending=True)
        # NOTE(review): N < 4 gives num == 0, and 'l1' then divides by zero.
        num = int(N / 4)
        # Drop entries whose target similarity is below the num-th largest
        # (0-based) value of that row.
        drop = sim_tgt < tgt_sorted[:, :, num:num + 1]
        src = torch.where(drop, 0 * sim_src, sim_src)
        tgt = torch.where(drop, 0 * sim_tgt, sim_tgt)
        if self.loss_mode == 'l1':
            return self.criterion((N / num) * src, (N / num) * tgt)
        sim_pos = F.cosine_similarity(src, tgt, dim=-1)
        return self.criterion(torch.ones_like(sim_pos), sim_pos)

    def forward(self, f_src, f_tgt):
        """
        calculate the spatial similarity and dissimilarity loss for given features from source and target domain
        :param f_src: source domain features
        :param f_tgt: target domain features
        :return: scalar loss from ``compare_sim``
        """
        sim_src, sim_tgt = self.cal_sim(f_src, f_tgt)
        # calculate the spatial similarity for source and target domain
        loss = self.compare_sim(sim_src, sim_tgt)
        return loss

class Normalize(nn.Module):
    """Divide ``x`` by its Lp norm taken over dim 1, plus a 1e-7 epsilon.

    :param power: the ``p`` of the Lp norm (default 2).
    """

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        # abs() makes this a true Lp norm for odd or fractional powers as well.
        # For the default p=2 the result is identical to x.pow(2).
        norm = x.abs().pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm + 1e-7)
        return out


class SRC_Loss(nn.Module):
    """Semantic relation consistency (SRC) loss.

    Both inputs are L2-normalised over dim 1 and reshaped to (batch, tokens,
    channels). Each side then gets a token-to-token similarity matrix, turned
    into distributions with softmax over dim 1. The result is the
    Jensen-Shannon divergence between the two sides, scaled by ``weight``.
    Gradients flow only through ``feat_q``, because ``feat_k`` is detached.

    :param weight: multiplier applied to the divergence (default 0.01).
    """

    def __init__(self, weight=0.01):
        super().__init__()
        self.weight = weight

    @staticmethod
    def _to_tokens(feat, batch, dim):
        """Reshape ``feat`` to (batch, tokens, channels) for the relation matrix."""
        if feat.dim() == 4:
            # (B, C, H, W) -> (B, H*W, C): one C-dim vector per spatial position.
            # A plain view(B, -1, C) would reinterpret channel-major memory and
            # mix channels with positions.
            return feat.flatten(2).transpose(1, 2)
        # NOTE(review): for 2-D (N, dim) input this yields (N, 1, dim), so the
        # relation matrix is 1x1 and the loss is always 0. Confirm intended use.
        return feat.view(batch, -1, dim)

    def forward(self, feat_q, feat_k):
        '''
        :param feat_q: target
        :param feat_k: source (detached; used only as the reference distribution)
        :return: scalar SRC loss, already multiplied by ``self.weight``
        '''
        batch, dim = feat_q.shape[0], feat_q.shape[1]
        feat_k = Normalize()(feat_k.detach())
        feat_q = Normalize()(feat_q)

        feat_q_v = self._to_tokens(feat_q, batch, dim)
        feat_k_v = self._to_tokens(feat_k, batch, dim)

        # Token-to-token cosine similarities, (batch, tokens, tokens).
        spatial_q = torch.bmm(feat_q_v, feat_q_v.transpose(2, 1))
        spatial_k = torch.bmm(feat_k_v, feat_k_v.transpose(2, 1))

        spatial_q = F.softmax(spatial_q, dim=1)
        spatial_k = F.softmax(spatial_k, dim=1).detach()

        return self.get_jsd(spatial_q, spatial_k) * self.weight

    def get_jsd(self, p1, p2):
        '''
        Jensen-Shannon divergence 0.5 * (KL(p1 || m) + KL(p2 || m)), m = (p1 + p2) / 2.

        :param p1: probability tensor (entries > 0); dim 0 is the batch
        :param p2: probability tensor with the same shape as ``p1``
        :return: scalar; the element sum divided by the batch size ('batchmean')
        '''
        log_m = torch.log(0.5 * (p1 + p2))
        # kl_div(input, target, log_target=True) computes KL(target || input).
        kl_1 = F.kl_div(log_m, torch.log(p1), reduction='batchmean', log_target=True)
        kl_2 = F.kl_div(log_m, torch.log(p2), reduction='batchmean', log_target=True)
        return 0.5 * (kl_1 + kl_2)
