import torch
import torch_sparse
from ..utils import reset_slice_dict_edges
from torch_geometric.graphgym.config import cfg


class OutputLayer(torch.nn.Module):
    """
        Final projection head for the node and edge embeddings carried by a batch.

        Nodes: ``batch.x`` goes through a single linear layer (node_dim_in -> node_dim_out).
        Edges: the attributes of the two directions of every node pair are placed side by
        side (NOT averaged), giving one entry per pair, and are then passed through
        Linear(2 * edge_dim_in, edge_dim_in) -> ReLU -> Linear(edge_dim_in, edge_dim_out).
        When ``cfg.gt.sizing`` is truthy, an additional Linear -> SiLU -> Linear MLP emits
        one scalar per node, stored in ``batch.x_features``.

        Original author's description: implements the final denoising step, mapping the
        (B x D) x EMBED_DIM input tensor to (B x D) x S where S is the class number.
        NOTE(review): those shapes are not checked anywhere in this class.
    """
    def __init__(self, node_dim_in, node_dim_out, edge_dim_in, edge_dim_out):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine the
        # state_dict keys and the order in which parameters get initialised.
        self.nodes_output_layer = torch.nn.Linear(node_dim_in, node_dim_out)

        # Optional scalar head; only registered when the "sizing" option is enabled.
        sizing_enabled = cfg.gt.get("sizing", False)
        if sizing_enabled:
            feature_head = [
                torch.nn.Linear(node_dim_in, node_dim_in),
                torch.nn.SiLU(),
                torch.nn.Linear(node_dim_in, 1),
            ]
            self.node_features_output_layer = torch.nn.Sequential(*feature_head)

        self.act = torch.nn.ReLU()
        # Consumes the two directions of an edge concatenated, hence 2 * edge_dim_in.
        self.fc_edges = torch.nn.Linear(edge_dim_in * 2, edge_dim_in)
        self.edges_output_layer = torch.nn.Linear(edge_dim_in, edge_dim_out)

    def forward(self, batch):
        """
            Project node/edge embeddings in place on ``batch`` and return the object
            handed back by ``reset_slice_dict_edges``.

            Reads ``batch.x``, ``batch.edge_index``, ``batch.edge_attr``, ``batch.num_nodes``
            (and ``batch.x_features`` when the scalar head exists and
            ``cfg.gt.process_feats_with_x`` is falsy). Overwrites ``x``, ``edge_index``,
            ``edge_attr`` and, if the scalar head exists, ``x_features``.
        """
        # Computed before anything is written back, so the optional head below
        # still sees the un-projected node embeddings.
        projected_nodes = self.nodes_output_layer(batch.x)

        if hasattr(self, 'node_features_output_layer'):
            # Per the original author's note: the model predicts a velocity field, so
            # storing the result under `x_features` is an abuse of notation.
            # No rescaling back to the original feature range is applied here.
            head_input = batch.x if cfg.gt.process_feats_with_x else batch.x_features
            batch.x_features = self.node_features_output_layer(head_input)
        batch.x = projected_nodes

        # Split edges by direction. Self-loops (source == target) match neither mask
        # and are therefore dropped.
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr
        sources, targets = edge_index[0], edge_index[1]
        upper_mask = sources < targets
        lower_mask = sources > targets
        upper_attr = edge_attr[upper_mask]
        lower_attr = edge_attr[lower_mask]

        # Zero-pad so each direction owns a disjoint half of a double-width vector:
        # upper-triangular edges fill the first half, lower-triangular the second.
        upper_padded = torch.cat([upper_attr, torch.zeros_like(upper_attr)], dim=1)
        lower_padded = torch.cat([torch.zeros_like(lower_attr), lower_attr], dim=1)

        # Re-orient lower-triangular edges as (min, max) so both directions of a pair
        # share the same index; the summing coalesce then merges them into a single
        # [upper | lower] row. A pair seen in only one direction keeps zeros in the
        # other half.
        canonical_index = torch.cat(
            [edge_index[:, upper_mask], edge_index[:, lower_mask].flip(0)],
            dim=1,
        )
        stacked_attr = torch.cat([upper_padded, lower_padded], dim=0)
        merged_index, merged_attr = torch_sparse.coalesce(
            canonical_index,
            stacked_attr,
            batch.num_nodes,
            batch.num_nodes,
            op="sum",
        )

        batch.edge_index = merged_index
        batch.edge_attr = merged_attr

        # The number of edges changed; presumably this helper refreshes the batch's
        # edge slicing bookkeeping (defined in ..utils, not visible here).
        batch = reset_slice_dict_edges(batch)

        # Last edge MLP: 2 * edge_dim_in -> edge_dim_in -> edge_dim_out.
        edge_hidden = self.act(self.fc_edges(batch.edge_attr))
        batch.edge_attr = self.edges_output_layer(edge_hidden)
        return batch
