from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential
from torch.utils.data import DataLoader

import torch_geometric.nn as pyg_nn

from dirgt.layer.gine_conv_layer import GINEConvLayer


class DirGINE(pyg_nn.conv.MessagePassing):
    """Direction-aware wrapper around two ``GINEConvLayer`` instances.

    One layer is applied with the batch's ``edge_index`` as given and a
    second, independently parameterised layer is applied with the two rows
    of ``edge_index`` swapped. Node and edge features produced by the two
    layers are blended as ``(1 - alpha) * given + alpha * swapped``.

    Args:
        dim_in: Forwarded unchanged to both ``GINEConvLayer`` constructors.
        dim_out: Forwarded unchanged to both ``GINEConvLayer`` constructors.
        dropout: Forwarded unchanged to both ``GINEConvLayer`` constructors.
        residual: Forwarded unchanged to both ``GINEConvLayer`` constructors.
        edge_dim: Forwarded unchanged to both ``GINEConvLayer`` constructors.
        alpha: Weight of the swapped-direction output in the blend; the
            given-direction output is weighted by ``1 - alpha``.
        norm_type: Forwarded unchanged to both ``GINEConvLayer`` constructors.

    NOTE(review): ``GINEConvLayer`` is defined in another module; the exact
    meaning of the forwarded arguments should be confirmed there.
    """

    def __init__(
            self,
            dim_in,
            dim_out,
            dropout,
            residual,
            edge_dim,
            alpha=0.5,
            norm_type=None,
    ):
        super().__init__()
        self.conv_src_to_dst = GINEConvLayer(dim_in, dim_out, dropout, residual, edge_dim, norm_type)
        self.conv_dst_to_src = GINEConvLayer(dim_in, dim_out, dropout, residual, edge_dim, norm_type)

        self.alpha = alpha

    def forward(self, batch):
        """Run both directional layers on ``batch`` and blend their outputs.

        ``batch`` must expose ``x``, ``edge_attr`` and ``edge_index``. It is
        updated in place: on return ``batch.x`` and ``batch.edge_attr`` hold
        the blended features and ``batch.edge_index`` holds its original
        value. The same ``batch`` object is returned.
        """
        x_in, e_in, edge_index = batch.x, batch.edge_attr, batch.edge_index
        edge_index_t = torch.stack([edge_index[1], edge_index[0]], dim=0)

        # Swapped direction. The outputs are captured right away because the
        # sub-layer may update `batch` in place and hand the very same object
        # back (presumably it does, as is common for batch-in/batch-out
        # layers -- confirm against GINEConvLayer); reading them later would
        # then alias the second layer's results.
        batch.edge_index = edge_index_t
        out_dst_to_src = self.conv_dst_to_src(batch)
        x_dst_to_src, e_dst_to_src = out_dst_to_src.x, out_dst_to_src.edge_attr

        # Given direction. Restore the original inputs first so this layer
        # sees the same features as the swapped one rather than its output.
        # This is a no-op when the sub-layer does not mutate its argument.
        batch.x, batch.edge_attr = x_in, e_in
        batch.edge_index = edge_index
        out_src_to_dst = self.conv_src_to_dst(batch)
        x_src_to_dst, e_src_to_dst = out_src_to_dst.x, out_src_to_dst.edge_attr

        batch.x = (1. - self.alpha) * x_src_to_dst + self.alpha * x_dst_to_src
        batch.edge_attr = (1. - self.alpha) * e_src_to_dst + self.alpha * e_dst_to_src

        return batch
