import torch
from torch import nn
from torch_geometric import nn as nng

from cfd.models.base import BaseModel
from cfd.models.utils import MLP


class Conv(nng.MessagePassing):
    """Edge-conditioned graph convolution.

    A learned kernel network maps each edge's attributes to a dense
    (H, H) weight matrix, which is applied to the source node's features;
    messages are mean-aggregated at the destination node.
    """

    def __init__(self, hparams, kernel):
        # Mean aggregation over incoming messages.
        super().__init__(aggr='mean')
        self.size_hidden_layers = hparams['size_hidden_layers']
        self.kernel = kernel

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # Kernel output is reshaped to one (H, H) matrix per edge.
        h = self.size_hidden_layers
        weight = self.kernel(edge_attr).view(-1, h, h)
        # Batched matrix-vector product: (E, H, H) @ (E, H, 1) -> (E, H).
        return torch.matmul(weight, x_j.unsqueeze(-1)).squeeze(-1)

class GNO(BaseModel):
    """Graph Neural Operator model.

    A single kernel-parameterized ``Conv`` is built once and reused as the
    input layer, every hidden layer, and the output layer (tied weights),
    so the feature width must be identical throughout.
    """

    def __init__(self, hparams, encoder, decoder):
        super(GNO, self).__init__(hparams, encoder, decoder)
        # The shared Conv maps enc_dim -> enc_dim, so encoder/decoder
        # widths and the hidden size must all agree.
        assert self.enc_dim == self.dec_dim
        assert self.enc_dim == self.size_hidden_layers
        # The kernel MLP's last layer must emit enc_dim * dec_dim values
        # per edge, to be reshaped into one weight matrix per edge.
        assert self.enc_dim * self.dec_dim == hparams['kernel'][-1]

    def _in_layer(self, hparams):
        # Build the single Conv instance that all layers share.
        kernel = MLP(hparams['kernel'], batch_norm=False)
        self.conv = Conv(hparams, kernel)
        return self.conv

    def _hidden_layers(self, hparams):
        # Every hidden layer is the same module: deliberate weight tying,
        # not an accidental duplication.
        hidden_layers = nn.ModuleList()
        for _ in range(self.nb_hidden_layers - 1):
            hidden_layers.append(self.conv)
        return hidden_layers

    def _out_layer(self, hparams):
        return self.conv

    def get_edge_attr(self, x, edge_index):
        """Build per-edge attributes: relative position of the endpoints,
        both endpoint SDF values, and both endpoint normals.

        Assumes node feature columns are laid out as
        [position 0:3, sdf 3:4, normal 4:7] -- TODO confirm with caller.
        """
        # NOTE: advanced (tensor) indexing already returns new tensors, so
        # the previous full-tensor x.clone() was an unnecessary copy and
        # has been removed; x is never mutated here.
        src, dst = edge_index[0], edge_index[1]
        pos_i, pos_j = x[src, 0:3], x[dst, 0:3]
        sdf_i, sdf_j = x[src, 3:4], x[dst, 3:4]
        normal_i, normal_j = x[src, 4:7], x[dst, 4:7]
        return torch.cat([pos_i - pos_j, sdf_i, sdf_j, normal_i, normal_j], dim=1)
 