from typing import Optional, Tuple
from torch import Tensor

from ..data.data_loader import ComputationGraph
from ..data.graph import Graph
from .basic_modules import MergeLayer
from .feature_getter import FeatureGetter
from .time_encoding import TimeEncode
from .utils import anonymized_reindex
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
import math


class Fusion(nn.Module):
    """Two-layer MLP that fuses a prompt tensor with a memory tensor.

    The two inputs are concatenated along dim 1, projected to ``dim_hid``,
    squashed with a sigmoid and projected to ``dim_out``.
    """

    def __init__(self, dim_prom, dim_mem, dim_hid, dim_out):
        """Build the two linear layers.

        dim_prom: size of the prompt input along dim 1
        dim_mem: size of the memory input along dim 1
        dim_hid: hidden width
        dim_out: output width
        """
        super().__init__()
        # fc1 is created before fc2 so seeded initialisation stays the same.
        self.fc1 = nn.Linear(dim_prom + dim_mem, dim_hid)
        self.fc2 = nn.Linear(dim_hid, dim_out)
        self.act = nn.Sigmoid()

    def forward(self, prom, mem):
        """Return ``fc2(sigmoid(fc1(cat(prom, mem))))``."""
        fused = torch.cat((prom, mem), dim=1)
        hidden = self.fc1(fused)
        hidden = self.act(hidden)
        return self.fc2(hidden)


class TProG(nn.Module):
    """Common base for the TProG variants defined in this module.

    Stores the raw feature getter and the graph, mirrors the feature sizes
    exposed by the getter, resolves the prompt width and builds the shared
    time encoder. Subclasses implement ``forward``.
    """

    def __init__(self, raw_feat_getter: FeatureGetter, graph: Graph, prompt_dim=None):
        """
        raw_feat_getter: provides n_nodes, nfeat_dim and efeat_dim
        graph: graph object kept for use by subclasses
        prompt_dim: width of the prompt; falls back to nfeat_dim when None
        """
        super().__init__()
        self.raw_feat_getter = raw_feat_getter
        self.graph = graph

        # Sizes are copied from the feature getter.
        self.n_nodes = raw_feat_getter.n_nodes
        self.nfeat_dim = raw_feat_getter.nfeat_dim
        self.efeat_dim = raw_feat_getter.efeat_dim

        # A missing prompt width means "same as the node feature width".
        self.prompt_dim = self.nfeat_dim if prompt_dim is None else prompt_dim

        # The time feature width is whatever the encoder reports as its dim.
        self.time_encoder = TimeEncode(dim=self.nfeat_dim)
        self.tfeat_dim = self.time_encoder.dim

    def forward(self, nids: Tensor, ts: Tensor, emb: Tensor, src_or_dst_ids: Tensor,
                computation_graph: Optional[ComputationGraph]=None
               ) -> Tuple[Tensor, Tensor, Tensor]:
        """Abstract hook; every concrete TProG variant must override it."""
        raise NotImplementedError


class TransformerTProG(TProG):
    """TProG variant that encodes each node's recent events with self-attention.

    Up to ``hist_len`` past events per node are turned into event
    representations, passed through multi-head self-attention and pooled
    into h(t'-); a merge layer combines h(t'-) with the last event's
    features into h(t'+). One of the two is used as the prompt that is
    fused with the incoming embedding.
    """

    def __init__(self, raw_feat_getter: FeatureGetter, graph: Graph, prompt_dim=None,
                 *, hist_len: int=20, n_head=2, dropout=0.1, dyrep=False):
        """
        raw_feat_getter, graph, prompt_dim: forwarded to TProG
        hist_len: number of past events fetched per node
        n_head: number of attention heads
        dropout: dropout used by the attention and the merge layer
        dyrep: if True the prompt is taken from h(t'+), otherwise from h(t'-)
        """
        super().__init__(raw_feat_getter, graph, prompt_dim)

        self.hist_len = hist_len

        # Embedding table for the anonymized ids (hist_len + 1 rows).
        self.anony_emb = nn.Embedding(self.hist_len + 1, self.nfeat_dim)

        # Event representation = [src, dst, anony, edge, time].
        self.d_model = self.nfeat_dim * 3 + self.efeat_dim + self.tfeat_dim
        self.mha_fn = nn.MultiheadAttention(self.d_model, n_head, dropout)
        self.out_fn = nn.Linear(self.d_model, self.prompt_dim)
        # The merger sees h(t'-) plus the non-time part of the last event.
        self.merger = MergeLayer(self.prompt_dim, self.d_model - self.tfeat_dim,
                                 self.nfeat_dim, self.prompt_dim, dropout=dropout)

        # Fusion MLP applied to [prompt | emb].
        self.fc1 = torch.nn.Linear(self.nfeat_dim+self.prompt_dim, self.nfeat_dim)
        self.fc2 = torch.nn.Linear(self.nfeat_dim, self.nfeat_dim)
        self.act = torch.nn.Sigmoid()
        self.dyrep = dyrep

    def forward(self, nids: Tensor, ts: Tensor, emb: Tensor, src_or_dst_ids: Tensor,
                computation_graph: Optional[ComputationGraph]=None,
                is_neg: Optional[bool] = False
               ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute surrogate representations h(t'-) and h(t'+) and the
        prompt-fused embedding.
        -----
        nids: node ids
        ts: the current timestamps t
        emb: embeddings to be fused with the prompt, one row per entry of
             src_or_dst_ids
        src_or_dst_ids: node ids whose prompts are looked up; each is located
                        in nids with searchsorted (assumes nids is sorted
                        ascending and contains them -- TODO confirm)
        computation_graph: computation graph containing necessary information
                           This is only given during training.
        is_neg: when a computation graph is given, read restart_data_neg
                instead of restart_data
        -----
        returns: h_prev_left, h_prev_right, fused
        h_prev_left: h(t'-)
        h_prev_right: h(t'+)
        fused: fc2(sigmoid(fc1([prompt | emb])))
        """
        device = nids.device
        if computation_graph is None:
            # No precomputed data: query the graph on CPU and move to device.
            hist_nids, hist_eids, hist_ts, hist_dirs = self.graph.get_history(
                nids.cpu().numpy(), ts.cpu().numpy(), self.hist_len)
            anonymized_ids = anonymized_reindex(hist_nids)
            hist_nids = torch.from_numpy(hist_nids).to(device).long()  # [bs, len]
            anonymized_ids = torch.from_numpy(anonymized_ids).to(device).long()
            hist_eids = torch.from_numpy(hist_eids).to(device).long()
            hist_ts = torch.from_numpy(hist_ts).to(device).float()
            hist_dirs = torch.from_numpy(hist_dirs).to(device).long()
        else:
            if is_neg:
                data = computation_graph.restart_data_neg
            else:
                data = computation_graph.restart_data
            hist_nids = data.hist_nids.to(device)
            anonymized_ids = data.anonymized_ids.to(device)
            hist_eids = data.hist_eids.to(device)
            hist_ts = data.hist_ts.to(device)
            hist_dirs = data.hist_dirs.to(device)

        _, hist_len = hist_nids.shape
        # Positions whose neighbour id is 0 are excluded from attention.
        mask = (hist_nids == 0)  # [bs, len]
        # Rows without any valid event must be detected BEFORE the last
        # column is force-unmasked below; afterwards no row can be all-True
        # and the masked_fill calls further down would never fire.
        invalid_rows = mask.all(1, keepdim=True)  # [bs, 1]
        # Keep at least one attendable key per row so the attention softmax
        # never runs over a fully masked row.
        mask[:, -1] = False

        # event reprs = [src, dst, anony, edge, ts]
        # dirs is used to determine if the current node is src or dst
        # (assumes hist_dirs holds 0/1 flags -- TODO confirm)
        r_nids = nids.unsqueeze(1).repeat(1, hist_len)
        src_nids = r_nids * hist_dirs + hist_nids * (1-hist_dirs)
        dst_nids = r_nids * (1-hist_dirs) + hist_nids * hist_dirs

        src_vals = self.raw_feat_getter.get_node_embeddings(src_nids)
        dst_vals = self.raw_feat_getter.get_node_embeddings(dst_nids)
        edge_vals = self.raw_feat_getter.get_edge_embeddings(hist_eids)
        anony_vals = self.anony_emb(anonymized_ids)
        # Time features encode the gap between each event and the last one.
        ts_vals = self.time_encoder(hist_ts[:, -1].unsqueeze(1) - hist_ts)
        full_vals = torch.cat([src_vals, dst_vals, anony_vals, edge_vals, ts_vals], 2)  # [bs, len, D]

        n_event_feat = self.d_model - self.tfeat_dim
        # clone() is required: plain indexing returns a view of full_vals,
        # and the in-place zeroing on the next line would otherwise wipe
        # the features handed to the merger as well.
        last_event_feat = full_vals[:, -1, :n_event_feat].clone()
        full_vals[:, -1, :n_event_feat] = 0.  # only keep time feats
        qkv = full_vals.transpose(0, 1)  # [len, bs, D]
        out, _ = self.mha_fn(qkv, qkv, qkv, key_padding_mask=mask)
        # h(t'-): mean over the sequence dimension
        h_prev_left = self.out_fn(F.relu(out.mean(0)))  # [bs, prompt_dim]
        # h(t'+): h(t'-) merged with the last event's features
        h_prev_right = self.merger(h_prev_left, last_event_feat)
        # Nodes without any history get all-zero representations.
        h_prev_left = h_prev_left.masked_fill(invalid_rows, 0.)
        h_prev_right = h_prev_right.masked_fill(invalid_rows, 0.)

        # Row of each requested id inside nids. The result is moved with
        # .to(device), i.e. it is expected to come back as a tensor.
        # NOTE(review): relies on NumPy handing back a tensor here -- confirm.
        rows = np.searchsorted(nids.cpu(), src_or_dst_ids.cpu()).to(device)
        if self.dyrep:
            prompt = h_prev_right[rows]
        else:
            prompt = h_prev_left[rows]

        x = torch.cat([prompt, emb], dim=1)
        h = self.act(self.fc1(x))

        return h_prev_left, h_prev_right, self.fc2(h)


class VanillaTProG(TProG):
    """TProG variant with one learnable left/right prompt vector per node.

    Both embedding tables start at zero. The prompt of each requested node
    is concatenated with its embedding and passed through a two-layer MLP.
    """

    def __init__(self, raw_feat_getter: FeatureGetter, graph: Graph, dyrep=False, prompt_dim=None):
        """
        raw_feat_getter, graph, prompt_dim: forwarded to TProG
        dyrep: if True the prompt is read from right_emb, otherwise left_emb
        """
        super().__init__(raw_feat_getter, graph, prompt_dim)
        self.left_emb = nn.Embedding(self.n_nodes, self.prompt_dim)
        self.right_emb = nn.Embedding(self.n_nodes, self.prompt_dim)
        nn.init.zeros_(self.left_emb.weight)
        nn.init.zeros_(self.right_emb.weight)

        # Fusion MLP applied to [prompt | emb].
        self.fc1 = torch.nn.Linear(self.nfeat_dim+self.prompt_dim, self.nfeat_dim)
        self.fc2 = torch.nn.Linear(self.nfeat_dim, self.nfeat_dim)
        self.act = torch.nn.Sigmoid()

        self.dyrep = dyrep

    def forward(self, nids: Tensor, ts: Tensor, emb: Tensor, src_or_dst_ids: Tensor,
                computation_graph: Optional[ComputationGraph]=None,
                is_neg: Optional[bool] = False
               ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Look up per-node prompts and fuse them with emb.
        -----
        nids: node ids used to index the prompt tables
        ts: not used by this variant; kept for signature parity
        emb: embeddings to be fused, one row per entry of src_or_dst_ids
        src_or_dst_ids: node ids whose prompts are selected; each is located
                        in nids with searchsorted (assumes nids is sorted
                        ascending and contains them -- TODO confirm)
        computation_graph, is_neg: not used by this variant; kept for
                                   signature parity
        -----
        returns: h_left, h_right, fused
        h_left: left_emb rows for nids
        h_right: right_emb rows for nids
        fused: fc2(sigmoid(fc1([prompt | emb])))
        """
        # The previous-timestamp lookup was dropped on purpose: its result
        # never influenced any returned value.
        device = nids.device
        h_left = self.left_emb(nids)
        h_right = self.right_emb(nids)

        # Row of each requested id inside nids. The result is moved with
        # .to(device), i.e. it is expected to come back as a tensor.
        # NOTE(review): relies on NumPy handing back a tensor here -- confirm.
        rows = np.searchsorted(nids.cpu(), src_or_dst_ids.cpu()).to(device)
        if self.dyrep:
            prompt = h_right[rows]
        else:
            prompt = h_left[rows]

        x = torch.cat([prompt, emb], dim=1)
        h = self.act(self.fc1(x))

        return h_left, h_right, self.fc2(h)


class NormalLinear(nn.Linear):
    """``nn.Linear`` whose parameters are drawn from a zero-mean normal."""

    def reset_parameters(self):
        """Fill weight, then bias (if present), from N(0, std**2).

        ``std`` is ``1 / sqrt(weight.size(1))``.
        """
        std = 1. / math.sqrt(self.weight.size(1))
        # Weight first, bias second: the order fixes the RNG stream.
        for param in (self.weight, self.bias):
            if param is None:
                continue
            param.data.normal_(0, std)


class ProjectionTProG(TProG):
    """TProG variant whose per-node prompts are conditioned on elapsed time.

    Zero-initialised left/right prompt tables are looked up per node and
    passed, together with an encoding of ``ts - prev_ts``, through a small
    MLP (``context_convert``) before being fused with the embedding.
    """

    def __init__(self, raw_feat_getter: FeatureGetter, graph: Graph, dyrep=False, prompt_dim=None):
        """
        raw_feat_getter, graph, prompt_dim: forwarded to TProG
        dyrep: if True the prompt comes from the right table, else the left
        """
        super().__init__(raw_feat_getter, graph, prompt_dim)
        # Per-node prompt tables, both starting at zero.
        self.left_emb = nn.Embedding(self.n_nodes, self.prompt_dim)
        self.right_emb = nn.Embedding(self.n_nodes, self.prompt_dim)
        for table in (self.left_emb, self.right_emb):
            nn.init.zeros_(table.weight)

        # Fusion MLP applied to [prompt | emb].
        self.fc1 = nn.Linear(self.nfeat_dim + self.prompt_dim, self.nfeat_dim)
        self.fc2 = nn.Linear(self.nfeat_dim, self.nfeat_dim)
        self.act = nn.Sigmoid()

        self.dyrep = dyrep

        # Layers used by context_convert.
        self.embedding_layer = NormalLinear(self.nfeat_dim, self.nfeat_dim)
        self.fc3 = nn.Linear(self.nfeat_dim + self.prompt_dim, self.nfeat_dim)
        self.fc4 = nn.Linear(self.nfeat_dim, self.prompt_dim)

    def context_convert(self, embeddings, timediffs):
        """Project prompt rows conditioned on encoded time differences.

        Returns ``fc4(sigmoid(fc3([embedding_layer(timediffs) | embeddings])))``.
        """
        time_part = self.embedding_layer(timediffs)
        joined = torch.cat((time_part, embeddings), dim=1)
        hidden = self.act(self.fc3(joined))
        return self.fc4(hidden)

    def forward(self, nids: Tensor, ts: Tensor, emb: Tensor, src_or_dst_ids: Tensor,
                computation_graph: Optional[ComputationGraph] = None,
                is_neg: Optional[bool] = False
                ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Build time-conditioned prompts and fuse them with emb.
        -----
        nids: node ids used to index the prompt tables
        ts: the current timestamps; reshaped to the shape of prev_ts
        emb: embeddings to be fused, one row per entry of src_or_dst_ids
        src_or_dst_ids: node ids whose prompts are selected; each is located
                        in nids with searchsorted (assumes nids is sorted
                        ascending and contains them -- TODO confirm)
        computation_graph: when given, prev_ts is read from its restart data;
                           otherwise it is queried from the graph
        is_neg: read restart_data_neg instead of restart_data
        -----
        returns: h_left_projected, h_right_projected, fused
        fused: fc2(sigmoid(fc1([prompt | emb])))
        """
        device = nids.device
        if computation_graph is not None:
            restart = (computation_graph.restart_data_neg if is_neg
                       else computation_graph.restart_data)
            last_ts = restart.prev_ts
        else:
            # Ask the graph for a history of length 1 and keep its first
            # timestamp column.
            _, _, hist_ts, _ = self.graph.get_history(
                nids.cpu().numpy(), ts.cpu().numpy(), 1)
            last_ts = torch.from_numpy(hist_ts[:, 0]).to(device).float()

        left_rows = self.left_emb(nids)
        right_rows = self.right_emb(nids)

        # Elapsed time since the previous event, as a column vector.
        ts = ts.view(last_ts.size())
        last_ts = last_ts.to(device)
        elapsed = (ts - last_ts).view(-1, 1)

        # squeeze(1) drops a size-1 dim 1 from the encoder output
        # (presumably [bs, 1, dim] -> [bs, dim]; verify against TimeEncode).
        time_feats = self.time_encoder(elapsed).squeeze(1)

        h_left_projected = self.context_convert(left_rows, time_feats)
        h_right_projected = self.context_convert(right_rows, time_feats)

        chosen = h_right_projected if self.dyrep else h_left_projected
        # Row of each requested id inside nids.
        rows = np.searchsorted(nids.cpu(), src_or_dst_ids.cpu()).to(device)
        prompt = chosen[rows]

        fused_in = torch.cat((prompt, emb), dim=1)
        hidden = self.act(self.fc1(fused_in))

        return h_left_projected, h_right_projected, self.fc2(hidden)


class OneVecTProG(TProG):
    """TProG variant with a single left and a single right prompt vector.

    The chosen vector (shape [1, prompt_dim], zero-initialised) is shared by
    every row of the batch, concatenated with emb and passed through a
    two-layer MLP.
    """

    def __init__(self, raw_feat_getter: FeatureGetter, graph: Graph, dyrep=False, prompt_dim=None):
        """
        raw_feat_getter, graph, prompt_dim: forwarded to TProG
        dyrep: if True the prompt is right_prompt, otherwise left_prompt
        """
        super().__init__(raw_feat_getter, graph, prompt_dim)
        self.left_prompt = nn.Parameter(torch.zeros(1, self.prompt_dim), requires_grad=True)
        self.right_prompt = nn.Parameter(torch.zeros(1, self.prompt_dim), requires_grad=True)

        # Fusion MLP applied to [prompt | emb].
        self.fc1 = torch.nn.Linear(self.nfeat_dim + self.prompt_dim, self.nfeat_dim)
        self.fc2 = torch.nn.Linear(self.nfeat_dim, self.nfeat_dim)
        self.act = torch.nn.Sigmoid()

        self.dyrep = dyrep

    def forward(self, nids: Tensor, ts: Tensor, emb: Tensor, src_or_dst_ids: Tensor,
                computation_graph: Optional[ComputationGraph] = None,
                is_neg: Optional[bool] = False
                ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Fuse the shared prompt vector with emb.
        -----
        emb: embeddings to be fused; its first dimension is the batch size
        nids, ts, src_or_dst_ids, computation_graph, is_neg: not used by this
            variant; kept for signature parity
        -----
        returns: h_left, h_right, fused
        h_left: left_prompt, shape [1, prompt_dim] (not expanded)
        h_right: right_prompt, shape [1, prompt_dim] (not expanded)
        fused: fc2(sigmoid(fc1([prompt | emb])))
        """
        # The previous-timestamp lookup was dropped on purpose: its result
        # never influenced any returned value.
        h_left = self.left_prompt
        h_right = self.right_prompt

        bs = emb.size(0)

        # Share the single prompt row across the batch without copying.
        if self.dyrep:
            prompt = h_right.expand(bs, -1)
        else:
            prompt = h_left.expand(bs, -1)

        x = torch.cat([prompt, emb], dim=1)
        h = self.act(self.fc1(x))

        return h_left, h_right, self.fc2(h)