import torch
import torch.nn as nn


def _sims(user_emb, item_emb): #用户嵌入和项目嵌入之间的相似度矩阵
    """
    n_embs is the number of embedding vecters per user
    :param user_emb: tensor, shape [batch_size, n_embs, emb_size]
    :param item_emb: tensor, shape [batch_size, n_items, emb_size]
    :return: similarity prediction [batch_size, n_items, n_embs]
    """
    # print(torch.norm(user_emb, dim=-1).shape)
    user_emb_norm = user_emb / torch.clamp(
        torch.norm(user_emb, dim=-1, keepdim=True), 1e-7)
    item_emb_norm = item_emb / torch.clamp(
        torch.norm(item_emb, dim=-1, keepdim=True), 1e-7)
    # print(user_emb_norm.shape, item_emb_norm.shape)
    user_4d = torch.unsqueeze(user_emb_norm, dim=1)
    item_4d = torch.unsqueeze(item_emb_norm, dim=2)
    cos_sim = torch.sum(torch.mul(user_4d, item_4d), dim=-1, keepdim=True)#逐元素相乘得到相似度矩阵
    # print(cos_sim.shape)
    return cos_sim#相似度矩阵


class SimilarityOutputLayer(nn.Module):#计算用户和物品之间的相似度得分，并将其压缩为0到1之间的概率值
    def __init__(self):
        super(SimilarityOutputLayer, self).__init__()
        self._projection_layer = nn.Sequential(
            nn.Linear(1, 1, True), nn.Sigmoid())

    def forward(self, user_emb, item_emb):
        """
        n_embs is the number of embedding vecters per user
        :param user_emb: tensor, shape [batch_size, n_embs, emb_size]
        :param item_emb: tensor, shape [batch_size, n_items, emb_size]
        :return: similarity prediction [batch_size, n_items, n_embs]
        """
        # print(torch.norm(user_emb, dim=-1).shape)
        cos_sim = _sims(user_emb, item_emb)
        return torch.squeeze(self._projection_layer(cos_sim), dim=-1)


class MaxAggregator(nn.Module):#计算用户和物品之间的相似度，然后执行最大池化操作，返回相似度矩阵每一行的最大值
    def __init__(self):
        super(MaxAggregator, self).__init__()
        self._output_layer = SimilarityOutputLayer()#计算用户和物品之间的相似度得分，并将其压缩为0到1之间的概率值

    def forward(self, user_emb, item_emb, last_item_mask=None):
        sims = self._output_layer(user_emb, item_emb)
        # print("max aggregator", sims.shape, user_emb.shape, item_emb.shape)
        if last_item_mask is None:
            return sims.max(dim=-1, keepdim=False)[0]#计算相似度矩阵sims的每一行中的最大值
        last_item_mask = last_item_mask.unsqueeze(dim=1)#调整维度匹配sim
        sims *= torch.where(last_item_mask == 1,
                            last_item_mask,
                            torch.ones_like(last_item_mask) * -10000)
        return sims.max(dim=-1, keepdim=False)[0]


class WeightedMaxAggregator(nn.Module):#根据给定的权重和掩码，聚合用户和项目的嵌入，最终输出每个批次中的最大相似度值
    #捕捉用户和项目之间的复杂关系
    def __init__(self):
        super(WeightedMaxAggregator, self).__init__()
        self._projection_layer = nn.Sequential( #线性投影+归一化
            nn.Linear(1, 1, True), nn.Sigmoid())

    def forward(self, user_emb, item_emb, cluster_weights, last_item_mask):#模型在接收输入并进行计算时所执行的操作
        """
        :param user_emb: [batch_size, seq_len, emb_size]
        :param item_emb: [batch_size, n_items, emb_size]
        :param cluster_weights: [batch_size, seq_len, 1]
        :param last_item_mask: [batch_size, 1, seq_len]
        :return:
        """
        # print("cluster weight", cluster_weights.shape)
        sims = _sims(user_emb, item_emb)#求相似度
        # print("sims shape", sims.shape, 'cluster weights', cluster_weights.shape)
        sims = sims * cluster_weights.unsqueeze(dim=2)#相似度*聚类权重
        # print("sims shape", sims.shape)
        # print("last item mask shape", last_item_mask.shape)
        #last_item_mask = last_item_mask.unsqueeze(dim=1)
        sims = self._projection_layer(sims).squeeze(dim=-1)#通过_projection_layer投影这些相似度，然后使用Sigmoid函数进行归一化
        # print("sims shape", sims.shape, "last item shape", last_item_mask.shape)
        sims = sims.permute(0, 2, 1) * last_item_mask.to('cuda:0')
        sims = sims.max(dim=2, keepdim=False)[0] #从结果中提取每个批次中的最大值作为输出
        # print("sims final shape", sims.shape)
        return sims##返item和user的相似度*聚类相似度

        # return self._projection_layer(
        #     sims.max(dim=-1, keepdim=False)[0])
