from .optimizer_factory import optimizer_factory
from .ffnn_factory import FFNNModel
from .transformer_factory import TransformerTokenMZModel
from .transformer_factory import TransformerNumMZModel
from .transformer_factory import TransformerSinMZModel

def _lookup(mapping, key, default=None):
    """Return ``mapping[key]``, or *default* when *mapping* is None or lacks *key*."""
    if mapping is None or key not in mapping:
        return default
    return mapping[key]


def _section_lookup(config, section, key, default=None):
    """Return ``config[section][key]``, or *default* if either level is absent.

    A section that is present but set to None (e.g. an empty ``train:`` entry
    in a YAML file) is treated like a missing section rather than raising
    ``TypeError`` on the membership test.
    """
    return _lookup(_lookup(config, section), key, default)


def _train_setting(config, key, override=None, default=None):
    """Return *override* if it is not None, else ``config['train'][key]`` or *default*."""
    if override is not None:
        return override
    return _section_lookup(config, 'train', key, default)


def model_factory(
        config, num_workers=None, prefetch_factor=None, batch_size=None, num_batch_per_update=None,
        precision=None, **machine_params):
    """Build the model described by *config*, call its ``setup()`` and return it.

    Training settings are read from ``config['train']``. An explicit keyword
    argument that is not None takes precedence over the matching config entry.

    Args:
        config: Mapping describing the run. ``config['model_type']`` selects the
            model family (``'ffnn'`` or ``'transformer'``). For transformers,
            ``config['net']['model_foot']['type']`` further selects
            ``'token_mz'``, ``'num_mz'`` or ``'sin_mz'``.
        num_workers, prefetch_factor, batch_size, num_batch_per_update:
            Overrides for the same-named ``config['train']`` entries. Each is
            None when given in neither place.
        precision: Override for ``config['train']['precision']``. Falls back to 32.
        **machine_params: Machine settings. When none are passed,
            ``config['machine_params']`` (or an empty dict) is used instead.

    Returns:
        The constructed model, after ``model.setup()`` has been called.

    Raises:
        ValueError: if ``model_type`` is missing or unknown, or if the
            transformer foot type is unknown.
        KeyError: (assuming *config* is a dict) if a transformer config lacks
            ``peak_limits`` or one of the required ``net`` entries.
    """
    num_workers = _train_setting(config, 'num_workers', num_workers)
    prefetch_factor = _train_setting(config, 'prefetch_factor', prefetch_factor)
    batch_size = _train_setting(config, 'batch_size', batch_size)
    num_batch_per_update = _train_setting(config, 'num_batch_per_update', num_batch_per_update)
    precision = _train_setting(config, 'precision', precision, default=32)
    if not machine_params:
        machine_params = _lookup(config, 'machine_params', dict())
    wandb_args = _train_setting(config, 'wandb')
    gradient_clip_val = _train_setting(config, 'gradient_clip_val', default=0.0)
    train_data_path = _lookup(config, 'train_data_path')
    test_data_path = _lookup(config, 'test_data_path')
    augment_spectrum = _train_setting(config, 'augment_spectrum')
    model_path = _lookup(config, 'model_path')
    num_epochs = _train_setting(config, 'num_epochs')
    # Built before model_type is validated, matching the original call order.
    optimizer = optimizer_factory(config) if 'optim' in config else None

    # Leading positional arguments shared by every model constructor, in order.
    common_args = (
        train_data_path, test_data_path, model_path, num_epochs, batch_size,
        num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
        machine_params, wandb_args, gradient_clip_val, precision)

    if 'model_type' not in config:
        raise ValueError("config has no key 'model_type'")
    model_type = config['model_type']

    if model_type == 'ffnn':
        net = _lookup(config, 'net')
        model = FFNNModel(
            *common_args,
            _lookup(net, 'layers'), _lookup(net, 'dropout'), _lookup(net, 'batch_norm'),
            _lookup(net, 'l1'), _lookup(net, 'l2'))
    elif model_type == 'transformer':
        # Transformer settings are mandatory: missing keys raise KeyError.
        net = config['net']
        model_head = net['model_head']
        model_foot = net['model_foot']
        model_pred_head = net['model_pred_head']
        model_body = {
            'dropout': net['dropout'],
            'num_sa_layers': net['num_sa_layers'],
            'embd_dim': net['embd_dim'],
            'num_heads': net['num_heads'],
            'ff_dim': net['ff_dim']}
        peak_limits = config['peak_limits']

        foot_type = model_foot['type']
        if foot_type == 'token_mz':
            model_cls = TransformerTokenMZModel
        elif foot_type == 'num_mz':
            model_cls = TransformerNumMZModel
        elif foot_type == 'sin_mz':
            model_cls = TransformerSinMZModel
        else:
            raise ValueError('Unknown model foot type')

        model = model_cls(
            *common_args, peak_limits, model_body, model_foot, model_head, model_pred_head)
    else:
        raise ValueError('Unknown model type')

    model.setup()

    return model
