# Public API: the names exported by ``from <this module> import *``.
# get_current_pruning_modifier and mean_value are intentionally not listed.
__all__ = ["validation_scores", "get_current_sparsity"]


def validation_scores(model, loader, criterion, accuracy_funcs, device):
    """Evaluate ``model`` over every batch yielded by ``loader``.

    The loss and each accuracy metric are computed per batch, converted to
    Python numbers with ``.item()``, and averaged over the number of
    batches. Every batch has equal weight, regardless of its size.

    Args:
        model: Model to evaluate. ``model.eval()`` is called, and the model
            is left in that mode on return.
        loader: Iterable of ``(batch_data, batch_labels)`` pairs. Both
            elements must support ``.to(device)``.
        criterion: Callable ``criterion(preds, labels)`` returning a value
            with an ``.item()`` method.
        accuracy_funcs: Iterable of callables with the same signature as
            ``criterion``. It is materialised once, so generators are
            accepted.
        device: Target passed to ``.to()`` for every batch.

    Returns:
        dict: ``{"loss": float, "accuracy": [float, ...]}``. The accuracy
        list has one entry per function in ``accuracy_funcs``, in the same
        order.

    Raises:
        ValueError: If ``loader`` yields no batches, because the average is
            undefined.
    """
    # The functions are iterated once per batch. A one-shot iterator would
    # be exhausted after its first pass, so take a list copy up front.
    accuracy_funcs = list(accuracy_funcs)

    loss_sum = 0.0
    acc_sums = [0.0] * len(accuracy_funcs)
    num_batches = 0

    model.eval()
    # NOTE(review): autograd is not disabled here. Unless callers already
    # wrap this call in torch.no_grad(), each forward pass builds a graph
    # that is never used. Consider adding it.
    for batch_data, batch_labels in loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        # Call the model itself rather than model.forward(): for
        # nn.Module, __call__ runs any registered forward hooks around
        # forward(), whereas a direct forward() call skips them.
        preds = model(batch_data)

        loss_sum += criterion(preds, batch_labels).item()
        for i, accuracy in enumerate(accuracy_funcs):
            acc_sums[i] += accuracy(preds, batch_labels).item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("validation_scores: loader yielded no batches")

    return {
        "loss": loss_sum / num_batches,
        "accuracy": [acc / num_batches for acc in acc_sums],
    }


def get_current_pruning_modifier(manager, epoch):
    """Return the first pruning modifier whose epoch window contains ``epoch``.

    A modifier matches when ``start_epoch <= epoch < end_epoch``: the start
    is inclusive and the end is exclusive. Candidates are checked in the
    order ``manager.pruning_modifiers`` yields them. ``None`` is returned
    when no window matches.
    """
    active = (
        modifier
        for modifier in manager.pruning_modifiers
        if modifier.start_epoch <= epoch < modifier.end_epoch
    )
    return next(active, None)


def mean_value(values):
    """Collapse a scalar or a sequence of numbers into a single number.

    - ``int`` / ``float``: returned unchanged. ``bool`` is an ``int``
      subclass, so it is also returned unchanged.
    - ``list`` / ``tuple``: the arithmetic mean. An empty sequence yields
      ``0.0`` instead of raising ``ZeroDivisionError``.
    - Anything else (e.g. ``None``): ``0.0``.
    """
    if isinstance(values, (int, float)):
        return values
    if isinstance(values, (list, tuple)):
        # Empty input falls back to 0.0, matching the unsupported-type case.
        return sum(values) / len(values) if values else 0.0
    return 0.0


def get_current_sparsity(manager, epoch):
    """Return the sparsity applied by the pruning modifier active at ``epoch``.

    The active modifier is looked up with ``get_current_pruning_modifier``.
    Its ``applied_sparsity`` is reduced to one number with ``mean_value``.
    ``0.0`` is returned when no modifier is active.
    """
    modifier = get_current_pruning_modifier(manager, epoch)
    if modifier is None:
        return 0.0
    return mean_value(modifier.applied_sparsity)
