import torch

from .flood_echo import FloodModel
from .recgnn import RecGNN
from .gin import GIN
from .pgn import PGN

def _config_error(message):
    """Report an invalid configuration value and build the exit exception.

    The message goes to stderr. The returned ``SystemExit(1)`` is meant to
    be raised by the caller (``raise _config_error(...)``), so static
    analysers can see that control flow stops there. ``SystemExit`` is
    used instead of the ``exit`` builtin because ``exit`` is only injected
    by the ``site`` module for interactive use.
    """
    import sys

    print(message, file=sys.stderr)
    return SystemExit(1)


def _shared_model_kwargs(config, normalization, activation):
    """Return the constructor keyword arguments passed to every model class.

    Only called once the model name is known to be valid. This preserves
    the original behaviour of not reading these attributes for an unknown
    model.
    """
    return dict(
        input_dim=config.in_channels,
        output_dim=config.out_channels,
        num_rounds=config.num_rounds,
        hidden_dim=config.hidden_dimension,
        hidden_state_factor=config.hidden_state_factor,
        mlp_depth=config.mlp_depth,
        dropout=config.dropout,
        normalization=normalization,
        activation=activation,
        aggregation=config.aggregation,
        conv=config.conv,
    )


def get_model(config):
    """Build the model selected by ``config.model``.

    Args:
        config: attribute-style configuration object. Always read:
            ``normalization`` ('LayerNorm' or 'None'), ``activation``
            ('ReLU') and ``model`` (one of 'GIN', 'rGIN', 'RecGNN', 'PGN',
            'FloodEcho').
            For a recognised model the following are also read:
            ``in_channels``, ``out_channels``, ``num_rounds``,
            ``hidden_dimension``, ``hidden_state_factor``, ``mlp_depth``,
            ``dropout``, ``aggregation`` and ``conv``.
            Model-specific extras:
            ``prediction_mode`` (GIN, rGIN, FloodEcho); ``num_k`` (RecGNN,
            PGN); ``start_k``, ``start_mode``, ``start_selection``,
            ``pool_mode`` and ``train_mode`` (FloodEcho).

    Returns:
        The constructed model instance.

    Raises:
        SystemExit: with code 1, after printing a message to stderr, when
            the normalization, activation or model name is not recognised.
    """
    # The class itself is passed on, not an instance; the model
    # constructors receive it via the ``normalization`` keyword.
    if config.normalization == 'LayerNorm':
        normalization_function = torch.nn.LayerNorm
    elif config.normalization == 'None':
        normalization_function = torch.nn.Identity
    else:
        raise _config_error(
            f'Misspecified normalization function {config.normalization}'
        )

    if config.activation == 'ReLU':
        activation = torch.nn.ReLU()
    else:
        raise _config_error(f'Unknown activation {config.activation}')

    model_name = config.model

    if model_name in ('GIN', 'rGIN'):
        # 'rGIN' is the same GIN class constructed with random_init=True.
        return GIN(
            **_shared_model_kwargs(config, normalization_function, activation),
            random_init=(model_name == 'rGIN'),
            prediction_mode=config.prediction_mode,
        )

    if model_name in ('RecGNN', 'PGN'):
        # Both take identical arguments; no prediction_mode is passed.
        model_class = RecGNN if model_name == 'RecGNN' else PGN
        return model_class(
            **_shared_model_kwargs(config, normalization_function, activation),
            num_k=config.num_k,
        )

    if model_name == 'FloodEcho':
        return FloodModel(
            **_shared_model_kwargs(config, normalization_function, activation),
            # FloodEcho takes its num_k from config.start_k, not config.num_k.
            num_k=config.start_k,
            prediction_mode=config.prediction_mode,
            start_mode=config.start_mode,
            start_selection=config.start_selection,
            pool_mode=config.pool_mode,
            train_mode=config.train_mode,
        )

    raise _config_error(f'unknown model {model_name}')

