import argparse
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from coarsening import coarsen
from coordinate import get_coordinates, z2polar
from grid_graph import grid_graph
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import dgl
from dgl.data import load_data, register_data_args
from dgl.nn.pytorch.conv import ChebConv, GMMConv
from dgl.nn.pytorch.glob import MaxPooling

# Command-line interface of the script: compute device, model choice and
# mini-batch size. Parsing happens at import time, so ``args`` is available to
# every statement below.
argparser = argparse.ArgumentParser(prog="MNIST")
argparser.add_argument(
    "--gpu",
    default=-1,
    type=int,
    help="gpu id, use cpu if set to -1",
)
argparser.add_argument(
    "--model",
    default="chebnet",
    type=str,
    help="model to use, chebnet/monet",
)
argparser.add_argument(
    "--batch-size",
    default=100,
    type=int,
    help="batch size",
)
args = argparser.parse_args()

# Parameters of the pixel graph. They are handed to ``grid_graph`` below;
# presumably ``number_edges`` is the neighbour count per pixel -- confirm
# against grid_graph.py.
grid_side = 28
number_edges = 8
metric = "euclidean"

# Use the named constants rather than repeating the literals so the graph,
# the coordinates and the batcher cannot drift apart when one is changed.
A = grid_graph(grid_side, number_edges, metric)

# Coarsen the graph; ``L`` holds one sparse matrix per level and ``perm`` is
# the node ordering later used to rearrange (and pad) the pixel vector.
coarsening_levels = 4
L, perm = coarsen(A, coarsening_levels)
g_arr = [dgl.from_scipy(csr) for csr in L]

coordinate_arr = get_coordinates(g_arr, grid_side, coarsening_levels, perm)
str_to_torch_dtype = {
    "float16": torch.half,
    "float32": torch.float32,
    "float64": torch.float64,
}
# Cast the coordinates to the dtype of the adjacency matrix. The lookup is
# loop-invariant, so it is resolved once.
coord_dtype = str_to_torch_dtype[str(A.dtype)]
coordinate_arr = [coord.to(dtype=coord_dtype) for coord in coordinate_arr]

# Attach the node coordinates to every level and let ``z2polar`` compute the
# per-edge data from them. A distinct loop variable keeps ``coordinate_arr``
# bound to the full list instead of rebinding it to its last element.
for g, coord in zip(g_arr, coordinate_arr):
    g.ndata["xy"] = coord
    g.apply_edges(z2polar)


def batcher(batch):
    """Collate ``(image, label)`` samples into batched graphs and features.

    Each image is flattened, zero-padded to ``len(perm)`` entries and
    reordered with ``perm`` so that its entries line up with the node order
    of the coarsened graphs.

    Args:
        batch: Sequence of ``(x, y)`` pairs, where ``x`` is an image tensor
            with ``grid_side ** 2`` elements and ``y`` is its integer label.

    Returns:
        A tuple ``(g_batch, x_batch, y_batch)``:

        * ``g_batch`` -- list with one batched DGL graph per coarsening
          level (``coarsening_levels + 1`` entries), each containing one
          copy of that level's graph per sample.
        * ``x_batch`` -- tensor of shape ``(len(batch) * len(perm), 1)``
          holding the permuted pixel values of all samples.
        * ``y_batch`` -- ``LongTensor`` with the labels.
    """
    # Number of zero entries appended to each flattened image so that it has
    # len(perm) elements; derived from grid_side instead of a literal 28.
    n_pad = len(perm) - grid_side**2

    x_batch = []
    y_batch = []
    for x, y in batch:
        padded = torch.cat([x.view(-1), x.new_zeros(n_pad)], 0)
        x_batch.append(padded[perm])
        y_batch.append(y)

    x_batch = torch.cat(x_batch).unsqueeze(-1)
    n_samples = len(y_batch)
    y_batch = torch.LongTensor(y_batch)

    # Every sample shares the same graph at a given level, so each level is
    # simply that graph repeated once per sample and then batched.
    g_batch = [
        dgl.batch([g_arr[level]] * n_samples)
        for level in range(coarsening_levels + 1)
    ]
    return g_batch, x_batch, y_batch


# MNIST train / test splits, stored under (and downloaded to, if missing) the
# current directory; images are converted to tensors.
_mnist_kwargs = dict(root=".", download=True)
trainset = datasets.MNIST(
    train=True, transform=transforms.ToTensor(), **_mnist_kwargs
)
testset = datasets.MNIST(
    train=False, transform=transforms.ToTensor(), **_mnist_kwargs
)

# Both loaders collate with ``batcher`` and differ only in the dataset and in
# whether the samples are shuffled.
_loader_kwargs = dict(
    batch_size=args.batch_size,
    collate_fn=batcher,
    num_workers=6,
)
train_loader = DataLoader(trainset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(testset, shuffle=False, **_loader_kwargs)


class MoNet(nn.Module):
    """Graph classifier built from ``GMMConv`` layers with pooling in between.

    Each convolution runs on one graph of a multi-level graph list; its
    output is max-pooled over pairs of consecutive nodes before being fed to
    the next layer on the next graph. A graph-level max readout followed by
    a linear layer and log-softmax produces the class log-probabilities.

    Args:
        n_kernels: Number of kernels given to every ``GMMConv`` layer.
        in_feats: Size of the input node features.
        hiddens: Non-empty sequence with the output size of each ``GMMConv``
            layer, in order.
        out_feats: Number of output classes.
    """

    def __init__(self, n_kernels, in_feats, hiddens, out_feats):
        super().__init__()
        self.pool = nn.MaxPool1d(2)
        self.layers = nn.ModuleList()
        self.readout = MaxPooling()

        # Chain the layer widths: in_feats -> hiddens[0] -> ... -> hiddens[-1].
        # The third GMMConv argument (2) is the size of the per-edge
        # pseudo-coordinate passed in ``forward``.
        widths = [in_feats, *hiddens]
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.layers.append(GMMConv(fan_in, fan_out, 2, n_kernels))

        # An explicit ``dim`` avoids the deprecated implicit-dimension
        # inference (and its warning) of ``nn.LogSoftmax()``.
        self.cls = nn.Sequential(
            nn.Linear(hiddens[-1], out_feats), nn.LogSoftmax(dim=-1)
        )

    def forward(self, g_arr, feat):
        """Compute class log-probabilities.

        Args:
            g_arr: List of graphs, consumed in order by the layers; the last
                entry is used for the readout. Every graph paired with a
                layer must carry edge data under the key ``"u"``.
            feat: Node features for ``g_arr[0]``.

        Returns:
            Output of the classifier head applied to the graph readout.
        """
        for g, layer in zip(g_arr, self.layers):
            u = g.edata["u"]
            h = layer(g, feat, u)
            # MaxPool1d pools along the last axis of a 3-D input, so move the
            # node axis last and add a leading axis, pool adjacent node pairs,
            # then restore the original layout.
            h = h.transpose(-1, -2).unsqueeze(0)
            feat = self.pool(h).squeeze(0).transpose(-1, -2)
        return self.cls(self.readout(g_arr[-1], feat))


class ChebNet(nn.Module):
    """Graph classifier built from ``ChebConv`` layers with pooling in between.

    Each convolution runs on one graph of a multi-level graph list; its
    output is max-pooled over pairs of consecutive nodes before being fed to
    the next layer on the next graph. A graph-level max readout followed by
    a linear layer and log-softmax produces the class log-probabilities.

    Args:
        k: Third argument of every ``ChebConv`` layer (the Chebyshev filter
            size in DGL's API).
        in_feats: Size of the input node features.
        hiddens: Non-empty sequence with the output size of each ``ChebConv``
            layer, in order.
        out_feats: Number of output classes.
    """

    def __init__(self, k, in_feats, hiddens, out_feats):
        super().__init__()
        self.pool = nn.MaxPool1d(2)
        self.layers = nn.ModuleList()
        self.readout = MaxPooling()

        # Chain the layer widths: in_feats -> hiddens[0] -> ... -> hiddens[-1].
        widths = [in_feats, *hiddens]
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.layers.append(ChebConv(fan_in, fan_out, k))

        # An explicit ``dim`` avoids the deprecated implicit-dimension
        # inference (and its warning) of ``nn.LogSoftmax()``.
        self.cls = nn.Sequential(
            nn.Linear(hiddens[-1], out_feats), nn.LogSoftmax(dim=-1)
        )

    def forward(self, g_arr, feat):
        """Compute class log-probabilities.

        Args:
            g_arr: List of batched graphs, consumed in order by the layers;
                the last entry is used for the readout.
            feat: Node features for ``g_arr[0]``.

        Returns:
            Output of the classifier head applied to the graph readout.
        """
        for g, layer in zip(g_arr, self.layers):
            # One value of 2 per graph in the batch is passed as the layer's
            # third argument (``lambda_max`` in DGL's ChebConv API).
            h = layer(g, feat, [2] * g.batch_size)
            # MaxPool1d pools along the last axis of a 3-D input, so move the
            # node axis last and add a leading axis, pool adjacent node pairs,
            # then restore the original layout.
            h = h.transpose(-1, -2).unsqueeze(0)
            feat = self.pool(h).squeeze(0).transpose(-1, -2)
        return self.cls(self.readout(g_arr[-1], feat))


# Run on the CPU when --gpu is -1, otherwise on the device with that index.
device = torch.device("cpu" if args.gpu == -1 else args.gpu)

# Both architectures share the same layer widths, input size (1) and number
# of classes (10); any --model value other than "chebnet" selects MoNet.
_hidden_sizes = [32, 64, 128, 256]
model = (
    ChebNet(2, 1, _hidden_sizes, 10)
    if args.model == "chebnet"
    else MoNet(10, 1, _hidden_sizes, 10)
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Number of training batches between two progress reports.
log_interval = 50

# Train for 10 epochs; after each epoch report the accuracy on the test set.
for epoch in range(10):
    print("epoch {} starts".format(epoch))
    model.train()
    # Running counters, reset after every progress report.
    hit, tot = 0, 0
    loss_accum = 0
    for i, (g, x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        g = [g_i.to(device) for g_i in g]
        out = model(g, x)
        # The predicted class is the index of the largest log-probability.
        hit += (out.max(-1)[1] == y).sum().item()
        tot += len(y)
        loss = F.nll_loss(out, y)
        loss_accum += loss.item()

        if (i + 1) % log_interval == 0:
            # Mean loss and accuracy over the last ``log_interval`` batches.
            print(
                "loss: {}, acc: {}".format(loss_accum / log_interval, hit / tot)
            )
            hit, tot = 0, 0
            loss_accum = 0

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    hit, tot = 0, 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building (and holding in memory) a graph for every test batch.
    with torch.no_grad():
        for g, x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            g = [g_i.to(device) for g_i in g]
            out = model(g, x)
            hit += (out.max(-1)[1] == y).sum().item()
            tot += len(y)

    print("test acc: ", hit / tot)
