import torch

def gate_hook(gate_top_k_idx, gate_score, _):
    """Print the first-column gate choice when the hook is switched on.

    The hook is controlled by the function attribute ``gate_hook.enabled``.
    If that attribute does not exist yet, it is created with the value
    ``False`` on the first call, so output stays silent until a caller sets
    ``gate_hook.enabled = True``.

    Args:
        gate_top_k_idx: Indexable with ``[:, 0]``; column 0 is printed.
            Presumably a (tokens, top_k) tensor of expert indices -- TODO
            confirm against the caller.
        gate_score: Accepted for the hook signature; not used here.
        _: Accepted for the hook signature; not used here.

    Returns:
        None.
    """
    # Lazily create the on/off switch so callers can toggle it later.
    if not hasattr(gate_hook, "enabled"):
        gate_hook.enabled = False

    # Guard clause: nothing to do while the hook is disabled.
    if not gate_hook.enabled:
        return

    # The local is named ``gate_choice`` on purpose: the ``=`` f-string
    # specifier prints the expression text, so the name is part of the output.
    gate_choice = gate_top_k_idx[:, 0]
    print(f'{gate_choice=}')
    print(f'{gate_choice.shape=}')

