import torch
from torch import nn
import numpy as np
import math


from tgnnexplainer.xgraph.models.ext.tgn.model.temporal_attention import (
    TemporalAttentionLayer,
)

# from model.temporal_attention import TemporalAttentionLayer


class EmbeddingModule(nn.Module):
    def __init__(
        self,
        node_features,
        edge_features,
        memory,
        neighbor_finder,
        time_encoder,
        n_layers,
        n_node_features,
        n_edge_features,
        n_time_features,
        embedding_dimension,
        device,
        dropout,
    ):
        super(EmbeddingModule, self).__init__()
        self.node_features = node_features
        self.edge_features = edge_features
        # self.memory = memory
        self.neighbor_finder = neighbor_finder
        self.time_encoder = time_encoder
        self.n_layers = n_layers
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_time_features = n_time_features
        self.dropout = dropout
        self.embedding_dimension = embedding_dimension
        self.device = device

    def compute_embedding(
        self,
        memory,
        source_nodes,
        timestamps,
        n_layers,
        n_neighbors=20,
        time_diffs=None,
        use_time_proj=True,
    ):
        pass


class IdentityEmbedding(EmbeddingModule):
    def compute_embedding(
        self,
        memory,
        source_nodes,
        timestamps,
        n_layers,
        n_neighbors=20,
        time_diffs=None,
        use_time_proj=True,
    ):
        return memory[source_nodes, :], None


class TimeEmbedding(EmbeddingModule):
    def __init__(
        self,
        node_features,
        edge_features,
        memory,
        neighbor_finder,
        time_encoder,
        n_layers,
        n_node_features,
        n_edge_features,
        n_time_features,
        embedding_dimension,
        device,
        n_heads=2,
        dropout=0.1,
        use_memory=True,
        n_neighbors=1,
    ):
        super(TimeEmbedding, self).__init__(
            node_features,
            edge_features,
            memory,
            neighbor_finder,
            time_encoder,
            n_layers,
            n_node_features,
            n_edge_features,
            n_time_features,
            embedding_dimension,
            device,
            dropout,
        )

        class NormalLinear(nn.Linear):
            # From Jodie code
            def reset_parameters(self):
                stdv = 1.0 / math.sqrt(self.weight.size(1))
                self.weight.data.normal_(0, stdv)
                if self.bias is not None:
                    self.bias.data.normal_(0, stdv)

        self.embedding_layer = NormalLinear(1, self.n_node_features)

    def compute_embedding(
        self,
        memory,
        source_nodes,
        timestamps,
        n_layers,
        n_neighbors=20,
        time_diffs=None,
        use_time_proj=True,
    ):
        source_embeddings = memory[source_nodes, :] * (
            1 + self.embedding_layer(time_diffs.unsqueeze(1))
        )

        return source_embeddings, None


class GraphEmbedding(EmbeddingModule):
    def __init__(
        self,
        node_features,
        edge_features,
        memory,
        neighbor_finder,
        time_encoder,
        n_layers,
        n_node_features,
        n_edge_features,
        n_time_features,
        embedding_dimension,
        device,
        n_heads=2,
        dropout=0.1,
        use_memory=True,
    ):
        super(GraphEmbedding, self).__init__(
            node_features,
            edge_features,
            memory,
            neighbor_finder,
            time_encoder,
            n_layers,
            n_node_features,
            n_edge_features,
            n_time_features,
            embedding_dimension,
            device,
            dropout,
        )

        self.use_memory = use_memory
        self.device = device
        self.atten_weights_list = []
        self.n_heads = n_heads

    def compute_embedding(
        self,
        memory,
        source_nodes,
        timestamps,
        n_layers,
        n_neighbors=20,
        time_diffs=None,
        use_time_proj=True,
        edge_idx_preserve_list=None,
        candidate_weights_dict=None,
    ):
        """Recursive implementation of curr_layers temporal graph attention layers.

        src_idx_l [batch_size]: users / items input ids.
        cut_time_l [batch_size]: scalar representing the instant of the time where we want to extract the user / item representation.
        curr_layers [scalar]: number of temporal convolutional layers to stack.
        num_neighbors [scalar]: number of temporal neighbor to consider in each convolutional layer.
        """

        assert n_layers >= 0

        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        timestamps_torch = torch.unsqueeze(
            torch.from_numpy(timestamps).float().to(self.device), dim=1
        )

        # query node always has the start time -> time span == 0
        source_nodes_time_embedding = self.time_encoder(
            torch.zeros_like(timestamps_torch)
        )

        source_node_features = self.node_features[source_nodes_torch, :]

        if self.use_memory:
            source_node_features = memory[source_nodes, :] + source_node_features

        if n_layers == 0:
            return source_node_features
        else:
            (
                neighbors,
                edge_idxs,
                edge_times,
            ) = self.neighbor_finder.get_temporal_neighbor(
                source_nodes,
                timestamps,
                num_neighbors=n_neighbors,
                edge_idx_preserve_list=edge_idx_preserve_list,
            )

            neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

            edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

            edge_deltas = timestamps[:, np.newaxis] - edge_times

            # edge_times = torch.from_numpy(edge_times).float().to(self.device)

            edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

            neighbors = neighbors.flatten()
            # import ipdb; ipdb.set_trace()
            neighbor_embeddings = self.compute_embedding(
                memory,
                neighbors,
                # np.repeat(timestamps, n_neighbors),
                edge_times.flatten(),  # NOTE: important! otherwise igh_finder cannot find some neighbors that it should find.
                n_layers=n_layers - 1,
                n_neighbors=n_neighbors,
                edge_idx_preserve_list=edge_idx_preserve_list,
                candidate_weights_dict=candidate_weights_dict,
            )

            effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
            neighbor_embeddings = neighbor_embeddings.view(
                len(source_nodes), effective_n_neighbors, -1
            )
            edge_time_embeddings = self.time_encoder(edge_deltas_torch)

            edge_features = self.edge_features[edge_idxs, :]

            # mask = edge_idxs == 0 # NOTE: True to be masked out, i.e., 0 positions
            position0 = edge_idxs == 0
            mask = torch.zeros_like(position0).to(dtype=torch.float32)
            mask[position0] = -1e10

            # import ipdb; ipdb.set_trace()
            if candidate_weights_dict is not None:
                # TODO: support for pg explainer.
                # import ipdb; ipdb.set_trace()
                position0 = edge_idxs == 0
                mask = torch.zeros_like(position0).to(dtype=torch.float32)

                event_idxs = candidate_weights_dict["candidate_events"]
                event_weights = candidate_weights_dict["edge_weights"]
                for i, e_idx in enumerate(event_idxs):
                    indices = edge_idxs == e_idx
                    mask[indices] = event_weights[i]

                mask[
                    position0
                ] = (
                    -1e10
                )  # because the addition in torch's multi-head attention implementation
                # import ipdb; ipdb.set_trace()

            source_embedding, atten_weights = self.aggregate(
                n_layers,
                source_node_features,
                source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings,
                edge_features,
                mask,
            )

            # preserve_mask = edge_idxs != 0
            # edge_idxs = edge_idxs[preserve_mask]
            # atten_weights = atten_weights[preserve_mask].reshape((self.n_heads, neighbors_torch.shape[1]))
            self.atten_weights_list.append(
                {
                    "layer": n_layers,
                    "src_nodes": source_nodes_torch[source_nodes_torch != 0],
                    "src_ngh_nodes": neighbors_torch[neighbors_torch != 0],
                    "src_ngh_eidx": edge_idxs[edge_idxs != 0],
                    "attn_weight": atten_weights[edge_idxs != 0],
                }
            )
            # import ipdb; ipdb.set_trace()

            return source_embedding

    def aggregate(
        self,
        n_layers,
        source_node_features,
        source_nodes_time_embedding,
        neighbor_embeddings,
        edge_time_embeddings,
        edge_features,
        mask,
    ):
        return None


class GraphSumEmbedding(GraphEmbedding):
    def __init__(
        self,
        node_features,
        edge_features,
        memory,
        neighbor_finder,
        time_encoder,
        n_layers,
        n_node_features,
        n_edge_features,
        n_time_features,
        embedding_dimension,
        device,
        n_heads=2,
        dropout=0.1,
        use_memory=True,
    ):
        super(GraphSumEmbedding, self).__init__(
            node_features=node_features,
            edge_features=edge_features,
            memory=memory,
            neighbor_finder=neighbor_finder,
            time_encoder=time_encoder,
            n_layers=n_layers,
            n_node_features=n_node_features,
            n_edge_features=n_edge_features,
            n_time_features=n_time_features,
            embedding_dimension=embedding_dimension,
            device=device,
            n_heads=n_heads,
            dropout=dropout,
            use_memory=use_memory,
        )
        self.linear_1 = torch.nn.ModuleList(
            [
                torch.nn.Linear(
                    embedding_dimension + n_time_features + n_edge_features,
                    embedding_dimension,
                )
                for _ in range(n_layers)
            ]
        )
        self.linear_2 = torch.nn.ModuleList(
            [
                torch.nn.Linear(
                    embedding_dimension + n_node_features + n_time_features,
                    embedding_dimension,
                )
                for _ in range(n_layers)
            ]
        )

    def aggregate(
        self,
        n_layer,
        source_node_features,
        source_nodes_time_embedding,
        neighbor_embeddings,
        edge_time_embeddings,
        edge_features,
        mask,
    ):
        neighbors_features = torch.cat(
            [neighbor_embeddings, edge_time_embeddings, edge_features], dim=2
        )
        neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
        neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))

        source_features = torch.cat(
            [source_node_features, source_nodes_time_embedding.squeeze()], dim=1
        )
        source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
        source_embedding = self.linear_2[n_layer - 1](source_embedding)

        return source_embedding, None


class GraphAttentionEmbedding(GraphEmbedding):  #! actually use this
    def __init__(
        self,
        node_features,
        edge_features,
        memory,
        neighbor_finder,
        time_encoder,
        n_layers,
        n_node_features,
        n_edge_features,
        n_time_features,
        embedding_dimension,
        device,
        n_heads=2,
        dropout=0.1,
        use_memory=True,
    ):
        super(GraphAttentionEmbedding, self).__init__(
            node_features,
            edge_features,
            memory,
            neighbor_finder,
            time_encoder,
            n_layers,
            n_node_features,
            n_edge_features,
            n_time_features,
            embedding_dimension,
            device,
            n_heads,
            dropout,
            use_memory,
        )

        self.attention_models = torch.nn.ModuleList(
            [
                TemporalAttentionLayer(
                    n_node_features=n_node_features,
                    n_neighbors_features=n_node_features,
                    n_edge_features=n_edge_features,
                    time_dim=n_time_features,
                    n_head=n_heads,
                    dropout=dropout,
                    output_dimension=n_node_features,
                )
                for _ in range(n_layers)
            ]
        )

    def aggregate(
        self,
        n_layer,
        source_node_features,
        source_nodes_time_embedding,
        neighbor_embeddings,
        edge_time_embeddings,
        edge_features,
        mask,
    ):
        attention_model = self.attention_models[n_layer - 1]

        source_embedding, atten_weights = attention_model(
            source_node_features,
            source_nodes_time_embedding,
            neighbor_embeddings,
            edge_time_embeddings,
            edge_features,
            mask,
        )

        return source_embedding, atten_weights


def get_embedding_module(
    module_type,
    node_features,
    edge_features,
    memory,
    neighbor_finder,
    time_encoder,
    n_layers,
    n_node_features,
    n_edge_features,
    n_time_features,
    embedding_dimension,
    device,
    n_heads=2,
    dropout=0.1,
    n_neighbors=None,
    use_memory=True,
):
    if module_type == "graph_attention":
        return GraphAttentionEmbedding(
            node_features=node_features,
            edge_features=edge_features,
            memory=memory,
            neighbor_finder=neighbor_finder,
            time_encoder=time_encoder,
            n_layers=n_layers,
            n_node_features=n_node_features,
            n_edge_features=n_edge_features,
            n_time_features=n_time_features,
            embedding_dimension=embedding_dimension,
            device=device,
            n_heads=n_heads,
            dropout=dropout,
            use_memory=use_memory,
        )
    elif module_type == "graph_sum":
        return GraphSumEmbedding(
            node_features=node_features,
            edge_features=edge_features,
            memory=memory,
            neighbor_finder=neighbor_finder,
            time_encoder=time_encoder,
            n_layers=n_layers,
            n_node_features=n_node_features,
            n_edge_features=n_edge_features,
            n_time_features=n_time_features,
            embedding_dimension=embedding_dimension,
            device=device,
            n_heads=n_heads,
            dropout=dropout,
            use_memory=use_memory,
        )

    elif module_type == "identity":
        return IdentityEmbedding(
            node_features=node_features,
            edge_features=edge_features,
            memory=memory,
            neighbor_finder=neighbor_finder,
            time_encoder=time_encoder,
            n_layers=n_layers,
            n_node_features=n_node_features,
            n_edge_features=n_edge_features,
            n_time_features=n_time_features,
            embedding_dimension=embedding_dimension,
            device=device,
            dropout=dropout,
        )
    elif module_type == "time":
        return TimeEmbedding(
            node_features=node_features,
            edge_features=edge_features,
            memory=memory,
            neighbor_finder=neighbor_finder,
            time_encoder=time_encoder,
            n_layers=n_layers,
            n_node_features=n_node_features,
            n_edge_features=n_edge_features,
            n_time_features=n_time_features,
            embedding_dimension=embedding_dimension,
            device=device,
            dropout=dropout,
            n_neighbors=n_neighbors,
        )
    else:
        raise ValueError("Embedding Module {} not supported".format(module_type))
