import torch
from tqdm import tqdm
def cal_logits(model, dataloader, use_gpu=True, use_tqdm=True, device=None):
    """Run ``model`` over every batch of ``dataloader`` and return all logits.

    The model is switched to evaluation mode (and left in it) and the whole
    pass runs under ``torch.no_grad()``, so the result carries no autograd
    history.

    Args:
        model: Callable exposing ``eval()``. Each call must return a
            two-element result whose first element is the logits tensor of
            the batch; the second element is ignored.
        dataloader: Iterable of batches. Only ``batch[0]`` is fed to the
            model. Must support ``len()`` when ``use_tqdm`` is true.
        use_gpu: When true and ``device`` is not ``None``, each input batch
            is moved to ``device`` before the forward pass.
        use_tqdm: When true, show a progress bar with one tick per batch.
        device: Target handed to ``Tensor.to``. With ``None`` the inputs are
            left on whatever device the dataloader produced them on.

    Returns:
        The per-batch logits concatenated along dimension 0. If the
        dataloader yields no batches, an empty tensor (``torch.empty(0)``)
        is returned instead of letting ``torch.cat`` fail on an empty list.
    """
    model.eval()
    move_inputs = use_gpu and device is not None
    # Only touch ``tqdm`` / ``len(dataloader)`` when a bar was requested, so
    # length-less iterables still work with ``use_tqdm=False``.
    progress_bar = tqdm(total=len(dataloader)) if use_tqdm else None
    batch_logits = []
    try:
        with torch.no_grad():
            for samples in dataloader:
                images = samples[0]
                if move_inputs:
                    images = images.to(device)
                # The model returns a pair; only the logits are collected.
                logits, _ = model(images)
                batch_logits.append(logits)
                if progress_bar is not None:
                    progress_bar.update()
            if not batch_logits:
                return torch.empty(0)
            return torch.cat(batch_logits, dim=0)
    finally:
        # Close the bar even if the forward pass or the concatenation raised.
        if progress_bar is not None:
            progress_bar.close()