from tqdm import tqdm

from protonets.utils import filter_opt
from protonets.models import get_model
from protonets.data.base import CudaTransform
import torch

def load(opt):
    """Instantiate the model described by the model-related entries of ``opt``.

    ``filter_opt(opt, 'model')`` selects the model options (see
    ``protonets.utils.filter_opt`` for the exact filtering rule). The
    ``'model_name'`` entry is removed from that selection and used to look
    up the model constructor. The remaining entries are passed through to
    ``get_model``.

    Raises:
        KeyError: if the filtered options contain no ``'model_name'``.
    """
    model_kwargs = filter_opt(opt, 'model')
    # The name selects the constructor; it is not itself a constructor option.
    name = model_kwargs.pop('model_name')
    return get_model(name, model_kwargs)

def _move_sample(sample, method_name):
    """Replace each value of ``sample`` in place with ``value.<method_name>()``.

    Only values that actually have ``method_name`` (tensors have both
    ``'cuda'`` and ``'cpu'``) are converted. Anything else, such as string
    labels, is left as-is.
    """
    for key in sample:
        value = sample[key]
        if hasattr(value, method_name):
            sample[key] = getattr(value, method_name)()


def evaluate(model, data_loader, meters, desc=None, cuda=True):
    """Run ``model`` over ``data_loader`` in eval mode and accumulate metrics.

    Args:
        model: object exposing ``eval()`` and ``loss(sample, eval=True)``.
            ``loss`` must return a ``(loss, output)`` pair whose ``output``
            is indexable by every key of ``meters``.
        data_loader: iterable of sample dicts. When ``cuda`` is True, the
            dicts are mutated in place.
        meters: mapping from output field name to a meter object with
            ``reset()`` and ``add(value)``. Every meter is reset first.
        desc: if not None, wrap the loader in a ``tqdm`` progress bar
            labelled with this description.
        cuda: if True, move the tensor values of each sample to the GPU
            before the forward pass. Afterwards, move them back to the CPU
            and release unused cached GPU memory.

    Returns:
        The same ``meters`` mapping, now populated with one ``add`` call per
        batch per field.

    Note:
        The model is left in eval mode. Callers that continue training must
        switch it back themselves.
    """
    model.eval()

    for meter in meters.values():
        meter.reset()

    if desc is not None:
        data_loader = tqdm(data_loader, desc=desc)

    # Evaluation only: disable autograd tracking for the forward passes.
    with torch.no_grad():
        for sample in data_loader:
            if cuda:
                _move_sample(sample, 'cuda')

            _, output = model.loss(sample, eval=True)

            if cuda:
                # The sample dict was mutated in place. Move it back so it
                # does not pin GPU memory if something else still references
                # it (presumably the loader may cache/reuse samples). Then
                # return unused cached blocks to the device.
                _move_sample(sample, 'cpu')
                torch.cuda.empty_cache()

            for field, meter in meters.items():
                meter.add(output[field])
    return meters
