import numpy as np
import torch
import torch.nn as nn
from models.modules import TimeEncoder, LinkPredictor_v1, MultiHeadAttention
from utils.utils import NeighborSampler
from numba.typed import List
from numba import jit, prange


@jit(nopython=True, parallel=True)
def fast_convert_neighbor(nodes, timestamps, nodes_neighbor_times, time_step, time_length):
    """Build, per query, a histogram of how long ago the node's past interactions happened.

    For query ``i`` the history ``nodes_neighbor_times[nodes[i]]`` is scanned from the
    start and every interaction strictly earlier than ``timestamps[i]`` is counted in
    the bucket ``ceil(age / time_length) - 1``, where ``age`` is the elapsed time.
    Bucket 0 therefore holds ages in ``(0, time_length]``, bucket 1 the next window,
    and so on; interactions older than ``time_step`` windows are not counted.

    The scan stops at the first history entry that is not earlier than the query time,
    so each history is expected to be sorted in non-decreasing order.

    :param nodes: 1-D sequence of node ids, used as indices into ``nodes_neighbor_times``
    :param timestamps: 1-D sequence of query times, aligned with ``nodes``
    :param nodes_neighbor_times: indexable collection of per-node interaction-time arrays
    :param time_step: number of buckets (columns of the result)
    :param time_length: width of one bucket; presumably positive -- TODO confirm at caller
    :return: float32 array of shape ``(len(nodes), time_step)`` holding the counts
    """
    num_queries = len(nodes)
    histogram = np.zeros((num_queries, time_step), dtype=np.float32)
    # Every parallel iteration writes only to its own row of ``histogram``.
    for row in prange(num_queries):
        query_time = timestamps[row]
        history = nodes_neighbor_times[nodes[row]]
        for past_time in history:
            if past_time >= query_time:
                # Remaining entries are not in the past of this query.
                break
            bucket = int(np.ceil((query_time - past_time) / time_length) - 1)
            if bucket < time_step:
                histogram[row, bucket] += 1
    return histogram


class NodeFeature(nn.Module):
    """Node features learned from binned counts of each node's past interactions.

    At construction the interaction data is converted into per-node histories
    (neighbor ids, edge ids, interaction times). ``forward`` turns a batch of
    ``(node, timestamp)`` queries into count histograms via ``fast_convert_neighbor``
    and maps them through a two-layer MLP to ``node_feature_dim`` features.
    """

    def __init__(self, datas, time_step, time_length, node_feature_dim, device):
        """
        :param datas: re-iterable collection (e.g. a list) of objects exposing
            ``src_node_ids``, ``dst_node_ids``, ``edge_ids`` and ``node_interact_times``
        :param time_step: number of time buckets in the raw count feature
        :param time_length: width of one time bucket
        :param node_feature_dim: size of the output feature vector
        :param device: device the raw count features are moved to in ``forward``
        """
        super(NodeFeature, self).__init__()
        self.datas = datas
        # Per-node histories as NumPy arrays, indexed by node id.
        self.nodes_neighbor_ids = []
        self.nodes_edge_ids = []
        self.nodes_neighbor_times = []
        # The same histories as numba typed lists, usable from jitted code.
        self.nb_nodes_neighbor_ids = None
        self.nb_nodes_edge_ids = None
        self.nb_nodes_neighbor_times = None

        self.convert_data_to_neighbor(self.datas)
        self.time_step = time_step
        self.time_length = time_length
        self.node_feature_dim = node_feature_dim

        # Bottleneck width of the MLP: floor(sqrt(node_feature_dim)).
        hidden_dim = int(np.floor(np.sqrt(node_feature_dim)))
        self.node_feature_map = nn.Sequential(nn.Linear(self.time_step, hidden_dim),
                                              nn.ReLU(),
                                              nn.Linear(hidden_dim, node_feature_dim))
        self.device = device

    @staticmethod
    def _sorted_interactions_per_node(datas):
        """Return, for every node id, its ``(neighbor, edge_id, time)`` tuples sorted by edge id.

        :raises AssertionError: if, for some node, increasing edge ids do not imply
            non-decreasing interaction times
        """
        max_node_id = max(max(data.src_node_ids.max(), data.dst_node_ids.max()) for data in datas)
        adj_list = [[] for _ in range(max_node_id + 1)]
        for data in datas:
            for src_node_id, dst_node_id, edge_id, node_interact_time in zip(data.src_node_ids, data.dst_node_ids,
                                                                             data.edge_ids, data.node_interact_times):
                # Each interaction is recorded on both of its endpoints.
                adj_list[src_node_id].append((dst_node_id, edge_id, node_interact_time))
                adj_list[dst_node_id].append((src_node_id, edge_id, node_interact_time))

        per_node = []
        for interactions in adj_list:
            ordered = sorted(interactions, key=lambda x: x[1])
            for earlier, later in zip(ordered, ordered[1:]):
                # Raised explicitly instead of using ``assert`` so the check survives
                # ``python -O``: fast_convert_neighbor stops scanning at the first
                # non-earlier time and would miscount on unsorted histories.
                if not later[2] >= earlier[2]:
                    raise AssertionError("edge id increasing do not mean time increasing")
            per_node.append(ordered)
        return per_node

    def convert_data_to_neighbor(self, datas):
        """Rebuild the per-node neighbor/edge/time histories from ``datas``.

        The histories are rebuilt from scratch on every call, so calling this method
        again replaces the previous histories instead of appending after them.

        :param datas: re-iterable collection of objects exposing ``src_node_ids``,
            ``dst_node_ids``, ``edge_ids`` and ``node_interact_times``
        :raises AssertionError: if edge-id order and time order disagree for some node
        """
        per_node = self._sorted_interactions_per_node(datas)

        self.nodes_neighbor_ids = [np.array([x[0] for x in entries]) for entries in per_node]
        self.nodes_edge_ids = [np.array([x[1] for x in entries]) for entries in per_node]
        self.nodes_neighbor_times = [np.array([x[2] for x in entries]) for entries in per_node]

        # Explicit dtypes give every element of a typed list the same array type,
        # including the empty histories of nodes that never interact.
        self.nb_nodes_neighbor_ids = List([np.array(x, dtype=np.int64) for x in self.nodes_neighbor_ids])
        self.nb_nodes_edge_ids = List([np.array(x, dtype=np.int64) for x in self.nodes_edge_ids])
        self.nb_nodes_neighbor_times = List([np.array(x, dtype=np.float64) for x in self.nodes_neighbor_times])

    def forward(self, nodes, timestamps):
        """Compute features for a batch of ``(node, timestamp)`` queries.

        :param nodes: array of node ids of any shape (presumably a NumPy array, since
            it is handed to the jitted kernel -- TODO confirm at caller)
        :param timestamps: array of query times with one entry per element of ``nodes``
        :return: tensor of shape ``(*nodes.shape, node_feature_dim)``
        """
        node_shape = nodes.shape
        # The kernel works on flat arrays; the original shape is restored at the end.
        nodes, timestamps = nodes.reshape(-1), timestamps.reshape(-1)
        node_raw_feature = fast_convert_neighbor(nodes, timestamps, self.nb_nodes_neighbor_times,
                                                 self.time_step, self.time_length)
        node_raw_feature = torch.from_numpy(node_raw_feature).to(self.device)
        node_feature = self.node_feature_map(node_raw_feature)
        return node_feature.reshape(*node_shape, -1)
