from typing import Iterable
from typing import Optional
from typing import Tuple, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import GCN, GAT, GraphSAGE, GIN, GCNConv, APPNP, SGConv

from AlibiTransformer import ALiBiTransformer

class TGModel(nn.Module):
    """Graph-aware sequence model that decodes node scores and event times.

    Pipeline (see ``forward``): a GAT refines the node-embedding table, token
    embeddings gathered from it are combined with time / message / type
    encodings, an ``ALiBiTransformer`` processes the sequence, and two small
    MLP heads decode candidate-node scores and a time prediction.
    """

    def __init__(self, timescale: float, num_node: int, msgdim: int, hiddim: int, *transformer_args, **transformer_kwargs) -> None:
        """Build the embedding tables, encoders, decoders, transformer and GNN.

        Args:
            timescale: Raw times are multiplied by ``1 / timescale`` before
                entering the network; predicted times are multiplied by
                ``timescale`` on the way out.
            num_node: Number of rows in the node-embedding table.
            msgdim: Feature width of a raw message vector.
            hiddim: Hidden width shared by the embeddings, encoders, GNN and
                transformer.
            *transformer_args: Forwarded to ``ALiBiTransformer`` after
                ``hiddim``.
            **transformer_kwargs: Forwarded to ``ALiBiTransformer``.
        """
        super().__init__()
        self.timescale = timescale

        # NOTE(review): `_freeze` is a private nn.Embedding keyword; presumably
        # it keeps this table out of gradient updates -- confirm against the
        # installed torch version.
        self.nodeEmb = nn.Embedding(num_node, hiddim, _freeze=True)
        nn.init.normal_(self.nodeEmb.weight)

        # Per-token role embeddings, applied multiplicatively in forward().
        self.in_typeEmb = nn.Embedding(3, hiddim)  # 0: time, 1: src, 2: dst
        self.out_typeEmb = nn.Embedding(3, hiddim)  # 0: time, 1: src, 2: dst

        # Encoders that lift raw messages / scalar times to the hidden width.
        self.msgEnc = nn.Sequential(nn.Linear(msgdim, msgdim), nn.SiLU(inplace=True), nn.Linear(msgdim, hiddim))
        self.timeEnc = nn.Sequential(nn.Linear(1, hiddim), nn.SiLU(inplace=True), nn.Linear(hiddim, hiddim))

        # Time head: bottleneck of width floor(sqrt(hiddim)); the trailing
        # Softplus makes the predicted offset strictly positive.
        decdim = int(hiddim**0.5)
        self.timeDec = nn.Sequential(nn.Linear(hiddim, decdim), nn.SiLU(inplace=True), nn.Linear(decdim, 1), nn.Softplus())

        # Node head: nodeDec1 projects the sequence state, the result is
        # multiplied with candidate node embeddings, nodeDec2 scores the product.
        self.nodeDec1 = nn.Sequential(nn.Linear(hiddim, hiddim), nn.SiLU(inplace=True))
        self.nodeDec2 = nn.Sequential(nn.Linear(hiddim, hiddim), nn.SiLU(inplace=True), nn.Linear(hiddim, 1))

        self.transformer = ALiBiTransformer(hiddim, *transformer_args, **transformer_kwargs)
        self.gnn = GAT(in_channels=hiddim, hidden_channels=hiddim, out_channels=hiddim, num_layers=4, dropout=0.1)

    def _embed_tokens(self, node_emb: Tensor, featdict: Dict[str, Tensor]) -> Tensor:
        """Assemble the transformer input sequence from ``featdict``."""
        # Indexing with a tensor yields a fresh tensor, so the in-place edits
        # below do not write into ``node_emb``.
        x = node_emb[featdict["tokens"]]
        if "timeidx" in featdict:
            # Time positions are overwritten with an encoding of the scaled time.
            scaled_times = featdict["times"].unsqueeze(-1).to(torch.float) * (1 / self.timescale)
            x[featdict["timeidx"]] = self.timeEnc(scaled_times)
        if "msgidx" in featdict:
            # Each position adds the encoding of the message it points to.
            x += self.msgEnc(featdict["msgs"])[featdict["msgidx"]]
        if "in_type" in featdict:
            x = x * self.in_typeEmb(featdict["in_type"])
        if "out_type" in featdict:
            x = x * self.out_typeEmb(featdict["out_type"])
        return x

    def forward(
        self,
        edge_index: Tensor,
        edge_weight: Tensor,
        edge_feat: Tensor,
        sample_node: Tensor,
        featdict: Dict[str, Tensor],
        basetime: Tensor,
        realtime: Tensor,
        out_node_idx: Optional[Tensor] = None,
        out_time_idx: Optional[Tensor] = None,
        past_kvts: Optional[List[Tuple[Tensor, Tensor, Tensor]]] = None,
        out_cache: bool = False,
        out_innertimepred: bool = False,
    ):
        '''
        Run the GNN and transformer, then decode the requested outputs.

        edge_index, edge_weight: graph passed to the GNN (edge_weight is
            supplied as ``edge_attr``)
        edge_feat: currently unused; kept for interface compatibility
        sample_node (#outnodes, K) int: candidate node ids scored for each
            output position; only read when out_node_idx is given
        featdict:
            tokens (L, ) int
            in_type (L, ) int
            out_type (L, ) int
            timeidx (#time, ) int
            times (#time, ) float
            msgidx (L, ) int
            msgs (#msg, msgdim) float
            Only ``tokens`` is required; the other groups are applied when
            their index key (timeidx / msgidx / in_type / out_type) is present.
        basetime (L,) int: currently unused; kept for interface compatibility
        realtime (L,) int: multiplied by 1/timescale before use
        out_node_idx (#outnodes), int, optional
        out_time_idx (#outtime), int, optional
        past_kvts, out_cache: passed through to the transformer
        out_innertimepred: currently unused; kept for interface compatibility

        return (3-tuple):
        out_node: one score per (output position, candidate) pair, i.e.
            (#outnodes, K, 1) when the transformer output is (L, hiddim);
            None when out_node_idx is None
        out_time (#outtimes): predicted time in the original (unscaled) units;
            None when out_time_idx is None
        cache_kvts: second value returned by the transformer
        '''
        node_emb = self.gnn(x=self.nodeEmb.weight, edge_index=edge_index, edge_attr=edge_weight)
        x = self._embed_tokens(node_emb, featdict)

        realtime = realtime.to(torch.float) * (1 / self.timescale)
        x, ret_caches = self.transformer.forward(x, realtime, past_kvts, out_cache)

        out_node = None
        if out_node_idx is not None:
            # Gather candidate embeddings only when node scores are requested.
            sample_nodeEmb = node_emb[sample_node.view(-1)].view(sample_node.shape[0], sample_node.shape[1], -1)
            out_node = self.nodeDec2(self.nodeDec1(x[out_node_idx]).unsqueeze(1) * sample_nodeEmb)

        out_time = None
        if out_time_idx is not None:
            # Positive offset on top of the scaled reference time, mapped back
            # to the caller's units. Done inside the branch so a call without
            # out_time_idx returns None instead of failing on None.squeeze.
            out_time = self.timeDec(x[out_time_idx]) + realtime[out_time_idx].unsqueeze(-1)
            out_time = out_time.squeeze(-1) * self.timescale
        return out_node, out_time, ret_caches