import torch
import torch.nn as nn
import torch.nn.functional as F

import lib.odefunc as odefunc
import lib.utils as utils


def model_selection(model, dim, nlayer, nunit, **kwargs):
	"""Construct the ODE dynamics model selected by ``model``.

	Args:
		model: Model type name. One of 'spnn', 'gfinn', 'gnodev1', 'gnodev2'.
		dim: First positional argument forwarded to the chosen ``odefunc`` class.
		nlayer: Accepted for interface compatibility; not used by this function.
		nunit: Accepted for interface compatibility; not used by this function.
		**kwargs: Model-specific hyperparameters forwarded positionally to the
			constructor. The required keys per model are:
			- 'spnn':    lE, lS, nE, nS
			- 'gfinn':   lE, lS, nE, nS, K
			- 'gnodev1': lE, lS, nE, nS, D1, D2
			- 'gnodev2': lE, lS, lA, lB, lD, nE, nS, nA, nB, nD, D, C2

	Returns:
		The constructed dynamics model. As a side effect, prints the model type
		and its parameter count (via ``utils.count_parameters``).

	Raises:
		ValueError: If ``model`` is not one of the supported names.
		KeyError: If a required hyperparameter is missing from ``kwargs``.
	"""
	if model == 'spnn':
		dyn_model = odefunc.ODEfunc_SPNN(dim, kwargs['lE'], kwargs['lS'], kwargs['nE'], kwargs['nS'])
	elif model == 'gfinn':
		dyn_model = odefunc.ODEfunc_GFINN(dim, kwargs['lE'], kwargs['lS'], kwargs['nE'], kwargs['nS'], kwargs['K'])
	elif model == 'gnodev1':
		dyn_model = odefunc.ODEfunc_GNODEv1(dim, kwargs['lE'], kwargs['lS'], kwargs['nE'], kwargs['nS'], kwargs['D1'], kwargs['D2'])
	elif model == 'gnodev2':
		dyn_model = odefunc.ODEfunc_GNODEv2(dim, kwargs['lE'], kwargs['lS'], kwargs['lA'], kwargs['lB'], kwargs['lD'], kwargs['nE'], kwargs['nS'], kwargs['nA'], kwargs['nB'], kwargs['nD'], kwargs['D'], kwargs['C2'])
	else:
		# Without this branch dyn_model would be unbound and the function
		# would fail later with an opaque UnboundLocalError.
		raise ValueError(
			f"unknown model type {model!r}; expected one of "
			"'spnn', 'gfinn', 'gnodev1', 'gnodev2'"
		)
	print('model type:', model)
	print('total number of model parameters:',utils.count_parameters(dyn_model))
	return dyn_model
		
