import torch

from model_selection import register_model_selection_algorithm


@register_model_selection_algorithm("TB")
def target_best(target_test_loss, **kwargs):
    """Select the candidate with the lowest target test loss ("TB").

    Builds a weight tensor that is zero everywhere except at the position
    picked by taking the minimum of ``target_test_loss`` along dim 0, where
    it is set to 1.

    Args:
        target_test_loss: Tensor of losses reduced along dim 0. Presumably a
            1-D tensor with one loss per candidate model — TODO confirm
            against callers. For higher-rank inputs the minimum is taken per
            column, and whole slices along dim 0 are set to 1.
        **kwargs: Accepted so every selection algorithm shares one call
            signature; ignored here.

    Returns:
        A tensor with the same shape, dtype and device as
        ``target_test_loss``: 1 at the selected position(s), 0 elsewhere.
    """
    del kwargs  # unused; present only for signature compatibility

    # torch.min with ``dim`` returns a (values, indices) pair; only the
    # argmin position is needed.
    _, best_idx = torch.min(target_test_loss, dim=0)

    selection = torch.zeros_like(target_test_loss)
    selection[best_idx] = 1
    return selection
