import numpy as np
import torch
import torch.nn as nn

from models.modules import TimeEncoder, TimeMixer
from utils.utils import NeighborSampler


class TGMixer(nn.Module):

    def __init__(self, node_raw_features: np.ndarray, edge_raw_features: np.ndarray, neighbor_sampler: NeighborSampler,
                 time_feat_dim: int, num_tokens: int, num_layers: int = 2, token_dim_expansion_factor: float = 0.5,
                 channel_dim_expansion_factor: float = 4.0, dropout: float = 0.1, device: str = 'cpu',
                 max_time_shift: float = 0.0):
        """
        TGMixer model.
        :param node_raw_features: ndarray, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: ndarray, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: neighbor sampler
        :param time_feat_dim: int, dimension of time features (encodings)
        :param num_tokens: int, number of tokens
        :param num_layers: int, number of stacked MLPMixer layers
        :param token_dim_expansion_factor: float, dimension expansion factor for tokens
        :param channel_dim_expansion_factor: float, dimension expansion factor for channels
        :param dropout: float, dropout rate
        :param device: str, device
        :param max_time_shift: float, divisor used to scale the time elapsed since a node's latest interaction
            before the exponential decay in compute_src_dst_node_temporal_embeddings; also forwarded to every
            MLPMixer layer. NOTE: it is used as a divisor, so the default 0.0 must be overridden with a non-zero
            value before temporal embeddings are computed.
        """
        super(TGMixer, self).__init__()

        self.node_raw_features = torch.from_numpy(node_raw_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_raw_features.astype(np.float32)).to(device)

        self.neighbor_sampler = neighbor_sampler
        self.node_feat_dim = self.node_raw_features.shape[1]
        self.edge_feat_dim = self.edge_raw_features.shape[1]
        self.time_feat_dim = time_feat_dim
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.token_dim_expansion_factor = token_dim_expansion_factor
        self.channel_dim_expansion_factor = channel_dim_expansion_factor
        self.dropout = dropout
        self.device = device
        self.max_time_shift = max_time_shift

        # the channel width of the mixer is tied to the edge feature dimension
        self.num_channels = self.edge_feat_dim
        self.time_encoder = TimeEncoder(time_dim=time_feat_dim, parameter_requires_grad=False)
        self.projection_layer = nn.Linear(self.node_feat_dim + self.edge_feat_dim + time_feat_dim, self.num_channels)

        self.mlp_mixers = nn.ModuleList([
            MLPMixer(num_tokens=self.num_tokens, num_channels=self.num_channels,
                     token_dim_expansion_factor=self.token_dim_expansion_factor,
                     channel_dim_expansion_factor=self.channel_dim_expansion_factor, dropout=self.dropout,
                     max_time_shift=self.max_time_shift)
            for _ in range(self.num_layers)
        ])

        # State tensors handed to the mixers. They are plain tensors (not parameters or registered buffers) and are
        # preallocated with 200 rows along dim 1; compute_node_temporal_embeddings slices the first batch_size rows,
        # so a single call is expected to process at most 200 nodes.
        self.hidden = torch.zeros((1, 200, self.num_channels), requires_grad=False).float().to(self.device)
        self.cell = torch.zeros((1, 200, self.num_channels), requires_grad=False).float().to(self.device)

        self.output_layer = nn.Linear(in_features=self.num_channels + self.node_feat_dim, out_features=self.node_feat_dim, bias=True)

    def _compute_latest_interaction_decay(self, node_ids: np.ndarray, node_interact_times: np.ndarray,
                                          num_neighbors: int) -> torch.Tensor:
        """
        compute, for every node, exp(-2 * dt / max_time_shift) where dt is the time elapsed between the query
        time and the most recent sampled historical neighbor time of that node
        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node
        :return: Tensor, shape (batch_size, 1), located on self.device
        """
        # only the neighbor times are needed here, neighbor ids and edge ids are discarded
        _, _, neighbor_times = \
            self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                           node_interact_times=node_interact_times,
                                                           num_neighbors=num_neighbors)
        # Tensor, shape (batch_size, ), time elapsed since the latest sampled neighbor interaction
        latest_time_interval = torch.from_numpy(node_interact_times - np.max(neighbor_times, axis=1)).float().to(self.device)
        # Tensor, shape (batch_size, 1)
        return torch.exp(- 2 * latest_time_interval / self.max_time_shift).unsqueeze(1)

    def compute_src_dst_node_temporal_embeddings(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, neg_src_node_ids:np.ndarray, neg_dst_node_ids :np.ndarray,
                                                 node_interact_times: np.ndarray, num_neighbors: int = 20, time_gap: int = 2000):
        """
        compute source and destination node temporal embeddings, and (optionally) those of negative samples
        :param src_node_ids: ndarray, shape (batch_size, )
        :param dst_node_ids: ndarray, shape (batch_size, )
        :param neg_src_node_ids: ndarray, shape (batch_size, ), negative source node ids, only used when
            neg_dst_node_ids is non-empty
        :param neg_dst_node_ids: ndarray, shape (batch_size, ), negative destination node ids, pass an empty
            sequence to skip the negative samples
        :param node_interact_times: ndarray, shape (batch_size, )
        :param num_neighbors: int, number of neighbors to sample for each node
        :param time_gap: int, time gap for neighbors to compute node features
        :return: tuple (src_node_embeddings, dst_node_embeddings, neg_src_node_embeddings, neg_dst_node_embeddings),
            each a Tensor of shape (batch_size, num_channels); the two negative entries are None when
            neg_dst_node_ids is empty
        """
        train_link = len(neg_dst_node_ids) > 0

        # The decay weights are computed for all node sets before any embedding is computed, in the order
        # src, dst, neg_src, neg_dst, so that the sequence of neighbor sampler calls stays unchanged.
        # Tensor, shape (batch_size, 1)
        src_time_decay = self._compute_latest_interaction_decay(node_ids=src_node_ids,
                                                                node_interact_times=node_interact_times,
                                                                num_neighbors=num_neighbors)
        # Tensor, shape (batch_size, 1)
        dst_time_decay = self._compute_latest_interaction_decay(node_ids=dst_node_ids,
                                                                node_interact_times=node_interact_times,
                                                                num_neighbors=num_neighbors)
        neg_src_time_decay, neg_dst_time_decay = None, None
        if train_link:
            # Tensor, shape (batch_size, 1)
            neg_src_time_decay = self._compute_latest_interaction_decay(node_ids=neg_src_node_ids,
                                                                        node_interact_times=node_interact_times,
                                                                        num_neighbors=num_neighbors)
            # Tensor, shape (batch_size, 1)
            neg_dst_time_decay = self._compute_latest_interaction_decay(node_ids=neg_dst_node_ids,
                                                                        node_interact_times=node_interact_times,
                                                                        num_neighbors=num_neighbors)

        # Tensor, shape (batch_size, num_channels)
        src_node_embeddings = self.compute_node_temporal_embeddings(node_ids=src_node_ids, node_interact_times=node_interact_times,
                                                                    num_neighbors=num_neighbors, time_gap=time_gap,
                                                                    softmax_time_interval=src_time_decay)
        # Tensor, shape (batch_size, num_channels)
        # every destination node gets its own decay weight, exactly like the source and negative nodes
        # (previously the weights were averaged over the whole batch into a single scalar)
        dst_node_embeddings = self.compute_node_temporal_embeddings(node_ids=dst_node_ids, node_interact_times=node_interact_times,
                                                                    num_neighbors=num_neighbors, time_gap=time_gap,
                                                                    softmax_time_interval=dst_time_decay)
        neg_src_node_embeddings, neg_dst_node_embeddings = None, None
        if train_link:
            # Tensor, shape (batch_size, num_channels)
            neg_src_node_embeddings = self.compute_node_temporal_embeddings(node_ids=neg_src_node_ids,
                                                                            node_interact_times=node_interact_times,
                                                                            num_neighbors=num_neighbors, time_gap=time_gap,
                                                                            softmax_time_interval=neg_src_time_decay)
            # Tensor, shape (batch_size, num_channels)
            neg_dst_node_embeddings = self.compute_node_temporal_embeddings(node_ids=neg_dst_node_ids,
                                                                            node_interact_times=node_interact_times,
                                                                            num_neighbors=num_neighbors, time_gap=time_gap,
                                                                            softmax_time_interval=neg_dst_time_decay)

        return src_node_embeddings, dst_node_embeddings, neg_src_node_embeddings, neg_dst_node_embeddings

    def compute_node_temporal_embeddings(self, node_ids: np.ndarray, node_interact_times: np.ndarray,
                                         num_neighbors: int = 20, time_gap: int = 2000,
                                         softmax_time_interval: torch.Tensor = None):
        """
        given node ids node_ids, and the corresponding time node_interact_times, return the temporal embeddings of nodes in node_ids
        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param num_neighbors: int, number of neighbors to sample for each node
        :param time_gap: int, time gap for neighbors to compute node features (accepted for interface
            compatibility, not read by this method)
        :param softmax_time_interval: Tensor, per-node decay weights forwarded unchanged to every mixer layer as
            time_interval; compute_src_dst_node_temporal_embeddings passes shape (batch_size, 1)
        :return: Tensor, shape (batch_size, num_channels)
        """
        # link encoder
        # get temporal neighbors, including neighbor ids, edge ids and time information
        # neighbor_node_ids, ndarray, shape (batch_size, num_neighbors)
        # neighbor_edge_ids, ndarray, shape (batch_size, num_neighbors)
        # neighbor_times, ndarray, shape (batch_size, num_neighbors)
        neighbor_node_ids, neighbor_edge_ids, neighbor_times = \
            self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                           node_interact_times=node_interact_times,
                                                           num_neighbors=num_neighbors)

        nodes_neighbor_latest_time_interval = softmax_time_interval

        batch_size = neighbor_node_ids.shape[0]
        # Tensor, shape (batch_size, num_neighbors, node_feat_dim)
        nodes_node_raw_features = self.node_raw_features[torch.from_numpy(neighbor_node_ids)]
        # Tensor, shape (batch_size, num_neighbors, edge_feat_dim)
        nodes_edge_raw_features = self.edge_raw_features[torch.from_numpy(neighbor_edge_ids)]
        # Tensor, shape (batch_size, num_neighbors, time_feat_dim)
        nodes_neighbor_time_features = self.time_encoder(timestamps=torch.from_numpy(node_interact_times[:, np.newaxis] - neighbor_times).float().to(self.device))

        # set the time features to all zeros for the padded neighbors (neighbor id 0)
        nodes_neighbor_time_features[torch.from_numpy(neighbor_node_ids == 0)] = 0.0
        # reset the state rows used by this batch, so every call starts from an all-zero state
        self.hidden[:, :batch_size, ], self.cell[:, :batch_size, ] = 0.0, 0.0

        # Tensor, shape (batch_size, num_neighbors, node_feat_dim + edge_feat_dim + time_feat_dim)
        combined_features = torch.cat([nodes_node_raw_features, nodes_edge_raw_features, nodes_neighbor_time_features], dim=-1)
        # Tensor, shape (batch_size, num_neighbors, num_channels)
        combined_features = self.projection_layer(combined_features)

        # Backbone of TGMixer
        # detached copies of the state rows, Tensor, shape (1, batch_size, num_channels)
        hidden, cell = self.hidden[:, :batch_size, ].data.clone(), self.cell[:, :batch_size, ].data.clone()
        for mlp_mixer in self.mlp_mixers:
            # Tensor, shape (batch_size, num_neighbors, num_channels); hidden and cell are threaded through the layers
            combined_features, hidden, cell = mlp_mixer(input_tensor=combined_features,
                                                        hidden=hidden,
                                                        cell=cell,
                                                        time_interval=nodes_neighbor_latest_time_interval)
        # store the final state without keeping the autograd graph alive
        self.hidden[:, :batch_size, ], self.cell[:, :batch_size, ] = hidden.data.clone(), cell.data.clone()

        # Tensor, shape (batch_size, num_channels), average over the neighbor (token) dimension
        node_embeddings = torch.mean(combined_features, dim=1)

        return node_embeddings

    def set_neighbor_sampler(self, neighbor_sampler: NeighborSampler):
        """
        set neighbor sampler to neighbor_sampler and reset the random state (for reproducing the results for uniform and time_interval_aware sampling)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :return:
        """
        self.neighbor_sampler = neighbor_sampler
        if self.neighbor_sampler.sample_neighbor_strategy in ['uniform', 'time_interval_aware']:
            assert self.neighbor_sampler.seed is not None
            self.neighbor_sampler.reset_random_state()


class FeedForwardNet(nn.Module):

    def __init__(self, input_dim: int, dim_expansion_factor: float, dropout: float = 0.0):
        """
        position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout, mapping input_dim back to input_dim.
        :param input_dim: int, size of the last dimension of the input (and of the output)
        :param dim_expansion_factor: float, the hidden width is int(dim_expansion_factor * input_dim)
        :param dropout: float, dropout rate applied after the activation and after the second linear layer
        """
        super().__init__()

        self.input_dim = input_dim
        self.dim_expansion_factor = dim_expansion_factor
        self.dropout = dropout

        # width of the intermediate representation, truncated towards zero
        hidden_dim = int(dim_expansion_factor * input_dim)

        # the layers are created in forward order, so that parameter names (ffn.0, ffn.3) stay the same
        layers = [
            nn.Linear(in_features=input_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=input_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """
        apply the two-layered MLP to the last dimension of x
        :param x: Tensor, shape (*, input_dim)
        :return: Tensor, shape (*, input_dim)
        """
        transformed = self.ffn(x)
        return transformed


class MLPMixer(nn.Module):

    def __init__(self, num_tokens: int, num_channels: int, token_dim_expansion_factor: float = 0.5,
                 channel_dim_expansion_factor: float = 4.0, dropout: float = 0.0, max_time_shift: float = 0.0):
        """
        MLP Mixer block whose channel-mixing step is performed by a TimeMixer.
        :param num_tokens: int, number of tokens
        :param num_channels: int, number of channels
        :param token_dim_expansion_factor: float, dimension expansion factor for tokens
        :param channel_dim_expansion_factor: float, dimension expansion factor for channels (accepted for
            interface compatibility, not read by this block)
        :param dropout: float, dropout rate
        :param max_time_shift: float, forwarded to the TimeMixer on every forward pass
        """
        super(MLPMixer, self).__init__()

        self.token_norm = nn.LayerNorm(num_tokens)
        self.token_feedforward = FeedForwardNet(input_dim=num_tokens, dim_expansion_factor=token_dim_expansion_factor,
                                                dropout=dropout)

        self.channel_norm = nn.LayerNorm(num_channels)

        # channel mixing is delegated to the TimeMixer (there is no channel feed-forward net in this block)
        self.lstm_mixer = TimeMixer(in_dim=num_channels, hid_dim=num_channels)
        self.max_time_shift = max_time_shift

    def forward(self, input_tensor: torch.Tensor, hidden: torch.Tensor, cell: torch.Tensor, time_interval: torch.Tensor):
        """
        mlp mixer to compute over tokens and channels
        :param input_tensor: Tensor, shape (batch_size, num_tokens, num_channels)
        :param hidden: Tensor, hidden state forwarded unchanged to the TimeMixer (TGMixer in this file passes
            shape (1, batch_size, num_channels))
        :param cell: Tensor, cell state forwarded unchanged to the TimeMixer (same shape as hidden in this file)
        :param time_interval: Tensor, forwarded unchanged to the TimeMixer together with max_time_shift
        :return: tuple (output_tensor, hidden_output, cell_output), output_tensor has shape
            (batch_size, num_tokens, num_channels), hidden_output and cell_output are the states returned by
            the TimeMixer
        """
        # mix tokens
        # Tensor, shape (batch_size, num_channels, num_tokens)
        hidden_tensor = self.token_norm(input_tensor.permute(0, 2, 1))
        # Tensor, shape (batch_size, num_tokens, num_channels)
        hidden_tensor = self.token_feedforward(hidden_tensor).permute(0, 2, 1)
        # Tensor, shape (batch_size, num_tokens, num_channels), residual connection
        output_tensor = hidden_tensor + input_tensor

        # mix channels
        # Tensor, shape (batch_size, num_tokens, num_channels)
        hidden_tensor = self.channel_norm(output_tensor)
        # Tensor, shape (batch_size, num_tokens, num_channels)
        hidden_tensor, hidden_output, cell_output = self.lstm_mixer(input=hidden_tensor,
                                                                    cell=cell,
                                                                    hidden=hidden,
                                                                    time_interval=time_interval,
                                                                    max_time_shift=self.max_time_shift)

        # Tensor, shape (batch_size, num_tokens, num_channels), residual connection around the channel-mixing
        # step: the skip path must carry the token-mixed tensor (output_tensor), not the raw block input,
        # otherwise the token-mixing residual above is dropped from the skip path
        output_tensor = hidden_tensor + output_tensor

        return output_tensor, hidden_output, cell_output
