from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from AlibiTransformer import ALiBiTransformer
from memory_module import TGNMemory, IdentityMessage, LastAggregator


class TGModel(nn.Module):
    """Temporal-graph model combining a node-memory module with an ALiBi transformer.

    Token embeddings are read from ``self.memory`` (a ``TGNMemory``). They are
    optionally overwritten / augmented with encoded timestamps and messages,
    modulated by input/output token-type embeddings, and passed through an
    ``ALiBiTransformer``. Two heads decode the transformer output:

    * a node head producing one scalar per (output position, sampled node) pair;
    * a time head producing a positive offset (Softplus) added to the input time.

    Args:
        timescale: Raw times are multiplied by ``1 / timescale`` before entering
            the network; predicted times are multiplied back by ``timescale``.
        num_node: Number of nodes (size of the node memory / embedding table).
        msgdim: Dimension of raw messages.
        hiddim: Hidden and memory dimension.
        *transformer_args: Forwarded positionally to ``ALiBiTransformer`` after ``hiddim``.
        memory_time_dim: ``time_dim`` given to the memory and its message module (default 100).
        memory_t_enc_grad: ``t_enc_grad`` given to the memory module (default True).
        **transformer_kwargs: Forwarded to ``ALiBiTransformer``.
    """

    def __init__(
        self,
        timescale: float,
        num_node: int,
        msgdim: int,
        hiddim: int,
        *transformer_args,
        memory_time_dim: int = 100,
        memory_t_enc_grad: bool = True,
        **transformer_kwargs,
    ) -> None:
        super().__init__()
        self.timescale = timescale
        # NOTE: nodeEmb is not referenced anywhere in this class (embeddings come
        # from self.memory); it is kept so existing state_dicts still load.
        self.nodeEmb = nn.Embedding(num_node, hiddim)
        # Token-type embeddings. 0: time, 1: src, 2: dst.
        self.in_typeEmb = nn.Embedding(3, hiddim)
        self.out_typeEmb = nn.Embedding(3, hiddim)
        self.msgEnc = nn.Sequential(nn.Linear(msgdim, msgdim), nn.SiLU(inplace=True), nn.Linear(msgdim, hiddim))
        self.timeEnc = nn.Sequential(nn.Linear(1, hiddim), nn.SiLU(inplace=True), nn.Linear(hiddim, hiddim))
        # Narrow (~sqrt(hiddim)) time decoder; Softplus keeps the predicted offset positive.
        decdim = int(hiddim**0.5)
        self.timeDec = nn.Sequential(nn.Linear(hiddim, decdim), nn.SiLU(inplace=True), nn.Linear(decdim, 1), nn.Softplus())
        self.nodeDec1 = nn.Sequential(nn.Linear(hiddim, hiddim), nn.SiLU(inplace=True))
        self.nodeDec2 = nn.Sequential(nn.Linear(hiddim, hiddim), nn.SiLU(inplace=True), nn.Linear(hiddim, 1))
        self.transformer = ALiBiTransformer(hiddim, *transformer_args, **transformer_kwargs)

        # Node memory: supplies the per-node vectors used as token embeddings.
        self.memory = TGNMemory(
            num_nodes=num_node,
            raw_msg_dim=msgdim,
            memory_dim=hiddim,
            time_dim=memory_time_dim,
            message_module=IdentityMessage(msgdim, hiddim, memory_time_dim),
            aggregator_module=LastAggregator(),
            t_enc_grad=memory_t_enc_grad,
        )

    def update_memory(self, src: Tensor, dst: Tensor, t: Tensor, msg: Tensor):
        """Forward a batch of (src, dst, t, msg) events to ``self.memory.update_state``."""
        self.memory.update_state(src, dst, t, msg)

    def detach_memory(self):
        """Call ``self.memory.detach()``."""
        self.memory.detach()

    def reset_memory(self):
        """Call ``self.memory.reset_state()``."""
        self.memory.reset_state()

    def forward(
        self,
        sample_node: Tensor,
        featdict: Dict[str, Tensor],
        basetime: Tensor,
        realtime: Tensor,
        out_node_idx: Optional[Tensor] = None,
        out_time_idx: Optional[Tensor] = None,
        past_kvts: Optional[List[Tuple[Tensor, Tensor, Tensor]]] = None,
        out_cache: bool = False,
        out_innertimepred: bool = False,
    ):
        """Encode a token sequence and decode node and/or time predictions.

        Args:
            sample_node: Integer node ids whose memory embeddings are combined
                with the node-output positions. Must be 2-D; presumably
                (#outnodes, K) — TODO confirm with caller.
            featdict:
                tokens (L,) int — ids looked up in ``self.memory``.
                in_type (L,) int — optional; 0: time, 1: src, 2: dst.
                out_type (L,) int — optional; same coding as ``in_type``.
                timeidx (#time,) int — optional positions of time tokens.
                times (#time,) float — timestamps written at ``timeidx``.
                msgidx (L,) int — optional per-token row into ``msgs``.
                msgs (#msg, msgdim) float.
            basetime: (L,) int. Currently unused.
            realtime: (L,) int. Per-token time; scaled by ``1 / timescale`` and
                passed to the transformer.
            out_node_idx: (#outnodes,) int positions to decode node outputs from, or None.
            out_time_idx: (#outtime,) int positions to decode time outputs from, or None.
            past_kvts: Cached tensors forwarded to the transformer, or None.
            out_cache: Forwarded to the transformer.
            out_innertimepred: Currently ignored; kept for interface compatibility.

        Returns:
            out_node: (#outnodes, K, 1) when ``sample_node`` is (#outnodes, K)
                and ``out_node_idx`` is given; otherwise None.
            out_time: (#outtime,) predicted times in the caller's (unscaled)
                units when ``out_time_idx`` is given; otherwise None.
            ret_caches: Second value returned by the transformer.
        """
        # Per-token embeddings from the node memory (last-update times unused here).
        # NOTE(review): x is modified in place below; assumes the memory returns a
        # fresh tensor rather than a view of its internal state — confirm.
        x, _ = self.memory(featdict["tokens"])
        if "timeidx" in featdict:
            # Time tokens: replace their rows with an encoding of the scaled timestamp.
            scaled_times = featdict["times"].unsqueeze(-1).to(torch.float) * (1 / self.timescale)
            x[featdict["timeidx"]] = self.timeEnc(scaled_times)
        if "msgidx" in featdict:
            # Each token i receives the encoded message msgs[msgidx[i]].
            x += self.msgEnc(featdict["msgs"])[featdict["msgidx"]]
        if "in_type" in featdict:
            x = x * self.in_typeEmb(featdict["in_type"])
        if "out_type" in featdict:
            x = x * self.out_typeEmb(featdict["out_type"])

        realtime = realtime.to(torch.float) * (1 / self.timescale)
        # Invoke the module (not .forward) so registered hooks run.
        x, ret_caches = self.transformer(x, realtime, past_kvts, out_cache)

        # Look up sampled nodes as a flat batch, then restore sample_node's 2-D layout.
        sample_nodeEmb, _ = self.memory(sample_node.reshape(-1))
        sample_nodeEmb = sample_nodeEmb.reshape(sample_node.shape[0], sample_node.shape[1], -1)

        out_node, out_time = None, None
        if out_node_idx is not None:
            # (#outnodes, 1, hid) * sample_nodeEmb broadcasts over the sampled-node axis.
            out_node = self.nodeDec2(self.nodeDec1(x[out_node_idx]).unsqueeze(1) * sample_nodeEmb)
        if out_time_idx is not None:
            out_time = self.timeDec(x[out_time_idx]) + realtime[out_time_idx].unsqueeze(-1)
            # Undo the time scaling so the prediction is in the caller's units.
            out_time = out_time.squeeze(-1) * self.timescale
        return out_node, out_time, ret_caches