import torch
import torch.nn as nn

from torch_geometric.nn.models import DimeNetPlusPlus as DimeNet
from .baselines.schnet import SchNetModel as SchNet
from .baselines.comenet import ComENetModel as ComENet
from .baselines.egnn import EGNNModel as EGNN

from register import MODEL_REGISTRY

class SchNetWrapper(nn.Module):
    """Thin ``nn.Module`` adapter around the SchNet baseline.

    The wrapped network is built from entries of ``model_config`` and kept
    as ``self.model``; ``forward`` hands its input to that network unchanged.
    """

    # Entries read from ``model_config``. Each one is passed to SchNet as a
    # keyword argument of the same name, in this order.
    _CONFIG_KEYS = (
        'hidden_channels',
        'in_dim',
        'out_dim',
        'num_filters',
        'num_layers',
        'num_gaussians',
        'cutoff',
        'node_prediction',
        'break_symmetry',
    )

    def __init__(self, model_config):
        """Build the wrapped SchNet model.

        Args:
            model_config: Subscriptable configuration providing every key
                listed in ``_CONFIG_KEYS``. A missing key surfaces as the
                lookup error raised by ``model_config`` itself.
        """
        super().__init__()
        schnet_kwargs = {key: model_config[key] for key in self._CONFIG_KEYS}
        self.model = SchNet(**schnet_kwargs)

    def forward(self, data_obj):
        """Return the wrapped model's output for ``data_obj``."""
        network = self.model
        return network(data_obj)

class DimeNetWrapper(nn.Module):
    """Thin ``nn.Module`` adapter around torch_geometric's DimeNet++.

    The wrapped network is built from entries of ``model_config`` and kept
    as ``self.model``. Unlike a plain pass-through, ``forward`` unpacks the
    ``z``, ``pos`` and ``batch`` attributes of its input and passes them on
    as separate positional arguments.
    """

    # Entries read from ``model_config``. Each one is passed to DimeNet as a
    # keyword argument of the same name, in this order.
    _CONFIG_KEYS = (
        'hidden_channels',
        'out_channels',
        'num_blocks',
        'int_emb_size',
        'basis_emb_size',
        'out_emb_channels',
        'num_spherical',
        'num_radial',
        'cutoff',
        'max_num_neighbors',
        'num_before_skip',
        'num_after_skip',
        'num_output_layers',
    )

    def __init__(self, model_config):
        """Build the wrapped DimeNet++ model.

        Args:
            model_config: Subscriptable configuration providing every key
                listed in ``_CONFIG_KEYS``. A missing key surfaces as the
                lookup error raised by ``model_config`` itself.
        """
        super().__init__()
        dimenet_kwargs = {key: model_config[key] for key in self._CONFIG_KEYS}
        self.model = DimeNet(**dimenet_kwargs)

    def forward(self, data_obj):
        """Run the wrapped model on the fields of ``data_obj``.

        Args:
            data_obj: Object exposing ``z``, ``pos`` and ``batch`` attributes
                (presumably atomic numbers, coordinates and the per-node
                batch assignment -- confirm against the data pipeline).

        Returns:
            Whatever the wrapped model returns for ``(z, pos, batch)``.
        """
        network = self.model
        z, pos, batch = data_obj.z, data_obj.pos, data_obj.batch
        return network(z, pos, batch)

class ComENetWrapper(nn.Module):
    """Thin ``nn.Module`` adapter around the ComENet baseline.

    The wrapped network is built from entries of ``model_config`` and kept
    as ``self.model``; ``forward`` hands its input to that network unchanged.
    """

    # Entries read from ``model_config``. Each one is passed to ComENet as a
    # keyword argument of the same name, in this order.
    _CONFIG_KEYS = (
        'in_dim',
        'out_dim',
        'cutoff',
        'num_layers',
        'hidden_channels',
        'middle_channels',
        'num_radial',
        'num_spherical',
        'num_output_layers',
    )

    def __init__(self, model_config):
        """Build the wrapped ComENet model.

        Args:
            model_config: Subscriptable configuration providing every key
                listed in ``_CONFIG_KEYS``. A missing key surfaces as the
                lookup error raised by ``model_config`` itself.
        """
        super().__init__()
        comenet_kwargs = {key: model_config[key] for key in self._CONFIG_KEYS}
        self.model = ComENet(**comenet_kwargs)

    def forward(self, data_obj):
        """Return the wrapped model's output for ``data_obj``."""
        network = self.model
        return network(data_obj)

class EGNNWrapper(nn.Module):
    """Thin ``nn.Module`` adapter around the EGNN baseline.

    The wrapped network is built from entries of ``model_config`` and kept
    as ``self.model``; ``forward`` hands its input to that network unchanged.
    """

    # (EGNN keyword argument, ``model_config`` key) pairs, in the order the
    # configuration is read. The names match except for ``activation``, which
    # is stored in the configuration under ``'act'``.
    _KWARG_TO_CONFIG_KEY = (
        ('num_layers', 'num_layers'),
        ('emb_dim', 'emb_dim'),
        ('in_dim', 'in_dim'),
        ('out_dim', 'out_dim'),
        ('activation', 'act'),
        ('norm', 'norm'),
        ('aggr', 'aggr'),
        ('pool', 'pool'),
        ('residual', 'residual'),
        ('equivariant_pred', 'equivariant_pred'),
        ('break_symmetry', 'break_symmetry'),
    )

    def __init__(self, model_config):
        """Build the wrapped EGNN model.

        Args:
            model_config: Subscriptable configuration providing every config
                key listed in ``_KWARG_TO_CONFIG_KEY``. A missing key surfaces
                as the lookup error raised by ``model_config`` itself.
        """
        super().__init__()
        egnn_kwargs = {
            kwarg: model_config[config_key]
            for kwarg, config_key in self._KWARG_TO_CONFIG_KEY
        }
        self.model = EGNN(**egnn_kwargs)

    def forward(self, data_obj):
        """Return the wrapped model's output for ``data_obj``."""
        network = self.model
        return network(data_obj)


# Register every wrapper class under its string key, in a fixed order.
# (Presumably this key is how a model is selected from configuration --
# confirm against the registry implementation.)
for _registry_key, _wrapper_cls in (
    ('schnet', SchNetWrapper),
    ('dimenet', DimeNetWrapper),
    ('comenet', ComENetWrapper),
    ('egnn', EGNNWrapper),
):
    MODEL_REGISTRY.register(_registry_key, _wrapper_cls)
