import torch
from torch_scatter import scatter_sum, scatter_max
from torch_geometric.nn import MLP, MessagePassing
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import add_remaining_self_loops, to_undirected, to_dense_batch, degree, remove_self_loops


class SAGEConv(MessagePassing):
    """Message-passing layer combining neighbour features with encoded edge attributes.

    Node features are first projected by a linear layer. Each message is
    ``act(x_j + edge_mlp(edge_attr))``; the aggregated messages are added to
    the projected node features (skip connection) and fed through an MLP.

    Args:
        hid_dim: Width of the node features (input and output of the layer).
        edge_encode_layers: Number of layers in the edge MLP, which maps the
            1-dimensional edge attribute to ``hid_dim``.
        num_mlp_layers: Number of layers in the node-update MLP.
        act: Activation spec; resolved with ``activation_resolver`` for the
            message function and also forwarded to both MLPs.
        norm: Normalisation spec forwarded to both MLPs.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, hid_dim, edge_encode_layers, num_mlp_layers, act, norm, aggr):
        super().__init__(aggr=aggr)

        self.act = activation_resolver(act)
        self.lin_src = torch.nn.Linear(hid_dim, hid_dim)
        self.lin_msg = MLP([1] + [hid_dim] * edge_encode_layers, act=act, norm=norm)
        self.mlp = MLP([hid_dim] * (num_mlp_layers + 1), act=act, norm=norm, plain_last=False)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module of the layer."""
        # The projection layer is stored as ``lin_src``; there is no
        # ``lin_dst`` attribute on this class.
        self.lin_src.reset_parameters()
        self.lin_msg.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, x, edge_index, edge_attr, batch):
        """Apply one round of message passing.

        Args:
            x: Node features; the last dimension must be ``hid_dim``.
            edge_index: Edge indices with two rows (one per endpoint).
            edge_attr: Edge attributes whose last dimension is 1 (the input
                width of the edge MLP).
            batch: Per-node graph assignment vector.

        Returns:
            Updated node features produced by the node-update MLP.
        """
        x = self.lin_src(x)
        # Every edge is assigned to the graph of its first endpoint so that
        # the edge MLP receives a per-edge batch vector.
        edge_batch = batch[edge_index[0]]
        msg = self.lin_msg(edge_attr, edge_batch)
        out = self.propagate(edge_index, x=x, edge_attr=msg)
        out = out + x  # skip connection around the aggregation
        return self.mlp(out, batch)

    def message(self, x_j, edge_attr):
        """Build the per-edge message from neighbour feature and edge encoding."""
        return self.act(x_j + edge_attr)


class UnsupervisedGNN(torch.nn.Module):
    """GNN producing a per-node score in (0, 1) and three unsupervised loss terms.

    Every node starts from the same constant input feature (a single ``1``),
    is encoded by an MLP, refined by a stack of ``SAGEConv`` layers that see
    the edge weights, and finally mapped to a sigmoid score.

    Args:
        hid_dim: Hidden width used throughout the network.
        num_encode_layers: Number of layers of the input encoder MLP.
        num_conv_layers: Number of ``SAGEConv`` layers.
        edge_encode_layers: Number of layers of each conv layer's edge MLP.
        gnn_mlp_layers: Number of layers of each conv layer's update MLP.
        num_pred_layers: Number of hidden stages of the predictor MLP (a final
            projection to one output is appended).
        aggr: Aggregation scheme forwarded to every ``SAGEConv``.
        square_dist: If true, edge weights are squared in the third loss term.
    """

    def __init__(self,
                 hid_dim,
                 num_encode_layers,
                 num_conv_layers,
                 edge_encode_layers,
                 gnn_mlp_layers,
                 num_pred_layers,
                 aggr,
                 square_dist):
        super().__init__()
        self.vals_encoder = MLP([1] + [hid_dim] * num_encode_layers, act='relu', norm=None)
        self.num_conv_layers = num_conv_layers
        self.square_dist = square_dist

        self.gcns = torch.nn.ModuleList([
            SAGEConv(hid_dim=hid_dim,
                     edge_encode_layers=edge_encode_layers,
                     num_mlp_layers=gnn_mlp_layers,
                     act='relu', norm='graphnorm', aggr=aggr)
            for _ in range(num_conv_layers)
        ])

        self.predictor = MLP([hid_dim] * num_pred_layers + [1], act='relu', norm=None)

    def _node_probs(self, data):
        """Run encoder, conv stack and predictor; return one sigmoid score per node."""
        ones = torch.ones(data.num_nodes, 1, device=data.edge_index.device, dtype=torch.float)
        x = self.vals_encoder(ones)

        edge_attr = data.edge_weight[:, None]
        for i in range(self.num_conv_layers):
            x = self.gcns[i](x, data.edge_index, edge_attr, data.batch)

        # Squeeze only the trailing feature axis: a bare ``squeeze()`` would
        # also drop the node axis when the input holds exactly one node.
        return torch.sigmoid(self.predictor(x).squeeze(-1))

    def forward(self, data):
        """Compute the three per-graph loss terms.

        Args:
            data: Batch object exposing ``num_nodes``, ``edge_index``,
                ``edge_weight`` and ``batch``.

        Returns:
            Tuple ``(loss1, loss2, loss3)``, each reduced over ``data.batch``:

            * ``loss1`` - sum of the node scores.
            * ``loss2`` - sum over nodes of the product of ``1 - score + 1e-5``
              over the node's neighbourhood, self-loop included.
            * ``loss3`` - sum over nodes ``v`` and neighbours ``u`` of
              ``score_u * w(v, u) * prod_z (1 - score_z + 1e-5)``, where ``z``
              ranges over the neighbours of ``v`` ranked before ``u`` by edge
              weight (ties broken by position). ``w`` is squared when
              ``square_dist`` is set.

            NOTE(review): the terms look like a selection cost, an
            "uncovered" penalty and an expected nearest-selected distance -
            confirm against the training objective.
        """
        pred = self._node_probs(data)

        # First term: total score mass per graph.
        loss1 = scatter_sum(pred, data.batch)

        # Work on a symmetric graph in which every node has a zero-weight self-loop.
        edge_index, edge_weight = add_remaining_self_loops(data.edge_index, data.edge_weight, fill_value=0.,
                                                           num_nodes=data.num_nodes)
        edge_index, edge_weight = to_undirected(edge_index, edge_weight, reduce='mean')
        src, dst = edge_index

        # Second term: the neighbourhood product is evaluated in log space
        # (sum of logs, then exp).
        log_complement = (1 - pred + 1.e-5).log()
        loss2 = scatter_sum(log_complement[src], dst).exp()
        loss2 = scatter_sum(loss2, data.batch)

        # Third term: one dense row of edge weights per node, padded to the max degree.
        radii, mask = to_dense_batch(edge_weight, src, fill_value=0.)

        # closer[n, a, b] is True when slot b ranks before slot a for node n.
        # A plain ``>`` would mishandle equal weights, so ties are broken by slot index.
        dist_strictly_smaller = radii[:, :, None] > radii[:, None, :]
        dist_equal = radii[:, :, None] == radii[:, None, :]
        slots = torch.arange(radii.size(1), device=radii.device)
        index_smaller = slots[None, :, None] > slots[None, None, :]
        closer = dist_strictly_smaller | (dist_equal & index_smaller)
        # Padding slots must neither be ranked nor take part in any product.
        closer &= mask[:, :, None] & mask[:, None, :]

        # Padding is already 0 thanks to ``fill_value``; no extra masking needed.
        p_repeat, _ = to_dense_batch(pred[dst], src, fill_value=0.)
        complement_p, _ = to_dense_batch((1 - pred)[dst], src, fill_value=1.)

        # Broadcasting over the middle axis replaces an explicit
        # [N, maxdeg, maxdeg] copy of the complement scores.
        dense_log_complement = (complement_p + 1.e-5).log()
        prod_neighbors = torch.exp((dense_log_complement[:, None, :] * closer.float()).sum(2))

        if self.square_dist:
            radii = radii ** 2
        loss3 = (p_repeat * prod_neighbors * radii).sum(1)
        loss3 = scatter_sum(loss3, data.batch)
        return loss1, loss2, loss3

    @torch.no_grad()
    def predict(self, data):
        """Return the per-node sigmoid scores without tracking gradients."""
        return self._node_probs(data)
