
import torch
import numpy as np
from alethiometer import calc_zc_metrics  

def get_tcet(net, xloader, device=None):
    """Compute the ``tcet_snip_none`` score of ``net`` via alethiometer.

    Thin wrapper around ``alethiometer.calc_zc_metrics`` that requests only
    the ``'tcet_snip_none'`` metric and returns its value.

    Args:
        net: Model passed to ``calc_zc_metrics`` as ``model``.
        xloader: Data source passed to ``calc_zc_metrics`` as ``train_queue``
            (presumably a DataLoader-like iterable; verify against
            alethiometer).
        device: Device passed through to ``calc_zc_metrics``. Defaults to
            ``None``, meaning "the current CUDA device at call time"
            (``torch.cuda.current_device()``).

    Returns:
        The ``'tcet_snip_none'`` entry of the dict returned by
        ``calc_zc_metrics``. With ``aggregate=True`` this is presumably a
        single aggregated value; confirm against alethiometer.
    """
    # Resolve the default lazily. Evaluating torch.cuda.current_device() in
    # the signature ran it at import time. That crashed the import on
    # machines without CUDA and froze the default to the import-time device.
    if device is None:
        device = torch.cuda.current_device()
    metric = 'tcet_snip_none'
    results = calc_zc_metrics(metrics=[metric], model=net, train_queue=xloader,
                              device=device, aggregate=True)
    return results[metric]