import torch

from neural_mpm.nn import (
    UNet,
    FNO,
    FFNO,
    UNO,
)

# Registry mapping each supported model name to its model class.
# create_model() lists these keys as the valid choices in its error message.
ALL = dict(
    unet=UNet,
    fno=FNO,
    ffno=FFNO,
    uno=UNO,
)


# TODO use **config_dict in model call to pass all parameters at once
def _as_pair(value):
    """Return ``value`` unchanged if it is a list/tuple, else ``[value, value]``.

    Used to expand a scalar per-layer UNO setting (modes or scaling) into a
    two-entry value (presumably one per axis of the 2D grid -- confirm
    against UNO).
    """
    if isinstance(value, (list, tuple)):
        return value
    return [value, value]


def _build_fno(config_dict, in_channels):
    """Instantiate an FNO from a dict-style ``config_dict['architecture']``.

    Reads 'hidden' (required), 'modes' (optional; ``None``/missing falls back
    to ``config_dict['grid_size'] // 2``) and 'use_mlp' (default False).
    """
    architecture = config_dict['architecture']
    hidden_channels = architecture['hidden']
    modes = architecture.get('modes')
    if modes is None:
        modes = config_dict['grid_size'] // 2
    return FNO(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        num_preds=config_dict['steps_per_call'],
        out_channels=2,
        modes=modes,
        use_mlp=architecture.get('use_mlp', False),
    )


def _build_ffno(config_dict, in_channels):
    """Instantiate an FFNO from a list-style architecture ``[*hidden, modes]``.

    A ``None`` modes entry falls back to ``config_dict['grid_size'] // 2``.
    """
    architecture = config_dict['architecture']
    hidden_channels = architecture[:-1]
    modes = architecture[-1]
    if modes is None:
        modes = config_dict['grid_size'] // 2
    return FFNO(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        num_preds=config_dict['steps_per_call'],
        out_channels=2,
        modes=modes,
    )


def _build_unet(config_dict, in_channels):
    """Instantiate a UNet from ``architecture['hidden']`` + ``[steps_per_call]``."""
    # Copy into a list so tuples are accepted and the config is never aliased.
    architecture = list(config_dict['architecture']['hidden']) + [
        config_dict['steps_per_call']]
    # One factor of 2 for each transition between consecutive entries.
    factors = [2] * (len(architecture) - 1)
    return UNet(architecture, factors, in_channels=in_channels)


def _build_uno(config_dict, in_channels):
    """Instantiate a UNO from a dict-style ``config_dict['architecture']``.

    Recognised keys (defaults in brackets):
        hidden_channels: initial width of the UNO after lifting [128]
        uno_out_channels: output channels of each Fourier layer
            [[32, 64, 64, 32]]
        n_modes: Fourier modes of each Fourier layer; scalar entries are
            expanded to pairs [[5, 5, 5, 5]]
        scalings: scaling factor of each Fourier layer; scalar entries are
            expanded to pairs [[1.0, 0.5, 1.0, 2.0]]
    ``n_modes``/``scalings`` are forwarded as ``uno_n_modes``/``uno_scalings``,
    ``n_layers`` is set to ``len(uno_out_channels)``, and every other key is
    forwarded to ``UNO`` unchanged.

    Works on a shallow copy so the caller's config is left untouched and the
    same config can be used to build several models.
    """
    architecture = dict(config_dict['architecture'])
    architecture.setdefault('hidden_channels', 128)
    architecture.setdefault('uno_out_channels', [32, 64, 64, 32])
    n_modes = architecture.pop('n_modes', [5, 5, 5, 5])
    scalings = architecture.pop('scalings', [1.0, 0.5, 1.0, 2.0])

    architecture['uno_n_modes'] = [_as_pair(m) for m in n_modes]
    architecture['uno_scalings'] = [_as_pair(s) for s in scalings]
    architecture['n_layers'] = len(architecture['uno_out_channels'])

    return UNO(
        in_channels,
        2,
        num_preds=config_dict['steps_per_call'],
        **architecture,
    )


def create_model(model_type, config_dict):
    """
    Instantiates a model with the given parameters and returns it.

    config_dict must contain the architecture with the model parameters as
    well as the required hyperparameters such as steps_per_call (and
    grid_size when an FNO/FFNO is configured without explicit modes).

    Args:
        model_type: str
            Name of the model to instantiate: 'fno', 'ffno', 'unet' or 'uno'.
        config_dict: dict
            Dictionary containing the model parameters and hyperparameters.
            The layout of 'architecture' depends on model_type:
              - 'fno':  dict with 'hidden', optional 'modes' and 'use_mlp'
              - 'ffno': list [*hidden_channels, modes]
              - 'unet': dict with a 'hidden' list, to which steps_per_call
                        is appended
              - 'uno':  dict with optional 'hidden_channels',
                        'uno_out_channels', 'n_modes', 'scalings'
            The optional 'data' entry selects 6 input channels when it
            contains "WBC", otherwise 4. config_dict is not modified.

    Returns:
        model: torch.nn.Module
            The instantiated model

    Raises:
        ValueError: if model_type is not recognized.
        KeyError: if a required config entry is missing.
    """
    # 'data' is optional (presumably a dataset name/path -- confirm); when it
    # mentions "WBC" the model gets 6 input channels instead of 4.
    data = config_dict.get('data') or ''
    in_channels = 6 if "WBC" in data else 4

    builders = {
        'fno': _build_fno,
        'ffno': _build_ffno,
        'unet': _build_unet,
        'uno': _build_uno,
    }
    builder = builders.get(model_type)
    if builder is None:
        raise ValueError(f"Model {model_type} not recognized.\n "
                         f"Valid Types: {ALL.keys()}")

    return builder(config_dict, in_channels)


if __name__ == '__main__':
    # Smoke test for create_model: build each architecture and print the
    # shape of its output on a random input.
    # To run on GPU, call torch.set_default_device('cuda') here.

    x = torch.randn(1, 64, 64, 4)

    # 'data' is passed explicitly; without "WBC" in it create_model builds
    # 4-input-channel models, matching the last dimension of x.
    common = {'steps_per_call': 8, 'data': ''}

    cases = [
        ('FNO', 'fno', {
            'hidden': [128, 128],
            'modes': 5,
            'use_mlp': False,
        }),
        ('FFNO', 'ffno', [128, 128, 5]),
        ('UNet', 'unet', {
            'hidden': [128, 128],
        }),
        ('UNO', 'uno', {
            'hidden_channels': 128,
            'uno_out_channels': [32, 64, 64, 32],
            'n_modes': [5, 5, 5, 5],
            'scalings': [1.0, 0.5, 1.0, 2.0],
        }),
    ]

    for title, model_type, architecture in cases:
        header = f'Testing {title}'
        rule = '-' * (len(header) + 2)
        print(rule)
        print(header)
        print(rule)

        model = create_model(model_type, {**common,
                                          'architecture': architecture})
        # Forward pass only; no autograd graph is needed for a shape check.
        with torch.no_grad():
            print(model(x).shape)
