import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader


class MLP(nn.Module):
    """Two-layer feed-forward network used as the GINE update function.

    Widens the input to ``2 * in_feats`` (Linear -> BatchNorm -> ReLU), then
    projects back to ``in_feats`` (Linear -> BatchNorm). There is no
    activation after the final BatchNorm.
    """

    def __init__(self, in_feats):
        super().__init__()
        wide = in_feats * 2
        layers = [
            nn.Linear(in_feats, wide),
            nn.BatchNorm1d(wide),
            nn.ReLU(),
            nn.Linear(wide, in_feats),
            nn.BatchNorm1d(in_feats),
        ]
        # A single Sequential attribute named ``mlp`` keeps the state-dict
        # keys ("mlp.0.weight", ...) stable for saved checkpoints.
        self.mlp = nn.Sequential(*layers)

    def forward(self, h):
        """Map features ``h`` (last dim ``in_feats``) to the same width."""
        out = self.mlp(h)
        return out


class GIN(nn.Module):
    """GINE-based graph predictor with a virtual node.

    Node features are encoded with ``AtomEncoder``. Each of the ``n_layers``
    ``GINEConv`` layers gets its own ``BondEncoder`` for edge features.

    A learned virtual-node embedding (one row, initialised to zero) is
    broadcast-added to every node before each conv. After every layer
    except the last, the embedding is updated from the sum-pooled node
    features through an ``MLP``. The final node states are average-pooled
    per graph and fed to a linear predictor with ``n_output`` outputs.
    """

    def __init__(self, n_hidden, n_output, n_layers=5):
        super().__init__()
        # NOTE: attribute assignment order fixes the parameters()/state_dict
        # order, so it is kept stable.
        self.node_encoder = AtomEncoder(n_hidden)
        self.edge_encoders = nn.ModuleList(
            BondEncoder(n_hidden) for _ in range(n_layers)
        )

        self.pool = dglnn.AvgPooling()
        self.dropout = nn.Dropout(0.5)
        self.layers = nn.ModuleList(
            dglnn.GINEConv(MLP(n_hidden), learn_eps=True)
            for _ in range(n_layers)
        )
        self.predictor = nn.Linear(n_hidden, n_output)

        # Virtual node: starts at zero; one update MLP between each pair of
        # consecutive conv layers.
        self.virtual_emb = nn.Embedding(1, n_hidden)
        nn.init.constant_(self.virtual_emb.weight.data, 0)
        self.virtual_layers = nn.ModuleList(
            MLP(n_hidden) for _ in range(n_layers - 1)
        )
        self.virtual_pool = dglnn.SumPooling()

    def forward(self, g, x, x_e):
        """Return per-graph predictions for the batched graph ``g``.

        Args:
            g: batched DGL graph.
            x: node input features, fed to ``AtomEncoder``.
            x_e: edge input features, fed to each layer's ``BondEncoder``.
        """
        final_idx = len(self.layers) - 1
        # One virtual-node row per graph in the batch.
        v_emb = self.virtual_emb.weight.expand(g.batch_size, -1)
        hn = self.node_encoder(x)
        for i, (conv, edge_encoder) in enumerate(
            zip(self.layers, self.edge_encoders)
        ):
            hn = hn + dgl.broadcast_nodes(g, v_emb)
            hn = conv(g, hn, edge_encoder(x_e))
            hn = self.dropout(F.relu(hn))
            if i < final_idx:
                pooled = self.virtual_pool(g, hn) + v_emb
                v_emb = self.dropout(F.relu(self.virtual_layers[i](pooled)))
        return self.predictor(self.pool(g, hn))


@torch.no_grad()
def evaluate(dataloader, device, model, evaluator):
    """Run ``model`` over ``dataloader`` and score it with the OGB evaluator.

    The model is switched to eval mode and is not switched back; callers
    must call ``.train()`` again before training.

    Args:
        dataloader: yields ``(batched_graph, labels)`` pairs.
        device: device each batch is moved to.
        model: callable as ``model(graph, node_feat, edge_feat)``.
        evaluator: object whose ``eval`` accepts ``{"y_true", "y_pred"}``
            numpy arrays.

    Returns:
        Whatever ``evaluator.eval`` returns (a metric dict).
    """
    model.eval()
    truths, preds = [], []
    for graphs, targets in tqdm(dataloader):
        graphs = graphs.to(device)
        targets = targets.to(device)
        logits = model(graphs, graphs.ndata["feat"], graphs.edata["feat"])
        # Labels are reshaped to the prediction's shape before collection.
        truths.append(targets.view(logits.shape).detach().cpu())
        preds.append(logits.detach().cpu())
    return evaluator.eval(
        {
            "y_true": torch.cat(truths, dim=0).numpy(),
            "y_pred": torch.cat(preds, dim=0).numpy(),
        }
    )


def _train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean loss.

    Args:
        model: model called as ``model(graph, node_feat, edge_feat)``.
        dataloader: yields ``(batched_graph, labels)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Mean of the per-batch losses seen by THIS process as a Python float,
        or ``nan`` if the dataloader produced no batches.
    """
    model.train()
    # Accumulate on device so there is a single host sync per epoch rather
    # than one ``.item()`` per batch.
    loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for batched_graph, labels in dataloader:
        batched_graph, labels = batched_graph.to(device), labels.to(device)
        node_feat, edge_feat = (
            batched_graph.ndata["feat"],
            batched_graph.edata["feat"],
        )
        logits = model(batched_graph, node_feat, edge_feat)
        optimizer.zero_grad()
        # ``labels == labels`` is False exactly where a label is NaN, so NaN
        # (missing) targets are excluded from the loss.
        is_labeled = labels == labels
        loss = F.binary_cross_entropy_with_logits(
            logits.float()[is_labeled], labels.float()[is_labeled]
        )
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach()
        num_batches += 1
    return (loss_sum / num_batches).item() if num_batches else float("nan")


def train(rank, world_size, dataset_name, root):
    """Per-process DDP training entry point (the target of ``mp.spawn``).

    Trains a virtual-node GIN for 50 epochs. Rank 0 evaluates on the
    validation and test splits after every epoch and prints a summary line.

    Args:
        rank: process index; also used as the CUDA device index.
        world_size: total number of processes in the group.
        dataset_name: OGB graph property prediction dataset name.
        root: directory holding the (already downloaded) dataset.
    """
    dist.init_process_group(
        "nccl", "tcp://127.0.0.1:12347", world_size=world_size, rank=rank
    )
    try:
        torch.cuda.set_device(rank)

        dataset = AsGraphPredDataset(DglGraphPropPredDataset(dataset_name, root))
        evaluator = Evaluator(dataset_name)

        model = GIN(300, dataset.num_tasks).to(rank)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        # NOTE(review): with step_size=50 and 50 epochs the LR never decays
        # during this run; confirm this is intended.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

        # use_ddp=True splits the training set across ranks; set_epoch()
        # below reseeds the shuffle each epoch.
        train_dataloader = GraphDataLoader(
            dataset[dataset.train_idx], batch_size=256, use_ddp=True, shuffle=True
        )
        valid_dataloader = GraphDataLoader(dataset[dataset.val_idx], batch_size=256)
        test_dataloader = GraphDataLoader(dataset[dataset.test_idx], batch_size=256)

        for epoch in range(50):
            train_dataloader.set_epoch(epoch)
            # Mean loss over this rank's batches (previously only the last
            # batch's loss was reported, and it was unbound for empty shards).
            epoch_loss = _train_one_epoch(model, train_dataloader, optimizer, rank)
            scheduler.step()

            if rank == 0:
                # Evaluate the unwrapped module so no DDP collectives run;
                # other ranks wait at their next gradient all-reduce.
                val_metric = evaluate(
                    valid_dataloader, rank, model.module, evaluator
                )[evaluator.eval_metric]
                test_metric = evaluate(
                    test_dataloader, rank, model.module, evaluator
                )[evaluator.eval_metric]

                print(
                    f"Epoch: {epoch:03d}, Loss: {epoch_loss:.4f}, "
                    f"Val: {val_metric:.4f}, Test: {test_metric:.4f}"
                )
    finally:
        # Tear down the process group even if training raised.
        dist.destroy_process_group()


if __name__ == "__main__":
    import torch.multiprocessing as mp

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbg-molhiv",
        choices=["ogbg-molhiv", "ogbg-molpcba"],
        help="name of dataset (default: ogbg-molhiv)",
    )
    dataset_name = parser.parse_args().dataset
    root = "./data/OGB"

    # One process per visible GPU. With zero GPUs, mp.spawn(nprocs=0) would
    # silently do nothing, and the NCCL backend needs CUDA anyway, so fail
    # fast before downloading anything.
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise SystemExit(
            "No CUDA devices found: this script uses the NCCL backend and "
            "requires at least one GPU."
        )

    # Instantiate the dataset once in the parent, presumably so the files are
    # downloaded/processed before the workers each load it from ``root``.
    DglGraphPropPredDataset(dataset_name, root)

    print("Let's use", world_size, "GPUs!")
    args = (world_size, dataset_name, root)
    mp.spawn(train, args=args, nprocs=world_size, join=True)
