import math, os
from dataclasses import dataclass
from typing import Union

import torch
import scipy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree, remove_self_loops
from einops import rearrange, repeat, einsum
from selective_modeling.modules.graph_selective_modeling import Mamba
from torch.autograd import Function
import pywt

# Multiplier applied to a tensor sum before it is rounded into an integer RNG seed
# (KernelizedConv.forward derives its data-dependent projection seeds this way).
BIG_CONSTANT = 1e8


def shannon_entropy(matrix):
    """Return the row-wise Shannon entropy, in bits, of ``softmax(matrix, dim=1)``.

    The log-probabilities come from ``log_softmax`` instead of ``log2(softmax)``.
    When a probability underflows to exactly 0, ``log2`` gives ``-inf`` and
    ``0 * -inf`` is NaN. ``log_softmax`` stays finite, so such a term correctly
    contributes 0.

    :param matrix: tensor whose dim 1 holds the logits of each distribution.
    :return: entropy with dim 1 reduced, passed through ``squeeze()`` (so any
        size-1 dimensions are dropped as well).
    """
    log_probs = F.log_softmax(matrix, dim=1)
    probs = F.softmax(matrix, dim=1)
    # Natural-log entropy divided by ln(2) gives bits.
    entropy = -torch.sum(probs * log_probs, dim=1) / math.log(2.0)
    return entropy.squeeze()


def kl_divergence(x, y):
    """Row-wise KL divergence ``KL(softmax(y) || softmax(x))`` taken along dim 1.

    ``x`` supplies the log-probabilities (the "input") and ``y`` the target
    distribution, following the ``nn.KLDivLoss`` convention with
    ``reduction='none'``. The pointwise terms are then summed over dim 1.
    """
    input_log_probs = F.log_softmax(x, dim=1)
    target_probs = F.softmax(y, dim=1)
    pointwise = F.kl_div(input_log_probs, target_probs, reduction='none')
    return pointwise.sum(dim=1)


def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    """Build an ``[m, d]`` random-feature projection from stacked orthogonal blocks.

    Rows come from independent ``d x d`` blocks with orthonormal rows. The last
    block is truncated to the rows still needed. Each row is then rescaled.

    :param m: number of random features (rows of the result).
    :param d: input dimensionality (columns of the result).
    :param seed: base seed. Block ``i`` is generated with ``seed + i`` and the
        row scales with the seed after the last block. May be an int or
        anything ``torch.manual_seed`` accepts.
    :param scaling: 0 -> each row is scaled by the norm of an independent
        Gaussian vector; 1 -> every row is scaled by ``sqrt(d)``.
    :param struct_mode: if True, blocks come from
        ``create_products_of_givens_rotations`` instead of a QR decomposition.
    :return: float tensor of shape ``[m, d]``.
    :raises ValueError: if ``scaling`` is not 0 or 1.

    NOTE: this reseeds the *global* torch RNG (``torch.manual_seed``) as a side
    effect, so later random draws in the caller are affected.
    """
    # Fail fast, before any RNG reseeding or block construction.
    if scaling not in (0, 1):
        raise ValueError("Scaling must be one of {0, 1}. Was %s" % scaling)

    def _orthogonal_block(block_seed):
        # One d x d block with orthonormal rows, generated deterministically from block_seed.
        torch.manual_seed(block_seed)
        if struct_mode:
            return create_products_of_givens_rotations(d, block_seed)
        # torch.linalg.qr (reduced mode) replaces the deprecated torch.qr(some=True).
        q, _ = torch.linalg.qr(torch.randn((d, d)))
        return torch.t(q)

    nb_full_blocks = m // d
    block_list = [_orthogonal_block(seed + i) for i in range(nb_full_blocks)]
    next_seed = seed + nb_full_blocks
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        block_list.append(_orthogonal_block(next_seed)[:remaining_rows])
    final_matrix = torch.vstack(block_list)

    # Row scales use their own seed, one past the seed of the last (partial) block.
    torch.manual_seed(next_seed + 1)
    if scaling == 0:
        multiplier = torch.norm(torch.randn((m, d)), dim=1)
    else:
        multiplier = torch.sqrt(torch.tensor(float(d))) * torch.ones(m)

    # Same values as diag(multiplier) @ final_matrix, without materialising an m x m matrix.
    return multiplier.unsqueeze(1) * final_matrix


def create_products_of_givens_rotations(dim, seed):
    """Return a random ``[dim, dim]`` orthogonal matrix built from Givens rotations.

    ``dim * ceil(ln(dim))`` rotations are applied to the identity. Each uses an
    angle ``a`` drawn uniformly from ``[0, pi)`` and mixes two distinct rows
    ``i < j``::

        row_i <-  cos(a) * row_i + sin(a) * row_j
        row_j <- -sin(a) * row_i + cos(a) * row_j

    :param dim: matrix size.
    :param seed: seed for NumPy's *global* RNG (reseeded as a side effect).
    :return: float32 tensor of shape ``[dim, dim]``.
    """
    nb_givens_rotations = dim * int(math.ceil(math.log(float(dim))))
    q = np.eye(dim, dim)
    np.random.seed(seed)
    for _ in range(nb_givens_rotations):
        random_angle = math.pi * np.random.uniform()
        # Rows must be distinct: "rotating" a row with itself is not an orthogonal map.
        random_indices = np.random.choice(dim, 2, replace=False)
        index_i = min(random_indices[0], random_indices[1])
        index_j = max(random_indices[0], random_indices[1])
        cos_a = math.cos(random_angle)
        sin_a = math.sin(random_angle)
        slice_i = q[index_i]
        slice_j = q[index_j]
        # Both new rows are computed before either is written back (slice_i/j are views into q).
        new_slice_i = cos_a * slice_i + sin_a * slice_j
        new_slice_j = -sin_a * slice_i + cos_a * slice_j
        q[index_i] = new_slice_i
        q[index_j] = new_slice_j
    return torch.tensor(q, dtype=torch.float32)


def relu_kernel_transformation(data, is_query, projection_matrix=None, numerical_stabilizer=0.001):
    """ReLU random-feature map ``relu(data @ P^T / sqrt(M)) + numerical_stabilizer``.

    :param data: ``[B, N, H, D]`` tensor when a projection is given (required by the einsum).
    :param is_query: unused; kept so the signature matches softmax_kernel_transformation.
    :param projection_matrix: optional ``[M, D]`` projection. When None, the ReLU is
        applied to ``data`` directly.
    :param numerical_stabilizer: constant added so every feature is strictly positive.
    :return: ``[B, N, H, M]`` features, or ``data``'s shape when no projection is given.
    """
    del is_query
    if projection_matrix is None:
        return F.relu(data) + numerical_stabilizer
    # torch.tensor takes dtype as a keyword-only argument: torch.tensor(data, *, dtype=...).
    ratio = 1.0 / torch.sqrt(torch.tensor(projection_matrix.shape[0], dtype=torch.float32))
    data_dash = ratio * torch.einsum("bnhd,md->bnhm", data, projection_matrix)
    return F.relu(data_dash) + numerical_stabilizer


def softmax_kernel_transformation(data, is_query, projection_matrix=None, numerical_stabilizer=0.000001):
    """Positive random features ``exp(w^T x - ||x||^2 / 2) / sqrt(M)`` for the softmax kernel.

    ``data`` is first scaled by ``D ** -0.25`` (``D`` = its last dimension) and then
    projected with the ``[M, D]`` ``projection_matrix``. For numerical stability,
    the exponent is shifted by a maximum of the projected values:

    * queries (``is_query`` truthy): the maximum over the feature axis;
    * keys: the maximum over both the feature axis and the node axis.

    :param data: ``[B, N, H, D]`` tensor (layout fixed by the einsum).
    :param projection_matrix: ``[M, D]`` projection; must not be None.
    :return: ``[B, N, H, M]`` strictly positive features.
    """
    num_features = torch.tensor(projection_matrix.shape[0], dtype=torch.float32)
    ratio = 1.0 / torch.sqrt(num_features)

    head_dim = torch.tensor(data.shape[-1], dtype=torch.float32)
    scaled = data * (1.0 / torch.sqrt(torch.sqrt(head_dim)))

    projected = torch.einsum("bnhd,md->bnhm", scaled, projection_matrix)

    # ||x||^2 / 2 per row, kept as a trailing singleton axis for broadcasting.
    feature_axis = scaled.dim() - 1
    half_sq_norm = torch.unsqueeze(torch.sum(torch.square(scaled), dim=feature_axis) / 2.0, dim=feature_axis)

    last_axis = projected.dim() - 1
    node_axis = projected.dim() - 3
    shift = torch.max(projected, dim=last_axis, keepdim=True).values
    if not is_query:
        shift = torch.max(shift, dim=node_axis, keepdim=True).values

    return ratio * (torch.exp(projected - half_sq_norm - shift) + numerical_stabilizer)


def numerator(qs, ks, vs):
    """Linear-attention numerator ``sum_j (q_i . k_j) v_j`` for every node ``i``.

    :param qs: ``[N, B, H, M]`` query features.
    :param ks: ``[N, B, H, M]`` key features.
    :param vs: ``[N, B, H, D]`` values.
    :return: ``[N, B, H, D]``.
    """
    # Key/value summary aggregated over nodes once (U_k in the paper): [B, H, M, D].
    key_value_summary = torch.einsum("nbhm,nbhd->bhmd", ks, vs)
    result = torch.einsum("nbhm,bhmd->nbhd", qs, key_value_summary)
    return result


def denominator(qs, ks):
    """Linear-attention normaliser ``q_i . sum_j k_j`` for every node ``i``.

    :param qs: ``[N, B, H, M]`` query features.
    :param ks: ``[N, B, H, M]`` key features.
    :return: ``[N, B, H]``.
    """
    node_count = ks.shape[0]
    ones = torch.ones([node_count]).to(qs.device)
    # Sum of key features over nodes (O_k in the paper): [B, H, M].
    key_sum = torch.einsum("nbhm,n->bhm", ks, ones)
    result = torch.einsum("nbhm,bhm->nbh", qs, key_sum)
    return result


def numerator_gumbel(qs, ks, vs):
    """Linear-attention numerator computed separately for each of K key samples.

    :param qs: ``[N, B, H, M]`` query features.
    :param ks: ``[N, B, H, K, M]`` key features (one set per sample k).
    :param vs: ``[N, B, H, D]`` values.
    :return: ``[N, B, H, K, D]``.
    """
    # Per-sample key/value summary (U_k in the paper): [B, H, K, M, D].
    key_value_summary = torch.einsum("nbhkm,nbhd->bhkmd", ks, vs)
    result = torch.einsum("nbhm,bhkmd->nbhkd", qs, key_value_summary)
    return result


def denominator_gumbel(qs, ks):
    """Linear-attention normaliser computed separately for each of K key samples.

    :param qs: ``[N, B, H, M]`` query features.
    :param ks: ``[N, B, H, K, M]`` key features.
    :return: ``[N, B, H, K]``.
    """
    node_count = ks.shape[0]
    ones = torch.ones([node_count]).to(qs.device)
    # Per-sample sum of key features over nodes (O_k in the paper): [B, H, K, M].
    key_sum = torch.einsum("nbhkm,n->bhkm", ks, ones)
    result = torch.einsum("nbhm,bhkm->nbhk", qs, key_sum)
    return result


def kernelized_softmax(query, key, value, kernel_transformation, projection_matrix=None, edge_index=None, tau=0.25,
                       return_weight=True):
    '''
    All-pair attentive aggregation with linear complexity via a kernelized (random-feature) softmax.

    input: query/key/value [B, N, H, D]; kernel_transformation(x, is_query, projection_matrix)
           maps [B, N, H, D] -> [B, N, H, M]; edge_index is a (start, end) pair of index
           tensors and is only needed when return_weight is True
    return: (updated node emb [B, N, H, D], edge attention weight [B, E, H] or 0)
    raises: ValueError if return_weight is True but edge_index is None
    '''
    query = query / math.sqrt(tau)
    key = key / math.sqrt(tau)
    query_prime = kernel_transformation(query, True, projection_matrix)  # [B, N, H, M]
    key_prime = kernel_transformation(key, False, projection_matrix)  # [B, N, H, M]
    query_prime = query_prime.permute(1, 0, 2, 3)  # [N, B, H, M]
    key_prime = key_prime.permute(1, 0, 2, 3)  # [N, B, H, M]
    value = value.permute(1, 0, 2, 3)  # [N, B, H, D]

    # compute updated node emb, this step requires O(N)
    z_num = numerator(query_prime, key_prime, value)  # [N, B, H, D]
    # The per-node normaliser is also the edge-weight denominator below; compute it once.
    attn_normalizer = denominator(query_prime, key_prime)  # [N, B, H]

    z_num = z_num.permute(1, 0, 2, 3)  # [B, N, H, D]
    z_den = attn_normalizer.permute(1, 0, 2).unsqueeze(-1)  # [B, N, H, 1]
    z_output = z_num / z_den  # [B, N, H, D]

    if not return_weight:
        return z_output, 0

    if edge_index is None:
        raise ValueError("edge_index is required when return_weight is True")

    # query edge prob for computing edge-level reg loss, this step requires O(E)
    start, end = edge_index
    query_end, key_start = query_prime[end], key_prime[start]  # [E, B, H, M]
    edge_attn_num = torch.einsum("ebhm,ebhm->ebh", query_end, key_start)  # [E, B, H]
    edge_attn_num = edge_attn_num.permute(1, 0, 2)  # [B, E, H]
    edge_attn_dem = attn_normalizer[end].permute(1, 0, 2)  # [B, E, H]
    A_weight = edge_attn_num / edge_attn_dem  # [B, E, H]
    return z_output, A_weight


def kernelized_gumbel_softmax(query, key, value, kernel_transformation, projection_matrix=None, edge_index=None,
                              K=10, tau=0.25, return_weight=True):
    '''
    fast computation of all-pair attentive aggregation with linear complexity, averaged over K
    Gumbel-perturbed copies of the key features
    input: query/key/value [B, N, H, D]; edge_index is a (start, end) pair of index tensors,
           only needed when return_weight is True
    return: updated node emb [B, N, H, D], attention weight [B, E, H] (for computing edge loss) or 0
    raises: ValueError if return_weight is True but edge_index is None
    B = graph number (always equal to 1 in Node Classification), N = node number, H = head number,
    M = random feature dimension, D = hidden size, K = number of Gumbel sampling
    '''
    query = query / math.sqrt(tau)
    key = key / math.sqrt(tau)
    query_prime = kernel_transformation(query, True, projection_matrix)  # [B, N, H, M]
    key_prime = kernel_transformation(key, False, projection_matrix)  # [B, N, H, M]
    query_prime = query_prime.permute(1, 0, 2, 3)  # [N, B, H, M]
    key_prime = key_prime.permute(1, 0, 2, 3)  # [N, B, H, M]
    value = value.permute(1, 0, 2, 3)  # [N, B, H, D]

    # compute updated node emb, this step requires O(N)
    # Standard Gumbel noise -log(Exp(1)), scaled by 1/tau. It is allocated directly on the
    # target device instead of on the CPU and then copied over on every call.
    gumbels = -torch.empty(key_prime.shape[:-1] + (K,), device=query.device,
                           memory_format=torch.legacy_contiguous_format).exponential_().log() / tau  # [N, B, H, K]
    key_t_gumbel = key_prime.unsqueeze(3) * gumbels.exp().unsqueeze(4)  # [N, B, H, K, M]
    z_num = numerator_gumbel(query_prime, key_t_gumbel, value)  # [N, B, H, K, D]
    z_den = denominator_gumbel(query_prime, key_t_gumbel)  # [N, B, H, K]

    z_num = z_num.permute(1, 0, 2, 3, 4)  # [B, N, H, K, D]
    z_den = z_den.permute(1, 0, 2, 3).unsqueeze(-1)  # [B, N, H, K, 1]
    z_output = torch.mean(z_num / z_den, dim=3)  # [B, N, H, D]

    if not return_weight:
        return z_output, 0

    if edge_index is None:
        raise ValueError("edge_index is required when return_weight is True")

    # query edge prob (from the un-perturbed keys) for the edge-level reg loss, this step requires O(E)
    start, end = edge_index
    query_end, key_start = query_prime[end], key_prime[start]  # [E, B, H, M]
    edge_attn_num = torch.einsum("ebhm,ebhm->ebh", query_end, key_start)  # [E, B, H]
    edge_attn_num = edge_attn_num.permute(1, 0, 2)  # [B, E, H]
    attn_normalizer = denominator(query_prime, key_prime)  # [N, B, H]
    edge_attn_dem = attn_normalizer[end].permute(1, 0, 2)  # [B, E, H]
    A_weight = edge_attn_num / edge_attn_dem  # [B, E, H]
    return z_output, A_weight


def add_conv_relational_bias(x, edge_index, b, trans='sigmoid'):
    '''
    compute updated result by the relational bias of input adjacency
    the implementation is similar to the Graph Convolution Network with a (shared) scalar weight for each edge

    :param x: node features, sliced per head as x[:, :, i]; dim 1 is taken as the node count
              (the [B, N, H, D] annotations below are assumptions carried over from the original)
    :param edge_index: (row, col) pair of index tensors describing the adjacency
    :param b: per-head scalar weights, indexed b[i] for head i
    :param trans: 'sigmoid' squashes b through a sigmoid, 'identity' uses it unchanged
    :return: per-head aggregations stacked along dim 2
    :raises NotImplementedError: for any other value of trans
    '''
    # Transform the per-head weights once (and reject unknown modes) instead of once per head.
    if trans == 'sigmoid':
        head_weights = torch.sigmoid(b)
    elif trans == 'identity':
        head_weights = b
    else:
        raise NotImplementedError

    row, col = edge_index
    num_nodes = x.shape[1]
    # Symmetric degree normalisation 1 / sqrt(d_in(col) * d_out(row)), shared by every head.
    d_in = degree(col, num_nodes).float()
    d_norm_in = (1. / d_in[col]).sqrt()
    d_out = degree(row, num_nodes).float()
    d_norm_out = (1. / d_out[row]).sqrt()
    edge_ones = torch.ones_like(row)

    conv_output = []
    for i in range(x.shape[2]):
        value = edge_ones * head_weights[i] * d_norm_in * d_norm_out
        adj_i = SparseTensor(row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes))
        conv_output.append(matmul(adj_i, x[:, :, i]))  # [B, N, D]
    return torch.stack(conv_output, dim=2)  # [B, N, H, D]


class KernelizedConv(nn.Module):
    '''
    One all-node attention layer built from an ensemble of random-feature projections of the queries.

    forward returns: (node embeddings for the next layer, link loss). The link loss is always 0
    in this implementation; no edge-level attention weights are computed.
    '''

    def __init__(self, in_channels, out_channels, num_heads, kernel_transformation=softmax_kernel_transformation,
                 projection_matrix_type='a',
                 nb_random_features=10, use_gumbel=True, nb_gumbel_sample=10, rb_order=0, rb_trans='sigmoid',
                 use_edge_loss=True, ensemble_size=5, ensemble_nb_random_features=32):
        super(KernelizedConv, self).__init__()
        self.Wk = nn.Linear(in_channels, out_channels * num_heads)
        self.Wq = nn.Linear(in_channels, out_channels * num_heads)
        self.Wv = nn.Linear(in_channels, out_channels * num_heads)
        self.Wo = nn.Linear(out_channels * num_heads, out_channels)
        if rb_order >= 1:
            # Given a real starting value at the end of __init__, so the layer is usable without an
            # explicit reset_parameters() call (torch.FloatTensor(...) left it uninitialised).
            self.b = torch.nn.Parameter(torch.zeros(rb_order, num_heads, dtype=torch.float32), requires_grad=True)

        self.out_channels = out_channels
        self.num_heads = num_heads
        self.kernel_transformation = kernel_transformation
        self.projection_matrix_type = projection_matrix_type
        self.nb_random_features = nb_random_features
        self.use_gumbel = use_gumbel
        self.nb_gumbel_sample = nb_gumbel_sample
        self.rb_order = rb_order
        self.rb_trans = rb_trans
        self.use_edge_loss = use_edge_loss
        # Ensemble settings used by forward(); the defaults reproduce the values forward() previously
        # hard-coded (5 members, 32 random features each, regardless of nb_random_features).
        self.ensemble_size = ensemble_size
        self.ensemble_nb_random_features = ensemble_nb_random_features
        self._reset_relational_bias()

    def _reset_relational_bias(self):
        # Initialise the relational-bias weights b[order, head] according to rb_trans.
        if self.rb_order >= 1:
            if self.rb_trans == 'sigmoid':
                torch.nn.init.constant_(self.b, 0.1)
            elif self.rb_trans == 'identity':
                torch.nn.init.constant_(self.b, 1.0)

    def reset_parameters(self):
        self.Wk.reset_parameters()
        self.Wq.reset_parameters()
        self.Wv.reset_parameters()
        self.Wo.reset_parameters()
        self._reset_relational_bias()

    def forward(self, z, adjs, tau):
        '''
        :param z: node features; dim 1 is the node count N (projections are reshaped to
                  [-1, N, num_heads, out_channels])
        :param adjs: sequence of edge_index pairs; adjs[i] feeds relational-bias order i
        :param tau: unused by this implementation (kept for interface compatibility)
        :return: (z_next, link_loss) where link_loss is always 0
        :raises NotImplementedError: if projection_matrix_type is None
        '''
        N = z.size(1)
        query = self.Wq(z).reshape(-1, N, self.num_heads, self.out_channels)
        value = self.Wv(z).reshape(-1, N, self.num_heads, self.out_channels)
        # NOTE(review): Wk is not used by forward (key features were previously computed and
        # discarded); the module is kept so existing state_dicts still load.

        if self.projection_matrix_type is None:
            raise NotImplementedError(
                "KernelizedConv requires a projection matrix; projection_matrix_type=None is not supported")

        dim = query.shape[-1]
        # Data-dependent seed: the projections change whenever the queries change.
        # NOTE(review): the float -> int32 cast can overflow for large |sum(query)| * BIG_CONSTANT.
        seed = torch.ceil(torch.abs(torch.sum(query) * BIG_CONSTANT)).to(torch.int32)

        ensemble_size = self.ensemble_size
        ensemble_q = []
        for i in range(ensemble_size):
            proj_matrix = create_projection_matrix(self.ensemble_nb_random_features, dim,
                                                   seed=seed + i).to(query.device)
            ensemble_q.append(self.kernel_transformation(query, True, proj_matrix))  # [B, N, H, M]

        q_stack = torch.stack(ensemble_q, dim=0)  # [K, B, N, H, M]
        q_mean = q_stack.mean(dim=0)  # [B, N, H, M]
        q_centered = q_stack - q_mean  # [K, B, N, H, M]
        var_q = (q_centered ** 2).mean(dim=0) + 1e-6  # [B, N, H, M]
        # Every ensemble member uses the same value tensor.
        v_stack = torch.stack([value] * ensemble_size, dim=0)  # [K, B, N, H, D]
        # NOTE(review): because v_stack is identical across K and q_centered sums to zero over K,
        # cross_cov_qv is zero up to rounding, so z_next below is ~0 before the relational bias.
        # Confirm this is the intended estimator.
        cross_cov_qv = torch.einsum("kbnhm,kbnhd->bnhmd", q_centered, v_stack) / ensemble_size  # [B, N, H, M, D]
        K_gain = cross_cov_qv / var_q.unsqueeze(-1)  # [B, N, H, M, D]
        z_next = torch.einsum("bnhmd,bnhm->bnhd", K_gain, q_mean)  # [B, N, H, D]

        for i in range(self.rb_order):
            z_next = z_next + add_conv_relational_bias(value, adjs[i], self.b[i], self.rb_trans)

        z_next = self.Wo(z_next.flatten(-2, -1))
        link_loss = 0
        return z_next, link_loss

class KernelizedMP(nn.Module):
    '''
    Stack of KernelizedConv layers with an input projection, optional LayerNorm (enabled by use_bn),
    residual connections, optional activation, dropout and jumping-knowledge concatenation.

    forward(x, adjs, tau) takes node features x (presumably [N, in_channels]; a leading batch
    dim is added with unsqueeze(0) and removed with squeeze(0)). It returns
    (x_out, link_losses) when use_edge_loss is True, otherwise just x_out.
    '''

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, num_heads=4, dropout=0.0,
                 kernel_transformation=softmax_kernel_transformation, nb_random_features=30, use_bn=True,
                 use_gumbel=True,
                 use_residual=True, use_act=False, use_jk=False, nb_gumbel_sample=10, rb_order=0, rb_trans='sigmoid',
                 use_edge_loss=True):
        super(KernelizedMP, self).__init__()

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(in_channels, hidden_channels))
        self.bns = nn.ModuleList()
        self.bns.append(nn.LayerNorm(hidden_channels))
        for i in range(num_layers):
            self.convs.append(
                KernelizedConv(hidden_channels, hidden_channels, num_heads=num_heads,
                               kernel_transformation=kernel_transformation,
                               nb_random_features=nb_random_features, use_gumbel=use_gumbel,
                               nb_gumbel_sample=nb_gumbel_sample,
                               rb_order=rb_order, rb_trans=rb_trans, use_edge_loss=use_edge_loss))
            self.bns.append(nn.LayerNorm(hidden_channels))

        if use_jk:
            # The final head sees the input projection plus every layer's output concatenated.
            self.fcs.append(nn.Linear(hidden_channels * num_layers + hidden_channels, out_channels))
        else:
            self.fcs.append(nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout
        self.activation = F.elu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act
        self.use_jk = use_jk
        self.use_edge_loss = use_edge_loss

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()

    def forward(self, x, adjs, tau=1.0):
        x = x.unsqueeze(0)
        layer_ = []
        link_loss_ = []
        z = self.fcs[0](x)
        if self.use_bn:
            z = self.bns[0](z)
        z = self.activation(z)
        z = F.dropout(z, p=self.dropout, training=self.training)
        layer_.append(z)

        for i, conv in enumerate(self.convs):
            conv_out = conv(z, adjs, tau)
            # KernelizedConv always returns (z, link_loss); unpack it regardless of use_edge_loss
            # (binding the tuple to z used to break the residual add below). A bare tensor is
            # accepted too.
            if isinstance(conv_out, tuple):
                z, link_loss = conv_out
            else:
                z, link_loss = conv_out, 0
            if self.use_edge_loss:
                link_loss_.append(link_loss)
            if self.use_residual:
                z = z + layer_[i]
            if self.use_bn:
                z = self.bns[i + 1](z)
            if self.use_act:
                z = self.activation(z)
            z = F.dropout(z, p=self.dropout, training=self.training)
            layer_.append(z)

        if self.use_jk:  # use jk connection for each layer
            z = torch.cat(layer_, dim=-1)

        x_out = self.fcs[-1](z).squeeze(0)

        if self.use_edge_loss:
            return x_out, link_loss_
        else:
            return x_out


class MergeLayer(nn.Module):

    def __init__(self, input_dim1: int, input_dim2: int, hidden_dim: int, output_dim: int):
        """
        Merge Layer to merge two inputs via: input_dim1 + input_dim2 -> hidden_dim -> output_dim.
        :param input_dim1: int, dimension of first input
        :param input_dim2: int, dimension of the second input
        :param hidden_dim: int, hidden dimension
        :param output_dim: int, dimension of the output
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim1 + input_dim2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.act = nn.ReLU()

    def forward(self, input_1: torch.Tensor, input_2: torch.Tensor):
        """
        Concatenate the inputs along their last (feature) axis and project them.
        :param input_1: Tensor, shape (*, input_dim1)
        :param input_2: Tensor, shape (*, input_dim2), with the same leading dims as input_1
        :return: Tensor, shape (*, output_dim)
        """
        # Concatenate on the feature axis (-1) so any number of leading dims works;
        # for 2-D inputs this is identical to concatenating on dim 1.
        x = torch.cat([input_1, input_2], dim=-1)
        h = self.fc2(self.act(self.fc1(x)))
        return h

class DWTFunction_1D(Function):
    """
    Autograd function applying two analysis matrices along the last axis of a 3-D input.

    forward: input [batch, channels, length] -> (L, H), each [batch, channels, out_len], where
             L = input @ matrix_Low^T and H = input @ matrix_High^T (matrices are [out_len, length]).
    backward: d input = grad_L @ matrix_Low + grad_H @ matrix_High; the matrices get no gradient.
    """

    @staticmethod
    def forward(ctx, input, matrix_Low, matrix_High):
        ctx.save_for_backward(matrix_Low, matrix_High)
        batch_size, channels, length = input.size()
        # Remembered so backward can restore exactly the input's shape.
        ctx.input_shape = (batch_size, channels, length)

        input_reshaped = input.reshape(-1, length)
        L = torch.matmul(input_reshaped, matrix_Low.t())
        H = torch.matmul(input_reshaped, matrix_High.t())

        L = L.view(batch_size, channels, -1)
        H = H.view(batch_size, channels, -1)
        return L, H

    @staticmethod
    def backward(ctx, grad_L, grad_H):
        matrix_Low, matrix_High = ctx.saved_tensors

        grad_L = grad_L.reshape(-1, grad_L.size(-1))
        grad_H = grad_H.reshape(-1, grad_H.size(-1))

        grad_input = torch.matmul(grad_L, matrix_Low) + torch.matmul(grad_H, matrix_High)  # [batch*channels, length]
        # Row r of grad_input belongs to (batch, channel) = divmod(r, channels), matching forward's
        # reshape. A plain view restores that layout; the previous view(length, -1).permute(...)
        # scrambled values and returned the wrong shape whenever batch > 1.
        return grad_input.view(ctx.input_shape), None, None


class DWT_1D(nn.Module):
    """
    1-D discrete wavelet transform along the last axis, implemented as two dense matrix products
    with the wavelet's low-pass (dec_lo) and high-pass (dec_hi) decomposition filters.

    Row ``i`` of each analysis matrix places the filter taps starting at column
    ``2 * i - pad_size``. Taps falling outside ``[0, input_length)`` are dropped, and rows
    that start past the end of the signal stay zero. Both outputs keep the input length.
    """

    def __init__(self, wavename):
        super().__init__()
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.dec_lo
        self.band_high = wavelet.dec_hi
        self.band_length = len(self.band_low)
        self.pad_size = self.band_length // 2

    def get_matrix(self, input_length, device=None, dtype=torch.float32):
        """Build ``self.matrix_low`` / ``self.matrix_high`` ([input_length, input_length]).

        :param input_length: number of samples along the transformed axis.
        :param device: target device; defaults to CUDA when available, else CPU (previous behaviour).
        :param dtype: dtype of the matrices (float32 by default, as before).
        """
        # Output length equals the input length rather than (input_length + 1) // 2, presumably so
        # the result can be combined element-wise with the untransformed signal downstream.
        output_len = input_length
        matrix_low = np.zeros((output_len, input_length))
        matrix_high = np.zeros((output_len, input_length))

        # Place each tap individually and clip at both borders. The previous slice assignments raised
        # broadcast errors for filters longer than 2 near the end of the signal (or for very short signals).
        for i in range(output_len):
            start = i * 2 - self.pad_size
            if start >= input_length:
                break  # this row and every later row stay zero
            for tap in range(self.band_length):
                col = start + tap
                if 0 <= col < input_length:
                    matrix_low[i, col] = self.band_low[tap]
                    matrix_high[i, col] = self.band_high[tap]

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.matrix_low = torch.tensor(matrix_low, dtype=dtype, device=device)
        self.matrix_high = torch.tensor(matrix_high, dtype=dtype, device=device)

    def forward(self, input):
        _, _, T = input.size()
        # Build the matrices where the input lives, so the matmuls never mix devices or dtypes.
        self.get_matrix(T, device=input.device, dtype=input.dtype)
        return DWTFunction_1D.apply(input, self.matrix_low, self.matrix_high)


class WPL(nn.Module):
    """
    Wavelet attention over time: the DWT high-pass band, softmax-normalised along time,
    re-weights the low-pass band.

    Input ``[N, T, C]``; output ``[N, T', C]`` where ``T'`` is the DWT output length.
    """

    def __init__(self, wavename='haar'):
        super().__init__()
        self.dwt = DWT_1D(wavename)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # Unpacking the size enforces a 3-D input.
        _num_nodes, _num_steps, _num_channels = x.size()

        series = x.permute(2, 0, 1)  # [C, N, T]: the DWT runs over the last axis
        low_band, high_band = self.dwt(series)
        # Resample the high band to the low band's length before turning it into weights.
        high_band = F.interpolate(high_band, size=low_band.size(-1), mode='linear', align_corners=False)
        weighted = low_band * self.softmax(high_band)
        return weighted.permute(1, 2, 0)


class REDGSL(nn.Module):
    """
    Dynamic-graph model. Per-timestamp kernelized message passing (KernelizedMP) is followed by
    a wavelet re-weighting (WPL) and a graph selective state-space model (Mamba, ``DGSM``) over
    the sequence of per-timestamp embeddings.

    forward returns ``(embs, edge_losses, inter_loss)``.
    """
    def __init__(self, n_feats, hidden_channels, node_channels, mamba_features, beta2, lamda_1=0.5, num_layers=2,
                 num_heads=4, dropout=0.0, nb_random_features=30, use_bn=True, use_gumbel=True,
                 use_residual=True, use_act=False, use_jk=False, nb_gumbel_sample=10, rb_order=0, rb_trans='sigmoid',
                 tau=1.0):
        """
        :param n_feats: input feature size per node (in_channels of KernelizedMP)
        :param hidden_channels: hidden size inside KernelizedMP
        :param node_channels: output embedding size of KernelizedMP
        :param mamba_features: d_model of the Mamba block and normalised size of its LayerNorm
            (presumably the number of nodes, given how forward lays out the Mamba input — confirm)
        :param beta2: weight of the KL term in inter_loss
        :param lamda_1: weight of the sequence-model output added back onto the embeddings
        :param tau: temperature forwarded to KernelizedMP
        The remaining arguments are forwarded to KernelizedMP.
        """
        super(REDGSL, self).__init__()
        self.tau = tau
        self.node_channels = node_channels
        self.lamda_1 = lamda_1
        self.beta2 = beta2
        self.KernelizedMP = KernelizedMP(in_channels=n_feats, hidden_channels=hidden_channels,
                                         out_channels=node_channels,
                                         num_layers=num_layers, dropout=dropout,
                                         num_heads=num_heads, use_bn=use_bn, nb_random_features=nb_random_features,
                                         use_gumbel=use_gumbel, use_residual=use_residual, use_act=use_act,
                                         use_jk=use_jk,
                                         nb_gumbel_sample=nb_gumbel_sample, rb_order=rb_order, rb_trans=rb_trans)
        self.DGSM = Mamba(d_model=mamba_features,  # Model dimension d_model
                          d_state=16,  # SSM state expansion factor
                          d_conv=4,  # Local convolution width
                          expand=1, )

        self.LayerNorm = nn.LayerNorm(mamba_features, eps=1e-10)
        self.WPL = WPL()

    def forward(self, node_embeddings, adj_matrices, timestamp):
        """
        :param node_embeddings: per-timestamp node feature tensors, indexed [t]
        :param adj_matrices: per-timestamp adjacency inputs. adj_matrices[t] is passed to KernelizedMP,
            and adj_matrices[t][0] (presumably a [2, E] edge_index — confirm) builds the sparse
            adjacency handed to DGSM
        :param timestamp: number of timestamps to process
        :return: (embs, edge_losses, inter_loss)
        """
        mp_embs = []
        edge_losses = []
        num_nodes = node_embeddings[0].shape[0] # assumes every timestamp has the same node count as t=0
        selective_modeling_adjs = []

        for t in range(timestamp):
            # Intra- and Inter- Kernelized message passing
            local_emb, loss = self.KernelizedMP(node_embeddings[t], adj_matrices[t], self.tau)
            edge_losses.append(loss)
            if t == 0:
                total_emb = local_emb
            else:
                # Re-propagate the previous timestamp's embeddings over the previous graph.
                # NOTE(review): mp_embs[t - 1] has node_channels features while KernelizedMP was built
                # with in_channels=n_feats, so this requires n_feats == node_channels — confirm.
                time_emb, _ = self.KernelizedMP(mp_embs[t - 1], adj_matrices[t - 1], self.tau)
                total_emb = time_emb + local_emb

            mp_embs.append(total_emb)

            # Graph Selective Modeling adj matrices
            # Unit weight for every column (edge) of adj_matrices[t][0].
            values = torch.ones(adj_matrices[t][0].shape[1]).to(adj_matrices[t][0].device)

            size = (num_nodes, num_nodes)

            temp_adj = torch.sparse_coo_tensor(adj_matrices[t][0], values, size)

            selective_modeling_adjs.append(temp_adj)

        # Dynamic Graph Selective Modeling
        # presumably [T, N, node_channels] if KernelizedMP returns [N, node_channels] — confirm
        z_mp = torch.stack(mp_embs)
        # Average over the channel axis, keeping it as a singleton.
        graph_ssm_input = z_mp.mean(dim=-1, keepdim=True)
        # Wavelet re-weighting over the timestamp axis (WPL sees [N, T, 1]), added back as a residual.
        graph_ssm_input = self.WPL(graph_ssm_input.permute(1, 0, 2)).permute(1, 0, 2) + graph_ssm_input
        # Move the singleton channel axis to the front before the sequence model.
        graph_ssm_input = graph_ssm_input.permute(2, 0, 1)
        graph_ssm_output = self.DGSM(graph_ssm_input, selective_modeling_adjs)
        # squeeze() drops every size-1 dim (note: also T or N if either happens to be 1).
        graph_ssm_output = graph_ssm_output.squeeze()
        graph_ssm_output = self.LayerNorm(graph_ssm_output)

        # Calculate kl_loss and entropy (both reduce over dim 1 of their inputs)
        kl_loss = kl_divergence(graph_ssm_output, graph_ssm_input.squeeze())
        entropy = shannon_entropy(graph_ssm_output)

        # Broadcast each sequence-model scalar across node_channels and add it to the MP embeddings.
        z_seq = graph_ssm_output.unsqueeze(-1)
        z_seq = z_seq.repeat_interleave(self.node_channels, dim=-1)
        embs = z_mp + self.lamda_1 * z_seq

        inter_loss = entropy + self.beta2 * kl_loss # presumably one value per timestamp — confirm
        return embs, edge_losses, inter_loss