import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv
import argparse

class RGCN(nn.Module):
    """Two-layer relational GCN on top of a learnable node-embedding table.

    Both layers are ``RelGraphConv`` modules configured with the ``'basis'``
    regularizer, ``num_rels`` bases and no self-loop.

    Args:
        num_nodes: number of rows in the embedding table.
        h_dim: embedding size and hidden size of the first layer.
        out_dim: output size of the second layer.
        num_rels: number of relation types handled by both layers.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        # one trainable embedding row per node ID
        self.emb = nn.Embedding(num_nodes, h_dim)
        # settings shared by both relational layers
        conv_opts = dict(regularizer='basis', num_bases=num_rels, self_loop=False)
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, **conv_opts)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, **conv_opts)

    def forward(self, g):
        """Run both layers over ``g[0]`` and ``g[1]`` and return the second layer's output.

        Each of the two graphs must carry ``dgl.ETYPE`` and ``'norm'`` in its
        edge data; ``g[0]`` must also carry ``dgl.NID`` in its source-node data.
        """
        first_block = g[0]
        second_block = g[1]
        # look up input features by the node IDs stored on the first block
        feats = self.emb(first_block.srcdata[dgl.NID])
        hidden = self.conv1(first_block, feats,
                            first_block.edata[dgl.ETYPE],
                            first_block.edata['norm'])
        hidden = F.relu(hidden)
        return self.conv2(second_block, hidden,
                          second_block.edata[dgl.ETYPE],
                          second_block.edata['norm'])
    
def evaluate(model, labels, dataloader, inv_target):
    """Score ``model`` on every batch of ``dataloader`` without tracking gradients.

    Args:
        model: callable taking the list of blocks and returning logits.
        labels: label tensor, indexed by the IDs obtained through ``inv_target``.
        dataloader: iterable yielding ``(input_nodes, output_nodes, blocks)``.
        inv_target: tensor used to translate the loader's output node IDs
            into indices into ``labels``.

    Returns:
        A 2-element tensor ``[accuracy * num_seeds, num_seeds]`` so that
        callers can sum it across processes before dividing.
    """
    model.eval()
    logit_chunks, seed_chunks = [], []
    with torch.no_grad():
        for _, seeds, blocks in dataloader:
            seeds = inv_target[seeds]
            # attach the per-edge normalizer consumed by the model's layers
            for blk in blocks:
                blk.edata['norm'] = dgl.norm_by_dst(blk).unsqueeze(1)
            batch_logits = model(blocks)
            logit_chunks.append(batch_logits.cpu().detach())
            seed_chunks.append(seeds.cpu().detach())
    all_logits = torch.cat(logit_chunks)
    all_seeds = torch.cat(seed_chunks)
    seed_count = len(all_seeds)
    predictions = all_logits.argmax(dim=1)
    targets = labels[all_seeds].cpu()
    # scale the accuracy back by the number of seeds seen locally
    weighted_acc = accuracy(predictions, targets) * float(seed_count)
    return torch.tensor([weighted_acc.item(), float(seed_count)])

def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
    """Train ``model`` for 50 epochs with neighbor sampling and report per-epoch accuracy.

    Args:
        proc_id: rank of this process; only rank 0 prints.
        device: device handed to the data loaders.
        g: graph to sample from.
        target_idx: node IDs of the target category within ``g``.
        labels: label tensor indexed through ``inv_target``.
        train_idx: indices into ``target_idx`` selecting the training nodes.
        inv_target: tensor translating sampled output node IDs to label indices.
        model: the model to optimize.
    """
    # loss and optimizer
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # sampler shared by both loaders: fan-out of 4 for each of the two layers
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(
        g, target_idx[train_idx], sampler,
        device=device, batch_size=100, shuffle=True, use_ddp=True,
    )
    # there is no dedicated validation split, so the training nodes are reused
    val_loader = DataLoader(
        g, target_idx[train_idx], sampler,
        device=device, batch_size=100, shuffle=False, use_ddp=True,
    )
    for epoch in range(50):
        model.train()
        running_loss = 0
        # `step` is 1-based, so after the loop it equals the number of batches
        for step, (_, seeds, blocks) in enumerate(train_loader, start=1):
            seeds = inv_target[seeds]
            for blk in blocks:
                blk.edata['norm'] = dgl.norm_by_dst(blk).unsqueeze(1)
            batch_logits = model(blocks)
            loss = criterion(batch_logits, labels[seeds])
            opt.zero_grad()
            loss.backward()
            opt.step()
            running_loss += loss.item()
        # evaluate() returns [local_accuracy * local_count, local_count];
        # reducing onto rank 0 yields the global sums
        acc_parts = evaluate(model, labels, val_loader, inv_target).to(device)
        dist.reduce(acc_parts, 0)
        if proc_id != 0:
            continue
        global_acc = acc_parts[0] / acc_parts[1]
        print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} "
              .format(epoch, running_loss / step, global_acc.item()))
            
def run(proc_id, nprocs, devices, g, data):
    """Per-process entry point: set up the process group, train, test, tear down.

    Args:
        proc_id: rank of this process, also used to index ``devices``.
        nprocs: total number of processes (world size).
        devices: list of device IDs, one per process.
        g: graph to train and test on.
        data: tuple ``(num_rels, num_classes, labels, train_idx, test_idx,
            target_idx, inv_target)``.
    """
    # bind this rank to its device
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # join the process group
    dist.init_process_group(
        backend="nccl",
        init_method='tcp://127.0.0.1:12345',
        world_size=nprocs,
        rank=proc_id,
    )
    # unpack the payload and move the lookup tensors to this rank's device
    num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target = data
    labels = labels.to(device)
    inv_target = inv_target.to(device)
    # build the model: one embedding row per graph node, hidden size 16
    net = RGCN(g.num_nodes(), 16, num_classes, num_rels).to(device)
    net = DistributedDataParallel(net, device_ids=[device], output_device=device)
    # training
    train(proc_id, device, g, target_idx, labels, train_idx, inv_target, net)
    # testing: -1 samples all neighbors at each of the two layers
    full_sampler = MultiLayerNeighborSampler([-1, -1])
    test_loader = DataLoader(
        g, target_idx[test_idx], full_sampler,
        device=device, batch_size=32, shuffle=False, use_ddp=True,
    )
    acc_parts = evaluate(net, labels, test_loader, inv_target).to(device)
    # sum [weighted accuracy, count] over all ranks onto rank 0
    dist.reduce(acc_parts, 0)
    if proc_id == 0:
        test_acc = acc_parts[0] / acc_parts[1]
        print("Test accuracy {:.4f}".format(test_acc))
    # leave the process group
    dist.destroy_process_group()
        
if __name__ == '__main__':
    # Script entry point: parse arguments, load and preprocess the dataset once
    # in the parent process, then spawn one training process per requested GPU.
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling (multi-gpu)')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    parser.add_argument("--gpu", type=str, default='0',
                        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
                             " e.g., 0,1,2,3.")
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(',')))
    nprocs = len(devices)
    print(f'Training with DGL built-in RGCN module with sampling using {nprocs} GPU(s)')

    # load and preprocess dataset at master(parent) process
    dataset_classes = {
        'aifb': AIFBDataset,
        'mutag': MUTAGDataset,
        'bgs': BGSDataset,
        'am': AMDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    dataset = dataset_classes[args.dataset]()
    g = dataset[0]
    num_rels = len(g.canonical_etypes)
    category = dataset.predict_category
    labels = g.nodes[category].data.pop('labels')
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
    g.ndata['type_id'] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs;
    # entries for non-target nodes are left uninitialized and must not be read
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)
    inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype)
    # create graph formats and train/test indexes once here rather than in
    # each sub-process, to save memory
    g.create_formats_()
    # nonzero() returns an (N, 1) tensor; squeeze only that trailing column so
    # a split with a single node stays 1-D instead of collapsing to a 0-d scalar
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze(1)
    # thread limiting to avoid resource competition; clamp to at least one
    # thread because the integer division yields 0 when there are fewer than
    # 2 * nprocs CPUs, and 0 is not a valid thread count
    os.environ['OMP_NUM_THREADS'] = str(max(1, mp.cpu_count() // 2 // nprocs))

    data = num_rels, dataset.num_classes, labels, train_idx, test_idx, target_idx, inv_target
    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)


