import torch

from model_selection import register_model_selection_algorithm


@register_model_selection_algorithm("SB")
def source_best(source_val_loss, **kwargs):
    """Select the single model with the lowest mean source-validation loss.

    "SB" (source best) puts all of the weight on one model: the one whose
    validation loss, averaged over dim 1, is smallest.

    Args:
        source_val_loss: Loss tensor indexed as ``[n_models, n_samples]``;
            the mean is taken over dim 1 (samples).
        **kwargs: Accepted and ignored, presumably so all registered
            selection algorithms can be called with a common set of keyword
            arguments — confirm against the registry's call site.

    Returns:
        A one-hot tensor of shape ``[n_models]`` (same dtype/device as the
        per-model mean loss) with ``1`` at the selected model and ``0``
        elsewhere. Models whose mean loss is NaN are never preferred over a
        model with a finite loss.
    """
    del kwargs  # unused by this algorithm

    # torch.mean rejects integer tensors; promote them to the default float dtype.
    if not torch.is_floating_point(source_val_loss):
        source_val_loss = source_val_loss.float()

    mean_loss = torch.mean(source_val_loss, dim=1)  # [n_models]

    # torch.argmin propagates NaN, so a diverged model (NaN loss) would
    # otherwise be chosen as "best". Rank NaN losses last instead.
    ranking_loss = mean_loss.masked_fill(torch.isnan(mean_loss), float("inf"))

    # Select the model with minimum dev risk.
    min_index = torch.argmin(ranking_loss)
    model_weights = torch.zeros_like(mean_loss)
    model_weights[min_index] = 1
    return model_weights
