from utils import flip_signs
from sklearn.metrics import roc_auc_score, average_precision_score
import torch


def train_epoch(args, model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``, training when an optimizer is supplied.

    With an optimizer the model is put in train mode and one optimizer step
    is taken per batch. Without one the model is put in eval mode and the
    forward pass runs with gradient tracking disabled.

    Args:
        args: Mapping read for ``'device'`` and ``'task'``; when training,
            ``'struc_info_type'`` is read as well.
        model: Callable invoked as ``model(data)``; must provide ``train()``
            and ``eval()``.
        loader: Iterable of batches. Each batch must support ``.to(device)``
            and expose ``.y``; ``.p`` is also required when training with
            ``args['struc_info_type'] == 'laplacian'``.
        criterion: Called as ``criterion(out, data.y)``; must return a
            scalar tensor.
        optimizer: Optimizer to step after each batch, or ``None`` to
            evaluate only.

    Returns:
        dict: ``{'Loss': ...}`` holding the mean loss weighted by
        ``len(data.y)`` of each batch, plus ``'AP'`` (the equally weighted
        mean of per-batch average precision) when ``args['task'] == 'class'``.
        Both are ``0.0`` when the loader yields no labels.
    """
    training = optimizer is not None
    is_classification = args['task'] == 'class'
    # Random sign flipping of the columns of ``data.p`` is a train-time-only
    # augmentation; it is never applied during evaluation.
    flip_columns = training and args['struc_info_type'] == 'laplacian'

    # The mode does not change between batches, so set it once up front.
    if training:
        model.train()
    else:
        model.eval()

    loss_sum, ap_sum, n_seen = 0.0, 0.0, 0

    for data in loader:
        data = data.to(args['device'])

        if flip_columns:
            # One independent +1/-1 per column of ``p``, broadcast over rows.
            draws = torch.rand(data.p.size(1)).to(data.p.device)
            signs = (draws >= 0.5).to(draws.dtype) * 2 - 1
            data.p = data.p * signs.unsqueeze(0)

        if training:
            optimizer.zero_grad()

        # Skip building the autograd graph when no backward pass will follow.
        with torch.set_grad_enabled(training):
            out = model(data)
            loss = criterion(out, data.y)

        batch_size = len(data.y)
        n_seen += batch_size
        loss_sum += loss.item() * batch_size

        if is_classification:
            # NOTE: AP is computed per batch and then averaged, which is not
            # identical to AP computed once over the whole dataset.
            scores = out.detach().cpu().numpy()
            labels = data.y.detach().cpu().numpy()
            ap_sum += average_precision_score(labels, scores) * batch_size

        if training:
            loss.backward()
            optimizer.step()

    # Normalise by the number of labels actually processed so the weighted
    # mean stays correct even if the loader does not visit every dataset
    # item exactly once (e.g. ``drop_last=True``).
    denom = max(n_seen, 1)
    results = {'Loss': loss_sum / denom}
    if is_classification:
        results['AP'] = ap_sum / denom

    return results
