r"""
The Graph Neural Network from the `"Neural Message Passing for Quantum Chemistry"
<https://proceedings.mlr.press/v70/gilmer17a.html>`_ paper.
"""
import torch
import torch.nn as nn
from torch import Tensor
from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseGNN import GNNBasic
from .Classifiers import Classifier
from .GINs import GINEncoder, GINMolEncoder, GINFeatExtractor
from .Pooling import GlobalAddPool


@register.model_register
class vGIN(GNNBasic):
    r"""
    GIN with a virtual node, following the `"Neural Message Passing for Quantum Chemistry"
    <https://proceedings.mlr.press/v70/gilmer17a.html>`_ paper.

    The model is a :class:`vGINFeatExtractor` followed by a :class:`Classifier`.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`, :obj:`config.dataset.dim_node`, :obj:`config.dataset.num_classes`, :obj:`config.dataset.dataset_type`, :obj:`config.model.dropout_rate`)
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__(config)
        # Feature extractor first, classifier second: keeps sub-module registration order stable.
        self.feat_encoder = vGINFeatExtractor(config)
        self.classifier = Classifier(config)
        # Placeholder attribute; it is not written anywhere else in this class.
        self.graph_repr = None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        r"""
        Run the feature extractor on the inputs and classify its output.

        Args:
            *args (list): positional arguments handed unchanged to the feature extractor. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`
            **kwargs (dict): keyword arguments handed unchanged to the feature extractor. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`

        Returns (Tensor):
            label predictions

        """
        graph_feat = self.feat_encoder(*args, **kwargs)
        return self.classifier(graph_feat)


class vGINFeatExtractor(GINFeatExtractor):
    r"""
    vGIN feature extractor using the :class:`~vGINEncoder` or :class:`~vGINMolEncoder`.

    The encoder is chosen from :obj:`config.dataset.dataset_type`: ``'mol'`` selects
    :class:`~vGINMolEncoder` and sets :obj:`edge_feat` to ``True``; any other value selects
    :class:`~vGINEncoder` and sets :obj:`edge_feat` to ``False``.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`, :obj:`config.dataset.dim_node`, :obj:`config.dataset.dataset_type`, :obj:`config.model.dropout_rate`)
        **kwargs: forwarded to the selected encoder. The encoders read `without_readout`
            (output node features instead of graph features) and `mitigation_virtual`.
    """
    def __init__(self, config: Union[CommonArgs, Munch], **kwargs):
        # Only `config` is passed to the parent constructor; the encoder it may create
        # is replaced below (presumably the parent builds a plain GIN encoder - verify).
        super(vGINFeatExtractor, self).__init__(config)
        print("#D#Init vGINFeatExtractor")
        if config.dataset.dataset_type == 'mol':
            self.encoder = vGINMolEncoder(config, **kwargs)
            self.edge_feat = True
        else:
            self.encoder = vGINEncoder(config, **kwargs)
            self.edge_feat = False


class VirtualNodeEncoder(torch.nn.Module):
    r"""
    Mixin that owns the virtual-node components used by the vGIN encoders.

    It creates three sub-modules:

    * :obj:`virtual_node_embedding`: a one-entry embedding of width :obj:`dim_hidden`,
    * :obj:`virtual_mlp`: ``Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d -> ReLU -> Dropout``
      mapping :obj:`dim_hidden` to ``2 * dim_hidden`` and back,
    * :obj:`virtual_pool`: a :class:`GlobalAddPool` instance.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.dropout_rate`)
        **kwargs: accepted for signature compatibility; not used here.
    """
    def __init__(self, config: Union[CommonArgs, Munch], **kwargs):
        super().__init__()
        width = config.model.dim_hidden
        wide = 2 * width

        # Creation order matters for parameter initialisation and state-dict layout:
        # embedding first, then the MLP layers in sequence, then the pooling module.
        self.virtual_node_embedding = nn.Embedding(1, width)
        self.virtual_mlp = nn.Sequential(
            nn.Linear(width, wide),
            nn.BatchNorm1d(wide),
            nn.ReLU(),
            nn.Linear(wide, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Dropout(config.model.dropout_rate),
        )
        self.virtual_pool = GlobalAddPool()


class vGINEncoder(GINEncoder, VirtualNodeEncoder):
    r"""
    The vGIN encoder for non-molecule data: the inherited convolution stack combined with
    the virtual-node components from :class:`VirtualNodeEncoder`.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`, :obj:`config.dataset.dim_node`, :obj:`config.model.dropout_rate`)
        **kwargs: `without_readout` makes :meth:`forward` return node features;
            `mitigation_virtual` set to ``"weighted"`` passes the edge index and edge mask
            to the virtual-node pooling.
    """

    def __init__(self, config: Union[CommonArgs, Munch], **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        # Both options default to None when the caller does not supply them.
        self.without_readout = kwargs.get('without_readout')
        self.mitigation_virtual = kwargs.get('mitigation_virtual')
        print("mitigation_virtual in vGINEncoder = ", self.mitigation_virtual)

    def forward(self, x, edge_index, batch, batch_size, **kwargs):
        r"""
        Encode a batch of graphs and, unless disabled, apply the readout.

        Args:
            x (Tensor): node features
            edge_index (Tensor): edge indices
            batch (Tensor): batch indicator
            batch_size (int): Batch size.
            **kwargs: a truthy `without_readout` skips the readout for this call.

        Returns (Tensor):
            graph feature representations, or node feature representations when the
            readout is disabled at construction time or for this call
        """
        node_repr = self.get_node_repr(x, edge_index, batch, batch_size, **kwargs)

        skip_readout = self.without_readout or kwargs.get('without_readout')
        if not skip_readout:
            return self.readout(
                node_repr,
                batch,
                batch_size,
                edge_index=edge_index,
                edge_mask=self.convs[0].__edge_mask__,
            )
        return node_repr

    def get_node_repr(self, x, edge_index, batch, batch_size, **kwargs):
        r"""
        Compute node representations with virtual-node message passing.

        The first layer is applied to the raw node features. From the second layer on, the
        virtual-node feature of each node's graph is added to the layer input. After every
        layer except the first and the last, the virtual node is refreshed from the pooled
        node features plus its previous value.

        Args:
            x (Tensor): node features
            edge_index (Tensor): edge indices
            batch (Tensor): batch indicator
            batch_size (int): batch size (not used in this method)

        Returns (Tensor):
            node feature representations
        """
        # The number of graphs is taken from the last batch entry
        # (assumes `batch` is sorted in ascending order - TODO confirm).
        graph_count = batch[-1].item() + 1
        zero_index = torch.zeros(graph_count, device=self.config.device, dtype=torch.long)
        virtual_feat = self.virtual_node_embedding(zero_index)

        final_idx = len(self.convs) - 1
        node_feat = x
        stack = zip(self.convs, self.batch_norms, self.relus, self.dropouts)
        for idx, (conv, norm, activation, drop) in enumerate(stack):
            # --- Add global info (skipped for the first layer) ---
            conv_input = node_feat + virtual_feat[batch] if idx > 0 else node_feat
            hidden = norm(conv(conv_input, edge_index))
            if idx < final_idx:
                hidden = activation(hidden)
            node_feat = drop(hidden)

            # --- update global info (intermediate layers only) ---
            if 0 < idx < final_idx:
                pool_options = {}
                if self.mitigation_virtual == "weighted":
                    pool_options = dict(edge_index=edge_index,
                                        edge_mask=self.convs[0].__edge_mask__)
                pooled = self.virtual_pool(node_feat, batch, None, **pool_options)
                virtual_feat = self.virtual_mlp(pooled + virtual_feat)
        return node_feat



class vGINMolEncoder(GINMolEncoder, VirtualNodeEncoder):
    r"""
    The vGIN encoder for molecule data: the inherited atom encoder and edge-aware
    convolution stack combined with the virtual-node components from
    :class:`VirtualNodeEncoder`.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`, :obj:`config.model.dropout_rate`)
        **kwargs: `without_readout` makes :meth:`forward` return node features;
            `mitigation_virtual` set to ``"weighted"`` passes the edge index and edge mask
            to the virtual-node pooling.
    """

    def __init__(self, config: Union[CommonArgs, Munch], **kwargs):
        super().__init__(config, **kwargs)
        self.config: Union[CommonArgs, Munch] = config
        # Both options default to None when the caller does not supply them.
        self.without_readout = kwargs.get('without_readout')
        self.mitigation_virtual = kwargs.get('mitigation_virtual')
        print("mitigation_virtual in vGINMolEncoder = ", self.mitigation_virtual)

    def forward(self, x, edge_index, edge_attr, batch, batch_size, **kwargs):
        r"""
        Encode a batch of molecule graphs and, unless disabled, apply the readout.

        Args:
            x (Tensor): node features
            edge_index (Tensor): edge indices
            edge_attr (Tensor): edge attributes
            batch (Tensor): batch indicator
            batch_size (int): Batch size.
            **kwargs: a truthy `without_readout` skips the readout for this call.

        Returns (Tensor):
            graph feature representations, or node feature representations when the
            readout is disabled at construction time or for this call
        """
        node_repr = self.get_node_repr(x, edge_index, edge_attr, batch, batch_size, **kwargs)

        skip_readout = self.without_readout or kwargs.get('without_readout')
        if not skip_readout:
            return self.readout(
                node_repr,
                batch,
                batch_size,
                edge_index=edge_index,
                edge_mask=self.convs[0].__edge_mask__,
            )
        return node_repr

    def get_node_repr(self, x, edge_index, edge_attr, batch, batch_size, **kwargs):
        r"""
        Compute node representations with virtual-node message passing.

        Node features are first passed through :obj:`atom_encoder`. Before every layer, the
        virtual-node feature of each node's graph is added to the layer input. After every
        layer except the last, the virtual node is refreshed from the pooled node features
        plus its previous value.

        Args:
            x (Tensor): node features
            edge_index (Tensor): edge indices
            edge_attr (Tensor): edge attributes
            batch (Tensor): batch indicator
            batch_size (int): Batch size (not used in this method).

        Returns (Tensor):
            node feature representations
        """
        # The number of graphs is taken from the last batch entry
        # (assumes `batch` is sorted in ascending order - TODO confirm).
        graph_count = batch[-1].item() + 1
        zero_index = torch.zeros(graph_count, device=self.config.device, dtype=torch.long)
        virtual_feat = self.virtual_node_embedding(zero_index)

        final_idx = len(self.convs) - 1
        node_feat = self.atom_encoder(x)
        stack = zip(self.convs, self.batch_norms, self.relus, self.dropouts)
        for idx, (conv, norm, activation, drop) in enumerate(stack):
            has_next_layer = idx < final_idx

            # --- Add global info ---
            conv_input = node_feat + virtual_feat[batch]
            hidden = norm(conv(conv_input, edge_index, edge_attr))
            if has_next_layer:
                hidden = activation(hidden)
            node_feat = drop(hidden)

            # --- update global info (not needed after the last layer) ---
            if has_next_layer:
                pool_options = {}
                if self.mitigation_virtual == "weighted":
                    pool_options = dict(edge_index=edge_index,
                                        edge_mask=self.convs[0].__edge_mask__)
                pooled = self.virtual_pool(node_feat, batch, None, **pool_options)
                virtual_feat = self.virtual_mlp(pooled + virtual_feat)
        return node_feat
