import torch
import numpy as np
import torch.nn as nn
from utils.utils import NeighborSampler
import math
from models.modules import TimeEncoder
import geoopt
from geoopt.manifolds.stereographic.math import project
from geoopt.manifolds.stereographic import StereographicExact
from geoopt import ManifoldTensor
from geoopt import ManifoldParameter
import torch.nn.functional as F

class RandomProjectionModule(nn.Module):
    def __init__(self, node_num: int, edge_num: int, dim_factor: int, num_layer: int, time_decay_weight: float,
                 device: str, use_matrix: bool, beginning_time: np.float64, not_scale: bool, enforce_dim: int):
        """
        This model maintains a series of temporal walk matrices A^(0)(t), A^(1)(t), ..., A^(k)(t) through
        random feature propagation, and extracts the pairwise features from the obtained random projections.
        :param node_num: int, the number of nodes
        :param edge_num: int, the number of edges
        :param dim_factor: int, the parameter to control the dimension of random projections. Specifically, the
                           dimension of the random projections is set to be dim_factor * log(2*edge_num)
                           (capped at node_num)
        :param num_layer: int, the max hop of the maintained temporal walk matrices
        :param time_decay_weight: float, the time decay weight (lambda of the original paper)
        :param device: str, torch device
        :param use_matrix: bool, if True, explicitly maintain the temporal walk matrices (dimension = node_num)
        :param beginning_time: np.float64, the earliest time in the given temporal graph
        :param not_scale: bool, if True, the inner product of nodes' random projections will not be scaled
        :param enforce_dim: int, if not -1, explicitly set the dimension of random projections to enforce_dim
        """
        super(RandomProjectionModule, self).__init__()
        self.node_num = node_num
        self.edge_num = edge_num
        if enforce_dim != -1:
            self.dim = enforce_dim
        else:
            self.dim = min(int(math.log(self.edge_num * 2)) * dim_factor, node_num)
        self.num_layer = num_layer
        self.time_decay_weight = time_decay_weight
        # NOTE: "begging_time" is a misspelling of "beginning_time"; the name is kept because it is part of the
        # state_dict keys of saved checkpoints
        self.begging_time = nn.Parameter(torch.tensor(beginning_time), requires_grad=False)
        self.now_time = nn.Parameter(torch.tensor(beginning_time), requires_grad=False)
        self.device = device
        self.random_projections = nn.ParameterList()
        self.use_matrix = use_matrix
        self.node_feature_dim = 128  # not read anywhere in this class
        self.not_scale = not_scale
        # if use_matrix = True, directly store the temporal walk matrices
        if self.use_matrix:
            self.dim = self.node_num
            for i in range(self.num_layer + 1):
                if i == 0:
                    self.random_projections.append(
                        nn.Parameter(torch.eye(self.node_num), requires_grad=False))
                else:
                    self.random_projections.append(
                        nn.Parameter(torch.zeros_like(self.random_projections[i - 1]), requires_grad=False))
        # otherwise, store the random projection of the temporal walk matrices
        else:
            for i in range(self.num_layer + 1):
                if i == 0:
                    self.random_projections.append(
                        nn.Parameter(torch.normal(0, 1 / math.sqrt(self.dim), (self.node_num, self.dim)),
                                     requires_grad=False))
                else:
                    self.random_projections.append(
                        nn.Parameter(torch.zeros_like(self.random_projections[i - 1]), requires_grad=False))
        # pairwise feature = all inner products among the (num_layer + 1) projections of both endpoints
        self.pair_wise_feature_dim = (2 * self.num_layer + 2) ** 2
        self.mlp = nn.Sequential(nn.Linear(self.pair_wise_feature_dim, self.pair_wise_feature_dim * 4), nn.ReLU(),
                                 nn.Linear(self.pair_wise_feature_dim * 4, self.pair_wise_feature_dim))

    def update(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, node_interact_times: np.ndarray):
        """
        updating the temporal walk matrices after observing a batch of interactions.
        The module clock is advanced to the largest timestamp of the batch. An empty batch is a no-op.
        :param src_node_ids: np.ndarray, shape (batch,), source node ids
        :param dst_node_ids: np.ndarray, shape (batch,), destination node ids
        :param node_interact_times: np.ndarray, shape (batch,), timestamps of interactions
        """
        if len(node_interact_times) == 0:
            # nothing observed: leave the projections and the clock untouched
            return
        interact_times = np.asarray(node_interact_times, dtype=np.float64)
        # scatter_add_ requires int64 indices
        src_node_ids = torch.from_numpy(np.asarray(src_node_ids, dtype=np.int64)).to(self.device)
        dst_node_ids = torch.from_numpy(np.asarray(dst_node_ids, dtype=np.int64)).to(self.device)
        # use max() rather than the last entry so an unsorted batch never produces decay weights > 1
        next_time = interact_times.max()
        # time differences are computed in float64 *before* casting: raw timestamps (e.g. ~1e9 seconds) are
        # too coarse in float32 for the differences between nearby interactions to survive
        time_weight = torch.from_numpy(np.exp(-self.time_decay_weight * (next_time - interact_times))).to(
            dtype=self.random_projections[0].dtype, device=self.device)[:, None]

        # updating for the current timestamp being moved
        # since the current timestamp will be set to the biggest timestamp in this batch,
        # layer i (walks of length i) is multiplied by the per-step decay raised to the i-th power
        step_decay = np.exp(-self.time_decay_weight * (next_time - self.now_time.item()))
        for i in range(1, self.num_layer + 1):
            self.random_projections[i].data = self.random_projections[i].data * np.power(step_decay, i)

        # updating for adding new interactions
        # we use the batch updating schema, where we first compute the influence of each interaction
        # and then aggregate them together; layers are visited from high to low so that layer i-1 is still
        # the pre-batch state when it is propagated into layer i
        for i in range(self.num_layer, 0, -1):
            src_update_messages = self.random_projections[i - 1][dst_node_ids] * time_weight
            dst_update_messages = self.random_projections[i - 1][src_node_ids] * time_weight
            self.random_projections[i].scatter_add_(dim=0, index=src_node_ids[:, None].expand(-1, self.dim),
                                                    src=src_update_messages)
            self.random_projections[i].scatter_add_(dim=0, index=dst_node_ids[:, None].expand(-1, self.dim),
                                                    src=dst_update_messages)

        # set current timestamp to the biggest timestamp in this batch
        self.now_time.data = torch.tensor(next_time, device=self.device)

    def get_random_projections(self, node_ids: np.ndarray):
        """
        get the random projections of the given node ids.
        :param node_ids: np.ndarray, shape (batch,)
        :return: list of num_layer + 1 tensors, each of shape (batch, dim), one per walk length
        """
        return [self.random_projections[i][node_ids] for i in range(self.num_layer + 1)]

    def get_pair_wise_feature(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray):
        """
        get pairwise feature for given source nodes and destination nodes.
        :param src_node_ids: np.ndarray, shape (batch,)
        :param dst_node_ids: np.ndarray, shape (batch,)
        :return: Tensor, shape (batch, pair_wise_feature_dim), the MLP-encoded Gram matrix of the stacked
                 source and destination projections
        """
        src_random_projections = torch.stack(self.get_random_projections(src_node_ids), dim=1)
        dst_random_projections = torch.stack(self.get_random_projections(dst_node_ids), dim=1)
        random_projections = torch.cat([src_random_projections, dst_random_projections], dim=1)
        random_feature = torch.matmul(random_projections, random_projections.transpose(1, 2)).reshape(
            len(src_node_ids), -1)
        if self.not_scale:
            return self.mlp(random_feature)
        else:
            # negative inner products are clipped before log-scaling
            random_feature[random_feature < 0] = 0
            random_feature = torch.log(random_feature + 1.0)
            return self.mlp(random_feature)

    def reset_random_projections(self):
        """
        reset the random projections: zero the walk layers, rewind the clock and redraw layer 0
        (layer 0 stays the identity when use_matrix is True)
        """
        for i in range(1, self.num_layer + 1):
            nn.init.zeros_(self.random_projections[i])
        self.now_time.data = self.begging_time.clone()
        if not self.use_matrix:
            nn.init.normal_(self.random_projections[0], mean=0, std=1 / math.sqrt(self.dim))

    def backup_random_projections(self):
        """
        backup the random projections (layer 0 is not modified by update(), so it is not copied).
        :return: tuple of (now_time, random_projections)
        """
        return self.now_time.clone(), [self.random_projections[i].clone() for i in
                                       range(1, self.num_layer + 1)]

    def reload_random_projections(self, random_projections):
        """
        reload the random projections.
        :param random_projections: tuple of (now_time, random_projections) as returned by backup_random_projections
        """
        now_time, random_projections = random_projections
        self.now_time.data = now_time.clone()
        for i in range(1, self.num_layer + 1):
            self.random_projections[i].data = random_projections[i - 1].clone()


class DyGMoCE(torch.nn.Module):
    def __init__(self, node_raw_features: np.ndarray, edge_raw_features: np.ndarray, src_geo_feature: torch.Tensor,
                 dst_geo_feature: torch.Tensor, neighbor_sampler: NeighborSampler,
                 time_feat_dim: int, dropout: float, random_projections: RandomProjectionModule,
                 num_layers: int, num_neighbors: int, margin: float, init_curvs: list, device: str):
        """
        DyGMoCE temporal link prediction model. For each node it gathers historical neighbors, builds neighbor
        features from node / edge / time-encoding / random-projection pairwise / per-edge geometric features,
        and encodes them with a stack of RiemannTransformerEncoder layers over manifolds built from init_curvs.
        :param node_raw_features: ndarray, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: ndarray, shape (num_edges + 1, edge_feat_dim)
        :param src_geo_feature: Tensor indexed by edge id, geometric features used for source-side neighbors;
                                its last dimension is the geometric feature width
        :param dst_geo_feature: Tensor indexed by edge id, geometric features used for destination-side neighbors;
                                must have the same last dimension as src_geo_feature
        :param neighbor_sampler: neighbor sampler
        :param time_feat_dim: int, dimension of time features (encodings)
        :param dropout: float, dropout rate
        :param random_projections: RandomProjectionModule or None, the projected time decay temporal walk matrices
        :param num_layers: int, number of transformer layers
        :param num_neighbors: int, number of sampled neighbors
        :param margin: float, margin of the curvature/hyperbolicity ranking regularizer
        :param init_curvs: list, curvature of each geoopt.Stereographic manifold
        :param device: str, device
        :raises ValueError: if src_geo_feature and dst_geo_feature have different last dimensions
        """
        super(DyGMoCE, self).__init__()

        self.node_raw_features = torch.from_numpy(node_raw_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_raw_features.astype(np.float32)).to(device)

        self.src_geo_feature = src_geo_feature.to(device)
        self.dst_geo_feature = dst_geo_feature.to(device)

        self.node_feat_dim = self.node_raw_features.shape[1]
        self.edge_feat_dim = self.edge_raw_features.shape[1]
        self.time_feat_dim = time_feat_dim
        self.dropout = dropout
        self.device = device

        # number of nodes, including the padded node
        self.num_nodes = self.node_raw_features.shape[0]

        # the geometric features are concatenated raw onto the neighbor features, so the projection layer
        # must be sized from the data instead of a hard-coded width
        self.geo_feat_dim = int(self.src_geo_feature.shape[-1])
        if int(self.dst_geo_feature.shape[-1]) != self.geo_feat_dim:
            raise ValueError(f"src_geo_feature and dst_geo_feature must share their last dimension, got "
                             f"{self.geo_feat_dim} and {int(self.dst_geo_feature.shape[-1])}")

        self.random_projections = random_projections
        self.time_encoder = TimeEncoder(time_dim=time_feat_dim)
        # the embedding module's only call of geo_encoder is commented out; kept for checkpoint compatibility
        self.geo_encoder = nn.Linear(self.geo_feat_dim, out_features=self.time_feat_dim, bias=True)

        # embedding module
        self.embedding_module = DyGMoCEEmbedding(node_raw_features=self.node_raw_features,
                                                 edge_raw_features=self.edge_raw_features,
                                                 neighbor_sampler=neighbor_sampler,
                                                 time_encoder=self.time_encoder,
                                                 geo_encoder=self.geo_encoder,
                                                 node_feat_dim=self.node_feat_dim,
                                                 edge_feat_dim=self.edge_feat_dim,
                                                 time_feat_dim=self.time_feat_dim,
                                                 geo_feat_dim=self.geo_feat_dim,
                                                 num_layers=num_layers,
                                                 num_neighbors=num_neighbors,
                                                 dropout=self.dropout,
                                                 random_projections=self.random_projections,
                                                 src_geo_feature=self.src_geo_feature,
                                                 dst_geo_feature=self.dst_geo_feature,
                                                 margin=margin,
                                                 init_curvs=init_curvs,
                                                 )

    def compute_src_dst_node_temporal_embeddings(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray,
                                                 src_node_hyperbolicities: np.ndarray,
                                                 dst_node_hyperbolicities: np.ndarray,
                                                 node_interact_times: np.ndarray,
                                                 positive=True):
        """
        compute source and destination node temporal embeddings.
        :param src_node_ids: ndarray, shape (batch_size, )
        :param dst_node_ids: ndarray, shape (batch_size, )
        :param src_node_hyperbolicities: array-like or Tensor, one value per source node
        :param dst_node_hyperbolicities: array-like or Tensor, one value per destination node
        :param node_interact_times: ndarray, shape (batch_size, )
        :param positive: unused; kept for call-site compatibility
        :return: tuple (src_node_embeddings, dst_node_embeddings, regular_loss); the embeddings have shape
                 (batch_size, node_feat_dim)
        """
        # sources and destinations are embedded in one pass; both halves see the same (src, dst) pair
        node_embeddings, regular_loss = self.embedding_module.compute_node_temporal_embeddings(
            node_ids=np.concatenate([src_node_ids, dst_node_ids]),
            src_node_ids=np.tile(src_node_ids, 2),
            dst_node_ids=np.tile(dst_node_ids, 2),
            node_interact_times=np.tile(node_interact_times, 2),
            src_node_hyperbolicities=src_node_hyperbolicities,
            dst_node_hyperbolicities=dst_node_hyperbolicities
        )
        num_src = len(src_node_ids)
        src_node_embeddings, dst_node_embeddings = node_embeddings[:num_src], node_embeddings[num_src:]
        return src_node_embeddings, dst_node_embeddings, regular_loss

    def set_neighbor_sampler(self, neighbor_sampler: NeighborSampler):
        """
        set neighbor sampler to neighbor_sampler and reset the random state (for reproducing the results for uniform and time_interval_aware sampling).
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :return:
        """
        self.embedding_module.neighbor_sampler = neighbor_sampler
        if self.embedding_module.neighbor_sampler.sample_neighbor_strategy in ['uniform', 'time_interval_aware']:
            assert self.embedding_module.neighbor_sampler.seed is not None
            self.embedding_module.neighbor_sampler.reset_random_state()


class DyGMoCEEmbedding(nn.Module):
    def __init__(self, node_raw_features: torch.Tensor, edge_raw_features: torch.Tensor,
                 neighbor_sampler: NeighborSampler,
                 time_encoder: nn.Module, geo_encoder: nn.Module, node_feat_dim: int, edge_feat_dim: int,
                 time_feat_dim: int, geo_feat_dim: int,
                 num_layers: int, num_neighbors: int, dropout: float, random_projections: RandomProjectionModule,
                 src_geo_feature, dst_geo_feature, margin: float, init_curvs: list):
        """
        Embedding module of DyGMoCE: projects per-neighbor features and encodes them with a stack of
        RiemannTransformerEncoder layers, then mean-pools over neighbors.
        :param node_raw_features: Tensor, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: Tensor, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :param time_encoder: TimeEncoder
        :param geo_encoder: nn.Module, stored but not applied in this module
        :param node_feat_dim: int, dimension of node features
        :param edge_feat_dim: int, dimension of edge features
        :param time_feat_dim: int, dimension of time features (encodings)
        :param geo_feat_dim: int, last dimension of the geometric edge features
        :param num_layers: int, number of transformer layers
        :param num_neighbors: int, number of sampled neighbors
        :param dropout: float, dropout rate
        :param random_projections: RandomProjectionModule or None
        :param src_geo_feature: Tensor indexed by edge id, geometric features for source-side neighbors
        :param dst_geo_feature: Tensor indexed by edge id, geometric features for destination-side neighbors
        :param margin: float, margin of the curvature ranking regularizer
        :param init_curvs: list, curvature of each geoopt.Stereographic manifold
        """
        super(DyGMoCEEmbedding, self).__init__()

        self.src_geo_feature = src_geo_feature
        self.dst_geo_feature = dst_geo_feature
        self.node_raw_features = node_raw_features
        self.edge_raw_features = edge_raw_features
        self.neighbor_sampler = neighbor_sampler
        self.time_encoder = time_encoder
        self.geo_encoder = geo_encoder
        self.node_feat_dim = node_feat_dim
        self.edge_feat_dim = edge_feat_dim
        self.time_feat_dim = time_feat_dim
        self.geo_feat_dim = geo_feat_dim
        self.num_layers = num_layers
        self.num_neighbors = num_neighbors
        self.dropout = dropout
        self.random_projections = random_projections
        self.margin = margin
        if self.random_projections is None:
            self.random_feature_dim = 0
        else:
            # relative encoding w.r.t. both the source and the destination node
            self.random_feature_dim = self.random_projections.pair_wise_feature_dim * 2
        # the geometric features are always part of the projection input
        self.projection_layer = nn.Sequential(
            nn.Linear(node_feat_dim + edge_feat_dim + time_feat_dim + self.random_feature_dim + self.geo_feat_dim,
                      self.node_feat_dim * 2),
            nn.ReLU(), nn.Linear(self.node_feat_dim * 2, self.node_feat_dim))
        self.init_curvs = init_curvs
        self.manifolds = nn.ModuleList()
        for curv in self.init_curvs:
            self.manifolds.append(geoopt.Stereographic(k=curv, learnable=False))

        self.transformers = nn.ModuleList([
            RiemannTransformerEncoder(manifolds=self.manifolds, init_curvs=self.init_curvs,
                                      attention_dim=self.node_feat_dim, dropout=self.dropout)
            for _ in range(self.num_layers)
        ])

    def compute_node_temporal_embeddings(self, node_ids: np.ndarray, src_node_ids: np.ndarray,
                                         dst_node_ids: np.ndarray, node_interact_times: np.ndarray,
                                         src_node_hyperbolicities, dst_node_hyperbolicities):
        """
        given node ids node_ids and the corresponding times node_interact_times, return the temporal embeddings.
        :param node_ids: ndarray, shape (batch_size, ), node ids; the first half are source nodes and the
                         second half destination nodes
        :param src_node_ids: ndarray, shape (batch_size, ), source node of the pair each entry belongs to
        :param dst_node_ids: ndarray, shape (batch_size, ), destination node of the pair each entry belongs to
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param src_node_hyperbolicities: array-like or Tensor, one value per source node (presumably 1-D — confirm)
        :param dst_node_hyperbolicities: array-like or Tensor, one value per destination node
        :return: tuple (embeddings of shape (batch_size, node_feat_dim), summed curvature regularization loss)
        """
        n_nodes = len(node_ids) // 2
        device = self.node_raw_features.device
        # get temporal neighbors, each ndarray of shape (batch_size, num_neighbors)
        neighbor_node_ids, neighbor_edge_ids, neighbor_times = \
            self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                           node_interact_times=node_interact_times,
                                                           num_neighbors=self.num_neighbors)
        # node features, shape (batch, num_neighbors, node_feat_dim)
        neighbor_node_features = self.node_raw_features[torch.from_numpy(neighbor_node_ids)]
        # log-scaled elapsed time since each neighbor interaction
        neighbor_delta_times = torch.from_numpy(node_interact_times[:, np.newaxis] - neighbor_times).float().to(device)
        neighbor_delta_times = torch.log(neighbor_delta_times + 1.0)
        # time encoding, shape (batch, num_neighbors, time_feat_dim)
        neighbor_time_features = self.time_encoder(neighbor_delta_times)
        # edge features, shape (batch, num_neighbors, edge_feat_dim)
        neighbor_edge_features = self.edge_raw_features[torch.from_numpy(neighbor_edge_ids)]

        # geometric edge features: source-side table for the source half, destination-side table for the rest
        src_neighbor_geo_feature = self.src_geo_feature[torch.from_numpy(neighbor_edge_ids[:n_nodes, :])]
        dst_neighbor_geo_feature = self.dst_geo_feature[torch.from_numpy(neighbor_edge_ids[n_nodes:, :])]
        neighbor_geo_feature = torch.cat((src_neighbor_geo_feature, dst_neighbor_geo_feature), dim=0)

        # accept ndarrays as well as tensors and keep them on the model device for the regularizer
        node_hyperbolicities = torch.cat((torch.as_tensor(src_node_hyperbolicities, device=device),
                                          torch.as_tensor(dst_node_hyperbolicities, device=device)), dim=0)

        feature_parts = [neighbor_node_features, neighbor_time_features, neighbor_edge_features]
        # assign relative encodings for neighbor nodes
        # given a source node u, a destination node v, and a target node w (neighbor of u or v)
        # its relative encoding is [r_{w|u},r_{w|v}], where r_{w|u}/r_{w|v} is the pairwise feature
        # given by calling get_pair_wise_feature(w,u)/get_pair_wise_feature(w,v) of the RandomProjectionModule
        if self.random_projections is not None:
            # [2*batch*num_neighbors, random_feature_dim]
            concat_neighbor_random_features = self.random_projections.get_pair_wise_feature(
                src_node_ids=np.tile(neighbor_node_ids.reshape(-1), 2),
                dst_node_ids=np.concatenate(
                    [np.repeat(src_node_ids, self.num_neighbors), np.repeat(dst_node_ids, self.num_neighbors)]))
            # [batch, num_neighbors, random_feature_dim*2]
            neighbor_random_features = torch.cat(
                [concat_neighbor_random_features[:len(node_ids) * self.num_neighbors],
                 concat_neighbor_random_features[len(node_ids) * self.num_neighbors:]],
                dim=1).reshape(len(node_ids), self.num_neighbors, -1)
            feature_parts.append(neighbor_random_features)
        # projection_layer always expects the geometric features, with or without random features
        feature_parts.append(neighbor_geo_feature)
        neighbor_combine_features = torch.cat(feature_parts, dim=2)

        # shape (batch, num_neighbors, node_feat_dim)
        embeddings = self.projection_layer(neighbor_combine_features)
        # mask the pad nodes (i.e., id = 0); masked_fill is out-of-place, so its result must be kept
        padding_mask = torch.from_numpy(neighbor_node_ids == 0)[:, :, None].to(device)
        embeddings = embeddings.masked_fill(padding_mask, 0)
        regular_loss = 0
        for transformer in self.transformers:
            embeddings, atten_node_curv, ffn_node_curv = transformer(embeddings)
            regular_loss += self.compute_regular_loss(atten_node_curv, node_hyperbolicities)
            regular_loss += self.compute_regular_loss(ffn_node_curv, node_hyperbolicities)
        # shape (batch, node_feat_dim)
        embeddings = torch.mean(embeddings, dim=1)

        return embeddings, regular_loss

    def compute_regular_loss(self, node_curv, hyperbolicities):
        """
        Margin ranking loss tying the order of node curvatures to the order of node hyperbolicities:
        a strictly smaller hyperbolicity requires a curvature smaller by at least `margin`, a strictly
        larger one requires a curvature larger by at least `margin`, and equal hyperbolicities require
        curvatures within `margin` of each other.
        :param node_curv: Tensor, one curvature per node (assumed 1-D, aligned with hyperbolicities — TODO confirm)
        :param hyperbolicities: Tensor, one hyperbolicity per node
        :return: scalar Tensor, summed hinge violations divided by the number of comparable pairs
        """
        threshold = self.margin

        # pairwise grids: element [a, b] compares node b ("i") against node a ("j")
        hi = hyperbolicities.unsqueeze(0)
        hj = hyperbolicities.unsqueeze(1)
        ci = node_curv.unsqueeze(0)
        cj = node_curv.unsqueeze(1)

        # Masks
        mask_less = hi < hj
        mask_greater = hi > hj
        mask_equal = hi == hj

        # 1. δ_i < δ_j ⇒ c_i + threshold < c_j
        loss_less = F.relu(threshold + ci - cj) * mask_less

        # 2. δ_i > δ_j ⇒ c_i - threshold > c_j
        loss_greater = F.relu(threshold + cj - ci) * mask_greater

        # 3. δ_i == δ_j ⇒ |c_i - c_j| <= threshold
        diff_equal = (ci - cj).abs()
        loss_equal = F.relu(diff_equal - threshold) * mask_equal

        # Combine; pairs involving NaN hyperbolicities satisfy none of the masks and are excluded
        valid_pairs = mask_less | mask_greater | mask_equal
        denom = valid_pairs.sum().clamp(min=1)
        reg_loss = (loss_less + loss_greater + loss_equal).sum() / denom
        return reg_loss
    
class FeedForwardNet(nn.Module):

    def __init__(self, input_dim: int, dim_expansion_factor: float, dropout: float = 0.0):
        """
        Position-wise two-layer perceptron: Linear -> GELU -> Dropout -> Linear -> Dropout.
        The hidden width is int(dim_expansion_factor * input_dim); the output width equals input_dim.
        :param input_dim: int, dimension of input (and output)
        :param dim_expansion_factor: float, hidden-width multiplier
        :param dropout: float, dropout rate applied after the activation and after the output layer
        """
        super(FeedForwardNet, self).__init__()

        self.input_dim = input_dim
        self.dim_expansion_factor = dim_expansion_factor
        self.dropout = dropout

        hidden_dim = int(dim_expansion_factor * input_dim)
        stages = [
            nn.Linear(in_features=input_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=input_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """
        Apply the perceptron to the last dimension of x.
        :param x: Tensor, shape (*, input_dim)
        :return: Tensor, shape (*, input_dim)
        """
        transformed = self.ffn(x)
        return transformed


class MLPMixer(nn.Module):

    def __init__(self, num_tokens: int, num_channels: int, token_dim_expansion_factor: float = 0.5,
                 channel_dim_expansion_factor: float = 4.0, dropout: float = 0.0):
        """
        MLP-Mixer block: a token-mixing residual sublayer followed by a channel-mixing residual sublayer,
        each preceded by LayerNorm.
        :param num_tokens: int, number of tokens (size of dimension 1 of the input)
        :param num_channels: int, number of channels (size of dimension 2 of the input)
        :param token_dim_expansion_factor: float, hidden-width multiplier of the token-mixing MLP
        :param channel_dim_expansion_factor: float, hidden-width multiplier of the channel-mixing MLP
        :param dropout: float, dropout rate inside both MLPs
        """
        super(MLPMixer, self).__init__()

        # token mixing operates on the token axis
        self.token_norm = nn.LayerNorm(num_tokens)
        self.token_feedforward = FeedForwardNet(input_dim=num_tokens,
                                                dim_expansion_factor=token_dim_expansion_factor,
                                                dropout=dropout)

        # channel mixing operates on the channel axis
        self.channel_norm = nn.LayerNorm(num_channels)
        self.channel_feedforward = FeedForwardNet(input_dim=num_channels,
                                                  dim_expansion_factor=channel_dim_expansion_factor,
                                                  dropout=dropout)

    def forward(self, input_tensor: torch.Tensor):
        """
        Mix information across tokens, then across channels, with a residual connection around each step.
        :param input_tensor: Tensor, shape (batch_size, num_tokens, num_channels)
        :return: Tensor, shape (batch_size, num_tokens, num_channels)
        """
        # move tokens to the last axis so LayerNorm / MLP act over them: (batch_size, num_channels, num_tokens)
        tokens_last = input_tensor.permute(0, 2, 1)
        token_mixed = self.token_feedforward(self.token_norm(tokens_last)).permute(0, 2, 1)
        after_tokens = token_mixed + input_tensor

        channel_mixed = self.channel_feedforward(self.channel_norm(after_tokens))
        return channel_mixed + after_tokens


class RiemannTransformerEncoder(nn.Module):

    def __init__(self, manifolds, init_curvs, attention_dim: int, dropout: float = 0.1):
        """
        Transformer encoder block made of a MoeMultiHeadAttention sublayer (preceded by LayerNorm) and a
        MoeFFN sublayer (no LayerNorm is applied before it in this block), each with a residual connection.
        :param manifolds: presumably the nn.ModuleList of geoopt manifolds built by the caller (shared, not copied)
        :param init_curvs: list, curvature of each manifold
        :param attention_dim: int, model width (input and output feature dimension)
        :param dropout: float, dropout rate for both residual branches, also passed to the sublayers
        """
        super(RiemannTransformerEncoder, self).__init__()

        self.manifolds = manifolds
        self.init_curvs = init_curvs

        self.multi_head_attention = MoeMultiHeadAttention(manifolds=manifolds, init_curvs=init_curvs,
                                                          embed_size=attention_dim, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.ffn = MoeFFN(manifolds=manifolds, init_curvs=init_curvs, embed_size=attention_dim, dropout=dropout)

        # NOTE: linear_layers and norm_layers[1] are not used by forward(); they stay registered so that
        # checkpoints keep the same parameter set
        widened_dim = 4 * attention_dim
        self.linear_layers = nn.ModuleList([
            nn.Linear(in_features=attention_dim, out_features=widened_dim),
            nn.Linear(in_features=widened_dim, out_features=attention_dim),
        ])
        self.norm_layers = nn.ModuleList([nn.LayerNorm(attention_dim) for _ in range(2)])

    def forward(self, inputs: torch.Tensor):
        """
        encode the inputs by the Transformer encoder block
        :param inputs: Tensor, shape (batch_size, num_patches, attention_dim)
        :return: tuple (outputs, atten_node_curv, ffn_node_curv); outputs has the shape of inputs, the other
                 two are the curvature outputs returned by the attention and FFN sublayers
        """
        # NOTE(review): the attention sublayer receives the input transposed to (num_patches, batch_size, dim)
        # while its output is added directly to the batch-first `inputs`; this only lines up if
        # MoeMultiHeadAttention returns a batch-first tensor — confirm against its implementation.
        normed = self.norm_layers[0](inputs.transpose(0, 1))
        attended, atten_node_curv = self.multi_head_attention(input_q=normed, input_k=normed, input_v=normed)
        residual = inputs + self.dropout(attended)

        fed_forward, ffn_node_curv = self.ffn(residual)
        return residual + self.dropout(fed_forward), atten_node_curv, ffn_node_curv


class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        """
        Noisy top-k gate: scores experts from the sequence-averaged input, adds learned-scale Gaussian noise
        and keeps only the top_k experts per sample.
        :param n_embed: int, input feature dimension
        :param num_experts: int, number of experts
        :param top_k: int, number of experts kept per sample
        """
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        # router logits and the (softplus-ed) noise scale share the same input
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        """
        :param mh_output: Tensor, presumably (B, L, n_embed); logits are averaged over dimension 1
        :return: tuple (router_output, indices): softmax weights of shape (B, num_experts) that are zero outside
                 the selected experts, and the selected expert indices of shape (B, top_k)
        """
        clean_logits = self.topkroute_linear(mh_output).mean(dim=1)
        noise_scale = F.softplus(self.noise_linear(mh_output).mean(dim=1))
        noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_scale

        kept_values, indices = noisy_logits.topk(self.top_k, dim=-1)
        # -inf everywhere except the kept experts, so softmax assigns them all the mass
        masked_logits = torch.full_like(noisy_logits, float('-inf')).scatter(-1, indices, kept_values)
        return F.softmax(masked_logits, dim=-1), indices
 
 
class ContinuousTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        """
        Noisy gate that selects a run of top_k *adjacent* experts (the window with the largest summed logit)
        instead of the top_k individual experts.
        :param n_embed: int, input feature dimension
        :param num_experts: int, number of experts
        :param top_k: int, length of the selected window of experts
        """
        super(ContinuousTopkRouter, self).__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        # router logits and the (softplus-ed) noise scale share the same input
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def _best_contiguous_window(self, scores):
        """Return (B, top_k) indices of the length-top_k run of adjacent experts with the largest summed score."""
        window = self.top_k
        if window == 1:
            # a window of one is just the argmax
            return scores.argmax(dim=-1, keepdim=True)
        # every contiguous window: (B, E - window + 1, window) -> summed to (B, E - window + 1)
        window_totals = scores.unfold(-1, window, 1).sum(dim=-1)
        starts = window_totals.argmax(dim=-1, keepdim=True)  # (B, 1)
        offsets = torch.arange(window, device=scores.device).unsqueeze(0)  # (1, window)
        return starts + offsets  # (B, window)

    def forward(self, mh_output):
        """
        :param mh_output: Tensor, shape (B, L, n_embed); logits are averaged over L
        :return: tuple (router_output, indices): softmax weights of shape (B, num_experts) that are zero outside
                 the selected window, and the window indices of shape (B, top_k)
        """
        clean_logits = self.topkroute_linear(mh_output).mean(dim=1)  # (B, E)
        noise_scale = F.softplus(self.noise_linear(mh_output).mean(dim=1))
        noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_scale

        # unpacking doubles as a check that the routing logits are 2-D (B, E)
        _batch, _experts = noisy_logits.shape
        indices = self._best_contiguous_window(noisy_logits)

        # keep the individual noisy logits of the chosen experts (not the window sum); everything else -> -inf
        chosen_logits = noisy_logits.gather(-1, indices)
        sparse_logits = torch.full_like(noisy_logits, float('-inf')).scatter(-1, indices, chosen_logits)
        return F.softmax(sparse_logits, dim=-1), indices
 
    
class RiemannScaledProduct(nn.Module):
    """Attention expert that scores query/key pairs by geodesic distance.

    Queries, keys and values are each decomposed into (norm, direction,
    distance from the origin). The query-key distance is then obtained from
    the law of cosines of a space with constant curvature ``kappa``:
    hyperbolic for kappa < 0, Euclidean for kappa == 0, spherical for
    kappa > 0. Attention weights are ``softmax(-distance)`` over the keys.
    The output is the attention-weighted sum of ``0.5 * rho_v * u_v`` over
    the values.
    """

    def __init__(self, manifold, kappa, dropout: float = 0.1):
        super(RiemannScaledProduct, self).__init__()

        self.manifold = manifold  # stored; not referenced by this class's methods
        self.kappa = kappa  # Python number or tensor curvature
        # NOTE(review): created but never applied in forward(); attention
        # weights are not dropped out.
        self.dropout = nn.Dropout(dropout)

    def rie_polar_decompose(self, x, eps=1e-15):
        """Split ``x`` into norm, unit direction and distance from the origin.

        :param x: tensor whose last dimension holds point coordinates
            (used here as [N, L, d]).
        :param eps: lower bound on the norm used to build the direction. It
            is also the margin kept below 1 before ``atanh``. For that margin
            it is raised to the dtype's machine epsilon, because
            ``1 - 1e-15`` rounds to exactly 1.0 in float32 and would let
            ``atanh`` return inf.
        :return: ``(r, u, rho)``:

            - ``r = ||x||``, clamped to ``>= eps``;
            - ``u = x / r``;
            - ``rho``, which is ``2/sqrt(-k) * atanh(sqrt(-k) * r)`` for
              k < 0, ``2/sqrt(k) * atan(sqrt(k) * r)`` for k > 0, and
              ``2 * r`` for k == 0.
        """
        r = torch.norm(x, dim=-1).clamp_min(eps)  # [N, L]
        u = x / r.unsqueeze(-1)  # [N, L, d]

        # as_tensor accepts Python numbers and tensors alike (no copy warning).
        kappa = torch.as_tensor(self.kappa, dtype=x.dtype, device=x.device)

        if self.kappa < 0:
            sqrt_neg_k = torch.sqrt(-kappa)
            margin = max(eps, torch.finfo(x.dtype).eps)
            z = (sqrt_neg_k * r).clamp(-1.0 + margin, 1.0 - margin)
            rho = 2.0 / sqrt_neg_k * torch.atanh(z)  # [N, L]
        elif self.kappa > 0:
            sqrt_k = torch.sqrt(kappa)
            rho = 2.0 / sqrt_k * torch.atan(sqrt_k * r)
        else:  # kappa == 0
            rho = 2.0 * r

        return r, u, rho

    def forward(self, mask, q, k, v, scale):
        """Attend from the masked queries to the masked keys/values.

        :param mask: boolean index applied to dim 0 of q, k and v.
        :param q: queries; after masking, indexed as [N, Lq, d].
        :param k: keys; after masking, indexed as [N, Lk, d].
        :param v: values; after masking, indexed as [N, Lk, d].
        :param scale: accepted for interface compatibility; unused.
        :return: tensor of shape [N, Lq, d].
        """
        q, k, v = q[mask], k[mask], v[mask]

        _, q_u, q_rho = self.rie_polar_decompose(q)
        _, k_u, k_rho = self.rie_polar_decompose(k)
        _, v_u, v_rho = self.rie_polar_decompose(v)

        # Always materialise kappa. It used to be assigned only for
        # non-tensor curvatures, so a tensor kappa raised NameError below.
        kappa = torch.as_tensor(self.kappa, dtype=q.dtype, device=q.device)

        # Cosine of the angle between every query and key direction: [N, Lq, Lk].
        cos_theta = (q_u.unsqueeze(2) * k_u.unsqueeze(1)).sum(dim=-1)
        cos_theta = cos_theta.clamp(-1.0 + 1e-6, 1.0 - 1e-6)

        rho_q = q_rho.unsqueeze(-1)  # [N, Lq, 1]
        rho_k = k_rho.unsqueeze(1)   # [N, 1, Lk]
        if self.kappa < 0:
            sqrt_neg_k = torch.sqrt(-kappa)
            # Cap the arguments so cosh/sinh do not overflow.
            arg_q = (sqrt_neg_k * rho_q).clamp_max(20.0)
            arg_k = (sqrt_neg_k * rho_k).clamp_max(20.0)
            # Hyperbolic law of cosines.
            cosh_d = (torch.cosh(arg_q) * torch.cosh(arg_k)
                      - torch.sinh(arg_q) * torch.sinh(arg_k) * cos_theta)
            cosh_d = cosh_d.clamp_min(1.0 + 1e-6)  # keep acosh in its domain
            distance = torch.acosh(cosh_d) / sqrt_neg_k
        elif self.kappa == 0:
            # Euclidean law of cosines; the 1e-12 keeps sqrt differentiable at 0.
            d2 = rho_q ** 2 + rho_k ** 2 - 2 * rho_q * rho_k * cos_theta
            distance = torch.sqrt(d2.clamp_min(0.0) + 1e-12)
        else:  # kappa > 0
            sqrt_k = torch.sqrt(kappa)
            ang_q = sqrt_k * rho_q
            ang_k = sqrt_k * rho_k
            # Spherical law of cosines, clamped to avoid numerical out-of-range.
            cos_d = (torch.cos(ang_q) * torch.cos(ang_k)
                     + torch.sin(ang_q) * torch.sin(ang_k) * cos_theta)
            cos_d = cos_d.clamp(-1.0 + 1e-6, 1.0 - 1e-6)
            distance = torch.acos(cos_d) / sqrt_k

        # Non-finite distances become a large value so they get ~zero weight.
        distance = torch.nan_to_num(distance, nan=1e6, posinf=1e6, neginf=1e6)
        attn = F.softmax(-distance, dim=-1)  # [N, Lq, Lk]

        # Per-value vector 0.5 * rho_v * u_v, mixed by attention over the keys.
        # matmul avoids materialising an [N, Lq, Lk, d] intermediate.
        value_vec = 0.5 * v_rho.unsqueeze(-1) * v_u  # [N, Lk, d]
        return torch.matmul(attn, value_vec)  # [N, Lq, d]
    
class MoeMultiHeadAttention(nn.Module):
    """Mixture-of-experts attention whose experts live on different manifolds.

    Each manifold has its own components:

    - a ``RiemannLinear`` that projects the input to concatenated
      query/key/value features (width 3 * embed_size);
    - a ``RiemannScaledProduct`` expert with curvature ``init_curvs[i]``.

    A ``ContinuousTopkRouter`` with ``top_k=2`` chooses which experts handle
    each batch row and supplies the mixing weights. The gate-weighted sum of
    the curvatures is returned alongside the attention output.
    """

    def __init__(self, manifolds, init_curvs, embed_size, dropout=0.1):
        super(MoeMultiHeadAttention, self).__init__()
        self.manifolds = manifolds
        self.embed_size = embed_size

        # NOTE(review): self.dropout and self.out_proj are constructed but not
        # used in forward(); they still contribute state/parameters.
        self.dropout = nn.Dropout(dropout)
        # Creation order below fixes parameter registration order and the
        # RNG draws used for initialisation, so it is kept deliberately.
        self.in_proj_weights = nn.ModuleList(
            [RiemannLinear(space, embed_size, 3 * embed_size, dropout) for space in manifolds]
        )

        self.init_curvs = init_curvs
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.num_experts = len(self.init_curvs)
        self.router = ContinuousTopkRouter(embed_size, self.num_experts, top_k=2)
        self.experts = nn.ModuleList(
            [RiemannScaledProduct(space, kappa=init_curvs[idx])
             for idx, space in enumerate(self.manifolds)]
        )

    def forward(self, input_q, input_k, input_v):
        """Run mixture-of-experts attention.

        :param input_q: tensor unpacked as (tgt_len, bsz, embed_dim).
        :param input_k: only unpacked as (src_len, _, _), which checks it is 3-D.
        :param input_v: not used.
        :return: ``(attn_output, node_curv)`` from
            ``moe_scaled_dot_product_attention``. attn_output is batch-first.

        NOTE(review): queries, keys and values are all projected from
        ``input_q``, so this is effectively self-attention. Confirm callers
        never pass distinct key/value inputs.
        """
        # Unpacking validates that both inputs are 3-D; the sizes are unused.
        _tgt_len, _bsz, _embed_dim = input_q.shape
        _src_len, _, _ = input_k.shape

        queries, keys, values = [], [], []
        for idx, space in enumerate(self.manifolds):
            # Treat input_q as a tangent vector at the origin and move it onto the manifold.
            tangent = space.proju(space.origin(input_q.shape), input_q)
            on_manifold = space.expmap0(tangent, project=True)
            q_part, k_part, v_part = self.in_proj_weights[idx](on_manifold).chunk(3, dim=-1)
            queries.append(q_part.contiguous())
            keys.append(k_part.contiguous())
            values.append(v_part.contiguous())

        # Stack per manifold, then swap dims 1 and 2: presumably
        # [M, L, B, D] -> [M, B, L, D] (assumes the projection keeps leading dims).
        q, k, v = (torch.stack(parts).transpose(1, 2) for parts in (queries, keys, values))

        return self.moe_scaled_dot_product_attention(input_q, q, k, v)

    def moe_scaled_dot_product_attention(self, x, q, k, v):
        """Send batch rows to their routed experts and mix the results.

        :param x: routing input; dims 0 and 1 are swapped here
            ((L, B, d) -> (B, L, d)).
        :param q: per-expert stack indexed as [expert, B, L, d].
        :param k: per-expert stack indexed as [expert, B, L, d].
        :param v: per-expert stack indexed as [expert, B, L, d].
        :return: ``(final_output, node_curv)``.
            final_output has the transposed shape of x.
            node_curv is the gate-weighted sum of ``init_curvs`` per batch row.
        """
        _, _, _, head_dim = q.shape
        scale = math.sqrt(head_dim)

        x = x.transpose(0, 1)  # [L, B, d] -> [B, L, d]
        gates, routed_ids = self.router(x)

        curvatures = torch.tensor(self.init_curvs, device=x.device)
        node_curv = (gates * curvatures).sum(dim=1)

        final_output = torch.zeros_like(x)
        for expert_id, expert in enumerate(self.experts):
            # Batch rows whose selected experts include this one.
            rows = (routed_ids == expert_id).any(dim=-1)
            if not rows.any():
                continue
            expert_out = expert(rows, q[expert_id], k[expert_id], v[expert_id], scale)
            final_output[rows] += expert_out * gates[rows, expert_id][:, None, None]

        return final_output, node_curv

   

class MoeFFN(nn.Module):
    """Mixture-of-experts feed-forward layer with one RiemannMLP per manifold.

    A ``ContinuousTopkRouter`` with ``top_k=2`` chooses experts per batch
    row. Each chosen expert's output is scaled by its gate weight and
    accumulated. The gate-weighted sum of ``init_curvs`` is returned as
    ``node_curv``.
    """

    def __init__(self, manifolds, init_curvs, embed_size, dropout=0.1):
        super(MoeFFN, self).__init__()
        self.manifolds = manifolds
        self.embed_size = embed_size
        # NOTE(review): constructed but not used in forward().
        self.dropout = nn.Dropout(dropout)

        self.init_curvs = init_curvs
        self.num_experts = len(self.init_curvs)
        self.router = ContinuousTopkRouter(embed_size, self.num_experts, top_k=2)
        # Experts are created after the router; this order fixes parameter
        # registration and the RNG draws used for initialisation.
        self.experts = nn.ModuleList(
            [RiemannMLP(space, embed_size, embed_size) for space in self.manifolds]
        )

    def forward(self, x):
        """Mix routed expert outputs.

        :param x: router input; rows along dim 0 are dispatched to experts.
        :return: ``(final_output, node_curv)``. final_output has x's shape.
        """
        gates, routed_ids = self.router(x)

        curvatures = torch.tensor(self.init_curvs, device=x.device)
        node_curv = (gates * curvatures).sum(dim=1)

        final_output = torch.zeros_like(x)
        for expert_id, expert in enumerate(self.experts):
            # Rows whose selected experts include this one.
            rows = (routed_ids == expert_id).any(dim=-1)
            if rows.any():
                expert_out = expert(x[rows])
                final_output[rows] += expert_out * gates[rows, expert_id][:, None, None]

        return final_output, node_curv


class RiemannMLP(nn.Module):
    """Feed-forward block (in_dim -> 4*in_dim -> out_dim) built from RiemannLinear layers.

    The input is treated as a tangent vector at the origin. Processing runs
    in these steps:

    1. Map the input onto the manifold and pass it through the first
       RiemannLinear.
    2. Bring the result back with logmap0 and apply GELU + dropout.
    3. Map onto the manifold again for the second RiemannLinear.
    4. Return the result in tangent space via logmap0.
    """

    def __init__(self, manifold, in_dim: int, out_dim: int, dropout: float=0.0):
        super(RiemannMLP, self).__init__()
        hidden_dim = in_dim * 4
        # The inner layers get dropout=0; `dropout` here only affects the
        # hidden activations in forward().
        self.linear_layers = nn.ModuleList([
            RiemannLinear(manifold, in_dim, hidden_dim, dropout=0),
            RiemannLinear(manifold, hidden_dim, out_dim, dropout=0),
        ])
        self.manifold = manifold
        self.dropout = nn.Dropout(dropout)

    def _lift(self, tangent):
        """Project ``tangent`` onto the tangent space at the origin, then expmap0 it onto the manifold."""
        space = self.manifold
        at_origin = space.proju(space.origin(tangent.shape), tangent)
        return space.expmap0(at_origin, project=True)

    def forward(self, inputX):
        """Apply the two-layer MLP; input and output are tangent vectors at the origin."""
        first, second = self.linear_layers

        hidden = self.manifold.logmap0(first(self._lift(inputX)))
        hidden = self.dropout(F.gelu(hidden))

        return self.manifold.logmap0(second(self._lift(hidden)))




class RiemannLinear(nn.Module):
    """Möbius linear layer: ``mobius_matvec`` with an optional ``mobius_add`` bias.

    The weight is dropped out element-wise during training (``dropout`` is
    the probability passed to ``F.dropout``). The bias is stored as a
    tangent vector at the origin and mapped onto the manifold with
    ``expmap0`` before the Möbius addition.
    """

    def __init__(self, manifold, in_dim: int, out_dim: int, dropout: float=0.0, use_bias: bool=True):
        super(RiemannLinear, self).__init__()
        self.manifold = manifold
        self.dropout = dropout  # float probability, not an nn.Dropout module
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.Tensor(out_dim, in_dim))
        # The bias parameter is allocated even when use_bias is False, so the
        # set of registered parameters does not depend on use_bias.
        self.bias = nn.Parameter(torch.Tensor(out_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weight, zero bias."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply ``mobius_matvec(weight, x)``, then add the bias via ``mobius_add`` if enabled."""
        space = self.manifold
        weight = F.dropout(self.weight, self.dropout, training=self.training)
        out = space.mobius_matvec(weight, x, project=True)
        if not self.use_bias:
            return out

        bias_tangent = space.proju(space.origin(self.bias.shape), self.bias)
        bias_point = space.expmap0(bias_tangent, project=True)
        return space.mobius_add(out, bias_point, project=True)



