import torch
from torchvision.utils import make_grid

from generalization_study.utils import collect_per_factor


def test_epoch(epoch, model, data_loader, writer, device, rsquared=None,
               mode='test'):
    """Evaluate ``model`` on ``data_loader`` and log per-factor metrics.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. For each batch the squared error is averaged over
    dim 0 and, if ``rsquared`` is given, its result is recorded as well. The
    per-batch values are then handed to ``collect_per_factor`` together with
    ``data_loader.dataset.factor_names`` and ``writer``.

    The first (up to) 64 samples of the first batch are additionally written
    to ``writer`` as an image grid under the tag ``'<mode>/batch'``.

    Args:
        epoch: Step value forwarded to ``writer`` and ``collect_per_factor``.
        model: Callable model; must provide ``eval()``.
        data_loader: Iterable yielding ``(batch, targets)`` pairs; its
            ``dataset`` must expose ``factor_names``.
        writer: Summary writer providing ``add_image``.
        device: Device the batch and targets are moved to.
        rsquared: Optional callable ``rsquared(predictions, targets)``. When
            ``None`` the R-squared metric is skipped entirely.
        mode: Name of the evaluation split; used as logging prefix.

    Returns:
        None. Note that the model is left in eval mode.
    """
    model.eval()
    log = {'rsquared': [], 'mse': []}
    with torch.no_grad():
        for iteration, (batch, targets) in enumerate(data_loader):
            batch = batch.to(device)
            targets = targets.to(device)
            predictions = model(batch)
            squared_diff = (targets - predictions).pow(2)

            # bookkeeping
            if rsquared is not None:
                # rsquared is optional (defaults to None); calling it
                # unconditionally would raise a TypeError.
                r_squared_per_factor = rsquared(predictions, targets)
                log['rsquared'].append(r_squared_per_factor.detach())
            log['mse'].append(squared_diff.mean(dim=0).detach())

            if iteration == 0:
                # Log a preview of the inputs once per call. The tag follows
                # ``mode`` so that e.g. validation runs do not overwrite the
                # images logged for the test split.
                grid = make_grid(batch[:64], pad_value=1)
                writer.add_image('{}/batch'.format(mode), grid, epoch)

    factor_names = data_loader.dataset.factor_names
    if rsquared is not None:
        collect_per_factor(log['rsquared'], epoch, 'rsquared', mode,
                           factor_names, writer)
    # NOTE(review): MSE is aggregated with torch.sum while R-squared relies on
    # collect_per_factor's default aggregation -- confirm this is intended.
    collect_per_factor(log['mse'], epoch, 'mse', mode,
                       factor_names, writer,
                       aggregate_fct=torch.sum)
