""" Encoder configs """
# Encoder preset selected by name: 'linear'.
linear_encoder = dict(name='linear')

# Encoder preset selected by name: 'embedding' (used by the text/retrieval tasks).
embedding_encoder = dict(name='embedding')

# Decoder preset: the 'feature' decoder in 'last' mode.
# NOTE(review): 'last' presumably reads only the final timestep (contrast the
# 'sequence' decoder used for per-timestep regression) — confirm in the decoder.
feature_decoder = dict(name='feature', mode='last')


# Classification task: linear encoder, 'feature' decoder in 'last' mode,
# cross-entropy loss, scored by accuracy.
# The encoder/decoder presets are copied rather than referenced so that
# mutating this task's nested config cannot leak into the shared presets
# (or into any other task built from them).
multiclass_classification = {
    '_target_': 'tasks.tasks.GeneralTask',
    'encoder': dict(linear_encoder),
    'decoder': dict(feature_decoder),
    'loss': 'cross_entropy',
    'metrics': ['accuracy'],
}


# Regression task: linear encoder, 'feature' decoder in 'last' mode, MSE loss.
# No 'metrics' key is set for this task.
# The encoder/decoder presets are copied rather than referenced so that
# mutating this task's nested config cannot leak into the shared presets.
mse_regression = {
    '_target_': 'tasks.tasks.GeneralTask',
    'encoder': dict(linear_encoder),
    'decoder': dict(feature_decoder),
    'loss': 'mse',
}

# Autoregressive regression (every timestep): mse_regression with the
# 'sequence' decoder instead of the feature decoder.
# dict.copy() is shallow, so the nested encoder dict is copied explicitly;
# otherwise regression['encoder'] would be the very same object as
# mse_regression['encoder'].
regression = {
    **mse_regression,
    'encoder': dict(mse_regression['encoder']),
    'decoder': {'name': 'sequence'},
}

# Binary variant of multiclass_classification: same encoder and decoder,
# but binary cross-entropy loss scored by binary accuracy.
# Nested dicts are copied explicitly (dict.copy() is shallow) so that editing
# this config's encoder/decoder never alters multiclass_classification.
binary_classification = {
    **multiclass_classification,
    'encoder': dict(multiclass_classification['encoder']),
    'decoder': dict(multiclass_classification['decoder']),
    'loss': 'binary_cross_entropy',
    'metrics': ['binary_accuracy'],
}

# For text classification (e.g. imdb), we need to encode with embeddings and
# decode using the provided length information of the sequences
# ('length' mode). Loss and metrics are inherited from binary_classification.
# Mutable values are copied (dict.copy() is shallow), so this config does not
# share its encoder or metrics list with any other config.
text_binary_classification = {
    **binary_classification,
    'encoder': dict(embedding_encoder),
    'decoder': {
        'name': 'feature',
        'mode': 'length',
    },
    'metrics': list(binary_classification['metrics']),
}

# Multiclass text classification: multiclass_classification with the
# embedding encoder. Decoder, loss and metrics are inherited unchanged.
# Mutable values are copied (dict.copy() is shallow), so this config does not
# share its decoder dict or metrics list with multiclass_classification.
text_classification = {
    **multiclass_classification,
    'encoder': dict(embedding_encoder),
    'decoder': dict(multiclass_classification['decoder']),
    'metrics': list(multiclass_classification['metrics']),
}

# Retrieval task: embedding encoder, 'retrieval' decoder in 'length' mode,
# cross-entropy loss scored by accuracy.
# The encoder preset is copied so this task does not alias embedding_encoder.
retrieval = {
    '_target_': 'tasks.tasks.GeneralTask',
    'encoder': dict(embedding_encoder),
    'decoder': {
        'name': 'retrieval',
        'mode': 'length',
        # NOTE(review): None presumably means "filled in downstream" — confirm.
        'd_model': None,
        'nli': True,
        'activation': 'relu',
    },
    'loss': 'cross_entropy',
    'metrics': ['accuracy'],
}

# Language-modeling task config, targeting tasks.tasks.LMTask.
# The 'tied', 'rescale' and 'init' options are passed through unchanged;
# their meaning is defined by LMTask (not visible here).
lm = dict(
    _target_='tasks.tasks.LMTask',
    tied=False,
    rescale=False,
    init=None,
)
