import json
import torch
import socket
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from tensorboardX import SummaryWriter
import numpy as np
from modules.MLP import DynamicMLP
# Public API of this module: only the model class is exported by
# ``from <module> import *``.
__all__ = ["DGRModel"]

class DGRModel(nn.Module):
    """Graph-based refinement of a matrix of item embeddings.

    ``_build_seq_graph`` turns an ``(items, features)`` matrix into a
    thresholded similarity graph and then runs ``pool_layers`` rounds of
    attention-weighted neighbourhood aggregation over it.

    Every layer whose shape depends on the input is created on first use and
    then cached on the module. Repeated calls therefore reuse the same
    weights, and those weights show up in ``parameters()`` and
    ``state_dict()``. Because the layers are created lazily, an optimizer must
    be built (or refreshed) after the first call to ``_build_seq_graph``.
    """

    # Logit written into masked-out positions before a softmax.
    _MASK_FILL = -(2 ** 32) + 1
    # Lower bound for the min-max normalisation denominator.
    _RANGE_EPS = 1e-12

    def __init__(self, device=None, seed=None):
        """Set up hyper-parameters and the containers for lazily built layers.

        Args:
            device: Device on which tensors and layers are created. Defaults
                to ``cuda:0`` when CUDA is available, otherwise ``cpu``.
            seed: Optional integer. When given, it is passed to
                ``torch.manual_seed`` so that weight initialisation is
                reproducible. ``None`` leaves the global RNG untouched.
        """
        super(DGRModel, self).__init__()
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.seed = seed
        if seed is not None:
            torch.manual_seed(seed)

        # Multiplier applied to the non-zero edge count to obtain the rank of
        # the similarity score used as the adjacency cut-off.
        self.relative_threshold = 0.5
        # Number of weighted-cosine similarity heads averaged into the graph.
        self.metric_heads = 1
        # Number of attention heads averaged in each fusion layer.
        self.attention_heads = 1
        # Number of fusion layers applied by ``_build_seq_graph``.
        self.pool_layers = 1
        # When True, attention weights are keyed by head only (shared across
        # layers); otherwise they are keyed by layer and head.
        self.layer_shared = True
        # The three attributes below are not read anywhere in this class.
        self.pool_length = 30
        self.layer_keeps = torch.tensor(1, dtype=torch.float32, requires_grad=False).to(device)
        self.hidden_size = 40

        # Caches for layers/parameters that can only be sized once an input
        # has been seen.
        self._lazy_layers = nn.ModuleDict()
        self._attention_mats = nn.ParameterDict()
        self._last_attention_key = None

    @property
    def attention_mat(self):
        """Projection matrix used by the latest ``_attention_fcn`` call.

        Returns ``None`` until ``_attention_fcn`` has been called once.
        """
        key = self._last_attention_key
        return None if key is None else self._attention_mats[key]

    @staticmethod
    def _slot_key(*parts):
        """Join ``parts`` into a key that is valid for Module/ParameterDict."""
        return "_".join(str(part) for part in parts).replace(".", "_")

    def _cached_module(self, key, factory):
        """Return the layer stored under ``key``, building it on first use."""
        if key not in self._lazy_layers:
            module = factory().to(self.device)
            # A layer created after ``eval()`` must not start in train mode.
            module.train(self.training)
            self._lazy_layers[key] = module
        return self._lazy_layers[key]

    def _as_feature_matrix(self, input_embedding):
        """Return a detached float copy of ``input_embedding`` on ``self.device``.

        Raises:
            ValueError: If the result is not a non-empty 2-D matrix.
        """
        if isinstance(input_embedding, torch.Tensor):
            # Same "detached copy" semantics as ``torch.tensor(tensor)`` but
            # without the copy-construct warning.
            X = input_embedding.detach().to(device=self.device, dtype=torch.float, copy=True)
        else:
            X = torch.tensor(input_embedding, dtype=torch.float, device=self.device)
        if X.dim() != 2 or X.numel() == 0:
            raise ValueError(
                "input_embedding must be a non-empty 2-D (items, features) matrix, "
                "got shape {}".format(tuple(X.shape))
            )
        return X

    def _build_seq_graph(self, input_embedding):
        """Build the item graph for ``input_embedding`` and return fused features.

        Side effects: sets ``item_history_embedding``, ``item_lookup``,
        ``target_item_embedding``, ``mask``, ``real_sequence_length`` and
        ``weighted_layers`` on the module.

        Args:
            input_embedding: Tensor or array-like convertible to a 2-D float
                matrix of shape ``(items, features)``.

        Returns:
            Tensor of shape ``(items, features)`` produced by the last fusion
            layer, or the converted input when ``pool_layers`` is 0.

        Raises:
            ValueError: If the input is not a non-empty 2-D matrix.
        """
        self.item_history_embedding = input_embedding
        X = self._as_feature_matrix(input_embedding)
        instance_cnt, feature_dim = X.shape

        # The lookup table is sized by the input, so it can only be reused
        # while the input shape stays the same.
        lookup = getattr(self, "item_lookup", None)
        if (not isinstance(lookup, nn.Embedding)
                or tuple(lookup.weight.shape) != (instance_cnt, feature_dim)):
            lookup = nn.Embedding(instance_cnt, feature_dim).to(self.device)
            self.item_lookup = lookup
        # Row index = truncated mean of the row's features, clamped into the
        # table's valid range.
        items_tensor = torch.clamp(torch.mean(X, dim=-1).long(), 0, instance_cnt - 1)
        self.target_item_embedding = lookup(items_tensor).detach().clone()

        self.mask = torch.triu(
            torch.ones(instance_cnt, instance_cnt, dtype=torch.float32, device=self.device)
        )
        self.real_sequence_length = torch.sum(self.mask, dim=1)

        A = self._learn_adjacency(X, self.mask)
        torch.cuda.empty_cache()

        # NOTE(review): every layer is fed the original X, not the previous
        # layer's output -- confirm this is intended before raising
        # ``pool_layers`` above 1.
        Xc = X
        for layer in range(self.pool_layers):
            Xc, A = self._interest_fusion_extraction(X, A, layer, layer > 0)
        return Xc

    def _learn_adjacency(self, X, mask):
        """Compute a 0/1 adjacency matrix from weighted cosine similarity.

        Args:
            X: ``(items, features)`` float matrix.
            mask: ``(items, items)`` float matrix; the similarity matrix is
                multiplied by ``mask`` and by its transpose.

        Returns:
            ``(items, items)`` float tensor that is 1 where the masked
            similarity is strictly above the cut-off score and 0 elsewhere.
        """
        feature_dim = X.size(-1)
        heads = getattr(self, "weighted_layers", None)
        if (not isinstance(heads, nn.ModuleList)
                or len(heads) != self.metric_heads
                or any(head.out_features != feature_dim for head in heads)):
            heads = nn.ModuleList(
                [nn.Linear(1, feature_dim, bias=False) for _ in range(self.metric_heads)]
            ).to(self.device)
            self.weighted_layers = heads

        unit = torch.ones(1, 1, device=self.device)
        similarities = []
        for head in heads:
            # head(unit) is a (1, features) vector of per-feature weights.
            X_fts = F.normalize(X * head(unit), p=2, dim=1)
            S_one = torch.matmul(X_fts, X_fts.T)
            S_min = torch.min(S_one, dim=1, keepdim=True)[0]
            S_max = torch.max(S_one, dim=1, keepdim=True)[0]
            # A constant row (e.g. a single item) has zero range; flooring the
            # denominator maps it to 0 instead of NaN.
            value_range = (S_max - S_min).clamp_min(self._RANGE_EPS)
            similarities.append((S_one - S_min) / value_range)
        S = torch.mean(torch.stack(similarities, dim=0), dim=0)

        # NOTE(review): with the upper-triangular mask built by
        # ``_build_seq_graph``, mask * mask^T is non-zero only on the
        # diagonal -- confirm that this is the intended masking.
        S = S * mask * mask.transpose(0, 1)

        ranked, _ = torch.sort(S.reshape(-1), descending=True)
        num_edges = torch.count_nonzero(S).float()
        keep_rank = torch.ceil(num_edges * self.relative_threshold).long()
        # Keep the rank inside the score vector even when every entry is
        # non-zero and the threshold is 1.
        keep_rank = keep_rank.clamp(max=ranked.numel() - 1)
        threshold_score = ranked[keep_rank]
        return (S > threshold_score).float()

    def _interest_fusion_extraction(self, X, A, layer, reuse):
        """Run one attention-weighted aggregation step over the graph.

        Args:
            X: ``(items, features)`` node features.
            A: ``(items, items)`` adjacency; only ``A > 0`` is used.
            layer: Layer index, used to key per-layer weights.
            reuse: Forwarded to ``_attention_fcn`` when weights are shared.

        Returns:
            Tuple ``(Xc, A_norm)``: the updated node features and the
            symmetrically normalised adjacency (self-loops included).
        """
        node_cnt = A.shape[0]
        eye = torch.eye(node_cnt, device=A.device)
        A_bool = (A > 0).type(A.dtype)
        # Force the diagonal to exactly 1 (self-loops).
        A_bool = A_bool * (torch.ones_like(eye) - eye) + eye

        D = torch.sqrt(torch.sum(A_bool, dim=-1))[:, None] + 1e-7
        A = (A_bool / D) / D.transpose(0, 1)
        # Two-hop propagated features serve as the attention query.
        X_q = torch.matmul(A, torch.matmul(A, X))

        Xc = []
        for i in range(self.attention_heads):
            if self.layer_shared:
                f1_name = 'f1_shared' + '_' + str(i)
                f2_name = 'f2_shared' + '_' + str(i)
                head_reuse = reuse
            else:
                f1_name = 'f1_layer_' + str(layer) + '_' + str(i)
                f2_name = 'f2_layer_' + str(layer) + '_' + str(i)
                head_reuse = False
            _, f_1 = self._attention_fcn(X_q, X, f1_name, head_reuse, return_alpha=True)
            _, f_2 = self._attention_fcn(self.target_item_embedding, X, f2_name, head_reuse,
                                         return_alpha=True)

            E = F.leaky_relu(A_bool * f_1 + A_bool * f_2)
            # Non-edges get a very negative logit so softmax ignores them.
            E = F.softmax(
                torch.where(A_bool == 1, E, torch.full_like(E, self._MASK_FILL)),
                dim=-1
            )

            Xc_one = torch.matmul(E, X)
            in_dim, out_dim = Xc_one.shape[1], X.shape[1]
            # Output size equals the input feature size so the residual below
            # is well defined.
            dense_layer = self._cached_module(
                self._slot_key('fusion_dense', layer, i, in_dim, out_dim),
                lambda: nn.Linear(in_features=in_dim, out_features=out_dim, bias=False),
            )
            Xc_one = dense_layer(Xc_one) + X  # residual keeps the original information
            Xc.append(F.leaky_relu(Xc_one))

        # Aggregate the attention heads.
        Xc = torch.mean(torch.stack(Xc, 0), 0)
        return Xc, A

    def _attention_fcn(self, query, key_value, name, reuse, return_alpha=False):
        """Score ``key_value`` against ``query`` with a small MLP.

        Args:
            query: Query features; last dimension is the query size.
            key_value: Features that are projected, scored and re-weighted.
            name: Identifies the weight set. Calls with the same ``name`` (and
                the same feature sizes) share one projection matrix and MLP.
            reuse: Accepted for interface compatibility; sharing is decided by
                ``name`` alone.
            return_alpha: When True, also return the attention weights.

        Returns:
            ``key_value`` multiplied by the attention weights, and the weights
            themselves when ``return_alpha`` is True.

        Note:
            Reads ``self.mask``; assumes it broadcasts against the MLP output
            -- TODO confirm for inputs that do not come from
            ``_build_seq_graph``.
        """
        query_size = query.shape[-1]
        kv_size = key_value.shape[-1]
        boolean_mask = self.mask == 1

        att_key = self._slot_key(name, 'att', kv_size, query_size)
        if att_key not in self._attention_mats:
            weight = torch.empty(kv_size, query_size, device=self.device)
            init.trunc_normal_(weight, mean=0.0, std=1.0, a=-2, b=2)
            self._attention_mats[att_key] = nn.Parameter(weight)
        self._last_attention_key = att_key
        attention_mat = self._attention_mats[att_key]

        att_inputs = torch.tensordot(key_value, attention_mat, dims=([-1], [0]))

        if query.dim() != att_inputs.dim():
            # Tile the query so it lines up with every key position.
            queries = query.repeat(1, att_inputs.shape[1]).view_as(att_inputs)
        else:
            queries = query

        last_hidden_nn_layer = torch.cat(
            [att_inputs, queries, att_inputs - queries, att_inputs * queries], dim=-1
        )
        input_size = last_hidden_nn_layer.shape[-1]
        model = self._cached_module(
            self._slot_key(name, 'mlp', input_size),
            lambda: DynamicMLP(input_size, [80, 40], 1, ['relu', 'tanh'], True, 0.5),
        )

        att_fnc_output = model(last_hidden_nn_layer).squeeze(-1)
        mask_paddings = torch.full_like(att_fnc_output, self._MASK_FILL)
        att_weights = F.softmax(
            torch.where(boolean_mask, att_fnc_output, mask_paddings), dim=-1
        )

        output = key_value * att_weights.unsqueeze(-1)
        if return_alpha:
            return output, att_weights
        return output

   
