import argparse
import time
from contextlib import contextmanager

import numpy as np
import sklearn.linear_model as lm
import sklearn.metrics as skm
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn

class DistSAGE(nn.Module):
    """GraphSAGE model built from mean-aggregator ``SAGEConv`` layers.

    The stack consists of one ``in_feats -> n_hidden`` layer,
    ``n_layers - 2`` layers of ``n_hidden -> n_hidden`` (none when that count
    is not positive) and one ``n_hidden -> n_classes`` layer. ``activation``
    followed by dropout is applied after every layer except the last one.
    """

    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        """
        in_feats : size of the input node features.
        n_hidden : output size of every layer except the last.
        n_classes : output size of the last layer.
        n_layers : requested number of layers (see the class docstring for
            how it maps to the layer stack).
        activation : callable applied to the output of non-final layers.
        dropout : dropout probability applied after the activation.
        """
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for _ in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """
        Apply the layers to a list of sampled blocks.

        blocks : one block per layer, consumed in order.
        x : input features for the source nodes of the first block.

        Returns the output of the last layer that was applied.
        """
        last = len(self.layers) - 1
        h = x
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if i != last:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without
        neighbor sampling).

        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of nodes processed per mini-batch.
        device : device on which each mini-batch is computed.

        Returns the distributed tensor holding the output of the last layer.

        The inference code is written in a fashion that it could handle any
        number of nodes and layers.
        """
        # During inference with sampling, multi-layer blocks are very
        # inefficient because lots of computations in the first few layers are
        # repeated. Therefore, we compute the representation of all nodes layer
        # by layer. The nodes on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.num_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
        last = len(self.layers) - 1
        y = None
        for i, layer in enumerate(self.layers):
            if i == last:
                y = dgl.distributed.DistTensor(
                    (g.num_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )
            else:
                # Alternate between two hidden buffers so that a layer never
                # writes into the tensor it is currently reading its inputs
                # from; otherwise a model with 3+ layers would overwrite the
                # previous layer's outputs while other batches still need
                # them. Layer 0 keeps using the name "h".
                y = dgl.distributed.DistTensor(
                    (g.num_nodes(), self.n_hidden),
                    th.float32,
                    "h" if i % 2 == 0 else "h_alt",
                    persistent=True,
                )
            # Create sampler: fanout -1 requests all neighbors for one hop.
            sampler = dgl.dataloading.NeighborSampler([-1])
            # Create dataloader
            dataloader = dgl.dataloading.DistNodeDataLoader(
                g,
                nodes,
                sampler,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
            )

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                h = x[input_nodes].to(device)
                # The leading rows of the source features are used as the
                # destination-node features.
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if i != last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            # The output of this layer becomes the input of the next one;
            # wait at the barrier before the next layer starts reading it.
            x = y
            g.barrier()
        return y

    @contextmanager
    def join(self):
        """dummy join for standalone"""
        yield


def load_subtensor(g, input_nodes, device):
    """
    Fetch the "features" rows of the given nodes and move them to `device`.

    g : graph whose node data holds a "features" entry.
    input_nodes : IDs of the nodes whose features are needed.
    device : target device for the returned tensor.
    """
    features = g.ndata["features"]
    return features[input_nodes].to(device)


class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy over dot-product scores of positive and negative
    edges."""

    def forward(self, block_outputs, pos_graph, neg_graph):
        """
        block_outputs : node representations assigned to both graphs.
        pos_graph : graph whose edges are labelled 1.
        neg_graph : graph whose edges are labelled 0.

        Returns the loss computed on the concatenated edge scores.
        """

        def edge_scores(graph):
            # Score each edge by the dot product of its endpoint
            # representations; local_scope keeps the temporary "h"/"score"
            # fields from persisting on the graph.
            with graph.local_scope():
                graph.ndata["h"] = block_outputs
                graph.apply_edges(fn.u_dot_v("h", "h", "score"))
                return graph.edata["score"]

        pos_score = edge_scores(pos_graph)
        neg_score = edge_scores(neg_graph)

        logits = th.cat([pos_score, neg_score])
        targets = th.cat(
            [th.ones_like(pos_score), th.zeros_like(neg_score)]
        ).float()
        return F.binary_cross_entropy_with_logits(logits, targets)


def generate_emb(model, g, inputs, batch_size, device):
    """
    Compute an embedding for every node with the model's layer-wise
    inference routine.

    model : model exposing ``eval()`` and ``inference(...)``.
    g : The entire graph.
    inputs : The features of all the nodes.
    batch_size : Number of nodes to compute at the same time.
    device : The device to evaluate on.
    """
    # Switch to evaluation mode and skip gradient tracking for the whole pass.
    model.eval()
    with th.no_grad():
        return model.inference(g, inputs, batch_size, device)


def compute_acc(emb, labels, train_nids, val_nids, test_nids):
    """
    Compute the accuracy of prediction given the labels.

    We first train a LogisticRegression model on the standardized embeddings
    of the training set, then report its accuracy on the validation and test
    sets.

    emb: The pretrained embeddings
    labels: The ground truth
    train_nids: The training set node ids
    val_nids: The validation set node ids
    test_nids: The test set node ids

    Returns ``(eval_acc, test_acc)``.
    """

    emb = emb[np.arange(labels.shape[0])].cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    val_nids = val_nids.cpu().numpy()
    test_nids = test_nids.cpu().numpy()
    labels = labels.cpu().numpy()

    # Standardize every embedding dimension. A dimension that is constant
    # across all nodes has zero standard deviation; dividing by it would
    # yield NaN and make the regression fit fail, so such columns are divided
    # by 1 instead (they end up as all zeros after centering).
    centered = emb - emb.mean(0, keepdims=True)
    scale = emb.std(0, keepdims=True)
    scale[scale == 0] = 1.0
    emb = centered / scale

    # NOTE(review): `multi_class` is reportedly deprecated in newer
    # scikit-learn releases - confirm against the pinned version.
    lr = lm.LogisticRegression(multi_class="multinomial", max_iter=10000)
    lr.fit(emb[train_nids], labels[train_nids])

    pred = lr.predict(emb)
    eval_acc = skm.accuracy_score(labels[val_nids], pred[val_nids])
    test_acc = skm.accuracy_score(labels[test_nids], pred[test_nids])
    return eval_acc, test_acc


def run(args, device, data):
    """
    Train a DistSAGE model with an edge-based unsupervised loss, then
    generate node embeddings, evaluate them and save them to "emb.pt".

    args : parsed command-line arguments.
    device : device used for training and inference.
    data : tuple ``(train_eids, train_nids, in_feats, g, global_train_nid,
        global_valid_nid, global_test_nid, labels)``.
    """
    # Unpack data (train_nids is part of the tuple layout but unused here).
    (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    ) = data
    # Create sampler
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(args.num_negs)
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )
    # Create dataloader. When edge removal is requested, reverse_eids is the
    # identity mapping over all edge IDs.
    exclude = "reverse_id" if args.remove_edge else None
    reverse_eids = th.arange(g.num_edges()) if args.remove_edge else None
    dataloader = dgl.dataloading.DistEdgeDataLoader(
        g,
        train_eids,
        sampler,
        negative_sampler=neg_sampler,
        exclude=exclude,
        reverse_eids=reverse_eids,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )
    # Define model and optimizer. The output size equals the hidden size
    # because the model produces embeddings rather than class scores.
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        args.num_hidden,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            dev_id = g.rank() % args.num_gpus
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[dev_id], output_device=dev_id
            )
    loss_fcn = CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    for epoch in range(args.num_epochs):
        num_seeds = 0
        num_inputs = 0

        step_time = []
        sample_t = []
        feat_copy_t = []
        forward_t = []
        backward_t = []
        update_t = []
        iter_tput = []

        start = time.time()
        with model.join():
            # Loop over the dataloader to sample the computation dependency
            # graph as a list of blocks.
            for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(
                dataloader
            ):
                tic_step = time.time()
                sample_t.append(tic_step - start)

                copy_t = time.time()
                pos_graph = pos_graph.to(device)
                neg_graph = neg_graph.to(device)
                blocks = [block.to(device) for block in blocks]
                batch_inputs = load_subtensor(g, input_nodes, device)
                copy_time = time.time()
                feat_copy_t.append(copy_time - copy_t)

                # Compute loss and prediction
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, pos_graph, neg_graph)
                forward_end = time.time()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_t.append(forward_end - copy_time)
                backward_t.append(compute_end - forward_end)

                # Aggregate gradients in multiple nodes.
                optimizer.step()
                update_t.append(time.time() - compute_end)

                pos_edges = pos_graph.num_edges()

                step_t = time.time() - start
                step_time.append(step_t)
                iter_tput.append(pos_edges / step_t)
                num_seeds += pos_edges
                # Count the input nodes so the "#inputs" figure in the epoch
                # summary reflects the work actually done.
                num_inputs += len(input_nodes)
                if step % args.log_every == 0:
                    # Skip the first three iterations as warm-up when enough
                    # samples exist; otherwise average what is available so
                    # the early logs do not report nan.
                    tput_samples = iter_tput[3:] or iter_tput
                    print(
                        "[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed "
                        "(samples/sec) {:.4f} | time {:.3f}s | sample {:.3f} | "
                        "copy {:.3f} | forward {:.3f} | backward {:.3f} | "
                        "update {:.3f}".format(
                            g.rank(),
                            epoch,
                            step,
                            loss.item(),
                            np.mean(tput_samples),
                            np.sum(step_time[-args.log_every:]),
                            np.sum(sample_t[-args.log_every:]),
                            np.sum(feat_copy_t[-args.log_every:]),
                            np.sum(forward_t[-args.log_every:]),
                            np.sum(backward_t[-args.log_every:]),
                            np.sum(update_t[-args.log_every:]),
                        )
                    )
                # Restart the clock so the next sample time covers only the
                # dataloader wait.
                start = time.time()

        print(
            "[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, "
            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
            "#inputs: {}".format(
                g.rank(),
                np.sum(step_time),
                np.sum(sample_t),
                np.sum(feat_copy_t),
                np.sum(forward_t),
                np.sum(backward_t),
                np.sum(update_t),
                num_seeds,
                num_inputs,
            )
        )

    # evaluate the embedding using LogisticRegression; the wrapped model is
    # unwrapped because `inference` is defined on DistSAGE itself.
    pred = generate_emb(
        model if args.standalone else model.module,
        g,
        g.ndata["features"],
        args.batch_size_eval,
        device,
    )
    if g.rank() == 0:
        eval_acc, test_acc = compute_acc(
            pred, labels, global_train_nid, global_valid_nid, global_test_nid
        )
        print("eval acc {:.4f}; test acc {:.4f}".format(eval_acc, test_acc))

    # sync for eval and test
    if not args.standalone:
        th.distributed.barrier()
        g.barrier()

    # save features into file (only rank 0 writes in distributed mode)
    if args.standalone or g.rank() == 0:
        th.save(pred, "emb.pt")


def main(args):
    """
    Initialize the distributed runtime, load the partitioned graph, build the
    training/evaluation ID sets and hand everything to `run`.

    args : parsed command-line arguments.
    """
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph(
        args.graph_name, part_config=args.part_config
    )
    print("rank:", g.rank())
    print("number of edges", g.num_edges())

    # An all-True mask is split, so every edge / node is a training seed.
    train_eids = dgl.distributed.edge_split(
        th.ones((g.num_edges(),), dtype=th.bool),
        g.get_partition_book(),
        force_even=True,
    )
    train_nids = dgl.distributed.node_split(
        th.ones((g.num_nodes(),), dtype=th.bool), g.get_partition_book()
    )

    all_nids = np.arange(g.num_nodes())

    def _mask_to_nids(mask_name):
        # Pull the whole mask and keep the positions of its nonzero entries.
        # reshape(-1) always yields a 1-D tensor; a bare squeeze() would
        # collapse a single-element result to 0-d and break `.shape[0]`.
        selected = np.nonzero(g.ndata[mask_name][all_nids])
        return th.LongTensor(selected).reshape(-1)

    global_train_nid = _mask_to_nids("train_mask")
    global_valid_nid = _mask_to_nids("val_mask")
    global_test_nid = _mask_to_nids("test_mask")
    labels = g.ndata["labels"][all_nids]
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    print("number of train {}".format(global_train_nid.shape[0]))
    print("number of valid {}".format(global_valid_nid.shape[0]))
    print("number of test {}".format(global_test_nid.shape[0]))
    data = (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    )
    run(args, device, data)
    print("parent ends")


if __name__ == "__main__":
    # Command-line entry point: parse the options and start training.
    parser = argparse.ArgumentParser(description="GCN")
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument("--n_classes", type=int, help="the number of classes")
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU device. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    # Every other option uses underscores; accept that spelling here too and
    # keep the hyphenated form as an alias. Both store to `args.num_layers`.
    parser.add_argument("--num_layers", "--num-layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="get rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument("--num_negs", type=int, default=1)
    parser.add_argument(
        "--remove_edge",
        default=False,
        action="store_true",
        help="whether to remove edges during sampling",
    )
    args = parser.parse_args()
    print(args)
    main(args)
