import torch.nn as nn

from torch_geometric import nn as nng
from cfd.models.base import BaseModel

class GAT(BaseModel):
    """Graph Attention Network built from ``torch_geometric`` ``GATConv`` layers.

    Every layer is a multi-head ``GATConv`` with ``concat=True``. Each of the
    ``hparams['nb_heads']`` heads produces ``width // nb_heads`` channels, and
    the heads are concatenated. Each layer's output width is therefore exactly
    the requested ``width``, which is why every width must be divisible by the
    number of heads.

    NOTE(review): ``enc_dim``, ``size_hidden_layers``, ``nb_hidden_layers`` and
    ``dec_dim`` are read from ``self`` but never set in this class. They are
    presumably populated by ``BaseModel.__init__``, which presumably also calls
    the ``_in_layer`` / ``_hidden_layers`` / ``_out_layer`` factories below.
    Confirm against ``cfd.models.base``.
    """

    def __init__(self, hparams, encoder, decoder):
        super().__init__(hparams, encoder, decoder)

    @staticmethod
    def _checked_heads(hparams, width):
        """Return ``hparams['nb_heads']`` after validating it against ``width``.

        Raises:
            ValueError: if the head count is not positive, or if ``width`` is
                not divisible by it.
        """
        heads = hparams['nb_heads']
        # Explicit checks rather than ``assert`` so they still run under
        # ``python -O``, and so a zero head count does not surface as an
        # opaque ZeroDivisionError from the modulo below.
        if heads <= 0:
            raise ValueError(f"nb_heads must be positive, got {heads!r}")
        if width % heads != 0:
            raise ValueError(
                f"layer width {width} must be divisible by nb_heads={heads}"
            )
        return heads

    @classmethod
    def _gat_conv(cls, hparams, in_channels, width):
        """Build a concatenating multi-head ``GATConv`` of total output ``width``.

        Args:
            hparams: Hyper-parameter mapping; ``hparams['nb_heads']`` is the
                number of attention heads.
            in_channels: Input feature width of the layer.
            width: Total output width, i.e. per-head channels times heads.

        Raises:
            ValueError: see ``_checked_heads``.
        """
        heads = cls._checked_heads(hparams, width)
        return nng.GATConv(
            in_channels=in_channels,
            out_channels=width // heads,
            heads=heads,
            concat=True,
        )

    def _in_layer(self, hparams):
        """Input layer: ``enc_dim`` -> ``size_hidden_layers``."""
        return self._gat_conv(hparams, self.enc_dim, self.size_hidden_layers)

    def _hidden_layers(self, hparams):
        """Return ``nb_hidden_layers - 1`` GAT layers of width ``size_hidden_layers``.

        An empty ``ModuleList`` is returned when ``nb_hidden_layers <= 1``.
        """
        # Validate up front so a bad configuration is rejected even when the
        # loop below builds zero layers (the original checked unconditionally).
        self._checked_heads(hparams, self.size_hidden_layers)
        return nn.ModuleList(
            self._gat_conv(hparams, self.size_hidden_layers, self.size_hidden_layers)
            for _ in range(self.nb_hidden_layers - 1)
        )

    def _out_layer(self, hparams):
        """Output layer: ``size_hidden_layers`` -> ``dec_dim``.

        NOTE(review): heads are concatenated here as in every other layer. The
        original GAT paper averages heads in the final layer
        (``concat=False``). Kept as-is to preserve parameter shapes and
        existing checkpoints.
        """
        return self._gat_conv(hparams, self.size_hidden_layers, self.dec_dim)
