import torch
from torch.autograd import Variable
from utils.get_closure import get_optimizer_closure


def predict(model, data_loader, weight_decay):
    """Evaluate ``model`` on ``data_loader`` with an L2 (weight-decay) penalty.

    For every batch the closure built by ``get_optimizer_closure`` is called
    to obtain the loss and the model output. Parameter gradients are then
    read from ``p.grad``, so the closure is expected to run the backward
    pass. NOTE(review): confirm the closure also resets gradients between
    batches; otherwise ``p.grad`` accumulates across iterations.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        data_loader: Iterable of ``(input, target)`` batches that supports
            ``len()`` and exposes a ``dataset`` attribute.
        weight_decay: Coefficient of the L2 penalty added to loss and
            gradients.

    Returns:
        A tuple ``(total_loss, acc, sq_grad_norm)`` where ``total_loss`` is
        the mean per-batch loss plus ``weight_decay / 2 * ||params||^2``,
        ``acc`` is the percentage of correct predictions over
        ``len(data_loader.dataset)``, and ``sq_grad_norm`` is the squared L2
        norm of the batch-averaged, weight-decayed gradient.

    Raises:
        ValueError: If ``data_loader`` contains no batches.
    """
    model.eval()
    closure = get_optimizer_closure(model, return_output=True)

    num_batches = len(data_loader)
    if num_batches == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("data_loader is empty; nothing to evaluate")

    params = list(model.parameters())
    # None marks parameters that never received a gradient (e.g. frozen or
    # unused ones) so they can be skipped when the norm is computed.
    grads = [None] * len(params)
    total_loss = 0
    correct = 0
    use_cuda = torch.cuda.is_available()
    for inputs, target in data_loader:
        if use_cuda:
            inputs, target = inputs.cuda(), target.cuda()

        loss, output = closure(inputs, target)

        # The statistics below are pure bookkeeping: keep them out of
        # autograd so no graph is retained across batches.
        with torch.no_grad():
            for i, p in enumerate(params):
                if p.grad is None:
                    continue
                contribution = (p.grad + weight_decay * p) / num_batches
                if grads[i] is None:
                    grads[i] = contribution
                else:
                    grads[i] = grads[i] + contribution

            penalty = weight_decay / 2 * compute_params_squared_l2_norm(model)
            total_loss += loss.detach() + penalty
            pred = output.detach().max(1)[1]
            correct += pred.eq(target).cpu().sum().numpy()

    total_loss /= num_batches
    acc = 100. * correct / len(data_loader.dataset)
    total_loss = float(total_loss)

    sq_grad_norm = 0
    for grad in grads:
        if grad is not None:
            sq_grad_norm = sq_grad_norm + grad.norm() ** 2
    # float() handles both the tensor case and the "no gradients" int case.
    sq_grad_norm = float(sq_grad_norm)
    return total_loss, acc, sq_grad_norm


def compute_params_squared_l2_norm(model):
    """Return the sum of squared norms over all parameters of ``model``.

    Each parameter contributes ``p.norm() ** 2``. The result is the plain
    integer ``0`` when the model has no parameters, otherwise a tensor.
    """
    accumulated = 0
    for param in model.parameters():
        squared = param.norm() ** 2
        accumulated += squared
    return accumulated
