"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Author's implementation: https://github.com/joanbruna/GNN_community
"""

from __future__ import division

import argparse
import time
from itertools import permutations

import gnn
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from dgl.data import SBMMixtureDataset

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
parser.add_argument("--gpu", type=int, help="GPU index", default=-1)
parser.add_argument("--lr", type=float, help="Learning rate", default=0.001)
parser.add_argument(
    "--n-communities", type=int, help="Number of communities", default=2
)
parser.add_argument(
    "--n-epochs", type=int, help="Number of epochs", default=100
)
parser.add_argument(
    "--n-features", type=int, help="Number of features", default=16
)
parser.add_argument("--n-graphs", type=int, help="Number of graphs", default=10)
parser.add_argument("--n-layers", type=int, help="Number of layers", default=30)
parser.add_argument(
    "--n-nodes", type=int, help="Number of nodes", default=10000
)
parser.add_argument("--optim", type=str, help="Optimizer", default="Adam")
parser.add_argument("--radius", type=int, help="Radius", default=3)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()

dev = th.device("cpu") if args.gpu < 0 else th.device("cuda:%d" % args.gpu)
K = args.n_communities

training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(
    training_dataset,
    args.batch_size,
    collate_fn=training_dataset.collate_fn,
    drop_last=True,
)

ones = th.ones(args.n_nodes // K)
y_list = [
    th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))
]

feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)


def compute_overlap(z_list):
    ybar_list = [th.max(z, 1)[1] for z in z_list]
    overlap_list = []
    for y_bar in ybar_list:
        accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes
        overlap = (accuracy - 1 / K) / (1 - 1 / K)
        overlap_list.append(overlap)
    return sum(overlap_list) / len(overlap_list)


def from_np(f, *args):
    def wrap(*args):
        new = [
            th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args
        ]
        return f(*new)

    return wrap


@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """One step of training."""
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
    deg_lg = deg_lg.to(dev).unsqueeze(1)
    pm_pd = pm_pd.to(dev)
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - t0

    z_list = th.chunk(z, args.batch_size, 0)
    loss = (
        sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)
        / args.batch_size
    )
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    t0 = time.time()
    loss.backward()
    t_backward = time.time() - t0
    optimizer.step()

    return loss, overlap, t_forward, t_backward


@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
    deg_lg = deg_lg.to(dev).unsqueeze(1)
    pm_pd = pm_pd.to(dev)

    z = model(g, lg, deg_g, deg_lg, pm_pd)

    return z


def test():
    p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
    q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
    N = 1
    overlap_list = []
    for p, q in zip(p_list, q_list):
        dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N)
        loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
        g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
        z = inference(g, lg, deg_g, deg_lg, pm_pd)
        overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
    return overlap_list


n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
        loss, overlap, t_forward, t_backward = step(
            i, j, g, lg, deg_g, deg_lg, pm_pd
        )

        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward

        epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
        iteration = "0" * (len(str(n_iterations)) - len(str(j)))
        if args.verbose:
            print(
                "[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f"
                % (epoch, i, iteration, j, loss, overlap)
            )

    epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
    loss = total_loss / (j + 1)
    overlap = total_overlap / (j + 1)
    t_forward = s_forward / (j + 1)
    t_backward = s_backward / (j + 1)
    print(
        "[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs"
        % (epoch, i, loss, overlap, t_forward, t_backward)
    )

    overlap_list = test()
    overlap_str = " - ".join(["%.3f" % overlap for overlap in overlap_list])
    print("[epoch %s%d]overlap: %s" % (epoch, i, overlap_str))
