from .g_cnn import Gcnn

def get_model(params):
    """Build the model selected by ``params.model``.

    Supported values of ``params.model``:

    * ``'g_cnn'``     -- a ``Gcnn`` constructed with an empty
      ``antialiasing_kwargs`` dict; ``params.antialiasing_kwargs`` is not read.
    * ``'g_cnn_dwn'`` -- a ``Gcnn`` constructed with
      ``params.antialiasing_kwargs``.

    Every other ``Gcnn`` constructor argument is taken from the attribute of
    ``params`` with the same name.

    Args:
        params: Configuration object exposing ``model`` and the ``Gcnn``
            constructor arguments as attributes.

    Returns:
        The constructed ``Gcnn`` instance.

    Raises:
        ValueError: If ``params.model`` is not a supported model name.
    """
    # Resolve the only argument that differs between the variants first, so an
    # unknown model name fails before any constructor argument is read.
    if params.model == 'g_cnn':
        # A fresh dict per call, so model instances never share a mutable object.
        # NOTE(review): confirm that ignoring params.antialiasing_kwargs for this
        # variant is intended.
        antialiasing_kwargs = {}
    elif params.model == 'g_cnn_dwn':
        antialiasing_kwargs = params.antialiasing_kwargs
    else:
        raise ValueError(f'Model {params.model} not found')

    # Both variants share all remaining arguments; building the call in one
    # place keeps the two variants from drifting apart.
    return Gcnn(
        num_channels=params.num_channels,
        num_layers=params.num_layers,
        kernel_sizes=params.kernel_sizes,
        num_classes=params.num_classes,
        dwn_group_types=params.dwn_group_types,
        dwn_orders=params.dwn_orders,
        spatial_subsampling_factors=params.spatial_subsampling_factors,
        subsampling_factors=params.subsampling_factors,
        domain=params.domain,
        pooling_type=params.pooling_type,
        apply_antialiasing=params.apply_antialiasing,
        cannonicalize=params.cannonicalize,
        dropout_rate=params.dropout_rate,
        antialiasing_kwargs=antialiasing_kwargs,
        layer_kwargs=params.layer_kwargs,
        fully_convolutional=params.fully_convolutional,
    )