import torch
from torch.nn.functional import relu


def guided_activation(feature: torch.Tensor) -> torch.Tensor:
    """
    Guided activation function for branch-level neural substitution.

    The features are averaged over the first dimension and the average is
    passed through ReLU. Every position where that rectified mean is zero
    (i.e. the mean is non-positive) is set to zero in all N features; all
    other values are left untouched.

    Args:
        feature ('torch.Tensor'): Output features of multiple convolutions,
            stacked along dim 0. The shape is typically (N, C, H, W),
            representing number of features (not batch size), channels,
            height, and width, respectively. Any trailing shape is accepted.

    Returns:
        'torch.Tensor': Guided output features. This is the input tensor
        itself, which is modified in place.
    """
    # keepdim=True leaves a leading singleton dimension, so the mask
    # broadcasts against `feature` for any number of trailing dimensions and
    # no N-times repeated copy of the mask has to be materialised.
    dead_mask = relu(feature.mean(dim=0, keepdim=True)) == 0

    # In-place on purpose: callers that rely on their tensor being updated
    # (rather than on the return value) keep working.
    feature.masked_fill_(dead_mask, 0)
    return feature
