import torch
import numpy as np
import torch.nn.functional as F



# Every error function must accept at least (model, test_dataloader, loss_function).
# These functions are used to compute error metrics during training.

def test_error_QM7b(model, test_dataloader, loss_function, relative=False):
    """
    :return: the average of loss_function(pred, labels)
    """
    errors = []
    for batched_graph, labels in test_dataloader:
        with torch.no_grad():
            pred = model(batched_graph, batched_graph.ndata['attr'].float())
        errors.append([loss_function(pred[:, i], labels[:, i]).item() for i in range(labels.shape[1])])
    if relative:
        # Path not used anymore as it doesn't match the way benchmark models are evaluated
        # Would require to pass 'dataset' through train_model
        # errors = np.sum(errors, axis=0)
        ### Sum of absolute values for each feature
        # total_values = torch.sum(torch.abs(dataset.label), axis=0).tolist()
        # Relative error | mean over all features
        # return np.mean(errors / total_values)
        raise NotImplementedError
    errors = np.mean(errors, axis=0)
    return np.mean(errors)


def test_error_DBLP(model, test_dataloader, loss_function=None):
    """
    :param loss_function: not used but required by the error_function interface
    :return: the ratio of wrongly predicted labels
    """
    num_correct = 0
    num_tests = 0
    for batched_graph, labels in test_dataloader:
        with torch.no_grad():
            pred = model(batched_graph, batched_graph.ndata['attr'].float())
        # if model.flatten_labels:
        try:
            labels = labels[:, 0]
        except IndexError:
            pass
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)
    return (num_tests - num_correct) / num_tests

def test_error_QM9(model, test_dataloader, loss_function=None):
    """
    :param loss_function: not used but required by the error_function interface
    :return: the ratio of wrongly predicted labels
    """
    num_tests = 0
    loss = 0
    for batched_graph, labels in test_dataloader:
        with torch.no_grad():
            pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss += F.l1_loss(pred, labels, reduction='sum')
        num_tests += len(labels)
    return loss / num_tests


def accuracy_error_DBLP(pred, labels, reduction=None):
    """
    Count the samples whose predicted class differs from the label.

    :param pred: model outputs; the predicted class is the argmax along
        dimension 1
    :param labels: target classes, flattened to one dimension before comparing
    :param reduction: accepted but ignored (presumably kept so the signature
        matches that of a loss function; confirm against callers)
    :return: number of flattened labels minus the number of correct predictions
    """
    flat_labels = labels.reshape((-1,))
    predicted_classes = pred.argmax(1)
    correct = (predicted_classes == flat_labels).sum()
    return len(flat_labels) - correct

