import os
from typing import Union
import gc
import numpy

from modules.attention import MHGlobalQueryAttention, MultiHeadAttention
from modules.loss_funcs import CrossEntropyLoss, TripletLoss
from modules.output_layers import MaxAggregator, WeightedMaxAggregator
import torch
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn import cluster
from modules.attention import OneHot, RelValue, SineEnc
from modules.dgr import DGR





def _get_cluster_m(n_clusters, method):
    """Build a new, unfitted scikit-learn clustering estimator.

    :param n_clusters: requested number of clusters. It is not forwarded to
        DBSCAN, which is configured with ``eps=.2`` only.
    :param method: one of ``'ward'``, ``'dbscan'``, ``'spectrum'``,
        ``'birch'`` or ``'kmeans'``.
    :return: the estimator instance for ``method``.
    :raises NotImplementedError: if ``method`` is none of the names above.
    """
    # Each entry is a lazy factory so only the requested estimator is built.
    builders = (
        ('ward', lambda: AgglomerativeClustering(
            n_clusters=n_clusters, linkage='ward')),
        ('dbscan', lambda: cluster.DBSCAN(eps=.2)),
        ('spectrum', lambda: cluster.SpectralClustering(
            n_clusters=n_clusters,
            eigen_solver='arpack',
            affinity="nearest_neighbors")),
        ('birch', lambda: cluster.Birch(n_clusters=n_clusters)),
        ('kmeans', lambda: cluster.KMeans(
            n_clusters=n_clusters, random_state=0)),
    )
    for name, build in builders:
        if method == name:
            return build()
    raise NotImplementedError(f"Cluster method {method}")
class TisasAttention:
    """Empty marker class with no behaviour of its own.

    ``_AttentionModel.__init__`` compares its ``attention_cls`` argument
    against this type to decide how the positional-encoding width is chosen.
    """


def apply_dgr_model(group):
    """Run a freshly constructed ``DGRModel`` (placed on the CPU) over ``group``.

    A new model instance is created on every call.

    NOTE(review): ``DGRModel`` is not among the names imported in the visible
    part of this file (only ``DGR`` is imported from ``modules.dgr``) --
    confirm it is defined somewhere, otherwise this raises ``NameError``.

    :param group: data handed unchanged to ``DGRModel._build_seq_graph``.
    :return: whatever ``DGRModel._build_seq_graph`` returns.
    """
    dgr_model = DGRModel().to('cpu')
    return dgr_model._build_seq_graph(group)


class _AttentionModel(nn.Module):
    def __init__(self, hps, attention_cls):
        """Build the embedding, encoding, attention, scoring and loss layers.

        :param hps: dict of model hyper-parameters. Keys that may be read:
            ``activation``, ``emb_dim``, ``att_on_tem``, ``tem_enc``,
            ``tem_dim``, ``att_on_pos``, ``pos_dim``, ``d_model``, ``n_head``,
            ``dropout``, ``share_query``, ``seq_len``, ``loss_fn``,
            ``loss_margin`` and, optionally, ``item_count``. The dict is
            modified in place: ``input_dim`` is always written, and
            ``pos_dim`` is overwritten with ``emb_dim`` when
            ``attention_cls`` is ``TisasAttention``.
        :param attention_cls: attention layer class to instantiate, e.g.
            ``MHGlobalQueryAttention`` or ``MultiHeadAttention``.
        :raises AttributeError: if temporal attention is enabled and
            ``hps["tem_enc"]`` is not ``None``, ``"one_hot"`` or ``"sine"``.
        """
        super(_AttentionModel, self).__init__()

        # Activation class handed to the attention layer; any unrecognised
        # name falls back to the identity.
        act_fn = {
            'relu': nn.ReLU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
        }.get(hps['activation'], nn.Identity)

        # Without "item_count" no embedding table is created and
        # `_embed_item` passes its input through as floats.
        if "item_count" in hps:
            self._item_embedding = nn.Embedding(
                num_embeddings=hps["item_count"],
                embedding_dim=hps["emb_dim"])
        else:
            self._item_embedding = None

        hps["input_dim"] = hps["emb_dim"]
        if hps["att_on_tem"]:
            tem_enc = hps["tem_enc"]
            if tem_enc is None:
                self._tem_enc_layer = RelValue()
            elif tem_enc == "one_hot":
                self._tem_enc_layer = OneHot(max_k=hps["tem_dim"] - 1)
            elif tem_enc == "sine":
                self._tem_enc_layer = SineEnc(dim=hps["tem_dim"])
            else:
                raise AttributeError(f"unknown encoding layer {tem_enc}")
            # The attention input width grows by the temporal encoder's
            # output width.
            hps["input_dim"] += self._tem_enc_layer.output_dim
        else:
            self._tem_enc_layer = None

        if hps["att_on_pos"]:
            if attention_cls == TisasAttention:
                hps["pos_dim"] = hps["emb_dim"]
            else:
                hps["input_dim"] += hps["pos_dim"]
            self._pos_enc = SineEnc(dim=hps["pos_dim"])
        else:
            self._pos_enc = None

        self._attention = attention_cls(
            d_model=hps["d_model"],
            emb_size=hps["emb_dim"],
            input_dim=hps["input_dim"],
            n_head=hps["n_head"],
            dropout=hps["dropout"],
            activation=act_fn,
            share_query=hps["share_query"],
            seq_len=hps["seq_len"]
        )
        # Scores user embeddings against item embeddings.
        self._sims = MaxAggregator()

        if hps["loss_fn"] == "triplet":
            self._loss_fn = TripletLoss(hps["loss_margin"])
        else:
            self._loss_fn = CrossEntropyLoss()

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        # `_embed_user` and `clustered_inference` read `_device`, which is
        # otherwise only assigned by `_add_optim`; set it here as well so the
        # model is usable before an optimizer is attached.
        self._device = self.device
        self.to(self.device)
    def _add_optim(self, hps):
        """Attach an Adam optimizer and record the hyper-parameters and device.

        :param hps: dict of hyper-parameters; ``hps['lr']`` is the learning
            rate. The dict itself is stored as ``self._hps``.

        Sets ``self._device`` to ``cuda:0`` when CUDA is available (and moves
        the module there), otherwise to the CPU.
        """
        self._optimizer = Adam(self.parameters(), lr=hps['lr'])
        self._hps = hps
        cuda_ok = torch.cuda.is_available()
        self._device = torch.device('cuda:0' if cuda_ok else 'cpu')
        if cuda_ok:
            self.to(self._device)


    def forward(self, seq, pos, neg, timestamps=None):
        """Score positive and negative items against the user embeddings.

        :param seq: user history, passed to ``_embed_user``.
        :param pos: positive items, embedded with ``_embed_item``.
        :param neg: negative items, embedded with ``_embed_item``.
        :param timestamps: optional timestamps forwarded to ``_embed_user``.
        :return: ``(pos_sims, neg_sims)`` as produced by ``self._sims``.

        NaNs found in any intermediate tensor are reported with ``print``;
        they do not interrupt the computation.
        """
        # Keep the raw inputs under their own names so the diagnostics below
        # can show the offending input next to its embedding.
        pos_emb = self._embed_item(pos)
        neg_emb = self._embed_item(neg)
        if torch.any(torch.isnan(pos_emb)):
            print("pos embedding has nan", pos_emb)
            print("pos sequence", pos)
        if torch.any(torch.isnan(neg_emb)):
            print("neg embeddings has nan", neg_emb)
            print("neg sequence", neg)

        user_embs = self._embed_user(seq, timestamps)
        if torch.any(torch.isnan(user_embs)):
            print("user embeddings has nan", user_embs)

        pos_sims = self._sims(user_embs, pos_emb)
        neg_sims = self._sims(user_embs, neg_emb)
        if torch.any(torch.isnan(pos_sims)):
            print('pos sims has nan', pos_sims)
        if torch.any(torch.isnan(neg_sims)):
            print('neg sims has nan', neg_sims)
        return pos_sims, neg_sims

    def _embed_user(self, seq, timestamps=None):
        """Encode a user's history with the attention layer.

        :param seq: user history; embedded via ``_embed_item``.
        :param timestamps: optional timestamps. They are encoded only when a
            temporal encoding layer is configured.
        :return: the output of ``self._attention``.
        """
        time_enc = None
        if timestamps is not None and self._tem_enc_layer is not None:
            time_enc = self._tem_enc_layer(timestamps)

        pos_enc = None
        if self._pos_enc is not None:
            n_rows, n_steps = seq.shape[0], seq.shape[1]
            # Positions 0..n_steps-1 shaped [1, n_steps, 1], then tiled once
            # per row of `seq`.
            step_index = torch.arange(0, n_steps).to(self._device)
            step_index = step_index.reshape([1, -1, 1])
            pos_enc = self._pos_enc(step_index.repeat(n_rows, 1, 1))

        return self._attention(
            self._embed_item(seq),
            time_enc=time_enc, pos_enc=pos_enc, causality=False)

    def _embed_item(self, item):
        """Look up item embeddings, or pass the input through as floats.

        :param item: tensor of item ids when an embedding table exists,
            otherwise a tensor that is simply cast to float.
        :return: embeddings from ``self._item_embedding``, or ``item.float()``
            when no embedding table was created.
        """
        table = self._item_embedding
        if table is not None:
            # Move the ids to wherever the embedding weights live.
            ids = item.to(table.weight.device)
            return table(ids.long())
        return item.float()



    def _get_loss(self, pos_logits, neg_logits):
        """Compute the loss for positive and negative scores.

        Positive scores are labelled 1 and negative scores 0; scores and
        labels are each concatenated along dimension 1 before being passed to
        ``self._loss_fn``.

        :param pos_logits: model scores for positive samples.
        :param neg_logits: model scores for negative samples.
        :return: the value returned by ``self._loss_fn(y_pred, y_truth)``.
        """
        targets = torch.cat(
            [torch.ones_like(pos_logits), torch.zeros_like(neg_logits)],
            dim=1)
        scores = torch.cat([pos_logits, neg_logits], dim=1)
        return self._loss_fn(scores, targets)

    def supervised(self, seq, timestamps, pos, neg, train=True):
        """Run one supervised step and optionally update the parameters.

        :param seq: user history passed to ``forward``.
        :param timestamps: timestamps passed to ``forward``.
        :param pos: positive items.
        :param neg: negative items.
        :param train: when true, back-propagate the loss, clip the gradient
            norm to 10.0 and take an optimizer step. If any gradient contains
            NaN the step is skipped and the offending gradients are printed.
        :return: ``(pos_logits, neg_logits, loss)``.
        """
        pos_logits, neg_logits = self.forward(seq, pos, neg, timestamps)
        loss = self._get_loss(pos_logits, neg_logits)
        if not train:
            return pos_logits, neg_logits, loss

        self._optimizer.zero_grad()
        loss.backward()

        # Collect every gradient that contains NaN; a single one is enough to
        # abandon this update.
        nan_grads = []
        for param in self.parameters():
            grad = param.grad
            if grad is not None and torch.any(torch.isnan(grad)):
                nan_grads.append(grad)
        if len(nan_grads):
            print("Nan grad encountered", nan_grads)
            return pos_logits, neg_logits, loss

        torch.nn.utils.clip_grad_norm_(self.parameters(), 10.0)
        self._optimizer.step()
        return pos_logits, neg_logits, loss

    def clustered_inference(self, seq, timestamps, pos, neg,
                            pos_enc=None,
                            n_clusters=5,
                            selection: str = 'last',
                            method='ward'):
        """Score items against per-cluster representatives of the user history.

        Everything runs under ``torch.no_grad()``:

        1. embed the history with ``_embed_user``;
        2. ten times over: cluster every sequence, pool the rows of all
           sequences that share a cluster label, pass each pool through
           ``apply_dgr_model`` and write the results back to the rows'
           original positions;
        3. cluster the refined embeddings once more and keep one
           representative row per cluster;
        4. score ``pos`` / ``neg`` against the representatives and compute
           the loss.

        :param seq: user history passed to ``_embed_user``.
        :param timestamps: timestamps passed to ``_embed_user``.
        :param pos: positive items.
        :param neg: negative items.
        :param pos_enc: unused; kept for interface compatibility.
        :param n_clusters: number of clusters and of representatives kept per
            sequence. When fewer distinct labels are found, the existing
            representatives are repeated cyclically to fill the remainder.
        :param selection: ``'last'`` keeps the last member of each cluster;
            any other value keeps the medoid (the member with the smallest
            summed distance to the others). Clusters with at most two members
            always use their last member.
        :param method: clustering method name understood by
            ``_get_cluster_m``.
        :return: ``(pos_sims, neg_sims, loss)``.
        """
        refine_rounds = 10
        with torch.no_grad():
            user_emb_np = self._embed_user(seq, timestamps).cpu().numpy()
            # The embedding width is taken from the data instead of being
            # fixed, so any `d_model` works.
            batch_size, seq_len, emb_dim = user_emb_np.shape
            total_rows = batch_size * seq_len

            for _ in range(refine_rounds):
                # Step 1: cluster each sequence independently; labels are
                # concatenated in batch order so they line up with the
                # flattened embedding rows.
                labels = []
                for bid in range(batch_size):
                    cluster_m = _get_cluster_m(n_clusters, method)
                    cluster_m.fit(user_emb_np[bid])
                    labels.extend(cluster_m.labels_)
                labels_np = np.array(labels)

                # Step 2: pool rows by cluster label across the whole batch.
                flat_emb = user_emb_np.reshape(total_rows, emb_dim)
                groups = [flat_emb[labels_np == i] for i in range(n_clusters)]

                # Step 3: refine every pool with the DGR model.
                optimized_groups = [apply_dgr_model(group) for group in groups]

                # Step 4: scatter the refined rows back to their original
                # positions. Rows whose label is outside range(n_clusters)
                # stay zero.
                reconstructed = torch.zeros((total_rows, emb_dim))
                for i in range(n_clusters):
                    for j, index in enumerate(np.flatnonzero(labels_np == i)):
                        reconstructed[index, :] = optimized_groups[i][j][0]

                # Convert back to NumPy so every round (and the final
                # clustering) sees the same array type.
                user_emb_np = reconstructed.reshape(
                    batch_size, seq_len, emb_dim).detach().cpu().numpy()
                del labels_np, flat_emb, groups, optimized_groups, reconstructed
                gc.collect()

            # One representative per cluster: [batch_size, n_clusters, emb_dim].
            centers = np.empty([batch_size, n_clusters, emb_dim])
            for bid in range(batch_size):
                cluster_m = _get_cluster_m(n_clusters, method)
                c_labels = cluster_m.fit_predict(user_emb_np[bid])
                distinct = np.unique(c_labels)

                for cid, c in enumerate(distinct):
                    members = user_emb_np[bid][c_labels == c]
                    # With at most two members there is no meaningful medoid.
                    if selection == 'last' or members.shape[0] <= 2:
                        centers[bid][cid] = members[-1]
                    else:
                        pairwise = squareform(pdist(members))
                        medoid = np.argmin(np.sum(pairwise, axis=0))
                        centers[bid][cid] = members[medoid]

                # Some methods may return fewer clusters than requested;
                # reuse existing representatives for the missing slots.
                for cid in range(len(distinct), n_clusters):
                    centers[bid][cid] = centers[bid][cid % len(distinct)]

            user_emb = torch.from_numpy(centers).float().to(self._device)
            pos_sims = self._sims(user_emb, self._embed_item(pos))
            neg_sims = self._sims(user_emb, self._embed_item(neg))
            loss = self._get_loss(pos_sims, neg_sims)
            # Only names that are bound on every path are deleted here.
            del user_emb_np, centers
            gc.collect()
        return pos_sims, neg_sims, loss

    def count_parameters(self):
        """Return the total number of elements in all trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total


class GlobalQueryModel(_AttentionModel):
    """``_AttentionModel`` built on ``MHGlobalQueryAttention`` with an optimizer attached."""

    def __init__(self, hps):
        """
        :param hps: dict of hyper-parameters, forwarded to the base class and
            to ``_add_optim``.
        """
        super().__init__(hps, MHGlobalQueryAttention)
        self._add_optim(hps)
        target = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        self.to(self.device)

class StaticPinnerSagePlus(_AttentionModel):
    """``_AttentionModel`` built on ``MultiHeadAttention`` with an optimizer attached."""

    def __init__(self, hps):
        """
        :param hps: dict of hyper-parameters, forwarded to the base class and
            to ``_add_optim``.
        """
        super().__init__(hps, MultiHeadAttention)
        self._add_optim(hps)
        target = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        self.to(self.device)


class WeightedPinnerSagePlus(_AttentionModel):
    def __init__(self, hps):
        """Build the base attention model plus a learned cluster-weight network.

        :param hps: dict of hyper-parameters. ``att_on_tem`` and ``tem_enc``
            are overwritten (forced to ``True`` / ``'sine'``) before the base
            class reads them; ``tem_dim``, ``seq_len`` and ``emb_dim`` size
            the weight network.
        """
        hps["att_on_tem"], hps["tem_enc"] = True, 'sine'
        super().__init__(hps, MultiHeadAttention)

        # Replace the base class's scorer with the weighted variant.
        self._sims = WeightedMaxAggregator()
        # Separate one-hot encoder used only when computing cluster weights.
        self._weight_tem_enc = OneHot(max_k=hps["tem_dim"])

        # Input width = seq_len * (one-hot width) + embedding width.
        weight_in = (hps["seq_len"] * self._weight_tem_enc.output_dim
                     + hps["emb_dim"])
        weight_layers = [
            nn.Linear(in_features=weight_in, out_features=128),
            nn.Sigmoid(),
            nn.Linear(in_features=128, out_features=1),
            nn.Softplus(),
        ]
        self._weight_model = nn.Sequential(*weight_layers)

        self._add_optim(hps)
        target = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        self.to(self.device)

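    # load_unweighted loads the (unweighted) model parameters from a saved
    # model checkpoint file. It copies only the parameters that the current
    # model actually has, which is useful for transferring a model or for
    # continuing to train an existing one.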
    def load_unweighted(self, ckpt):
        """Copy matching parameters from a saved model into this one.

        ``ckpt`` must contain a whole pickled module, because
        ``state_dict()`` is called on the loaded object. Entries whose name
        does not exist in this model are skipped; entries that fail to copy
        are reported with ``print`` and skipped. The name of every
        successfully copied entry is printed.

        NOTE(review): ``torch.load`` unpickles arbitrary objects -- only load
        checkpoints from trusted sources.

        :param ckpt: path (or file-like object) accepted by ``torch.load``.
        """
        # Load onto the CPU so a checkpoint written on a GPU machine can be
        # restored on a CPU-only one; `copy_` below transfers each value to
        # the device of its destination tensor.
        state_dict = torch.load(ckpt, map_location="cpu").state_dict()
        own_state = self.state_dict()
        print(f"The following parameters are loaded from {ckpt}")
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
                print(name)
            except Exception as e:
                # Best effort: report and carry on with the other entries.
                print(f"failed to load {name} due to:\n {e}")
        print("\n")

    def _cluster_items(self, seq_embed, n_clusters=5, method='ward'):
        """Cluster the items of each sequence and return the cluster masks.

        The embeddings are first refined ten times (cluster every sequence,
        pool the rows of all sequences that share a cluster label, pass each
        pool through ``apply_dgr_model``, scatter the results back), then
        clustered once more to build the masks.

        :param seq_embed: item embeddings indexed as
            ``[batch, position, feature]``.
        :param n_clusters: number of clusters requested per sequence.
        :param method: clustering method name understood by
            ``_get_cluster_m``.
        :return: ``(cluster_mask, last_item_mask)``, both created on the CPU.
            ``cluster_mask[b, i, j]`` is 1 when items ``i`` and ``j`` of
            sequence ``b`` fall in the same final cluster (shape
            ``[batch, seq_len, seq_len]``); ``last_item_mask[b, 0, i]`` is 1
            when ``i`` is the highest index of its cluster (shape
            ``[batch, 1, seq_len]``).
        """
        # Take the feature width from the input rather than assuming a fixed
        # embedding size.
        batch_size, seq_len, emb_dim = seq_embed.shape
        total_rows = batch_size * seq_len
        cluster_mask = torch.zeros([batch_size, seq_len, seq_len])
        last_item_mask = torch.zeros([batch_size, 1, seq_len])
        refine_rounds = 10

        with torch.no_grad():
            for _ in range(refine_rounds):
                # Labels are concatenated in batch order so they line up with
                # the flattened embedding rows.
                labels = []
                for bid in range(batch_size):
                    cluster_m = _get_cluster_m(n_clusters, method)
                    cluster_m.fit(seq_embed[bid].detach().cpu().numpy())
                    labels.extend(cluster_m.labels_)
                labels_np = np.array(labels)

                flat_emb = seq_embed.reshape(
                    total_rows, emb_dim).detach().cpu().numpy()
                groups = [flat_emb[labels_np == i] for i in range(n_clusters)]
                optimized_groups = [apply_dgr_model(group) for group in groups]

                # Rows whose label is outside range(n_clusters) stay zero.
                reconstructed = torch.zeros((total_rows, emb_dim))
                for i in range(n_clusters):
                    for j, index in enumerate(np.flatnonzero(labels_np == i)):
                        reconstructed[index, :] = optimized_groups[i][j][0]

                seq_embed = reconstructed.reshape(batch_size, seq_len, emb_dim)

                # Drop the temporaries before the next round allocates again.
                del labels_np, flat_emb, groups, optimized_groups, reconstructed
                gc.collect()

        for bid in range(batch_size):
            cluster_m = _get_cluster_m(n_clusters, method)
            c_labels = cluster_m.fit_predict(
                seq_embed[bid].detach().cpu().numpy())
            for c in np.unique(c_labels):
                item_idx = torch.as_tensor(
                    np.flatnonzero(c_labels == c), dtype=torch.long)
                # Mark every (i, j) pair inside the cluster in one assignment.
                cluster_mask[bid, item_idx.unsqueeze(1), item_idx] = 1
                last_item_mask[bid, 0, item_idx[-1]] = 1

        gc.collect()
        return cluster_mask, last_item_mask
    def _compute_weights(self, cluster_emb, timestamps, cluster_mask):
        """Compute one weight per sequence position from time gaps and embeddings.

        :param cluster_emb: embeddings concatenated to the encoded time gaps
            along the last dimension.
        :param timestamps: timestamps; presumably ``[batch, seq_len, 1]``,
            given the squeeze/unsqueeze below -- TODO confirm.
        :param cluster_mask: pairwise same-cluster mask; gaps between items of
            different clusters are zeroed out.
        :return: the output of ``self._weight_model``.
        """
        # Pairwise differences between timestamps, encoded on their absolute
        # value.
        pairwise_gap = timestamps.squeeze(dim=-1).unsqueeze(dim=1) - timestamps
        gap_enc = self._weight_tem_enc(time_itvls=torch.abs(pairwise_gap))

        # Keep only the gaps between items that share a cluster.
        same_cluster = cluster_mask.unsqueeze(dim=-1).to(self.device)
        gap_enc = gap_enc * same_cluster

        # Flatten the per-pair encodings so they can sit next to the
        # embedding of each position.
        gap_enc = gap_enc.reshape(
            [pairwise_gap.shape[0], pairwise_gap.shape[1], -1])
        features = torch.cat([gap_enc, cluster_emb], dim=-1)
        return self._weight_model(features)
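    # Combines time, position, item features and the user's interaction
    # pattern to build user embeddings and compute their similarity to items.
    # This approach may be used in recommender systems or other sequence
    # processing tasks where the user's interaction history and timing have
    # an important influence on the prediction.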



    def forward(self, seq, pos, neg, timestamps, n_clusters=5, method='ward'):
        """Score positive and negative items using cluster-weighted user embeddings.

        :param seq: user history; embedded via ``_embed_item``.
        :param pos: positive items.
        :param neg: negative items.
        :param timestamps: timestamps of the history; required.
        :param n_clusters: number of clusters passed to ``_cluster_items``.
        :param method: clustering method passed to ``_cluster_items``.
        :return: ``(pos_sims, neg_sims)`` as produced by ``self._sims``.

        The user embeddings, cluster weights and cluster mask of the most
        recent call are kept on ``self.last_embs``, ``self.last_weight`` and
        ``self.last_cluster``. NaNs in the item embeddings or in the scores
        are reported with ``print``.
        """
        assert timestamps is not None, "Weighted model requires timestamps"

        # Keep the raw inputs under their own names so the diagnostics below
        # can show the offending input next to its embedding.
        pos_emb = self._embed_item(pos)
        neg_emb = self._embed_item(neg)
        if torch.any(torch.isnan(pos_emb)):
            print("pos embedding has nan", pos_emb)
            print("pos sequence", pos)
        if torch.any(torch.isnan(neg_emb)):
            print("neg embeddings has nan", neg_emb)
            print("neg sequence", neg)

        # Embed the user history.
        time_enc = self._tem_enc_layer(timestamps)
        if self._pos_enc is not None:
            positions = torch.arange(0, seq.shape[1]).to(self._device)
            positions = torch.reshape(positions, [1, -1, 1])
            pos_enc = self._pos_enc(positions.repeat(seq.shape[0], 1, 1))
        else:
            pos_enc = None
        seq_emb = self._embed_item(seq)
        cluster_mask, last_item_mask = self._cluster_items(
            seq_emb, n_clusters=n_clusters, method=method)
        user_embs = self._attention(seq_emb, time_enc=time_enc,
                                    pos_enc=pos_enc, causality=False,
                                    cluster_mask=cluster_mask)

        cluster_weight = self._compute_weights(
            user_embs, timestamps, cluster_mask)
        self.last_embs = user_embs
        self.last_weight = cluster_weight
        self.last_cluster = cluster_mask

        pos_sims = self._sims(
            user_embs, pos_emb, cluster_weight, last_item_mask)
        neg_sims = self._sims(
            user_embs, neg_emb, cluster_weight, last_item_mask)
        if torch.any(torch.isnan(pos_sims)):
            print('pos sims has nan', pos_sims)
        if torch.any(torch.isnan(neg_sims)):
            print('neg sims has nan', neg_sims)
        return pos_sims, neg_sims


    def supervised(self, seq, timestamps, pos, neg, train=True,
                   n_clusters=5, method='ward'):
        """Run one supervised step and optionally update the parameters.

        :param seq: user history passed to ``forward``.
        :param timestamps: timestamps passed to ``forward``.
        :param pos: positive items.
        :param neg: negative items.
        :param train: when true, back-propagate the loss, clip the gradient
            norm to 10.0 and take an optimizer step. If any gradient contains
            NaN the step is skipped and the offending gradients are printed.
        :param n_clusters: number of clusters passed to ``forward``.
        :param method: clustering method passed to ``forward``.
        :return: ``(pos_logits, neg_logits, loss)``.
        """
        pos_logits, neg_logits = self.forward(
            seq, pos, neg, timestamps, n_clusters=n_clusters, method=method)
        loss = self._get_loss(pos_logits, neg_logits)
        if not train:
            return pos_logits, neg_logits, loss

        self._optimizer.zero_grad()
        loss.backward()

        # A single NaN gradient is enough to abandon this update.
        nan_grads = []
        for param in self.parameters():
            grad = param.grad
            if grad is not None and torch.any(torch.isnan(grad)):
                nan_grads.append(grad)
        if len(nan_grads):
            print("Nan grad encountered", nan_grads)
            return pos_logits, neg_logits, loss

        torch.nn.utils.clip_grad_norm_(self.parameters(), 10.0)
        self._optimizer.step()
        return pos_logits, neg_logits, loss

    def clustered_inference(self, seq, timestamps, pos, neg,
                            pos_enc=None,
                            n_clusters=10,
                            selection: str ='last',
                            method='ward'):
        """Evaluate without updating parameters by delegating to ``supervised``.

        :param pos_enc: unused; kept for interface compatibility.
        :param selection: unused; kept for interface compatibility.
        :param n_clusters: number of clusters forwarded to ``supervised``
            (note the default here is 10, not 5).
        :param method: clustering method forwarded to ``supervised``.
        :return: ``(pos_logits, neg_logits, loss)`` from
            ``supervised(..., train=False)``.
        """
        eval_kwargs = dict(train=False, n_clusters=n_clusters, method=method)
        return self.supervised(seq, timestamps, pos, neg, **eval_kwargs)


class ExpDecayWeightModel(WeightedPinnerSagePlus):
    """``WeightedPinnerSagePlus`` variant whose cluster weights decay exponentially with age."""

    def __init__(self, hps):
        """
        :param hps: dict of hyper-parameters; ``hps['lambda']`` is the decay
            rate applied to the time gaps in ``_compute_weights``.
        """
        super(ExpDecayWeightModel, self).__init__(hps)
        self._sims = WeightedMaxAggregator()
        self._lambda = hps['lambda']
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def _compute_weights(self, cluster_emb, timestamps, cluster_mask):
        """Compute per-position weights as a masked sum of exponential decays.

        For position ``i`` the decay term is
        ``exp(-lambda * (t_last - t_i))``, where ``t_last`` is the final
        timestamp of the sequence; the term is multiplied by row ``i`` of
        ``cluster_mask`` and summed over the last dimension.

        :param cluster_emb: unused; kept so the signature matches the parent.
        :param timestamps: timestamps; presumably in days and shaped
            ``[batch, seq_len, 1]`` -- TODO confirm. The tensor is not
            modified.
        :param cluster_mask: pairwise same-cluster mask.
        :return: weights with the last dimension reduced to size 1.
        """
        # Scale out of place: the caller's tensor must not be altered, or the
        # factor would compound every time the same batch is reused.
        seconds = timestamps * (60 * 60 * 24)
        time_itvls = seconds[:, -1:, :] - seconds
        exponentials = torch.exp(time_itvls * -self._lambda)
        # The mask may live on another device than the timestamps; align it
        # before combining. Broadcasting spreads each decay term across its
        # mask row.
        mask = cluster_mask.to(exponentials.device)
        return torch.sum(exponentials * mask, dim=-1, keepdim=True)
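    # ExpDecayWeightModel computes weights from the gaps between timestamps
    # using an exponential decay function. This may be intended to capture
    # temporal dynamics in the data or to give more weight to newer data.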


