import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv
import argparse

class RGCN(nn.Module):
    """Two-layer relational graph convolution network on learned node embeddings.

    The input features are rows of a trainable embedding table with one row
    per node, so no external node features are needed.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        """Create the embedding table and the two RelGraphConv layers.

        Args:
            num_nodes: Number of rows in the embedding table.
            h_dim: Width of the embedding and of the hidden layer.
            out_dim: Width of the output layer.
            num_rels: Number of relation types; also used as the number of bases.
        """
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # both layers share the same basis-regularizer configuration
        conv_options = dict(regularizer='basis', num_bases=num_rels, self_loop=False)
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, **conv_options)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, **conv_options)

    def forward(self, g):
        """Run both layers over a sequence of sampled blocks.

        Args:
            g: Indexable sequence of blocks; ``g[0]`` feeds the first layer and
                ``g[1]`` the second. Each block must carry ``dgl.ETYPE`` and
                ``'norm'`` in its edge data.

        Returns:
            Output of the second layer (no activation applied).
        """
        input_block = g[0]
        # source-node IDs of the first block select the embedding rows
        features = self.emb(input_block.srcdata[dgl.NID])
        hidden = self.conv1(
            input_block,
            features,
            input_block.edata[dgl.ETYPE],
            input_block.edata['norm'],
        )
        hidden = F.relu(hidden)
        output_block = g[1]
        return self.conv2(
            output_block,
            hidden,
            output_block.edata[dgl.ETYPE],
            output_block.edata['norm'],
        )

def evaluate(model, label, dataloader, inv_target):
    """Return the classification accuracy of ``model`` over a dataloader.

    Args:
        model: Module called with the list of sampled blocks; must return one
            row of logits per seed node.
        label: Tensor of class labels, indexed by type-specific node ID.
        dataloader: Iterable yielding ``(input_nodes, output_nodes, blocks)``.
        inv_target: Tensor mapping the node IDs yielded by ``dataloader`` to
            type-specific node IDs (i.e. to positions in ``label``).

    Returns:
        float: Fraction of seed nodes whose argmax prediction equals its label.
    """
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for _, output_nodes, blocks in dataloader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu())
            eval_seeds.append(output_nodes.cpu())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    predictions = eval_logits.argmax(dim=1)
    # Use the ``label`` argument, not a module-level global, so the function
    # works when imported and when called with a different label tensor.
    targets = label[eval_seeds].cpu()
    # Computed directly rather than via torchmetrics' ``accuracy``: releases
    # >= 0.11 require a ``task`` argument and reject the two-argument call.
    return (predictions == targets).float().mean().item()

def train(device, g, target_idx, labels, train_mask, model, inv_target=None):
    """Train ``model`` for 50 epochs with neighbour sampling and print progress.

    Args:
        device: Device passed to the DGL dataloaders.
        g: Graph to sample from.
        target_idx: Node IDs in ``g`` of the nodes to classify.
        labels: Class labels indexed by type-specific node ID.
        train_mask: Mask over the target nodes selecting the training nodes
            (assumed 1-D, one entry per element of ``target_idx``).
        model: Module called with the list of sampled blocks.
        inv_target: Optional tensor mapping node IDs in ``g`` to type-specific
            node IDs. When omitted it is built from ``g`` and ``target_idx``,
            so the function does not depend on a module-level global.
    """
    if inv_target is None:
        # Non-target entries stay at -1; they are never looked up because the
        # loaders below only draw seed nodes from ``target_idx``.
        inv_target = torch.full((g.num_nodes(),), -1, dtype=torch.int64, device=device)
        inv_target[target_idx.to(device)] = torch.arange(
            target_idx.shape[0], dtype=torch.int64, device=device)
    # define train idx, loss function and optimizer
    # squeeze(1) keeps a 1-D index tensor even when only one node is selected
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloader
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                              batch_size=100, shuffle=True)
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                            batch_size=100, shuffle=False)
    for epoch in range(50):
        model.train()
        total_loss = 0
        num_batches = 0
        for _, output_nodes, blocks in train_loader:
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, labels, val_loader, inv_target)
        # explicit batch counter: no reliance on a loop variable that is
        # unbound when the loader yields nothing
        print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} "
              . format(epoch, total_loss / max(num_batches, 1), acc))

if __name__ == '__main__':
    # command line: the dataset name is the only option
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Training with DGL built-in RGCN module with sampling.')

    # load and preprocess dataset: look the dataset class up by name
    dataset_classes = {
        'aifb': AIFBDataset,
        'mutag': MUTAGDataset,
        'bgs': BGSDataset,
        'am': AMDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    data = dataset_classes[args.dataset]()

    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    # detach labels and split masks from the graph before it is homogenized
    labels = g.nodes[category].data.pop('labels').to(device)
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')

    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    is_target = g.ndata[dgl.NTYPE] == category_id
    target_idx = torch.arange(g.num_nodes())[is_target]

    # rename the fields as they can be changed by DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)

    # inv_target maps global node IDs to type-specific node IDs; only the
    # entries at target_idx are assigned
    num_targets = target_idx.shape[0]
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(0, num_targets, dtype=inv_target.dtype).to(device)

    # create RGCN model; featureless, so the input size is the node count
    in_size = g.num_nodes()
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels)
    model = model.to(device)

    train(device, g, target_idx, labels, train_mask, model)

    # final evaluation on the test split, using every neighbour (-1 fan-out)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler([-1, -1])
    test_loader = DataLoader(g, target_idx[test_idx], test_sampler, device=device,
                             batch_size=32, shuffle=False)
    acc = evaluate(model, labels, test_loader, inv_target)
    print("Test accuracy {:.4f}".format(acc))
