import torch
import torch_sparse
from ..utils import reset_slice_dict_edges

class OutputLayerPinPrediction(torch.nn.Module):
    """
    Final edge-prediction head.

    Folds the two directed copies of every edge into a single edge (i, j)
    with i < j, whose feature vector is the concatenation
    ``[attr(i -> j), attr(j -> i)]`` (width ``2 * edge_dim_in``). It then maps
    those features to ``edge_dim_out`` outputs with Linear -> ReLU -> Linear.

    Args:
        edge_dim_in: Feature width of each directed edge in ``batch.edge_attr``.
        edge_dim_out: Number of output features per folded edge.
    """

    def __init__(self, edge_dim_in, edge_dim_out):
        super().__init__()
        self.act = torch.nn.ReLU()
        # Input width is doubled because both edge directions are concatenated.
        self.fc_edges = torch.nn.Linear(edge_dim_in * 2, edge_dim_in)
        self.edges_output_layer = torch.nn.Linear(edge_dim_in, edge_dim_out)

    @staticmethod
    def _split_and_pad_directions(edge_index, edge_attr):
        """Re-key every directed edge onto (min, max) and zero-pad its features.

        An edge with ``src < dst`` keeps its index and places its features in
        the left half of a ``2 * D``-wide vector. An edge with ``src > dst`` has
        its index flipped to ``(dst, src)`` and places its features in the
        right half. Summing entries that share an index therefore yields
        ``[attr(i -> j), attr(j -> i)]``. If one direction is absent, its half
        stays zero. Self-loops (``src == dst``) match neither mask and are
        dropped.

        Args:
            edge_index: Integer tensor of shape (2, E) with source/target rows.
            edge_attr: Edge features; assumed 2-D of shape (E, D), since the
                halves are concatenated along dim 1.

        Returns:
            ``(index, attr)``: the re-keyed (2, E') index and matching (E', 2*D)
            features, with upper-triangle edges first. Index entries may repeat;
            they are not yet merged.
        """
        src, dst = edge_index[0], edge_index[1]
        upper = src < dst
        lower = src > dst

        upper_attr = edge_attr[upper]
        lower_attr = edge_attr[lower]

        # i -> j features go in the left half, j -> i features in the right half.
        padded_upper = torch.cat([upper_attr, torch.zeros_like(upper_attr)], dim=1)
        padded_lower = torch.cat([torch.zeros_like(lower_attr), lower_attr], dim=1)

        # Swap rows of the lower-triangle index so (j, i) lands on (i, j).
        index = torch.cat(
            [edge_index[:, upper], torch.flip(edge_index[:, lower], dims=[0])],
            dim=1,
        )
        attr = torch.cat([padded_upper, padded_lower], dim=0)
        return index, attr

    def forward(self, batch):
        """Fold both edge directions together and predict per-edge outputs.

        Args:
            batch: Graph batch object exposing ``edge_index`` (2, E),
                ``edge_attr`` (E, edge_dim_in) and ``num_nodes``. It is
                presumably a PyTorch Geometric ``Batch``; confirm against
                callers.

        Returns:
            The batch returned by ``reset_slice_dict_edges``. Its
            ``edge_index`` holds only pairs with i < j (self-loops removed).
            Its ``edge_attr`` has last dimension ``edge_dim_out``. The input
            ``batch`` object's ``edge_index``/``edge_attr`` are overwritten in
            place.
        """
        index, attr = self._split_and_pad_directions(batch.edge_index, batch.edge_attr)

        # Merge entries that share an (i, j) index. op="sum" combines the
        # zero-padded halves into [attr(i->j), attr(j->i)]. Duplicate
        # same-direction edges are also summed.
        merged_index, merged_attr = torch_sparse.coalesce(
            index,
            attr,
            batch.num_nodes,
            batch.num_nodes,
            op="sum",
        )

        batch.edge_index = merged_index
        batch.edge_attr = merged_attr

        # The edge count has changed. This helper presumably refreshes the
        # per-graph edge slicing bookkeeping; confirm in ..utils.
        batch = reset_slice_dict_edges(batch)

        # Final MLP: (E', 2*D) -> (E', D) -> (E', edge_dim_out).
        batch.edge_attr = self.edges_output_layer(self.act(self.fc_edges(batch.edge_attr)))
        return batch
