import torch
from torch.autograd import Variable
from utils.get_closure import get_optimizer_closure


def predict(model, data_loader, weight_decay):
    """Evaluate ``model`` on ``data_loader`` and report loss, accuracy and gradient norm.

    The objective is the data loss returned by the project closure plus an L2
    penalty ``weight_decay / 2 * sum ||p||^2``. Both the reported loss and the
    gradient of the objective are averaged over batches.

    Args:
        model: torch module. It is switched to eval mode as a side effect.
        data_loader: iterable of ``(input, target)`` batches. It must support
            ``len()`` and expose ``.dataset``, whose length is the accuracy
            denominator.
        weight_decay: coefficient of the L2 penalty.

    Returns:
        A tuple ``(loss, acc, sq_grad_norm)``:

        - ``loss``: mean regularized loss per batch, as a float.
        - ``acc``: accuracy in percent.
        - ``sq_grad_norm``: squared L2 norm of the batch-averaged regularized
          gradient, as a float.

    Raises:
        ZeroDivisionError: if ``data_loader`` has length 0.
    """
    model.eval()
    closure = get_optimizer_closure(model, return_output=True)
    num_batches = len(data_loader)

    total_loss = 0.0
    correct = 0
    # One accumulator per parameter. Entries stay as int 0 until a gradient is seen.
    grads = [0] * len(list(model.parameters()))
    for inputs, target in data_loader:
        inputs, target = inputs.cuda(), target.cuda()

        # NOTE(review): assumes the closure calls backward() and leaves this
        # batch's gradients (not a running sum) in p.grad. Confirm in
        # utils.get_closure.
        batch_loss, output = closure(inputs, target)
        for i, p in enumerate(model.parameters()):
            if p.grad is None:
                continue
            # Detach so the accumulators do not keep autograd graphs alive.
            grads[i] += (p.grad.detach() + weight_decay * p.detach()) / num_batches

        # Keep the batch loss under its own name. Reusing the accumulator's name
        # here used to discard the running total and double-count this batch.
        with torch.no_grad():
            reg = weight_decay / 2 * float(compute_params_squared_l2_norm(model))
        total_loss += float(batch_loss.detach()) + reg

        pred = output.detach().max(1)[1]
        correct += pred.eq(target.detach()).cpu().sum().numpy()

    loss = total_loss / num_batches
    acc = 100. * correct / len(data_loader.dataset)

    # Skip parameters that never received a gradient (still the int placeholder 0).
    sq_grad_norm = 0.0
    for grad in grads:
        if torch.is_tensor(grad):
            sq_grad_norm += float(grad.norm() ** 2)
    return loss, acc, sq_grad_norm


def compute_params_squared_l2_norm(model):
    """Return the sum of squared L2 norms of all parameters of ``model``.

    ``Tensor.norm()`` is taken over all elements of each parameter. The norm is
    squared, and the squares are added together. No detach is performed, so
    the result is a tensor that still participates in autograd. If the model
    has no parameters, the int ``0`` is returned.
    """
    return sum(param.norm() ** 2 for param in model.parameters())
