import argparse
import math
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
from torch.nn import (
    BCEWithLogitsLoss,
    Conv1d,
    Embedding,
    Linear,
    MaxPool1d,
    ModuleList,
)
from tqdm import tqdm

import dgl
from dgl.dataloading import DataLoader, Sampler
from dgl.nn import GraphConv, SortPooling
from dgl.sampling import global_uniform_negative_sampling


class Logger(object):
    """Accumulate ``(val_score, test_score)`` pairs per run and print summaries.

    Scores are stored as given and multiplied by 100 only when printed.
    """

    def __init__(self, runs, info=None):
        # One list of (val_score, test_score) tuples per run, filled over time.
        self.results = [[] for _ in range(runs)]
        self.info = info

    def add_result(self, run, result):
        """Append one evaluation point ``(val_score, test_score)`` to ``run``."""
        assert len(result) == 2
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None, f=sys.stdout):
        """Write a summary to ``f``.

        With ``run`` given, report that run's best validation score, the
        (1-based) evaluation point where it occurred, and the test score at
        that point.  With ``run=None``, report mean ± std over all runs of
        each run's best validation score and its matching test score.
        """
        if run is None:
            scores = 100 * torch.tensor(self.results)

            # For every run keep (best validation, test at that same point).
            picked = []
            for run_scores in scores:
                val_col = run_scores[:, 0]
                best_val = val_col.max().item()
                test_at_best = run_scores[val_col.argmax(), 1].item()
                picked.append((best_val, test_at_best))
            summary = torch.tensor(picked)

            print("All runs:", file=f)
            col = summary[:, 0]
            print(f"Highest Valid: {col.mean():.2f} ± {col.std():.2f}", file=f)
            col = summary[:, 1]
            print(f"   Final Test: {col.mean():.2f} ± {col.std():.2f}", file=f)
            return

        scores = 100 * torch.tensor(self.results[run])
        val_col = scores[:, 0]
        best_idx = val_col.argmax().item()
        print(f"Run {run + 1:02d}:", file=f)
        print(f"Highest Valid: {val_col.max():.2f}", file=f)
        print(f"Highest Eval Point: {best_idx + 1}", file=f)
        print(f"   Final Test: {scores[best_idx, 1]:.2f}", file=f)


class SealSampler(Sampler):
    """Extract a node-labeled enclosing subgraph for every seed edge.

    For each seed edge the sampler grows a ``num_hops`` neighbourhood around
    the two endpoints in ``g``, removes the target edge from the induced
    subgraph, attaches double-radius node labels as ``ndata["z"]`` and adds
    self-loops.  All subgraphs of a mini-batch are returned as one batched
    graph together with the seed edges' ``"y"`` labels.

    Parameters
    ----------
    g : graph that neighbourhoods and subgraphs are taken from.
    num_hops : number of expansion rounds around the endpoints.
    sample_ratio : fraction of each hop's newly found nodes to keep
        (values below 1.0 trigger random subsampling without replacement).
    directed : if true, expand along both in- and out-edges.
    prefetch_node_feats, prefetch_edge_feats : feature names passed to
        DGL's lazy-feature helpers in ``sample``.
    """

    def __init__(
        self,
        g,
        num_hops=1,
        sample_ratio=1.0,
        directed=False,
        prefetch_node_feats=None,
        prefetch_edge_feats=None,
    ):
        super().__init__()
        self.g = g
        self.num_hops = num_hops
        self.sample_ratio = sample_ratio
        self.directed = directed
        self.prefetch_node_feats = prefetch_node_feats
        self.prefetch_edge_feats = prefetch_edge_feats

    def _double_radius_node_labeling(self, adj):
        """Return integer node labels for an enclosing subgraph.

        ``adj`` is the subgraph adjacency matrix (indexable like a scipy
        sparse matrix).  Row/column 0 is treated as the source endpoint and
        row/column 1 as the destination endpoint.  Distances to one endpoint
        are computed with the other endpoint removed from the graph.

        Returns a ``torch.long`` tensor with one label per node: 1 for the two
        endpoints, 0 for nodes whose label evaluates to NaN, and
        ``1 + min(ds, dd) + (d // 2) * (d // 2 + d % 2 - 1)`` otherwise, where
        ``ds``/``dd`` are the distances to source/destination and
        ``d = ds + dd``.
        """
        N = adj.shape[0]
        # Adjacency with the source (index 0) dropped; destination becomes 0.
        adj_wo_src = adj[range(1, N), :][:, range(1, N)]
        # Adjacency with the destination (index 1) dropped; source stays 0.
        idx = list(range(1)) + list(range(2, N))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(
            adj_wo_dst, directed=False, unweighted=True, indices=0
        )
        # Re-insert a placeholder for the removed destination at position 1.
        dist2src = np.insert(dist2src, 1, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(
            adj_wo_src, directed=False, unweighted=True, indices=0
        )
        # Re-insert a placeholder for the removed source at position 0.
        dist2dst = np.insert(dist2dst, 0, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = (
            torch.div(dist, 2, rounding_mode="floor"),
            dist % 2,
        )

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        # Both endpoints always get label 1.
        z[0:2] = 1.0
        # shortest path may include inf values, which turn into NaN above
        z[torch.isnan(z)] = 0.0

        return z.to(torch.long)

    def sample(self, aug_g, seed_edges):
        """Build the batched enclosing subgraphs for ``seed_edges``.

        ``aug_g`` is only used to look up each seed edge's endpoints and its
        ``"y"`` label; neighbourhoods are taken from ``self.g``.

        Returns ``(batched_subgraphs, labels)``.
        """
        g = self.g
        subgraphs = []
        # construct k-hop enclosing graph for each link
        for eid in seed_edges:
            src, dst = map(int, aug_g.find_edges(eid))
            # The endpoints must be the first entries of the node list: the
            # edge removal and the node labeling below address them as
            # subgraph nodes 0 and 1.  (Sorting the node list, as np.union1d
            # does, would put arbitrary nodes in those slots.)
            targets = np.array(
                [src] if src == dst else [src, dst], dtype=np.int64
            )
            nodes = targets
            visited = np.unique(targets)
            fringe = np.unique(targets)
            for _ in range(self.num_hops):
                if not self.directed:
                    _, fringe = g.out_edges(fringe)
                else:
                    _, out_neighbors = g.out_edges(fringe)
                    in_neighbors, _ = g.in_edges(fringe)
                    fringe = np.union1d(in_neighbors, out_neighbors)
                # Keep only nodes not seen in earlier hops (unique, sorted).
                fringe = np.setdiff1d(fringe, visited)
                visited = np.union1d(visited, fringe)
                if self.sample_ratio < 1.0:
                    fringe = np.random.choice(
                        fringe,
                        int(self.sample_ratio * len(fringe)),
                        replace=False,
                    )
                if len(fringe) == 0:
                    break
                # fringe is disjoint from everything already in ``nodes``, so
                # appending keeps the list duplicate-free while preserving
                # the endpoints at positions 0 and 1.
                nodes = np.concatenate([nodes, fringe])
            # NOTE(review): relies on subgraph() numbering nodes in the order
            # given — confirm against the DGL version in use.
            subg = g.subgraph(nodes, store_ids=True)

            # remove edges to predict
            edges_to_remove = [
                subg.edge_ids(s, t)
                for s, t in [(0, 1), (1, 0)]
                if subg.has_edges_between(s, t)
            ]
            subg.remove_edges(edges_to_remove)
            # add double radius node labeling
            subg.ndata["z"] = self._double_radius_node_labeling(
                subg.adj(scipy_fmt="csr")
            )
            subg_aug = subg.add_self_loop()
            if "weight" in subg.edata:
                # The edges appended by add_self_loop get weight 1.
                subg_aug.edata["weight"][subg.num_edges() :] = torch.ones(
                    subg_aug.num_edges() - subg.num_edges()
                )
            subgraphs.append(subg_aug)

        subgraphs = dgl.batch(subgraphs)
        # NOTE(review): these calls target the last per-edge subgraph, not the
        # batched graph that is returned.  Left unchanged because the features
        # already come from g.subgraph(); confirm whether prefetching on the
        # batched graph was intended.
        dgl.set_src_lazy_features(subg_aug, self.prefetch_node_feats)
        dgl.set_edge_lazy_features(subg_aug, self.prefetch_edge_feats)

        return subgraphs, aug_g.edata["y"][seed_edges]


# An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module):
    """Graph-level scorer: label embedding -> GNN stack -> SortPooling -> 1-D CNN -> MLP.

    Parameters
    ----------
    hidden_channels : width of the label embedding and of every hidden GNN layer.
    num_layers : number of hidden-width GNN layers (one extra 1-channel layer
        is appended after them).
    k : number of nodes kept per graph by SortPooling.
    GNN : layer class, constructed as ``GNN(in_channels, out_channels)``.
    feature_dim : width of optional raw node features concatenated to the
        label embedding.
    """

    def __init__(
        self, hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.k = k
        self.sort_pool = SortPooling(k=k)

        # Embedding table for the structural node labels ("z").
        self.max_z = 1000
        self.z_embedding = Embedding(self.max_z, hidden_channels)

        # Layer widths: first layer consumes embedding (+ features), the
        # middle layers are hidden -> hidden, the last one maps to 1 channel.
        first_in = hidden_channels + self.feature_dim
        in_dims = [first_in] + [hidden_channels] * (num_layers - 1)
        in_dims.append(hidden_channels)
        out_dims = [hidden_channels] + [hidden_channels] * (num_layers - 1)
        out_dims.append(1)
        self.convs = ModuleList(
            [GNN(n_in, n_out) for n_in, n_out in zip(in_dims, out_dims)]
        )

        # 1-D convolution head over the sort-pooled node sequence.
        channels_a, channels_b = 16, 32
        total_latent_dim = hidden_channels * num_layers + 1
        second_kernel = 5
        # Kernel size == stride == per-node width, i.e. one step per node.
        self.conv1 = Conv1d(
            1, channels_a, kernel_size=total_latent_dim, stride=total_latent_dim
        )
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(
            channels_a, channels_b, kernel_size=second_kernel, stride=1
        )
        # Sequence length after max-pooling, then after the second conv.
        pooled_len = int((self.k - 2) / 2 + 1)
        dense_dim = (pooled_len - second_kernel + 1) * channels_b
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, g, z, x=None, edge_weight=None):
        """Return one logit per graph in the batched graph ``g``.

        ``z`` holds integer node labels, ``x`` optional node features and
        ``edge_weight`` is forwarded to every GNN layer.
        """
        h = self.z_embedding(z)
        if h.ndim == 3:
            # several integer labels per node: sum their embeddings
            h = h.sum(dim=1)
        if x is not None:
            h = torch.cat([h, x.to(torch.float)], 1)

        # Run the GNN stack and keep every layer's tanh-activated output.
        layer_outputs = []
        for conv in self.convs:
            h = torch.tanh(conv(g, h, edge_weight=edge_weight))
            layer_outputs.append(h)
        x = torch.cat(layer_outputs, dim=-1)

        # Global pooling, then add a channel axis for the 1-D convolutions.
        x = self.sort_pool(g, x)
        x = x.unsqueeze(1)
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten to one vector per graph

        # Two-layer MLP with dropout in between.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.lin2(x)


def get_pos_neg_edges(split, split_edge, g, percent=100):
    """Return ``(pos_edge, neg_edge)`` for ``split``, subsampled to ``percent``.

    Positive edges come from ``split_edge[split]["edge"]``.  Negatives are
    drawn with ``global_uniform_negative_sampling`` on ``g`` for the
    ``"train"`` split and read from ``split_edge[split]["edge_neg"]``
    otherwise.

    Subsampling uses ``np.random`` re-seeded with 123 (this resets NumPy's
    global random state).  Negatives with more than two dimensions are
    indexed with the same permutation as the positives and flattened to
    two columns; otherwise they get their own permutation.
    """
    pos_edge = split_edge[split]["edge"]
    if split != "train":
        neg_edge = split_edge[split]["edge_neg"]
    else:
        sampled = global_uniform_negative_sampling(
            g, num_samples=pos_edge.size(0), exclude_self_loops=True
        )
        neg_edge = torch.stack(sampled, dim=1)

    fraction = percent / 100

    # Fixed seed so the chosen subset is the same on every call.
    np.random.seed(123)
    pos_total = pos_edge.size(0)
    pos_keep = np.random.permutation(pos_total)[: int(fraction * pos_total)]
    pos_edge = pos_edge[pos_keep]

    if neg_edge.dim() <= 2:
        # Flat negatives: independent permutation, same seed.
        np.random.seed(123)
        neg_total = neg_edge.size(0)
        neg_keep = np.random.permutation(neg_total)[: int(fraction * neg_total)]
        neg_edge = neg_edge[neg_keep]
    else:
        # Negatives grouped per positive: keep the groups of the kept positives.
        neg_edge = neg_edge[pos_keep].view(-1, 2)

    return pos_edge, neg_edge


def train():
    """Run one training epoch and return the mean loss per subgraph.

    Uses the module-level ``model``, ``optimizer`` and ``train_loader``.
    """
    model.train()
    criterion = BCEWithLogitsLoss()
    loss_sum = 0
    graphs_seen = 0
    for batch_graph, labels in tqdm(train_loader, ncols=70):
        optimizer.zero_grad()
        logits = model(
            batch_graph,
            batch_graph.ndata["z"],
            batch_graph.ndata.get("feat", None),
            edge_weight=batch_graph.edata.get("weight", None),
        )
        loss = criterion(logits.view(-1), labels.to(torch.float))
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the number of graphs in the batch.
        graphs_in_batch = batch_graph.batch_size
        loss_sum += loss.item() * graphs_in_batch
        graphs_seen += graphs_in_batch

    return loss_sum / graphs_seen


@torch.no_grad()
def test():
    """Score the validation and test loaders and compute the chosen metric.

    Uses the module-level ``model``, ``val_loader``, ``test_loader`` and
    ``args``.  Returns the dict produced by ``evaluate_hits`` or
    ``evaluate_mrr`` depending on ``args.eval_metric``.
    """
    model.eval()

    def _scores_by_label(loader):
        # Run the model over ``loader`` and split the logits by true label.
        pred_chunks, label_chunks = [], []
        for batch_graph, labels in tqdm(loader, ncols=70):
            logits = model(
                batch_graph,
                batch_graph.ndata["z"],
                batch_graph.ndata.get("feat", None),
                edge_weight=batch_graph.edata.get("weight", None),
            )
            pred_chunks.append(logits.view(-1).cpu())
            label_chunks.append(labels.view(-1).cpu().to(torch.float))
        preds = torch.cat(pred_chunks)
        truth = torch.cat(label_chunks)
        return preds[truth == 1], preds[truth == 0]

    pos_val_pred, neg_val_pred = _scores_by_label(val_loader)
    pos_test_pred, neg_test_pred = _scores_by_label(test_loader)

    scores = (pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred)
    if args.eval_metric == "hits":
        results = evaluate_hits(*scores)
    elif args.eval_metric == "mrr":
        results = evaluate_mrr(*scores)

    return results


def evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
    """Return ``{"Hits@K": (valid, test)}`` for K in 20, 50 and 100.

    Uses the module-level ``evaluator`` and sets its ``K`` attribute before
    each pair of evaluations.
    """

    def _hits(pos_pred, neg_pred, metric):
        # One evaluator call on a fresh input dict.
        scored = evaluator.eval(
            {"y_pred_pos": pos_pred, "y_pred_neg": neg_pred}
        )
        return scored[metric]

    results = {}
    for K in (20, 50, 100):
        evaluator.K = K
        metric = f"hits@{K}"
        valid_hits = _hits(pos_val_pred, neg_val_pred, metric)
        test_hits = _hits(pos_test_pred, neg_test_pred, metric)
        results[f"Hits@{K}"] = (valid_hits, test_hits)

    return results


def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
    """Return ``{"MRR": (valid, test)}`` using the module-level ``evaluator``.

    Negative scores are reshaped to one row per positive score before
    evaluation; the mean of the evaluator's ``"mrr_list"`` is reported.
    """
    print(
        pos_val_pred.size(),
        neg_val_pred.size(),
        pos_test_pred.size(),
        neg_test_pred.size(),
    )
    # Group the negatives by the positive they belong to.
    neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1)
    neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)

    def _mean_mrr(pos_pred, neg_pred):
        # Average reciprocal rank over all positives.
        scored = evaluator.eval(
            {"y_pred_pos": pos_pred, "y_pred_neg": neg_pred}
        )
        return scored["mrr_list"].mean().item()

    valid_mrr = _mean_mrr(pos_val_pred, neg_val_pred)
    test_mrr = _mean_mrr(pos_test_pred, neg_test_pred)

    return {"MRR": (valid_mrr, test_mrr)}


if __name__ == "__main__":
    # Script entry point: parse options, load the OGB link-prediction dataset,
    # build dataloaders that extract enclosing subgraphs on the fly, then
    # train/evaluate a DGCNN for ``--runs`` independent runs.

    # Data settings
    parser = argparse.ArgumentParser(description="OGBL (SEAL)")
    parser.add_argument("--dataset", type=str, default="ogbl-collab")
    # GNN settings
    parser.add_argument("--sortpool_k", type=float, default=0.6)
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=32)
    # Subgraph extraction settings
    parser.add_argument("--ratio_per_hop", type=float, default=1.0)
    parser.add_argument(
        "--use_feature",
        action="store_true",
        help="whether to use raw node features as GNN input",
    )
    parser.add_argument(
        "--use_edge_weight",
        action="store_true",
        help="whether to consider edge weight in GNN",
    )
    # Training settings
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--train_percent", type=float, default=100)
    parser.add_argument("--val_percent", type=float, default=100)
    parser.add_argument("--test_percent", type=float, default=100)
    parser.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="number of workers for dynamic dataloaders",
    )
    # Testing settings
    parser.add_argument("--use_valedges_as_input", action="store_true")
    parser.add_argument("--eval_steps", type=int, default=1)
    args = parser.parse_args()

    data_appendix = "_rph{}".format("".join(str(args.ratio_per_hop).split(".")))
    if args.use_valedges_as_input:
        data_appendix += "_uvai"

    # Every invocation writes into its own timestamped results directory.
    args.res_dir = os.path.join(
        "results/{}_{}".format(args.dataset, time.strftime("%Y%m%d%H%M%S"))
    )
    print("Results will be saved in " + args.res_dir)
    os.makedirs(args.res_dir, exist_ok=True)
    log_file = os.path.join(args.res_dir, "log.txt")
    # Save command line input.
    cmd_input = "python " + " ".join(sys.argv) + "\n"
    with open(os.path.join(args.res_dir, "cmd_input.txt"), "a") as f:
        f.write(cmd_input)
    print("Command line input: " + cmd_input + " is saved.")
    with open(log_file, "a") as f:
        f.write("\n" + cmd_input)

    dataset = DglLinkPropPredDataset(name=args.dataset)
    split_edge = dataset.get_edge_split()
    graph = dataset[0]

    # re-format the data of citation2
    if args.dataset == "ogbl-citation2":
        for k in ["train", "valid", "test"]:
            src = split_edge[k]["source_node"]
            tgt = split_edge[k]["target_node"]
            split_edge[k]["edge"] = torch.stack([src, tgt], dim=1)
            if k != "train":
                tgt_neg = split_edge[k]["target_node_neg"]
                split_edge[k]["edge_neg"] = torch.stack(
                    [src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1
                )  # [Ns, Nt, 2]

    # reconstruct the graph for ogbl-collab data for validation edge augmentation and coalesce
    if args.dataset == "ogbl-collab":
        if args.use_valedges_as_input:
            val_edges = split_edge["valid"]["edge"]
            row, col = val_edges.t()
            # float edata for to_simple transform
            graph.edata.pop("year")
            graph.edata["weight"] = graph.edata["weight"].to(torch.float)
            # Each validation edge is added in both directions, so the weight
            # tensor needs one row per added edge (2 * number of val edges).
            val_weights = torch.ones(size=(2 * val_edges.size(0), 1))
            graph.add_edges(
                torch.cat([row, col]),
                torch.cat([col, row]),
                {"weight": val_weights},
            )
        graph = graph.to_simple(copy_edata=True, aggregator="sum")

    # Drop inputs the user did not ask for so they never reach the model.
    if not args.use_edge_weight and "weight" in graph.edata:
        graph.edata.pop("weight")
    if not args.use_feature and "feat" in graph.ndata:
        graph.ndata.pop("feat")

    if args.dataset.startswith("ogbl-citation"):
        args.eval_metric = "mrr"
        directed = True
    else:
        args.eval_metric = "hits"
        directed = False

    evaluator = Evaluator(name=args.dataset)
    if args.eval_metric == "hits":
        loggers = {
            "Hits@20": Logger(args.runs, args),
            "Hits@50": Logger(args.runs, args),
            "Hits@100": Logger(args.runs, args),
        }
    elif args.eval_metric == "mrr":
        loggers = {
            "MRR": Logger(args.runs, args),
        }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = dataset.root + "_seal{}".format(data_appendix)

    loaders = []
    prefetch_node_feats = ["feat"] if "feat" in graph.ndata else None
    prefetch_edge_feats = ["weight"] if "weight" in graph.edata else None

    train_edge, train_edge_neg = get_pos_neg_edges(
        "train", split_edge, graph, args.train_percent
    )
    val_edge, val_edge_neg = get_pos_neg_edges(
        "valid", split_edge, graph, args.val_percent
    )
    test_edge, test_edge_neg = get_pos_neg_edges(
        "test", split_edge, graph, args.test_percent
    )
    # create an augmented graph for sampling
    # Its edges are: all edges of ``graph`` (label 1), followed by the
    # val/test positives (label 1) and the train/val/test negatives (label 0).
    aug_g = dgl.graph(graph.edges())
    aug_g.edata["y"] = torch.ones(aug_g.num_edges())
    aug_edges = torch.cat(
        [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]
    )
    aug_labels = torch.cat(
        [
            torch.ones(len(val_edge) + len(test_edge)),
            torch.zeros(
                len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg)
            ),
        ]
    )
    aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {"y": aug_labels})
    # eids for sampling
    # split_len holds the sizes of the consecutive edge-id segments of aug_g,
    # so prefix sums give each segment's start/end edge id.
    split_len = [graph.num_edges()] + list(
        map(
            len,
            [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg],
        )
    )
    train_eids = torch.cat(
        [
            graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),
            torch.arange(sum(split_len[:3]), sum(split_len[:4])),
        ]
    )
    val_eids = torch.cat(
        [
            torch.arange(sum(split_len[:1]), sum(split_len[:2])),
            torch.arange(sum(split_len[:4]), sum(split_len[:5])),
        ]
    )
    test_eids = torch.cat(
        [
            torch.arange(sum(split_len[:2]), sum(split_len[:3])),
            torch.arange(sum(split_len[:5]), sum(split_len[:6])),
        ]
    )
    sampler = SealSampler(
        graph,
        1,
        args.ratio_per_hop,
        directed,
        prefetch_node_feats,
        prefetch_edge_feats,
    )
    # force to be dynamic for consistent dataloading
    for split, shuffle, eids in zip(
        ["train", "valid", "test"],
        [True, False, False],
        [train_eids, val_eids, test_eids],
    ):
        data_loader = DataLoader(
            aug_g,
            eids,
            sampler,
            shuffle=shuffle,
            device=device,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        loaders.append(data_loader)
    train_loader, val_loader, test_loader = loaders

    # convert sortpool_k from percentile to number.
    # Subgraph sizes are collected from the first training batches (stopping
    # once more than 1000 have been seen); k is the size at that percentile,
    # but never below 10.
    num_nodes = []
    for subgs, _ in train_loader:
        subgs = dgl.unbatch(subgs)
        if len(num_nodes) > 1000:
            break
        for subg in subgs:
            num_nodes.append(subg.num_nodes())
    num_nodes = sorted(num_nodes)
    k = num_nodes[int(math.ceil(args.sortpool_k * len(num_nodes))) - 1]
    k = max(k, 10)

    for run in range(args.runs):
        # A fresh model and optimizer for every run.
        model = DGCNN(
            args.hidden_channels,
            args.num_layers,
            k,
            feature_dim=graph.ndata["feat"].size(1) if args.use_feature else 0,
        ).to(device)
        parameters = list(model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
        # Count elements per parameter tensor directly (iterating over a
        # tensor's rows gives the same total but fails on 0-dim parameters).
        total_params = sum(p.numel() for p in parameters)
        print(f"Total number of parameters is {total_params}")
        print(f"SortPooling k is set to {k}")
        with open(log_file, "a") as f:
            print(f"Total number of parameters is {total_params}", file=f)
            print(f"SortPooling k is set to {k}", file=f)

        start_epoch = 1
        # Training starts
        for epoch in range(start_epoch, start_epoch + args.epochs):
            loss = train()

            if epoch % args.eval_steps == 0:
                results = test()
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                # Checkpoint model and optimizer at every evaluation point.
                model_name = os.path.join(
                    args.res_dir,
                    "run{}_model_checkpoint{}.pth".format(run + 1, epoch),
                )
                optimizer_name = os.path.join(
                    args.res_dir,
                    "run{}_optimizer_checkpoint{}.pth".format(run + 1, epoch),
                )
                torch.save(model.state_dict(), model_name)
                torch.save(optimizer.state_dict(), optimizer_name)

                for key, result in results.items():
                    valid_res, test_res = result
                    to_print = (
                        f"Run: {run + 1:02d}, Epoch: {epoch:02d}, "
                        + f"Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, "
                        + f"Test: {100 * test_res:.2f}%"
                    )
                    print(key)
                    print(to_print)
                    with open(log_file, "a") as f:
                        print(key, file=f)
                        print(to_print, file=f)

        # Per-run summary, to stdout and to the log file.
        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)
            with open(log_file, "a") as f:
                print(key, file=f)
                loggers[key].print_statistics(run, f=f)

    # Summary over all runs.
    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
        with open(log_file, "a") as f:
            print(key, file=f)
            loggers[key].print_statistics(f=f)
    print(f"Total number of parameters is {total_params}")
    print(f"Results are saved in {args.res_dir}")
