import torch
import torch.nn as nn
import numpy as np

class CMK_MMD(nn.Module):
    """Conditional multi-kernel MMD (CMK-MMD) domain-alignment loss.

    Combines ``kernel_num`` Gaussian kernels whose bandwidths form a geometric
    series with ratio ``kernel_mul``. The series is centred on either a fixed
    ``fix_sigma`` or an adaptive bandwidth estimated from the batch. When class
    labels are supplied, only pairs of samples whose labels differ contribute
    to the kernel sum.

    Args:
        kernel_mul: Ratio between consecutive kernel bandwidths.
        kernel_num: Number of Gaussian kernels to combine.
        fix_sigma: Fixed base bandwidth. If falsy (``None`` or ``0``), the base
            bandwidth is the mean off-diagonal pairwise squared distance.
    """

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super().__init__()
        self.kernel_mul = kernel_mul
        self.kernel_num = kernel_num
        self.fix_sigma = fix_sigma

    def mmd_loss(self, x, y, kernel_func, source_labels, target_labels):
        """Compute the MMD-style discrepancy ``mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)``.

        ``kernel_func`` is called as ``kernel_func(a, b, labels_a, labels_b)``.
        With :meth:`gaussian_kernel_with_label`, each call builds the kernel
        over the concatenation ``[a; b]``, so the cross term also contains
        within-domain pairs. Each call also re-estimates the adaptive bandwidth.
        NOTE(review): this differs from the textbook estimator, which splits a
        single shared kernel matrix into XX / YY / XY blocks; confirm intended.
        """
        xx = kernel_func(x, x, source_labels, source_labels)
        yy = kernel_func(y, y, target_labels, target_labels)
        xy = kernel_func(x, y, source_labels, target_labels)
        return torch.mean(xx) + torch.mean(yy) - 2 * torch.mean(xy)

    def gaussian_kernel_with_label(self, source, target, source_label=None, target_label=None,
                                   kernel_mul=None, kernel_num=None, fix_sigma=None):
        """Multi-bandwidth Gaussian kernel over ``[source; target]``, optionally label-masked.

        Args:
            source: 2-D tensor of shape ``(n, d)``.
            target: 2-D tensor of shape ``(m, d)``.
            source_label: Optional labels for ``source``.
            target_label: Optional labels for ``target``. Masking is applied
                only when both label tensors are given. Presumably these are
                1-D class indices of length ``n`` / ``m`` (TODO confirm); other
                shapes will not broadcast against the kernel matrix.
            kernel_mul, kernel_num, fix_sigma: Per-call overrides. ``None``
                falls back to the values passed to ``__init__``.

        Returns:
            Without labels, the element-wise sum of the ``(n + m, n + m)``
            Gaussian kernel matrices. With labels, a scalar tensor: the sum,
            over all kernels, of the kernel values of pairs whose labels differ.
        """
        # Previously the constructor settings were silently ignored.
        kernel_mul = self.kernel_mul if kernel_mul is None else kernel_mul
        kernel_num = self.kernel_num if kernel_num is None else kernel_num
        fix_sigma = self.fix_sigma if fix_sigma is None else fix_sigma

        total = torch.cat([source, target], dim=0)
        n_samples = int(total.size(0))

        # Pairwise squared Euclidean distances, shape (N, N): broadcast (1, N, D) - (N, 1, D).
        l2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

        if fix_sigma:
            # A truthy fixed sigma is used as-is; None / 0 select the adaptive estimate.
            bandwidth = fix_sigma
        else:
            # Mean squared distance over the n(n-1) off-diagonal pairs (guarded for n == 1).
            bandwidth = torch.sum(l2_distance) / max(n_samples * (n_samples - 1), 1)
        # Centre the geometric series of bandwidths on the base bandwidth.
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        if torch.is_tensor(bandwidth):
            # An all-identical batch yields a zero bandwidth and exp(-0/0) = NaN;
            # keep it strictly positive so identical points get kernel value 1.
            bandwidth = bandwidth.clamp_min(torch.finfo(bandwidth.dtype).tiny)

        bandwidth_list = [bandwidth * kernel_mul ** i for i in range(kernel_num)]
        kernels = [torch.exp(-l2_distance / bw) for bw in bandwidth_list]

        if source_label is None or target_label is None:
            return sum(kernels)

        # Keep only pairs whose labels differ. The mask is built once for all kernels.
        # NOTE(review): a class-conditional MMD usually compares same-class pairs (==);
        # confirm that != is the intended criterion.
        total_label = torch.cat([source_label, target_label], dim=0)
        mismatch = (total_label.unsqueeze(0) != total_label.unsqueeze(1)).float()
        return sum(torch.sum(k * mismatch) for k in kernels)

    def forward(self, source_features, target_features, source_labels, target_labels,
                lambda_mmd=1.0):
        """Return the weighted, pair-normalised CMK-MMD alignment loss.

        Args:
            source_features: Source batch of shape ``(n, d)``.
            target_features: Target batch of shape ``(m, d)``.
            source_labels: Labels for the source batch, or ``None`` for an unconditional loss.
            target_labels: Labels for the target batch, or ``None`` for an unconditional loss.
            lambda_mmd: Scalar weight applied to the loss.

        Returns:
            ``lambda_mmd * mmd / (n * m)``.
        """
        mmd_loss_val = self.mmd_loss(source_features, target_features,
                                     self.gaussian_kernel_with_label,
                                     source_labels, target_labels)
        # Normalise by the number of source-target pairs. The feature batch sizes
        # equal len(labels) whenever labels are given, and also work when labels are None.
        n_pairs = int(source_features.size(0)) * int(target_features.size(0))
        return lambda_mmd * mmd_loss_val / n_pairs


if __name__ == "__main__":
    # Executing this file directly performs no work; it only defines CMK_MMD.
    ...
