#!/usr/bin/env python
# coding: utf-8

import argparse
import time

import numpy as np
import ogb
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator

import dgl
import dgl.function as fn
import dgl.nn as dglnn


class RGAT(nn.Module):
    """Relational graph attention network over sampled message-flow graphs.

    Every layer owns one ``dglnn.GATConv`` per edge type plus a linear skip
    projection of the destination-node features. The per-edge-type attention
    outputs are summed onto the skip projection, then batch-normalised and
    passed through ELU and dropout. A two-layer MLP maps the final hidden
    representation to ``out_channels`` logits.

    Parameters
    ----------
    in_channels : int
        Size of the input node features.
    out_channels : int
        Number of output logits per node.
    hidden_channels : int
        Hidden size of every layer; must be divisible by ``num_heads`` because
        each attention head produces ``hidden_channels // num_heads`` features
        and the heads are flattened back to ``hidden_channels``.
    num_etypes : int
        Number of edge types; edges are selected by comparing the ``"etype"``
        edge feature against ``0 .. num_etypes - 1``.
    num_layers : int
        Number of graph layers. At least one layer is always built.
    num_heads : int
        Number of attention heads per ``GATConv``.
    dropout : float
        Dropout probability used after every layer and inside the MLP.
    pred_ntype : str
        Stored on the instance as ``self.pred_ntype``; not used by ``forward``.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive or does not divide
        ``hidden_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        num_etypes,
        num_layers,
        num_heads,
        dropout,
        pred_ntype,
    ):
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces in forward() as an
        # opaque shape error when the heads are flattened with .view().
        if num_heads <= 0 or hidden_channels % num_heads != 0:
            raise ValueError(
                "hidden_channels (%r) must be divisible by a positive "
                "num_heads (%r)" % (hidden_channels, num_heads)
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()

        head_channels = hidden_channels // num_heads
        # The first layer consumes the raw features, the others the hidden
        # representation. Modules are created in the order conv -> norm ->
        # skip per layer so parameter initialisation order (and therefore
        # seeded results and state_dict layout) stays stable.
        layer_inputs = [in_channels] + [hidden_channels] * (num_layers - 1)
        for layer_in in layer_inputs:
            self.convs.append(
                nn.ModuleList(
                    [
                        dglnn.GATConv(
                            layer_in,
                            head_channels,
                            num_heads,
                            allow_zero_in_degree=True,
                        )
                        for _ in range(num_etypes)
                    ]
                )
            )
            self.norms.append(nn.BatchNorm1d(hidden_channels))
            self.skips.append(nn.Linear(layer_in, hidden_channels))

        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels),
        )
        self.dropout = nn.Dropout(dropout)

        self.hidden_channels = hidden_channels
        self.pred_ntype = pred_ntype
        self.num_etypes = num_etypes

    def forward(self, mfgs, x):
        """Compute logits for the destination nodes of the last block.

        Parameters
        ----------
        mfgs : sequence
            Message-flow graphs, one per layer, each carrying an ``"etype"``
            edge feature. Layer ``i`` is applied to ``mfgs[i]``.
        x : torch.Tensor
            Features of the source nodes of ``mfgs[0]``. For every block the
            first ``num_dst_nodes()`` rows are used as the destination
            features.

        Returns
        -------
        torch.Tensor
            Output of the final MLP, one row per remaining node.
        """
        for i, mfg in enumerate(mfgs):
            x_dst = x[: mfg.num_dst_nodes()]
            graph = dgl.block_to_graph(mfg)
            x_skip = self.skips[i](x_dst)
            for j, conv in enumerate(self.convs[i]):
                # Restrict message passing to edges of type j; node IDs are
                # kept so the feature tensors still line up.
                subg = graph.edge_subgraph(
                    graph.edata["etype"] == j, relabel_nodes=False
                )
                # Flatten (nodes, heads, head_channels) to (nodes, hidden).
                x_skip += conv(subg, (x, x_dst)).view(
                    -1, self.hidden_channels
                )
            x = self.norms[i](x_skip)
            x = F.elu(x)
            x = self.dropout(x)
        return self.mlp(x)


class ExternalNodeCollator(dgl.dataloading.NodeCollator):
    """Node collator that attaches externally stored features and labels.

    The sampled blocks come from the parent ``NodeCollator``; this subclass
    then looks up the input features in ``feats`` and the labels in ``label``
    and stores them on the first and last block respectively.

    Parameters
    ----------
    g, idx, sampler
        Forwarded unchanged to ``dgl.dataloading.NodeCollator``.
    offset : int
        Value subtracted from the output node IDs before indexing ``label``.
    feats
        Indexable feature store, indexed by the sampled input node IDs.
    label
        Indexable label store, indexed by ``output_nodes - offset``.
    """

    def __init__(self, g, idx, sampler, offset, feats, label):
        super().__init__(g, idx, sampler)
        self.offset, self.feats, self.label = offset, feats, label

    def collate(self, items):
        """Sample blocks for ``items`` and attach ``"x"`` / ``"y"`` data.

        Returns the ``(input_nodes, output_nodes, mfgs)`` triple produced by
        the parent collator, with ``mfgs[0].srcdata["x"]`` holding the input
        features and ``mfgs[-1].dstdata["y"]`` holding the labels.
        """
        input_nodes, output_nodes, mfgs = super().collate(items)
        first_block, last_block = mfgs[0], mfgs[-1]

        # Copy input features
        batch_feats = self.feats[input_nodes]
        first_block.srcdata["x"] = torch.FloatTensor(batch_feats)

        # Shift the output node IDs by the offset to index the label store.
        label_rows = output_nodes - self.offset
        last_block.dstdata["y"] = torch.LongTensor(self.label[label_rows])

        return input_nodes, output_nodes, mfgs


def train(args, dataset, g, feats, paper_offset):
    """Train RGAT and checkpoint whenever validation accuracy improves.

    Parameters
    ----------
    args
        Parsed command-line arguments; ``epochs`` and ``model_path`` are read.
    dataset
        Dataset object providing ``get_idx_split``, ``paper_label``,
        ``num_paper_features`` and ``num_classes``.
    g
        Graph to sample neighbourhoods from.
    feats
        Feature store handed to ``ExternalNodeCollator``.
    paper_offset : int
        Offset added to the split indices to obtain node IDs in ``g``.
    """
    print("Loading masks and labels")
    train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
    valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
    label = dataset.paper_label

    print("Initializing dataloader...")
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])

    def make_loader(node_ids, workers):
        # Both loaders share every setting except the seed nodes and the
        # number of worker processes.
        collator = ExternalNodeCollator(
            g, node_ids, sampler, paper_offset, feats, label
        )
        return torch.utils.data.DataLoader(
            collator.dataset,
            batch_size=1024,
            shuffle=True,
            drop_last=False,
            collate_fn=collator.collate,
            num_workers=workers,
        )

    train_dataloader = make_loader(train_idx, 4)
    valid_dataloader = make_loader(valid_idx, 2)

    print("Initializing model...")
    model = RGAT(
        in_channels=dataset.num_paper_features,
        out_channels=dataset.num_classes,
        hidden_channels=1024,
        num_etypes=5,
        num_layers=2,
        num_heads=4,
        dropout=0.5,
        pred_ntype="paper",
    ).cuda()
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)

    best_acc = 0

    for _ in range(args.epochs):
        # ---- optimisation pass over the training split ----
        model.train()
        with tqdm.tqdm(train_dataloader) as progress:
            for _, _, blocks in progress:
                blocks = [block.to("cuda") for block in blocks]
                inputs = blocks[0].srcdata["x"]
                targets = blocks[-1].dstdata["y"]
                logits = model(blocks, inputs)
                loss = F.cross_entropy(logits, targets)
                opt.zero_grad()
                loss.backward()
                opt.step()
                batch_acc = (logits.argmax(1) == targets).float().mean()
                progress.set_postfix(
                    {
                        "loss": "%.4f" % loss.item(),
                        "acc": "%.4f" % batch_acc.item(),
                    },
                    refresh=False,
                )

        # ---- accuracy on the validation split ----
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for _, _, blocks in tqdm.tqdm(valid_dataloader):
                blocks = [block.to("cuda") for block in blocks]
                inputs = blocks[0].srcdata["x"]
                targets = blocks[-1].dstdata["y"]
                logits = model(blocks, inputs)
                correct += (logits.argmax(1) == targets).sum().item()
                total += logits.shape[0]
        acc = correct / total
        print("Validation accuracy:", acc)

        sched.step()

        # Keep only the weights of the best epoch seen so far.
        if acc > best_acc:
            best_acc = acc
            print("Updating best model...")
            torch.save(model.state_dict(), args.model_path)


def test(args, dataset, g, feats, paper_offset):
    """Evaluate the saved model and write the test-split predictions.

    Loads the weights stored at ``args.model_path``, prints the accuracy on
    the validation split, then saves the predicted classes for the test split
    through ``MAG240MEvaluator.save_test_submission`` to
    ``args.submission_path``.

    Parameters
    ----------
    args
        Parsed command-line arguments; ``model_path`` and ``submission_path``
        are read.
    dataset
        Dataset object providing ``get_idx_split``, ``paper_label``,
        ``num_paper_features`` and ``num_classes``.
    g
        Graph to sample neighbourhoods from.
    feats
        Feature store handed to ``ExternalNodeCollator``.
    paper_offset : int
        Offset added to the split indices to obtain node IDs in ``g``.
    """
    print("Loading masks and labels...")
    valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
    test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
    label = dataset.paper_label

    print("Initializing data loader...")
    sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])

    def make_loader(node_ids, workers):
        # Unshuffled so predictions come out in the order of the split.
        collator = ExternalNodeCollator(
            g, node_ids, sampler, paper_offset, feats, label
        )
        return torch.utils.data.DataLoader(
            collator.dataset,
            batch_size=16,
            shuffle=False,
            drop_last=False,
            collate_fn=collator.collate,
            num_workers=workers,
        )

    valid_dataloader = make_loader(valid_idx, 2)
    test_dataloader = make_loader(test_idx, 4)

    print("Loading model...")
    model = RGAT(
        in_channels=dataset.num_paper_features,
        out_channels=dataset.num_classes,
        hidden_channels=1024,
        num_etypes=5,
        num_layers=2,
        num_heads=4,
        dropout=0.5,
        pred_ntype="paper",
    ).cuda()
    model.load_state_dict(torch.load(args.model_path))

    model.eval()

    def predict(blocks):
        # Move the batch to the GPU and return (predicted class, stored label).
        blocks = [block.to("cuda") for block in blocks]
        logits = model(blocks, blocks[0].srcdata["x"])
        return logits.argmax(1), blocks[-1].dstdata["y"]

    correct = total = 0
    with torch.no_grad():
        for _, _, blocks in tqdm.tqdm(valid_dataloader):
            predicted, targets = predict(blocks)
            correct += (predicted == targets).sum().item()
            total += predicted.shape[0]
    acc = correct / total
    print("Validation accuracy:", acc)

    evaluator = MAG240MEvaluator()
    y_preds = []
    with torch.no_grad():
        for _, _, blocks in tqdm.tqdm(test_dataloader):
            predicted, _ = predict(blocks)
            y_preds.append(predicted.cpu())
    evaluator.save_test_submission(
        {"y_pred": torch.cat(y_preds)}, args.submission_path
    )


if __name__ == "__main__":
    # Command-line options as (flag, type, default, help), in display order.
    _cli_options = (
        ("--rootdir", str, ".", "Directory to download the OGB dataset."),
        ("--graph-path", str, "./graph.dgl", "Path to the graph."),
        (
            "--full-feature-path",
            str,
            "./full.npy",
            "Path to the features of all nodes.",
        ),
        ("--epochs", int, 100, "Number of epochs."),
        ("--model-path", str, "./model.pt", "Path to store the best model."),
        ("--submission-path", str, "./results", "Submission directory."),
    )
    parser = argparse.ArgumentParser()
    for _flag, _type, _default, _help in _cli_options:
        parser.add_argument(_flag, type=_type, default=_default, help=_help)
    args = parser.parse_args()

    dataset = MAG240MDataset(root=args.rootdir)

    print("Loading graph")
    # The file is expected to hold exactly one graph; unpacking enforces it.
    (g,), _ = dgl.load_graphs(args.graph_path)
    g = g.formats(["csc"])

    print("Loading features")
    # Paper node IDs start after the author and institution counts.
    paper_offset = dataset.num_authors + dataset.num_institutions
    num_nodes = paper_offset + dataset.num_papers
    num_features = dataset.num_paper_features
    # Read-only memory map of the float16 feature matrix.
    feats = np.memmap(
        args.full_feature_path,
        mode="r",
        dtype="float16",
        shape=(num_nodes, num_features),
    )

    # With --epochs 0 training is skipped and only evaluation runs.
    if args.epochs:
        train(args, dataset, g, feats, paper_offset)
    test(args, dataset, g, feats, paper_offset)
