
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from utils.cmk_mmd import CMK_MMD


class MemoryBank(nn.Module):
    """Holds feature/label buffers and computes a CMK-MMD loss between the
    "easy" and "hard" sample sets.

    All tensors are registered as buffers, so they follow the module through
    ``.to()`` / ``.cuda()`` and are included in ``state_dict()``.

    Args:
        num_features: Width of every stored feature vector.
        num_samples: Number of rows in ``features`` and length of ``labels``
            and ``pred_labels``.
        N: Initial length of ``easy_labels`` and ``hard_labels``.
        K: Initial number of rows in ``easy_samples`` and ``hard_samples``.
        temp: Stored as ``self.temp``; not used by the code in this class.
        momentum: Stored as ``self.momentum``; not used by the code in this
            class.
    """

    def __init__(self, num_features, num_samples, N=0, K=0, temp=0.05, momentum=0.2):
        super().__init__()
        self.num_features = num_features
        self.num_samples = num_samples

        self.momentum = momentum
        self.temp = temp

        self.register_buffer('features', torch.zeros(num_samples, num_features))
        # Labels of the source-like samples.
        self.register_buffer('labels', torch.zeros(num_samples).long())
        # Pseudo-labels.
        self.register_buffer('pred_labels', torch.zeros(num_samples).long())

        # NOTE(review): the sample buffers have K rows while the label buffers
        # below have N entries -- presumably both are overwritten by the
        # caller before forward() is used; confirm against the training loop.
        self.register_buffer('easy_samples', torch.zeros(K, num_features))
        self.register_buffer('hard_samples', torch.zeros(K, num_features))

        self.register_buffer('easy_labels', torch.zeros(N).long())
        self.register_buffer('hard_labels', torch.zeros(N).long())

    def forward(self, disargeed_tar_idx, args):
        """Return the weighted CMK-MMD loss between easy and hard samples.

        Args:
            disargeed_tar_idx: Sized collection; only its length is inspected.
                (The spelling is kept so existing keyword callers still work.)
            args: Object exposing ``cmk_mmd_co``, the loss weight. It is only
                read when ``disargeed_tar_idx`` is non-empty.

        Returns:
            ``CMK_MMD(easy_samples, hard_samples, easy_labels, hard_labels)``
            multiplied by ``args.cmk_mmd_co`` when ``disargeed_tar_idx`` is
            non-empty, otherwise a zero scalar tensor placed on the same
            device as this module's buffers.
        """
        # Follow the buffers instead of hard-coding CUDA, so the module also
        # works on CPU-only hosts and on whichever GPU it was moved to.
        device = self.easy_samples.device

        # The centroid can use the previous evaluation or be recalculated here.
        if len(disargeed_tar_idx) > 0:
            mmd_align_with_label = CMK_MMD(
                kernel_mul=2.0, kernel_num=5, fix_sigma=None
            ).to(device)

            cmk_mmd_loss = mmd_align_with_label(
                self.easy_samples,
                self.hard_samples,
                self.easy_labels,
                self.hard_labels,
            )
            classifier_loss = cmk_mmd_loss * args.cmk_mmd_co
        else:
            classifier_loss = torch.tensor(0.0, device=device)
        return classifier_loss

# Script entry point is intentionally a no-op: running this module directly
# only executes the definitions above.
if __name__ == '__main__':
    pass


     
