import numpy as np
import torch
from torch import nn
from torch_cluster import radius
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.aggr import Aggregation, SumAggregation

from src.modules.torch_modules import GEGLU, FeedForward


class SupernodeGnn(MessagePassing):
    """Message-passing layer that builds per-edge messages from embedded nodes.

    Node features are projected by ``input_proj`` and summed with a positional
    embedding of ``pos``. For every edge, the embedded features of both
    endpoints (plus, optionally, an embedding of their position difference)
    are concatenated, passed through ``message_net`` and aggregated with
    ``aggr``.

    Args:
        input_dim: Passed to ``FeedForward`` as ``input_dim`` for the input
            projection (presumably the raw feature size -- confirm against
            ``FeedForward``).
        hidden_dim: Width used for the projection, the positional embeddings
            and the message-network output.
        pos_embed: Factory (class or partial), called as
            ``pos_embed(dim=hidden_dim, ndim=ndim)``, that builds the absolute
            positional embedding module.
        act: Activation forwarded to both ``FeedForward`` networks.
        aggr: Aggregation module handed to ``MessagePassing``. When ``None``
            (the default) a new ``SumAggregation`` is created for this
            instance.
        ndim: Forwarded to the positional-embedding factories as ``ndim``.
        relative_pos_embed: Optional factory, called like ``pos_embed``, for
            an embedding of ``pos_i - pos_j``. When falsy, messages use only
            the two node embeddings.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        pos_embed: nn.Module,
        act: nn.Module = GEGLU,
        aggr: Aggregation = None,
        ndim: int = 3,
        relative_pos_embed: nn.Module = None,
    ):
        # A module instance as a default argument is created once at function
        # definition time and would be shared by every layer built with the
        # default; create the default per instance instead.
        if aggr is None:
            aggr = SumAggregation()
        super().__init__(aggr=aggr)
        self.input_proj = FeedForward(
            dim=hidden_dim, input_dim=input_dim, act=act, depth=2
        )
        # The message input is the concatenation of x_i and x_j, plus the
        # relative positional embedding when one is configured.
        input_dim_message = hidden_dim * 3 if relative_pos_embed else hidden_dim * 2
        self.message_net = FeedForward(
            dim=input_dim_message, output_dim=hidden_dim, act=act, depth=2
        )
        self.pos_embed = pos_embed(dim=hidden_dim, ndim=ndim)
        if relative_pos_embed:
            self.relative_pos_embed = relative_pos_embed(dim=hidden_dim, ndim=ndim)
        else:
            self.relative_pos_embed = None

    def forward(self, x, pos, edge_index):
        """Embed the nodes and propagate messages along ``edge_index``.

        Args:
            x: Node features fed to ``input_proj``.
            pos: Node positions fed to ``pos_embed`` and, per edge, to
                ``message``.
            edge_index: Edge index tensor handed to ``propagate``.

        Returns:
            The result of ``propagate``: the aggregated messages per node.
        """
        x = self.input_proj(x) + self.pos_embed(pos)
        return self.propagate(edge_index, x=x, pos=pos)

    def message(self, x_i, x_j, pos_i, pos_j):
        """Compute one message per edge.

        Following the PyTorch Geometric naming convention, ``*_i`` and ``*_j``
        are the per-edge gathered values of ``x`` and ``pos`` for the two
        endpoints of each edge.

        Returns:
            The output of ``message_net`` applied to the concatenated inputs.
        """
        if self.relative_pos_embed is not None:
            rel_pos_embed = self.relative_pos_embed(pos_i - pos_j)
            x = torch.cat([x_i, x_j, rel_pos_embed], dim=-1)
        else:
            x = torch.cat([x_i, x_j], dim=-1)
        return self.message_net(x)


class SupernodePooling(nn.Module):
    """Pool node features onto a selected subset of nodes (the supernodes).

    A radius search connects every supernode to the nodes around it, the
    wrapped message-passing network is run on the resulting edges, and only
    the rows belonging to the supernodes are returned.

    Args:
        net: Message-passing network called as ``net(x, pos, edge_index)``.
        supernodes_radius: Radius ``r`` passed to the ``radius`` search.
        supernodes_max_neighbours: ``max_num_neighbors`` passed to the
            ``radius`` search.
        act: Accepted but not used by this class.
    """

    def __init__(
        self,
        net: MessagePassing,
        supernodes_radius: float,
        supernodes_max_neighbours: int,
        act: nn.Module = GEGLU,
    ):
        super().__init__()
        self.supernodes_max_neighbours = supernodes_max_neighbours
        self.supernodes_radius = supernodes_radius
        self.net = net

    def forward(
        self,
        x,
        pos,
        batch_index,
        supernode_index,
        super_node_batch_index,
    ):
        """Run the wrapped network on node-to-supernode edges.

        Args:
            x: Node features forwarded to ``net``.
            pos: Node positions; also indexed by ``supernode_index`` to obtain
                the supernode positions for the radius search.
            batch_index: Batch assignment of the nodes (``batch_x``).
            supernode_index: Indices into ``pos`` / the output of ``net`` that
                select the supernodes.
            super_node_batch_index: Batch assignment of the supernodes
                (``batch_y``).

        Returns:
            The output of ``net`` restricted to the rows in
            ``supernode_index``.
        """
        supernode_pos = pos[supernode_index]
        assignment = radius(
            x=pos,
            y=supernode_pos,
            r=self.supernodes_radius,
            batch_x=batch_index,
            batch_y=super_node_batch_index,
            max_num_neighbors=self.supernodes_max_neighbours,
        )
        # Reverse the two rows so edges point from nodes to supernodes;
        # ``flip`` returns a copy, so the in-place write below does not touch
        # ``assignment``.
        edge_index = assignment.flip(0)
        # The target row still counts positions within ``supernode_index``;
        # translate it to indices into the full node set.
        edge_index[1] = supernode_index[edge_index[1]]
        pooled = self.net(x, pos, edge_index)
        return pooled[supernode_index]
