import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.dataloading import NeighborSampler, DataLoader
from dgl import apply_each
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm

class HeteroGAT(nn.Module):
    """Three-layer heterogeneous GAT that produces class logits for ``'paper'`` nodes.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one multi-head ``GATConv``
    per edge type. Head outputs are concatenated, and a final ``nn.Linear`` maps
    the ``'paper'`` representation to ``out_size`` logits.

    Parameters
    ----------
    etypes : iterable of str
        Edge type names; one GATConv is created per name in every layer.
    in_size : int
        Input feature size; the same size is used for every node type.
    hid_size : int
        Requested hidden width. Each head outputs ``hid_size // n_heads``
        features, so the actual width after concatenating heads is
        ``(hid_size // n_heads) * n_heads`` (equal to ``hid_size`` when divisible).
    out_size : int
        Number of output classes.
    n_heads : int
        Number of attention heads per GATConv.

    Raises
    ------
    ValueError
        If ``hid_size < n_heads`` (each head would have zero output features).
    """

    def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
        super().__init__()
        # Materialize once: the edge types are iterated for every layer.
        etypes = list(etypes)
        head_size = hid_size // n_heads
        if head_size == 0:
            raise ValueError(f'hid_size ({hid_size}) must be >= n_heads ({n_heads})')
        # Width of a layer's output after concatenating the heads.
        concat_size = head_size * n_heads
        self.layers = nn.ModuleList()
        for layer_in_size in (in_size, concat_size, concat_size):
            self.layers.append(dglnn.HeteroGraphConv({
                # allow_zero_in_degree: sampled blocks can contain destination nodes
                # with no in-edges of a given relation; GATConv would otherwise raise.
                etype: dglnn.GATConv(layer_in_size, head_size, n_heads,
                                     allow_zero_in_degree=True)
                for etype in etypes}))
        self.dropout = nn.Dropout(0.5)
        # NOTE: only 'paper' nodes are classified, so a plain Linear is used; a
        # HeteroLinear would be needed to project every node type.
        self.linear = nn.Linear(concat_size, out_size)

    def forward(self, blocks, x):
        """Run the message-passing layers and return logits for ``'paper'`` dst nodes.

        Parameters
        ----------
        blocks : list
            One sampled DGL block per layer, outermost hop first.
        x : dict
            Input features passed to the first HeteroGraphConv; presumably a
            mapping from node type to feature tensor (the callers in this file
            pass ``blocks[0].srcdata['feat']``).

        Returns
        -------
        torch.Tensor
            Logits of shape ``(num_paper_dst_nodes, out_size)``.

        Raises
        ------
        ValueError
            If the number of blocks differs from the number of layers.
        """
        if len(blocks) != len(self.layers):
            raise ValueError(
                f'expected {len(self.layers)} blocks, got {len(blocks)}')
        h = x
        last = len(self.layers) - 1
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # GATConv yields (num_dst, n_heads, head_size); merge the head dims.
            # Explicit sizes are used instead of view(shape[0], -1) because a node
            # type may have zero dst nodes, where -1 cannot be inferred.
            h = apply_each(h, lambda t: t.view(t.shape[0], t.shape[1] * t.shape[2]))
            if l != last:
                h = apply_each(h, F.relu)
                h = apply_each(h, self.dropout)
        return self.linear(h['paper'])

def evaluate(model, dataloader, desc):
    """Return the top-1 accuracy of ``model`` on the ``'paper'`` seed nodes of ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(blocks, x)``; must return per-class scores for the
        ``'paper'`` destination nodes of the last block.
    dataloader : iterable
        Yields ``(input_nodes, output_nodes, blocks)`` triples.
    desc : str
        Label for the tqdm progress bar.

    Returns
    -------
    torch.Tensor
        0-dim float tensor: fraction of seeds whose argmax prediction matches the label.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    """
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc=desc):
            x = blocks[0].srcdata['feat']
            # The label tensor is 2-D; column 0 holds the class index.
            y = blocks[-1].dstdata['label']['paper'][:, 0]
            y_hat = model(blocks, x)
            preds.append(y_hat.cpu())
            labels.append(y.cpu())
    if not preds:
        raise ValueError(f'{desc!r}: dataloader produced no batches')
    preds = torch.cat(preds, 0)
    labels = torch.cat(labels, 0)
    # Plain top-1 (micro) accuracy. Computed directly rather than via
    # torchmetrics.functional.accuracy, which requires `task=` since 0.11.
    return (preds.argmax(dim=1) == labels).float().mean()

def train(train_loader, val_loader, test_loader, model,
          num_epochs=10, lr=0.001, weight_decay=5e-4):
    """Train ``model`` with Adam and cross-entropy, reporting val/test accuracy each epoch.

    Parameters
    ----------
    train_loader, val_loader, test_loader : iterable
        Yield ``(input_nodes, output_nodes, blocks)`` triples; labels are read
        from ``blocks[-1].dstdata['label']['paper']``.
    model : nn.Module
        Called as ``model(blocks, x)``; returns logits for the ``'paper'`` seeds.
    num_epochs : int
        Number of passes over ``train_loader``.
    lr : float
        Adam learning rate.
    weight_decay : float
        Adam weight decay (L2 penalty).
    """
    loss_fcn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for input_nodes, output_nodes, blocks in tqdm.tqdm(train_loader, desc="Train"):
            x = blocks[0].srcdata['feat']
            # The label tensor is 2-D; column 0 holds the class index.
            y = blocks[-1].dstdata['label']['paper'][:, 0]
            y_hat = model(blocks, x)
            loss = loss_fcn(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
        model.eval()
        val_acc = evaluate(model, val_loader, 'Val. ')
        test_acc = evaluate(model, test_loader, 'Test ')
        # max(..., 1) avoids a division error if train_loader is empty.
        avg_loss = total_loss / max(num_batches, 1)
        print(f'Epoch {epoch:05d} | Loss {avg_loss:.4f} | Validation Acc. {float(val_acc):.4f} | Test Acc. {float(test_acc):.4f}')

if __name__ == '__main__':
    print('Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules')
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # --- Load ogbn-mag and attach the labels to the graph ---
    print('Loading data')
    dataset = DglNodePropPredDataset('ogbn-mag')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels
    # Add reverse edges for "cites" and a reversed edge type for every other relation.
    graph = dgl.AddReverse()(graph)
    # Derive author, topic and institution features by mean-pooling neighbour
    # features over these relations. Keep this order: the institution pass
    # presumably pools author features, so it must run after the author pass.
    for pooling_etype in ('rev_writes', 'has_topic', 'affiliated_with'):
        graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype=pooling_etype)

    # --- Train/validation/test seed indices, moved to the training device ---
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = (
        apply_each(split_idx[split_name], lambda idx: idx.to(device))
        for split_name in ('train', 'valid', 'test'))

    # --- Model ---
    in_size = graph.ndata['feat']['paper'].shape[1]
    out_size = dataset.num_classes
    model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)

    # --- Samplers: 3-hop neighbour sampling; wider fan-out for evaluation ---
    train_sampler, val_sampler = (
        NeighborSampler(fanouts,
                        prefetch_node_feats={ntype: ['feat'] for ntype in graph.ntypes},
                        prefetch_labels={'paper': ['label']})
        for fanouts in ([5, 5, 5], [10, 10, 10]))

    # --- Loaders share every setting except seeds, sampler and shuffling ---
    loader_opts = dict(device=device, batch_size=1000, drop_last=False,
                       num_workers=0, use_uva=use_cuda)
    train_dataloader = DataLoader(graph, train_idx, train_sampler, shuffle=True, **loader_opts)
    val_dataloader = DataLoader(graph, val_idx, val_sampler, shuffle=False, **loader_opts)
    test_dataloader = DataLoader(graph, test_idx, val_sampler, shuffle=False, **loader_opts)

    train(train_dataloader, val_dataloader, test_dataloader, model)
