import torch
import torch.nn.functional as F

@torch.no_grad()
def evaluate(model, dataset, split_idx, eval_func, criterion, args, result=None, edge_mask=None):
    """Run one full forward pass and score it on the train/valid/test splits.

    The model is switched to eval mode and fed ``dataset.graph['all_feat']``
    together with an edge mask: ``edge_mask`` when supplied, otherwise
    ``dataset.graph['edge_masks']``. Its output is turned into log-probabilities
    with ``log_softmax`` over the last dimension.

    Args:
        model: object exposing ``eval()``, ``__call__(feat, mask)`` and
            ``cls_loss(out, label, label_global, idx, criterion)``; the latter
            must return a 2-tuple whose first item is the loss.
        dataset: object with ``graph`` (dict), ``label`` and ``label_global``.
        split_idx: mapping with 'train', 'valid' and 'test' index entries.
        eval_func: callable ``(labels, out) -> metric``.
        criterion: loss forwarded to ``model.cls_loss``.
        args: unused here.
        result: unused here.
        edge_mask: optional override for ``dataset.graph['edge_masks']``.

    Returns:
        list: ``[train_acc, valid_acc, test_acc, train_loss, valid_loss,
        test_loss, out]`` where ``out`` is the log-softmax output.
    """
    model.eval()

    features = dataset.graph['all_feat']
    # Only fall back to the dataset's stored masks when no override is given.
    masks = edge_mask if edge_mask is not None else dataset.graph['edge_masks']
    # Author's shape note: [k, num_ori_nodes + num_cluster + num_global, n_cls] (not checked here).
    log_probs = F.log_softmax(model(features, masks), dim=-1)

    split_names = ('train', 'valid', 'test')

    losses = []
    for name in split_names:
        split_loss, _unused = model.cls_loss(
            log_probs, dataset.label, dataset.label_global, split_idx[name], criterion)
        losses.append(split_loss)

    accuracies = []
    for name in split_names:
        idx = split_idx[name]
        accuracies.append(eval_func(dataset.label[idx], log_probs[idx]))

    return [*accuracies, *losses, log_probs]
    
def evaluate_with_out(out, dataset, split_idx, eval_func, criterion, args, result=None):
    """Score an already-computed model output on the train/valid/test splits.

    Args:
        out: model output indexable by the split indices; its last dimension
            is argmax-ed to form test predictions.
        dataset: object with a ``label`` tensor (compared against predictions
            of shape ``[..., 1]``, so presumably ``[N, 1]`` — TODO confirm).
        split_idx: mapping with 'train', 'valid' and 'test' index entries.
        eval_func: callable ``(labels, out) -> metric``.
        criterion, args, result: unused here.

    Returns:
        tuple: ``([train_acc, valid_acc, test_acc], correct_mask)`` where
        ``correct_mask`` is the elementwise comparison of test labels with the
        argmax predictions (keepdim=True).
    """
    with torch.no_grad():
        scores = []
        for name in ('train', 'valid', 'test'):
            idx = split_idx[name]
            scores.append(eval_func(dataset.label[idx], out[idx]))

        test_idx = split_idx['test']
        predicted = out[test_idx].argmax(dim=-1, keepdim=True)
        correct_mask = dataset.label[test_idx] == predicted

    return scores, correct_mask

@torch.no_grad()
def evaluate_cl(embed, head, dataset, split_idx, eval_func, criterion, args, mask=None, result=None):
    """Score a classification head on the train/valid/test splits.

    Args:
        embed: input passed to ``head`` when ``result`` is None.
        head: module producing per-node scores; it is put into eval mode.
        dataset: object with a 2-D ``label`` tensor (``label.shape[1]`` is read).
        split_idx: mapping with 'train', 'valid' and 'test' index entries.
        eval_func: callable ``(labels, out) -> metric``.
        criterion: loss applied to the validation split.
        args: namespace; only ``args.dataset`` is read.
        mask: unused; kept for signature compatibility.
        result: precomputed head output; when given, ``head`` is not called.

    Returns:
        tuple: ``(train_acc, valid_acc, test_acc, valid_loss, out)``.
        For the 'questions' dataset ``out`` is the raw head output and the loss
        uses float (one-hot when labels have a single column) targets; for all
        other datasets ``out`` is ``log_softmax(out, dim=1)`` and the loss uses
        integer labels.
    """
    head.eval()
    out = result if result is not None else head(embed)

    train_acc = eval_func(
        dataset.label[split_idx['train']], out[split_idx['train']])
    valid_acc = eval_func(
        dataset.label[split_idx['valid']], out[split_idx['valid']])
    test_acc = eval_func(
        dataset.label[split_idx['test']], out[split_idx['test']])

    # Must be a tuple: `x in ('questions')` is a substring test on a plain
    # string and would also match e.g. "ion" or "quest".
    if args.dataset in ('questions',):
        if dataset.label.shape[1] == 1:
            num_classes = int(dataset.label.max()) + 1
            true_label = F.one_hot(dataset.label, num_classes).squeeze(1)
        else:
            true_label = dataset.label
        valid_loss = criterion(out[split_idx['valid']], true_label.squeeze(1)[
            split_idx['valid']].to(torch.float))
    else:
        out = F.log_softmax(out, dim=1)
        valid_loss = criterion(
            out[split_idx['valid']], dataset.label.squeeze(1)[split_idx['valid']])

    return train_acc, valid_acc, test_acc, valid_loss, out

@torch.no_grad()
def evaluate_cpu(model, dataset, split_idx, eval_func, criterion, args, device, result=None):
    """Evaluate on the CPU, either from a fresh forward pass or a given output.

    The dataset is moved to the CPU via ``dataset.to``. When ``result`` is None
    the model is put into eval mode, moved to the CPU and run on
    ``dataset.graph['node_feat']``, ``['struc_feat']`` and ``['norm_A']``.
    When ``result`` is given it is used as the model output (moved to the CPU
    so it matches the labels) and the model is left untouched.

    Args:
        model: module called as ``model(node_feat, struc_feat, norm_A)``.
        dataset: object with ``to(device)``, ``graph`` (dict) and a 2-D ``label``.
        split_idx: mapping with 'train', 'valid' and 'test' index entries.
        eval_func: callable ``(labels, out) -> metric``.
        criterion: loss applied to the validation split.
        args: namespace; only ``args.dataset`` is read.
        device: unused; the model is NOT moved back after evaluation.
        result: optional precomputed model output.

    Returns:
        tuple: ``(train_acc, valid_acc, test_acc, valid_loss, out)``. For the
        'questions' dataset ``out`` is the raw output and the loss uses float
        (one-hot when labels have a single column) targets; otherwise ``out``
        is ``log_softmax(out, dim=1)`` and the loss uses integer labels.
    """
    cpu = torch.device("cpu")
    dataset = dataset.to(cpu)

    if result is not None:
        # Previously this value was overwritten by an unconditional forward pass.
        out = result.to(cpu)
    else:
        model.eval()
        model.to(cpu)
        out = model(dataset.graph['node_feat'], dataset.graph['struc_feat'], dataset.graph['norm_A'])

    train_acc = eval_func(
        dataset.label[split_idx['train']], out[split_idx['train']])
    valid_acc = eval_func(
        dataset.label[split_idx['valid']], out[split_idx['valid']])
    test_acc = eval_func(
        dataset.label[split_idx['test']], out[split_idx['test']])

    # Must be a tuple: `x in ('questions')` is a substring test on a plain string.
    if args.dataset in ('questions',):
        if dataset.label.shape[1] == 1:
            num_classes = int(dataset.label.max()) + 1
            true_label = F.one_hot(dataset.label, num_classes).squeeze(1)
        else:
            true_label = dataset.label
        valid_loss = criterion(out[split_idx['valid']], true_label.squeeze(1)[
            split_idx['valid']].to(torch.float))
    else:
        out = F.log_softmax(out, dim=1)
        valid_loss = criterion(
            out[split_idx['valid']], dataset.label.squeeze(1)[split_idx['valid']])

    return train_acc, valid_acc, test_acc, valid_loss, out
