import torch
import fire


def prune_datadep(data_vals, layers=(1,)):
    """Print a positive-sensitivity summary for each requested layer.

    The object stored at ``data_vals`` is loaded onto the CPU. For every
    layer number in ``layers`` the entry ``pos_sensitivities[layer - 1]`` is
    summed along dimension 1, the mean of those sums is taken, and the
    resulting Python float is printed on its own line.

    Args:
        data_vals: Path (or anything else ``torch.load`` accepts) of the saved
            object. The loaded object must support
            ``obj['pos_sensitivities'][i]`` and each selected entry must be a
            tensor with at least two dimensions, since ``dim=1`` is reduced.
        layers: Layer number or iterable of layer numbers. Each number is
            offset by one before indexing, so ``1`` selects the first entry.
            A bare int is accepted as shorthand for a single layer; this is
            what the Fire CLI passes for ``--layers=2``.

    Note:
        ``torch.load`` is pickle-based; only point this at files you trust.
    """
    # A scalar layer number would otherwise fail in the ``for`` loop below.
    if isinstance(layers, int):
        layers = (layers,)

    # Keep the path and the loaded payload in separate names.
    loaded = torch.load(data_vals, map_location=torch.device('cpu'))

    for layer in layers:
        # NOTE(review): if 'pos_sensitivities' is a list/tuple/tensor, a layer
        # number below 1 wraps around via negative indexing instead of
        # failing - confirm whether that should be rejected.
        layer_sensitivities = loaded['pos_sensitivities'][layer - 1]
        print(layer_sensitivities.sum(dim=1).mean().item())


# Script entry point: hand prune_datadep to python-fire, which exposes the
# function's parameters (data_vals, layers) as command-line arguments.
if __name__ == '__main__':
    fire.Fire(prune_datadep)
