import argparse
import math

import torch
import torch.nn.functional as F
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from torch.nn import Linear
from torch.utils.data import DataLoader

import dgl
from dgl.dataloading.negative_sampler import GlobalUniform
from dgl.nn.pytorch import GraphConv, SAGEConv


class Logger(object):
    """Collect per-run ``(train, valid, test)`` scores and print summaries.

    ``results[run]`` is a list holding one 3-element entry per call to
    :meth:`add_result`. All scores are multiplied by 100 before printing.
    """

    def __init__(self, runs, info=None):
        """Create ``runs`` empty result lists; ``info`` is stored as given."""
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one 3-element ``result`` to the history of run ``run``."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    @staticmethod
    def _summarize(scores):
        """Reduce an ``(entries, 3)`` score tensor to four Python floats.

        Returns ``(max train, max valid, train at best valid, test at best
        valid)``, where "best valid" is the row with the highest column 1.
        """
        best_row = scores[:, 1].argmax()
        return (
            scores[:, 0].max().item(),
            scores[:, 1].max().item(),
            scores[best_row, 0].item(),
            scores[best_row, 2].item(),
        )

    def print_statistics(self, run=None):
        """Print the summary of one run, or mean ± std over all runs.

        Args:
            run: Index of the run to report. When ``None``, the per-run
                summaries are aggregated across every run instead.
        """
        if run is None:
            # Converting the whole nested list at once requires every run to
            # hold the same number of entries.
            all_scores = 100 * torch.tensor(self.results)
            summary = torch.tensor(
                [self._summarize(run_scores) for run_scores in all_scores]
            )
            top_train, top_valid, final_train, final_test = summary.unbind(
                dim=1
            )

            print("All runs:")
            print(
                f"Highest Train: {top_train.mean():.2f} ± {top_train.std():.2f}"
            )
            print(
                f"Highest Valid: {top_valid.mean():.2f} ± {top_valid.std():.2f}"
            )
            print(
                f"  Final Train: {final_train.mean():.2f} ± {final_train.std():.2f}"
            )
            print(
                f"   Final Test: {final_test.mean():.2f} ± {final_test.std():.2f}"
            )
            return

        run_scores = 100 * torch.tensor(self.results[run])
        top_train, top_valid, final_train, final_test = self._summarize(
            run_scores
        )
        print(f"Run {run + 1:02d}:")
        print(f"Highest Train: {top_train:.2f}")
        print(f"Highest Valid: {top_valid:.2f}")
        print(f"  Final Train: {final_train:.2f}")
        print(f"   Final Test: {final_test:.2f}")


class NGNN_GCNConv(torch.nn.Module):
    """A ``GraphConv`` followed by one or two ReLU + ``Linear`` stages.

    Args:
        in_channels: Input feature size of the graph convolution.
        hidden_channels: Output size of the convolution and of ``fc``.
        out_channels: Output size of the final linear layer ``fc2``.
        num_nonl_layers: When equal to 2, the extra ``ReLU -> fc`` stage is
            applied before the final ``ReLU -> fc2`` stage; any other value
            skips it.
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, num_nonl_layers
    ):
        super().__init__()
        # number of nonlinear layers in each conv layer
        self.num_nonl_layers = num_nonl_layers
        self.conv = GraphConv(in_channels, hidden_channels)
        self.fc = Linear(hidden_channels, hidden_channels)
        self.fc2 = Linear(hidden_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution and both linear layers."""
        self.conv.reset_parameters()
        linears = (self.fc, self.fc2)
        relu_gain = torch.nn.init.calculate_gain("relu")
        # Both weights are drawn before either bias so the sequence of
        # random draws stays fixed for a given seed.
        for lin in linears:
            torch.nn.init.xavier_uniform_(lin.weight, gain=relu_gain)
        for lin in linears:
            bound = 1.0 / math.sqrt(lin.bias.size(0))
            lin.bias.data.uniform_(-bound, bound)

    def forward(self, g, x):
        """Apply the convolution on graph ``g`` and the linear stage(s)."""
        hidden = self.conv(g, x)
        if self.num_nonl_layers == 2:
            hidden = self.fc(F.relu(hidden))
        return self.fc2(F.relu(hidden))


class GCN(torch.nn.Module):
    """Stack of ``GraphConv`` layers with NGNN layers at a chosen position.

    Layer layout (the last layer is always a plain
    ``GraphConv(hidden_channels, out_channels)``):

    * ``ngnn_type == "input"``: one ``NGNN_GCNConv`` input layer, then
      ``num_layers - 2`` plain ``GraphConv`` layers.
    * ``ngnn_type == "hidden"``: one plain ``GraphConv`` input layer, then
      ``num_layers - 2`` ``NGNN_GCNConv`` layers.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of every intermediate layer.
        out_channels: Size of the output node embeddings.
        num_layers: Target layer count; NGNN layers use two nonlinear
            stages when it exceeds 2, otherwise one.
        dropout: Dropout probability applied after every layer but the last.
        ngnn_type: ``"input"`` or ``"hidden"``.
        dataset: Dataset name, stored on the module as ``self.dataset``.

    Raises:
        ValueError: If ``ngnn_type`` is neither ``"input"`` nor ``"hidden"``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        dropout,
        ngnn_type,
        dataset,
    ):
        super(GCN, self).__init__()

        # An unknown type used to fall through both branches below and leave
        # the model with only its output layer; reject it explicitly instead.
        if ngnn_type not in ("input", "hidden"):
            raise ValueError(
                f"ngnn_type must be 'input' or 'hidden', got {ngnn_type!r}"
            )

        self.dataset = dataset
        self.convs = torch.nn.ModuleList()

        num_nonl_layers = (
            1 if num_layers <= 2 else 2
        )  # number of nonlinear layers in each conv layer
        if ngnn_type == "input":
            self.convs.append(
                NGNN_GCNConv(
                    in_channels,
                    hidden_channels,
                    hidden_channels,
                    num_nonl_layers,
                )
            )
            for _ in range(num_layers - 2):
                self.convs.append(GraphConv(hidden_channels, hidden_channels))
        else:  # "hidden"
            self.convs.append(GraphConv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(
                    NGNN_GCNConv(
                        hidden_channels,
                        hidden_channels,
                        hidden_channels,
                        num_nonl_layers,
                    )
                )

        self.convs.append(GraphConv(hidden_channels, out_channels))

        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every layer in ``self.convs``."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, g, x):
        """Return node embeddings for features ``x`` on graph ``g``.

        Every layer except the last is followed by ReLU and dropout; the
        last layer's output is returned without an activation.
        """
        for conv in self.convs[:-1]:
            x = conv(g, x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](g, x)
        return x


class NGNN_SAGEConv(torch.nn.Module):
    """A ``SAGEConv`` followed by one or two ReLU + ``Linear`` stages.

    Args:
        in_channels: Input feature size of the SAGE convolution.
        hidden_channels: Output size of the convolution and of ``fc``.
        out_channels: Output size of the final linear layer ``fc2``.
        num_nonl_layers: When equal to 2, the extra ``ReLU -> fc`` stage is
            applied before the final ``ReLU -> fc2`` stage; any other value
            skips it.
        reduce: Keyword-only; forwarded as the third positional argument of
            ``SAGEConv``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_nonl_layers,
        *,
        reduce,
    ):
        super().__init__()
        # number of nonlinear layers in each conv layer
        self.num_nonl_layers = num_nonl_layers
        self.conv = SAGEConv(in_channels, hidden_channels, reduce)
        self.fc = Linear(hidden_channels, hidden_channels)
        self.fc2 = Linear(hidden_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution and both linear layers."""
        self.conv.reset_parameters()
        linears = (self.fc, self.fc2)
        relu_gain = torch.nn.init.calculate_gain("relu")
        # Both weights are drawn before either bias so the sequence of
        # random draws stays fixed for a given seed.
        for lin in linears:
            torch.nn.init.xavier_uniform_(lin.weight, gain=relu_gain)
        for lin in linears:
            bound = 1.0 / math.sqrt(lin.bias.size(0))
            lin.bias.data.uniform_(-bound, bound)

    def forward(self, g, x):
        """Apply the convolution on graph ``g`` and the linear stage(s)."""
        hidden = self.conv(g, x)
        if self.num_nonl_layers == 2:
            hidden = self.fc(F.relu(hidden))
        return self.fc2(F.relu(hidden))


class SAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers with NGNN layers at a chosen position.

    Layer layout (the last layer is always a plain
    ``SAGEConv(hidden_channels, out_channels, reduce)``):

    * ``ngnn_type == "input"``: one ``NGNN_SAGEConv`` input layer, then
      ``num_layers - 2`` plain ``SAGEConv`` layers.
    * ``ngnn_type == "hidden"``: one plain ``SAGEConv`` input layer, then
      ``num_layers - 2`` ``NGNN_SAGEConv`` layers.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of every intermediate layer.
        out_channels: Size of the output node embeddings.
        num_layers: Target layer count; NGNN layers use two nonlinear
            stages when it exceeds 2, otherwise one.
        dropout: Dropout probability applied after every layer but the last.
        ngnn_type: ``"input"`` or ``"hidden"``.
        dataset: Dataset name, stored on the module as ``self.dataset``.
        reduce: Passed to every ``SAGEConv`` as its third positional
            argument (default ``"mean"``).

    Raises:
        ValueError: If ``ngnn_type`` is neither ``"input"`` nor ``"hidden"``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        dropout,
        ngnn_type,
        dataset,
        reduce="mean",
    ):
        super(SAGE, self).__init__()

        # An unknown type used to fall through both branches below and leave
        # the model with only its output layer; reject it explicitly instead.
        if ngnn_type not in ("input", "hidden"):
            raise ValueError(
                f"ngnn_type must be 'input' or 'hidden', got {ngnn_type!r}"
            )

        self.dataset = dataset
        self.convs = torch.nn.ModuleList()

        num_nonl_layers = (
            1 if num_layers <= 2 else 2
        )  # number of nonlinear layers in each conv layer
        if ngnn_type == "input":
            self.convs.append(
                NGNN_SAGEConv(
                    in_channels,
                    hidden_channels,
                    hidden_channels,
                    num_nonl_layers,
                    reduce=reduce,
                )
            )
            for _ in range(num_layers - 2):
                self.convs.append(
                    SAGEConv(hidden_channels, hidden_channels, reduce)
                )
        else:  # "hidden"
            self.convs.append(SAGEConv(in_channels, hidden_channels, reduce))
            for _ in range(num_layers - 2):
                self.convs.append(
                    NGNN_SAGEConv(
                        hidden_channels,
                        hidden_channels,
                        hidden_channels,
                        num_nonl_layers,
                        reduce=reduce,
                    )
                )

        self.convs.append(SAGEConv(hidden_channels, out_channels, reduce))

        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every layer in ``self.convs``."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, g, x):
        """Return node embeddings for features ``x`` on graph ``g``.

        Every layer except the last is followed by ReLU and dropout; the
        last layer's output is returned without an activation.
        """
        for conv in self.convs[:-1]:
            x = conv(g, x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](g, x)
        return x


class LinkPredictor(torch.nn.Module):
    """MLP that scores a node pair from the product of its two embeddings.

    The stack is ``Linear(in, hidden)``, then ``num_layers - 2`` layers of
    ``Linear(hidden, hidden)`` (none when ``num_layers <= 2``), then
    ``Linear(hidden, out)``.
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, num_layers, dropout
    ):
        super().__init__()

        # Consecutive pairs of widths define the layers, built first to last.
        extra_hidden = max(num_layers - 2, 0)
        widths = [
            in_channels,
            *([hidden_channels] * (extra_hidden + 1)),
            out_channels,
        ]
        self.lins = torch.nn.ModuleList()
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.lins.append(Linear(fan_in, fan_out))

        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer, in order."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Return ``sigmoid`` scores for the element-wise product of inputs.

        Every layer but the last is followed by ReLU and dropout.
        """
        *body, head = self.lins
        hidden = x_i * x_j
        for layer in body:
            hidden = F.relu(layer(hidden))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return torch.sigmoid(head(hidden))


def train(model, predictor, g, x, split_edge, optimizer, batch_size):
    """Run one training epoch and return the average loss.

    For every shuffled mini-batch of positive training edges the full graph
    is re-encoded with ``model``, the batch is scored with ``predictor``,
    and negative pairs are drawn with ``GlobalUniform(1)``. The loss is the
    sum of the mean negative log-likelihoods of positives and negatives.

    Args:
        model: GNN encoder called as ``model(g, x)``; must expose
            ``dataset``.
        predictor: Module scoring pairs as ``predictor(h_src, h_dst)``.
        g: Graph passed to the encoder and to the negative sampler.
        x: Node features; also determines the device the edges are moved to.
        split_edge: Mapping whose ``["train"]["edge"]`` holds the positive
            training edges (one pair per row, judging by the transpose).
        optimizer: Optimizer stepped once per mini-batch.
        batch_size: Number of positive edges per mini-batch.

    Returns:
        The loss averaged over all positive edges seen in the epoch.
    """
    model.train()
    predictor.train()

    eps = 1e-15  # keeps log() finite when a score saturates at 0 or 1
    train_edges = split_edge["train"]["edge"].to(x.device)
    sample_negatives = GlobalUniform(1)
    batches = DataLoader(
        range(train_edges.size(0)), batch_size, shuffle=True
    )

    loss_sum = 0
    edges_seen = 0
    for batch_ids in batches:
        optimizer.zero_grad()

        # Parameters change every step, so embeddings are recomputed here.
        h = model(g, x)

        src, dst = train_edges[batch_ids].t()
        pos_out = predictor(h[src], h[dst])
        pos_loss = -torch.log(pos_out + eps).mean()

        negatives = sample_negatives(g, src)
        neg_out = predictor(h[negatives[0]], h[negatives[1]])
        neg_loss = -torch.log(1 - neg_out + eps).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        # For ogbl-ddi the feature tensor itself is clipped as well
        # (presumably because it is a trainable embedding there; confirm
        # against the caller).
        if model.dataset == "ogbl-ddi":
            torch.nn.utils.clip_grad_norm_(x, 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        batch_edges = pos_out.size(0)
        loss_sum += loss.item() * batch_edges
        edges_seen += batch_edges

    return loss_sum / edges_seen


@torch.no_grad()
def test(model, predictor, g, x, split_edge, evaluator, batch_size):
    """Evaluate Hits@20/50/100 on the train-subset, validation and test edges.

    Node embeddings are computed once with ``model(g, x)`` and every edge set
    is then scored in mini-batches with ``predictor``.

    Args:
        model: GNN encoder called as ``model(g, x)``.
        predictor: Module scoring pairs as ``predictor(h_src, h_dst)``;
            expected to return one score per edge with a trailing dimension
            of size 1.
        g: Graph passed to the encoder.
        x: Node features passed to the encoder.
        split_edge: Mapping providing ``["eval_train"]["edge"]`` plus
            ``"edge"`` and ``"edge_neg"`` under both ``"valid"`` and
            ``"test"``.
        evaluator: Object with a settable ``K`` and an ``eval(dict)`` method
            returning a mapping that contains ``f"hits@{K}"``.
        batch_size: Number of edges scored per mini-batch.

    Returns:
        Dict mapping ``"Hits@K"`` to ``(train_hits, valid_hits, test_hits)``
        for K in 20, 50, 100. Training positives are ranked against the
        validation negatives.
    """
    model.eval()
    predictor.eval()

    h = model(g, x)

    pos_train_edge = split_edge["eval_train"]["edge"].to(h.device)
    pos_valid_edge = split_edge["valid"]["edge"].to(h.device)
    neg_valid_edge = split_edge["valid"]["edge_neg"].to(h.device)
    pos_test_edge = split_edge["test"]["edge"].to(h.device)
    neg_test_edge = split_edge["test"]["edge_neg"].to(h.device)

    def get_pred(test_edges, h):
        """Score ``test_edges`` batch by batch; return one 1-D CPU tensor."""
        preds = []
        for perm in DataLoader(range(test_edges.size(0)), batch_size):
            edge = test_edges[perm].t()
            # Drop only the trailing score dimension. A bare squeeze() would
            # also collapse a single-edge batch to a 0-dim tensor, which
            # torch.cat refuses to concatenate.
            preds.append(
                predictor(h[edge[0]], h[edge[1]]).squeeze(-1).cpu()
            )
        return torch.cat(preds, dim=0)

    pos_train_pred = get_pred(pos_train_edge, h)
    pos_valid_pred = get_pred(pos_valid_edge, h)
    neg_valid_pred = get_pred(neg_valid_edge, h)
    pos_test_pred = get_pred(pos_test_edge, h)
    neg_test_pred = get_pred(neg_test_edge, h)

    results = {}
    for K in [20, 50, 100]:
        evaluator.K = K
        train_hits = evaluator.eval(
            {
                "y_pred_pos": pos_train_pred,
                "y_pred_neg": neg_valid_pred,
            }
        )[f"hits@{K}"]
        valid_hits = evaluator.eval(
            {
                "y_pred_pos": pos_valid_pred,
                "y_pred_neg": neg_valid_pred,
            }
        )[f"hits@{K}"]
        test_hits = evaluator.eval(
            {
                "y_pred_pos": pos_test_pred,
                "y_pred_neg": neg_test_pred,
            }
        )[f"hits@{K}"]

        results[f"Hits@{K}"] = (train_hits, valid_hits, test_hits)

    return results


# Despite its name this is an evaluation routine, not a unit test; opt out of
# pytest collection in case the module is imported into a test file.
test.__test__ = False


def main():
    """Parse CLI options, then train and evaluate over several runs.

    Loads the chosen OGB link-prediction dataset, builds a GCN or GraphSAGE
    encoder with NGNN layers plus a ``LinkPredictor``, and for each run
    re-initialises the parameters, trains for ``--epochs`` epochs, evaluates
    every ``--eval_steps`` epochs and prints Hits@20/50/100 statistics.
    """
    parser = argparse.ArgumentParser(
        description="OGBL(Full Batch GCN/GraphSage + NGNN)"
    )
    add_arg = parser.add_argument

    # dataset setting
    add_arg(
        "--dataset",
        type=str,
        default="ogbl-ddi",
        choices=["ogbl-ddi", "ogbl-collab", "ogbl-ppa"],
    )

    # device setting
    add_arg(
        "--device",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training.",
    )

    # model structure settings
    add_arg(
        "--use_sage",
        action="store_true",
        help="If not set, use GCN by default.",
    )
    add_arg(
        "--ngnn_type",
        type=str,
        default="input",
        choices=["input", "hidden"],
        help="You can set this value from 'input' or 'hidden' to apply NGNN to different GNN layers.",
    )
    add_arg("--num_layers", type=int, default=3, help="number of GNN layers")
    add_arg("--hidden_channels", type=int, default=256)
    add_arg("--dropout", type=float, default=0.0)
    add_arg("--batch_size", type=int, default=64 * 1024)
    add_arg("--lr", type=float, default=0.001)
    add_arg("--epochs", type=int, default=400)

    # training settings
    add_arg("--eval_steps", type=int, default=1)
    add_arg("--runs", type=int, default=10)
    args = parser.parse_args()
    print(args)

    # Fall back to the CPU when requested (-1) or when CUDA is unavailable.
    use_cuda = args.device != -1 and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.device}" if use_cuda else "cpu")

    dataset = DglLinkPropPredDataset(name=args.dataset)
    g = dataset[0]
    split_edge = dataset.get_edge_split()

    # Evaluate "train" quality on a random subset of the training edges that
    # has as many edges as the validation split.
    train_edges = split_edge["train"]["edge"]
    num_eval = split_edge["valid"]["edge"].size(0)
    subset = torch.randperm(train_edges.size(0))[:num_eval]
    split_edge["eval_train"] = {"edge": train_edges[subset]}

    if dataset.name == "ogbl-ppa":
        g.ndata["feat"] = g.ndata["feat"].to(torch.float)

    is_ddi = dataset.name == "ogbl-ddi"
    if is_ddi:
        # Node features for ogbl-ddi come from a trainable embedding table.
        emb = torch.nn.Embedding(g.num_nodes(), args.hidden_channels).to(device)
        in_channels = args.hidden_channels
    else:  # ogbl-collab, ogbl-ppa
        in_channels = g.ndata["feat"].size(-1)

    # select model
    encoder_args = (
        in_channels,
        args.hidden_channels,
        args.hidden_channels,
        args.num_layers,
        args.dropout,
        args.ngnn_type,
        dataset.name,
    )
    if args.use_sage:
        model = SAGE(*encoder_args)
    else:  # GCN
        g = dgl.add_self_loop(g)
        model = GCN(*encoder_args)

    predictor = LinkPredictor(
        args.hidden_channels, args.hidden_channels, 1, 3, args.dropout
    )

    g = g.to(device)
    model = model.to(device)
    predictor = predictor.to(device)

    evaluator = Evaluator(name=dataset.name)
    loggers = {f"Hits@{k}": Logger(args.runs, args) for k in (20, 50, 100)}

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        trainable = list(model.parameters()) + list(predictor.parameters())
        if is_ddi:
            torch.nn.init.xavier_uniform_(emb.weight)
            g.ndata["feat"] = emb.weight
            trainable += list(emb.parameters())
        optimizer = torch.optim.Adam(trainable, lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(
                model,
                predictor,
                g,
                g.ndata["feat"],
                split_edge,
                optimizer,
                args.batch_size,
            )

            if epoch % args.eval_steps != 0:
                continue

            results = test(
                model,
                predictor,
                g,
                g.ndata["feat"],
                split_edge,
                evaluator,
                args.batch_size,
            )
            for key, result in results.items():
                loggers[key].add_result(run, result)
                train_hits, valid_hits, test_hits = result
                print(key)
                print(
                    f"Run: {run + 1:02d}, "
                    f"Epoch: {epoch:02d}, "
                    f"Loss: {loss:.4f}, "
                    f"Train: {100 * train_hits:.2f}%, "
                    f"Valid: {100 * valid_hits:.2f}%, "
                    f"Test: {100 * test_hits:.2f}%"
                )
            print("---")

        for key, logger in loggers.items():
            print(key)
            logger.print_statistics(run)

    for key, logger in loggers.items():
        print(key)
        logger.print_statistics()


# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
