import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
from dgl.data.ppi import PPIDataset
from dgl.dataloading import GraphDataLoader
from sklearn.metrics import f1_score

class GAT(nn.Module):
    """Three-layer graph attention network stacked from DGL ``GATConv`` layers.

    The first two layers use ELU activation and have their output flattened
    from dim 1 onward before being fed to the next layer; the last layer has
    no activation and its output is averaged over dim 1 (presumably the
    attention-head axis of ``GATConv``'s output -- see the DGL docs).

    Args:
        in_size: Input feature size passed to the first layer.
        hid_size: Per-head output size of the first two layers.
        out_size: Per-head output size of the last layer.
        heads: Indexable holding the number of heads for layers 0, 1 and 2.
    """

    def __init__(self, in_size, hid_size, out_size, heads):
        super().__init__()
        first_heads, second_heads, last_heads = heads[0], heads[1], heads[2]
        self.gat_layers = nn.ModuleList([
            # Input layer: the only one without a residual connection.
            dglnn.GATConv(in_size, hid_size, first_heads, activation=F.elu),
            # Hidden layer: consumes the flattened output of all first-layer heads.
            dglnn.GATConv(
                hid_size * first_heads,
                hid_size,
                second_heads,
                residual=True,
                activation=F.elu,
            ),
            # Output layer: no activation, so forward() returns raw values.
            dglnn.GATConv(
                hid_size * second_heads,
                out_size,
                last_heads,
                residual=True,
                activation=None,
            ),
        ])

    def forward(self, g, inputs):
        """Run ``inputs`` through every layer on graph ``g`` and return the result."""
        *hidden_layers, output_layer = self.gat_layers
        h = inputs
        for hidden_layer in hidden_layers:
            # Merge everything from dim 1 onward into a single feature axis.
            h = hidden_layer(g, h).flatten(1)
        # The last layer's output is averaged over dim 1 rather than flattened.
        return output_layer(g, h).mean(1)

def evaluate(g, features, labels, model):
    """Return the micro-averaged F1 score of ``model`` on a single graph.

    The model is switched to evaluation mode (and left in it) and is run
    without gradient tracking. Every output entry >= 0 is predicted as 1,
    every other entry as 0.

    Args:
        g: Graph passed to the model as its first argument.
        features: Node features passed to the model as its second argument.
        labels: Ground-truth labels compared against the predictions.
        model: Callable module producing one raw score per label entry.

    Returns:
        The value returned by ``f1_score(..., average='micro')``.
    """
    model.eval()
    with torch.no_grad():
        raw_scores = model(g, features).data.cpu().numpy()
    targets = labels.data.cpu().numpy()
    # Threshold the raw outputs at zero to obtain binary predictions.
    predictions = np.where(raw_scores >= 0, 1, 0)
    return f1_score(targets, predictions, average='micro')

def evaluate_in_batches(dataloader, device, model):
    """Return the mean per-batch F1 score of ``model`` over ``dataloader``.

    Each batch is moved to ``device``, scored with ``evaluate`` using its
    ``'feat'`` and ``'label'`` node data, and the batch scores are averaged
    with equal weight per batch (not per node or per graph).

    Args:
        dataloader: Iterable yielding batched graphs.
        device: Device each batched graph is moved to before evaluation.
        model: Model handed to ``evaluate``.

    Returns:
        The arithmetic mean of the per-batch scores.

    Raises:
        ValueError: If ``dataloader`` yields no batches, since the average
            would be undefined.
    """
    total_score = 0
    num_batches = 0  # counted explicitly so an empty loader is detectable
    for batched_graph in dataloader:
        batched_graph = batched_graph.to(device)
        features = batched_graph.ndata['feat']
        labels = batched_graph.ndata['label']
        total_score += evaluate(batched_graph, features, labels, model)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute an average score")
    return total_score / num_batches

def train(train_dataloader, val_dataloader, device, model, *,
          num_epochs=400, lr=5e-3, eval_every=5):
    """Train ``model`` in place with BCE-with-logits loss and Adam.

    After every epoch the mean training loss per batch is printed; every
    ``eval_every`` epochs the model is also scored on ``val_dataloader`` via
    ``evaluate_in_batches`` and the result is printed.

    Args:
        train_dataloader: Iterable yielding batched training graphs whose
            ``ndata`` holds ``'feat'`` and ``'label'`` entries.
        val_dataloader: Iterable of batched validation graphs.
        device: Device each batched graph is moved to.
        model: Module called as ``model(graph, features)``; it must return
            logits compatible with the labels for ``nn.BCEWithLogitsLoss``.
        num_epochs: Number of passes over ``train_dataloader``.
        lr: Learning rate for the Adam optimizer.
        eval_every: Validate after every this many epochs (must be positive).

    Raises:
        ValueError: If ``train_dataloader`` yields no batches in an epoch,
            since the mean loss would be undefined.
    """
    # define loss function and optimizer
    loss_fcn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    # training loop
    for epoch in range(num_epochs):
        # Re-enable training mode each epoch: validation switches it off.
        model.train()
        total_loss = 0
        num_batches = 0
        # mini-batch loop
        for batched_graph in train_dataloader:
            batched_graph = batched_graph.to(device)
            features = batched_graph.ndata['feat'].float()
            # BCEWithLogitsLoss needs floating-point targets.
            labels = batched_graph.ndata['label'].float()
            logits = model(batched_graph, features)
            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("train_dataloader yielded no batches; nothing to train on")
        print("Epoch {:05d} | Loss {:.4f} |".format(epoch, total_loss / num_batches))

        if (epoch + 1) % eval_every == 0:
            # evaluate F1-score instead of loss
            avg_score = evaluate_in_batches(val_dataloader, device, model)
            print("                            Acc. (F1-score) {:.4f} ".format(avg_score))


if __name__ == '__main__':
    print('Training PPI Dataset with DGL built-in GATConv module.')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyperparameters of this run.
    batch_size = 2
    hid_size = 256
    heads = [4, 4, 6]

    # Load the train / validation / test splits of the PPI dataset.
    train_dataset = PPIDataset(mode='train')
    val_dataset = PPIDataset(mode='valid')
    test_dataset = PPIDataset(mode='test')
    features = train_dataset[0].ndata['feat']

    # Size the GAT from the first training graph's features and the dataset's label count.
    in_size = features.shape[1]
    out_size = train_dataset.num_labels
    model = GAT(in_size, hid_size, out_size, heads=heads)
    model = model.to(device)

    # Fit the model, validating periodically.
    print('Training...')
    train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size)
    val_dataloader = GraphDataLoader(val_dataset, batch_size=batch_size)
    train(train_dataloader, val_dataloader, device, model)

    # Report the score of the trained model on the held-out test split.
    print('Testing...')
    test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size)
    avg_score = evaluate_in_batches(test_dataloader, device, model)
    print("Test Accuracy (F1-score) {:.4f}".format(avg_score))
