import torch
import torch.nn.functional as F

class TGNLinkPred(torch.nn.Module):
    """Link-prediction model built from a memory, a GNN, and a link predictor.

    A single ``forward`` call does three things for a batch of ``(src, dst)``
    pairs:

    1. Gathers the batch's nodes plus their sampled neighbors.
    2. Reads those nodes' memory and refines it with ``gnn``.
    3. Scores each pair with ``link_pred``.

    Args:
        memory: Called with node ids and returns ``(embeddings, last_update)``.
            It must also expose ``update_state(src, dst, t, msg)`` and
            ``reset_parameters()``.
        gnn: Called as ``gnn(z, last_update, edge_index, t, msg)`` and returns
            refined embeddings. It must also expose ``reset_parameters()``.
        link_pred: Called with two embedding tensors and returns the link
            predictions. It must also expose ``reset_parameters()``.
        neighbor_loader: Called with node ids and returns
            ``(n_id, edge_index, e_id)``. It must also expose
            ``insert(src, dst)`` and a ``cur_e_id`` attribute.
    """

    def __init__(self, memory, gnn, link_pred, neighbor_loader):
        super().__init__()
        # When these are nn.Module instances, assigning them after
        # super().__init__() registers them as child modules.
        self.memory = memory
        self.gnn = gnn
        self.link_pred = link_pred
        self.neighbor_loader = neighbor_loader

    def forward(self, data, src, dst, assoc):
        """Return link predictions for the node pairs ``(src[i], dst[i])``.

        Args:
            data: Object exposing ``t`` and ``msg``. Both are indexed by the
                edge ids that ``neighbor_loader`` returns.
            src: Tensor of source node ids.
            dst: Tensor of destination node ids.
            assoc: Integer tensor used as a global-id -> local-position lookup
                table. It is written IN PLACE and must be indexable by every
                node id the neighbor loader returns.

        Returns:
            The output of ``link_pred`` on the refined embeddings of ``src``
            and ``dst``.
        """
        device = src.device

        # Batch nodes (deduplicated), then expanded with sampled neighbors.
        n_id = torch.cat([src, dst]).unique()
        n_id, n_edge_index, n_e_id = self.neighbor_loader(n_id)
        # Map each global node id to its row in the local embedding tensor.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = self.memory(n_id)

        z = self.gnn(
            z,
            last_update,
            n_edge_index,
            data.t[n_e_id].to(device),
            data.msg[n_e_id].to(device),
        )

        return self.link_pred(z[assoc[src]], z[assoc[dst]])

    def train(self, mode=True):
        """Set training mode (``mode=True``) or eval mode (``mode=False``).

        This keeps the ``nn.Module.train`` contract. ``mode`` defaults to
        ``True``, ``self.training`` is updated, and ``self`` is returned so
        calls can be chained. ``nn.Module.train`` recurses into every
        registered child module and calls that child's own
        ``train(mode)``. This includes ``memory``, ``gnn`` and
        ``link_pred`` when they are ``nn.Module`` instances.
        """
        return super().train(mode)

    def eval(self):
        """Put this module and its child modules in eval mode; return ``self``."""
        return self.train(False)

    def reset_parameters(self):
        """Re-initialize the parameters of the memory, GNN and link predictor."""
        self.memory.reset_parameters()
        self.gnn.reset_parameters()
        self.link_pred.reset_parameters()

    def insert_neighbor(self, src, dst):
        """Forward the ``(src, dst)`` interactions to ``neighbor_loader.insert``."""
        self.neighbor_loader.insert(src, dst)

    def update_memory(self, src, dst, t, msg):
        """Forward the interaction batch to ``memory.update_state``."""
        self.memory.update_state(src, dst, t, msg)

    def get_current_edge_index(self):
        """Return ``neighbor_loader.cur_e_id``.

        NOTE(review): despite the method name, this is not an ``edge_index``
        tensor. It is the loader's ``cur_e_id`` attribute, presumably a
        running edge-id counter. Confirm this against the loader.
        """
        return self.neighbor_loader.cur_e_id