import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score


class LogReg(nn.Module):
    """Single linear layer mapping embeddings to per-class logits."""

    def __init__(self, hid_dim, n_classes):
        """Create the layer.

        Args:
            hid_dim: Size of the last dimension of the input features.
            n_classes: Number of output logits produced per input row.
        """
        super().__init__()
        self.fc = nn.Linear(hid_dim, n_classes)

    def forward(self, x):
        """Return the raw (un-normalised) logits ``self.fc(x)``."""
        return self.fc(x)
    

def unsupervised_eval_linear(data, embeds, args, device):
    """Evaluate embeddings with a linear probe over 10 train/val/test splits.

    For each split index ``i`` in ``range(10)`` a fresh ``LogReg`` is trained
    for 2000 full-batch Adam steps on the training rows of ``embeds`` and
    scored on the validation and test rows after every step.

    When ``args.dataset`` is ``'roman_empire'`` or ``'amazon_ratings'`` the
    probe is a multi-class classifier trained with cross-entropy and scored by
    accuracy.  For any other dataset the labels are cast to float, the probe
    has a single output trained with ``BCEWithLogitsLoss``, and the score is
    ROC AUC.

    Args:
        data: Object exposing ``x`` (only ``x.shape[0]`` is read, as the node
            count), ``y`` (labels) and ``train_mask`` / ``val_mask`` /
            ``test_mask``, each indexed as ``mask[:, i]`` for split ``i``.
        embeds: 2-D tensor of embeddings; ``embeds.shape[1]`` is the probe's
            input size and rows are selected with the split masks.
        args: Namespace providing ``dataset``, ``lr2`` (probe learning rate)
            and ``wd2`` (probe weight decay).
        device: Device on which the probe is trained.

    Returns:
        A list with one score per split: 0-dim CPU tensors when accuracy is
        used, plain floats (as returned by ``roc_auc_score``) otherwise.

    Raises:
        AssertionError: If the number of labels differs from the number of
            nodes, or a split's masks do not sum to the number of nodes.
    """
    multiclass = args.dataset in ('roman_empire', 'amazon_ratings')
    n_node = data.x.shape[0]

    label = data.y if multiclass else data.y.to(torch.float)
    # The binary case uses one logit, so the class count is only needed for
    # the multi-class datasets.
    n_outputs = torch.unique(label).shape[0] if multiclass else 1
    label = label.to(device)
    assert label.shape[0] == n_node

    # The probe must not backpropagate into whatever produced the embeddings:
    # that would leak gradients upstream and break repeated backward() calls,
    # since the graph is not retained between epochs.
    embeds = embeds.detach().to(device)

    loss_fn = nn.CrossEntropyLoss() if multiclass else nn.BCEWithLogitsLoss()

    results = []
    for i in range(10):
        train_mask = data.train_mask[:, i].to(device)
        val_mask = data.val_mask[:, i].to(device)
        test_mask = data.test_mask[:, i].to(device)

        assert torch.sum(train_mask + val_mask + test_mask) == n_node

        train_embs, train_labels = embeds[train_mask], label[train_mask]
        val_embs, val_labels = embeds[val_mask], label[val_mask]
        test_embs, test_labels = embeds[test_mask], label[test_mask]

        if not multiclass:
            # roc_auc_score needs NumPy targets; they never change, so convert
            # them once per split instead of once per epoch.
            val_targets = val_labels.cpu().numpy()
            test_targets = test_labels.cpu().numpy()

        logreg = LogReg(hid_dim=embeds.shape[1], n_classes=n_outputs)
        opt = torch.optim.Adam(logreg.parameters(), lr=args.lr2, weight_decay=args.wd2)
        logreg = logreg.to(device)

        best_val_acc = 0
        eval_acc = 0

        for epoch in range(2000):
            logreg.train()
            opt.zero_grad()
            logits = logreg(train_embs)
            if not multiclass:
                # BCEWithLogitsLoss expects logits shaped like the targets.
                logits = logits.squeeze(-1)
            loss = loss_fn(logits, train_labels)
            loss.backward()
            opt.step()

            logreg.eval()
            with torch.no_grad():
                val_logits = logreg(val_embs)
                test_logits = logreg(test_embs)

                if multiclass:
                    val_preds = torch.argmax(val_logits, dim=1)
                    test_preds = torch.argmax(test_logits, dim=1)
                    val_acc = torch.sum(val_preds == val_labels).float() / val_labels.shape[0]
                    test_acc = torch.sum(test_preds == test_labels).float() / test_labels.shape[0]
                else:
                    val_acc = roc_auc_score(y_true=val_targets, y_score=val_logits.squeeze(-1).cpu().numpy())
                    test_acc = roc_auc_score(y_true=test_targets, y_score=test_logits.squeeze(-1).cpu().numpy())

                # NOTE(review): the reported score is the highest test score
                # over all epochs whose validation score matched or beat the
                # best seen so far. This is more optimistic than taking the
                # test score at the single best-validation epoch; kept
                # unchanged so results stay comparable.
                if val_acc >= best_val_acc:
                    best_val_acc = val_acc
                    if test_acc > eval_acc:
                        eval_acc = test_acc

        if torch.is_tensor(eval_acc):
            results.append(eval_acc.detach().cpu())
        else:
            results.append(eval_acc)
    return results