from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential
from torch.utils.data import DataLoader

import torch_geometric.nn as pyg_nn


class DirGAT(pyg_nn.conv.MessagePassing):
    """Direction-aware GAT layer mixing forward and reversed edge attention.

    Two independent ``GATConv`` layers are applied to the same node features:
    one along the given edge direction (src -> dst) and one along the reversed
    edges (dst -> src). Their outputs are combined as a convex-style blend::

        out = (1 - alpha) * conv_src_to_dst(x) + alpha * conv_dst_to_src(x)

    The class subclasses ``MessagePassing`` but does not call ``propagate``
    itself; all message passing is delegated to the two ``GATConv`` layers.

    Args:
        dim_in: Input node feature size, passed to both ``GATConv`` layers.
        dim_out: Per-head output size, passed to both ``GATConv`` layers.
            (With ``GATConv``'s default head concatenation the returned
            feature size is presumably ``heads * dim_out`` — confirm against
            the installed PyG version.)
        dropout: Forwarded as ``dropout`` to both ``GATConv`` layers.
        edge_dim: Edge feature size, forwarded as ``edge_dim`` to both layers.
        alpha: Initial/fixed weight given to the reversed (dst -> src) branch.
        learn_alpha: If True, ``alpha`` becomes a trainable parameter;
            otherwise it is kept as the fixed value that was passed in.
        heads: Number of attention heads for both ``GATConv`` layers.
    """

    def __init__(
            self,
            dim_in,
            dim_out,
            dropout,
            edge_dim,
            alpha=0.5,
            learn_alpha=True,
            heads=1,
    ):
        super().__init__()
        self.conv_src_to_dst = pyg_nn.GATConv(dim_in, dim_out, heads=heads, dropout=dropout, edge_dim=edge_dim)
        self.conv_dst_to_src = pyg_nn.GATConv(dim_in, dim_out, heads=heads, dropout=dropout, edge_dim=edge_dim)
        if learn_alpha:
            # using a size of 1 might cause a CUDA misalignment error in some cases
            # we use a power of 2 to avoid this and then access the first element
            self.alpha = torch.nn.Parameter(torch.ones(4) * alpha, requires_grad=True)
        else:
            # Fixed mixing weight; stored unchanged (typically a Python float).
            self.alpha = alpha

    def _mixing_weight(self):
        """Return the scalar weight applied to the dst -> src branch."""
        alpha = self.alpha
        # Learnable case: the parameter is padded to length 4 and only the
        # first element is meaningful. Non-learnable case: a plain scalar,
        # which must NOT be indexed (the original code crashed here).
        if isinstance(alpha, Tensor) and alpha.dim() > 0:
            return alpha[0]
        return alpha

    def forward(self, batch):
        """Run both directional GAT convolutions and blend their outputs.

        Args:
            batch: Graph batch object exposing ``x``, ``edge_attr`` and
                ``edge_index`` attributes (assumed to be a PyG ``Data`` /
                ``Batch`` — confirm against callers).

        Returns:
            Tensor of blended node features. Note that this returns the
            features themselves, not an updated ``batch``.
        """
        x_in, e, edge_index = batch.x, batch.edge_attr, batch.edge_index
        # Swap source/target rows to obtain the reversed edge set; edge
        # attributes stay aligned because edge order is unchanged.
        edge_index_t = edge_index.flip(0)

        dst_to_src = self.conv_dst_to_src(x_in, edge_index_t, edge_attr=e)
        src_to_dst = self.conv_src_to_dst(x_in, edge_index, edge_attr=e)

        alpha = self._mixing_weight()
        x = (1. - alpha) * src_to_dst + alpha * dst_to_src
        return x
