import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.typing import OptTensor
                                    
from torch_geometric.utils import softmax

from modules.emb_module import TransformerConv

# Edge-weighted variant of TransformerConv, used as the TGN embedding convolution.
class WeightedTransformerConv(TransformerConv):
    """Extend TransformerConv (used as the embedding module in TGN) with edge weights.

    Callers may append one extra trailing column to ``edge_attr`` holding a
    per-edge weight ``w``:

    * if ``edge_attr`` has exactly ``self.edge_dim`` columns, no weight is
      present and the layer behaves like the plain convolution;
    * otherwise the last column is split off as ``w`` and the remaining
      columns are used as the ordinary edge features.

    ``log(w)`` is added to the attention logits before the per-target
    softmax. Because softmax exponentiates, this multiplies each unnormalized
    attention score by exactly ``w``: ``w == 1`` leaves an edge unchanged and
    ``w == 0`` removes it (its logit becomes ``-inf``).
    """

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute attention-weighted messages for every edge.

        Args:
            query_i: Queries of the target node of each edge.
            key_j: Keys of the source node of each edge.
            value_j: Values of the source node of each edge.
            edge_attr: Optional edge features, possibly with a trailing
                weight column (see the class docstring).
            index: Target-node index of each edge, used to group the softmax.
            ptr: Optional CSR pointer forwarded to ``softmax``.
            size_i: Number of target nodes, forwarded to ``softmax``.

        Returns:
            ``value_j`` (plus projected edge features, when present) scaled by
            the per-head attention coefficients; the coefficients before
            dropout are also stored in ``self._alpha``.
        """
        edge_weight = None
        # Only look for a weight column when edge features were actually passed.
        if edge_attr is not None and edge_attr.size(-1) != self.edge_dim:
            edge_weight = edge_attr[:, -1]
            edge_attr = edge_attr[:, :-1]
            if edge_attr.size(-1) == 0:
                # The weight was the only column: there are no real edge features.
                edge_attr = None

        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)

        if edge_weight is not None:
            # Natural log (not log2) so that the exp() inside softmax restores
            # exactly w; unsqueeze broadcasts one weight across all heads.
            # NOTE(review): if every incoming edge of some target has w == 0,
            # all its logits are -inf and the softmax may produce NaN for it —
            # confirm against torch_geometric.utils.softmax before relying on it.
            alpha = alpha + torch.log(edge_weight).unsqueeze(-1)

        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out


class WeightedGraphAttentionEmbedding(torch.nn.Module):
    """Graph-attention embedding for TGN with optional per-edge weights.

    Per-edge features are the encoded relative time concatenated with the raw
    message. When edge weights are given, they are appended as one extra
    trailing column, which ``WeightedTransformerConv`` splits off again.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        """Build the two-head weighted transformer convolution.

        Args:
            in_channels: Size of the node features passed to ``forward``.
            out_channels: Total output size; each of the 2 heads gets
                ``out_channels // 2`` channels.
            msg_dim: Size of each raw message.
            time_enc: Time-encoder module; its ``out_channels`` plus
                ``msg_dim`` gives the conv's ``edge_dim``.
        """
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = WeightedTransformerConv(
            in_channels, out_channels // 2, heads=2, dropout=0.1, edge_dim=edge_dim
        )

    def forward(self, x, last_update, edge_index, edge_weight, t, msg):
        """Embed nodes by attending over their sampled temporal edges.

        Args:
            x: Node features (e.g. memory states).
            last_update: Last-update time per node; indexed by the source
                node of each edge.
            edge_index: Edge list; row 0 holds the source nodes.
            edge_weight: Optional 1-D tensor with one weight per edge, or None.
            t: Timestamp of each edge.
            msg: Raw message of each edge.

        Raises:
            ValueError: If ``edge_weight`` is given but is not 1-D.
        """
        if edge_weight is not None and edge_weight.dim() != 1:
            raise ValueError(
                "edge_weight must be 1-D (one weight per edge), "
                f"got shape {tuple(edge_weight.shape)}"
            )

        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        parts = [rel_t_enc, msg]
        if edge_weight is not None:
            # Keep the weight column in the same dtype as the other features;
            # otherwise torch.cat would promote the whole edge_attr (e.g. to
            # float64 for double weights) and mismatch the conv's parameters.
            parts.append(edge_weight.to(rel_t_enc.dtype).unsqueeze(1))
        edge_attr = torch.cat(parts, dim=-1)
        return self.conv(x, edge_index, edge_attr)

    def reset_parameters(self):
        """Re-initialize the convolution's parameters."""
        # NOTE(review): time_enc is intentionally not reset here; presumably it
        # is owned and reset by another module (e.g. the memory) — confirm.
        self.conv.reset_parameters()

class WeightedTGN(torch.nn.Module):
    """Temporal graph network wrapper: memory + graph embedding + link predictor.

    Optionally feeds per-edge weights from ``data.edge_weights`` into the
    graph embedding module.
    """

    def __init__(self, memory, gnn, link_pred, neighbor_loader, device,use_edge_weight = False):
        """Store the TGN components.

        Args:
            memory: Memory module; called as ``memory(n_id) -> (z, last_update)``
                and updated via ``memory.update_state(...)``.
            gnn: Embedding module called as
                ``gnn(z, last_update, edge_index, edge_weight, t, msg)``.
            link_pred: Scores a pair of node-embedding batches.
            neighbor_loader: Temporal neighbor sampler; called as
                ``neighbor_loader(n_id) -> (n_id, edge_index, e_id)``.
            device: Device used for tensors created by this wrapper.
            use_edge_weight: Stored flag.
        """
        super().__init__()
        self.memory = memory
        self.gnn = gnn
        self.link_pred = link_pred
        self.neighbor_loader = neighbor_loader
        # NOTE(review): not read anywhere in this class; forward() decides on
        # weights from its edge_weight argument instead — confirm intent.
        self.use_edge_weight = use_edge_weight
        self.device = device

    def forward(self, data, edge_index, edge_weight, assoc):
        """Score the links in ``edge_index``.

        Args:
            data: Event store exposing ``t``, ``msg`` and ``edge_weights``,
                indexed by event id.
            edge_index: Batch edges; row 0 are sources, row 1 destinations.
            edge_weight: Acts only as a switch — when not None, the weights of
                the sampled neighborhood edges are gathered from
                ``data.edge_weights``; the tensor's values are not used.
            assoc: Global-to-local node id buffer; updated in place.

        Returns:
            The output of ``link_pred`` for the (src, dst) embedding pairs.
        """
        src = edge_index[0]
        dst = edge_index[1]
        device = src.device

        n_id = torch.cat([src, dst]).unique()
        n_id, n_edge_index, n_e_id = self.neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = self.memory(n_id)
        if edge_weight is not None:
            # Moved to the same device as t and msg below.
            edge_weight = data.edge_weights[n_e_id].to(device)
        z = self.gnn(
            z,
            last_update,
            n_edge_index,
            edge_weight,
            data.t[n_e_id].to(device),
            data.msg[n_e_id].to(device),
        )

        return self.link_pred(z[assoc[src]], z[assoc[dst]])

    def train(self, mode: bool = True):
        """Set training mode on this wrapper and its three components.

        Matches the ``torch.nn.Module.train`` signature, so ``train(False)``
        works and the call can be chained.
        """
        self.training = mode
        self.memory.train(mode)
        self.gnn.train(mode)
        self.link_pred.train(mode)
        return self

    def eval(self):
        """Switch to evaluation mode (equivalent to ``train(False)``)."""
        return self.train(False)

    def reset_parameters(self):
        """Reset the parameters of memory, gnn and link predictor."""
        self.memory.reset_parameters()
        self.gnn.reset_parameters()
        self.link_pred.reset_parameters()

    def insert_neighbor(self, src, dst):
        """Record new (src, dst) interactions in the neighbor loader."""
        self.neighbor_loader.insert(src, dst)

    def update_memory(self, src, dst, t, msg, edge_weight=None):
        """Update the memory with a batch of events.

        Defaults to a weight of 1 for every event when ``edge_weight`` is None.
        """
        if edge_weight is None:
            edge_weight = torch.ones(len(src), device=self.device)

        self.memory.update_state(src, dst, t, msg, edge_weight)

    def get_current_edge_index(self):
        """Return the neighbor loader's ``cur_e_id`` attribute."""
        return self.neighbor_loader.cur_e_id