import torch

def MMD_loss(source, target):
    classwise_mean_source = torch.mean(source, dim=-2)
    classwise_mean_target = torch.mean(target, dim=-2)
    # print("classwise_mean_source: ", classwise_mean_source)
    # print("classwise_mean_target: ", classwise_mean_target)
    mmd = torch.sum((classwise_mean_source - classwise_mean_target) ** 2)
    return mmd

