import argparse

import numpy as np
import torch as th
import torch.nn.functional as F
from model import InfoGraphS

import dgl
from dgl.data import QM9EdgeDataset
from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader


def argument():
    """Parse the command-line options of the InfoGraphS training script.

    Returns
    -------
    argparse.Namespace
        The parsed options with an extra ``device`` attribute, which is
        ``"cuda:<gpu>"`` when ``--gpu`` is not -1 and CUDA is available,
        and ``"cpu"`` otherwise.
    """
    parser = argparse.ArgumentParser(description="InfoGraphS")

    # One row per option: (flag, type, default, help text).
    option_table = (
        # data source params
        ("--target", str, "mu", "Choose regression task"),
        ("--train_num", int, 5000, "Size of training set"),
        # training params
        ("--gpu", int, -1, "GPU index, default:-1, using CPU."),
        ("--epochs", int, 200, "Training epochs."),
        ("--batch_size", int, 20, "Training batch size."),
        ("--val_batch_size", int, 100, "Validation batch size."),
        ("--lr", float, 0.001, "Learning rate."),
        ("--wd", float, 0, "Weight decay."),
        # model params
        ("--hid_dim", int, 64, "Hidden layer dimensionality"),
        ("--reg", float, 0.001, "Regularization coefficient"),
    )
    for flag, kind, default, description in option_table:
        parser.add_argument(flag, type=kind, default=default, help=description)

    args = parser.parse_args()

    # Fall back to the CPU when no GPU was requested or CUDA is unavailable.
    use_cuda = args.gpu != -1 and th.cuda.is_available()
    args.device = f"cuda:{args.gpu}" if use_cuda else "cpu"

    return args


class DenseQM9EdgeDataset(QM9EdgeDataset):
    """QM9 edge dataset whose items are densified (all-pairs) graphs."""

    def __getitem__(self, idx):
        r"""Get the densified graph and label by index

        The molecule's own graph is first rebuilt from the flat arrays held
        by the dataset, then replaced by a graph that connects every ordered
        pair of distinct nodes.

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
           The graph contains:

           - ``ndata['attr']``: the features of each atom
           - ``edata['edge_attr']``: the original bond features (zeros for
             node pairs that had no edge) followed by one extra column
             holding the Euclidean distance between the two nodes

        Tensor
            Property values of molecular graphs
        """
        node_lo, node_hi = self.n_cumsum[idx], self.n_cumsum[idx + 1]
        edge_lo, edge_hi = self.ne_cumsum[idx], self.ne_cumsum[idx + 1]

        # Rebuild the original (sparse) molecular graph for this item.
        sparse = dgl.graph(
            (self.src[edge_lo:edge_hi], self.dst[edge_lo:edge_hi])
        )
        sparse.ndata["pos"] = th.tensor(self.node_pos[node_lo:node_hi]).float()
        sparse.ndata["attr"] = th.tensor(
            self.node_attr[node_lo:node_hi]
        ).float()
        sparse.edata["edge_attr"] = th.tensor(
            self.edge_attr[edge_lo:edge_hi]
        ).float()

        label = th.tensor(self.targets[idx][self.label_keys]).float()

        # Enumerate every ordered node pair in row-major order:
        # row = 0,0,...,1,1,...  and  col = 0,1,...,0,1,...
        n_nodes = sparse.num_nodes()
        node_ids = th.arange(n_nodes)
        row = node_ids.repeat_interleave(n_nodes)
        col = node_ids.repeat(n_nodes)

        # Scatter the existing bond features into their slot of the
        # flattened n_nodes x n_nodes layout; all other slots stay zero.
        bond_src, bond_dst = sparse.edges()
        flat_slot = bond_src * n_nodes + bond_dst
        bond_feat = sparse.edata["edge_attr"]
        dense_shape = (n_nodes * n_nodes,) + tuple(bond_feat.shape[1:])
        dense_feat = bond_feat.new_zeros(dense_shape)
        dense_feat[flat_slot] = bond_feat

        # Append the pairwise Euclidean distance as the last feature column.
        coords = sparse.ndata["pos"]
        pair_dist = th.norm(coords[col] - coords[row], p=2, dim=-1).view(-1, 1)
        dense_feat = th.cat(
            [dense_feat, pair_dist.type_as(dense_feat)], dim=-1
        )

        dense = dgl.graph((row, col))
        dense.ndata["attr"] = sparse.ndata["attr"]
        dense.edata["edge_attr"] = dense_feat

        return dense.remove_self_loop(), label


def collate(samples):
    """Collate ``(graph, target)`` samples into a single training batch.

    Parameters
    ----------
    samples : list of tuple
        Pairs of a graph and its target, as produced by the dataset.

    Returns
    -------
    tuple
        The batched graph, whose ``ndata['graph_id']`` is filled by
        broadcasting each graph's position in the batch onto its nodes,
        and the targets packed into one tensor.
    """
    # Split the list of pairs into a list of graphs and a list of targets.
    graphs, targets = (list(column) for column in zip(*samples))

    batched_graph = dgl.batch(graphs)
    batched_targets = th.Tensor(targets)

    # Tag the nodes with the index of the graph they came from.
    batched_graph.ndata["graph_id"] = dgl.broadcast_nodes(
        batched_graph, th.arange(len(graphs))
    )

    return batched_graph, batched_targets


def evaluate(model, loader, num, device):
    """Compute the mean absolute error of ``model`` over ``loader``.

    Parameters
    ----------
    model : callable
        Invoked as ``model(graphs, nfeat, efeat)``; its output is compared
        element-wise with the batch targets.
    loader : iterable
        Yields ``(graphs, targets)`` batches. ``graphs`` must provide
        ``to(device)``, ``ndata['attr']`` and ``edata['edge_attr']``.
    num : int
        Number of samples the summed absolute error is divided by.
    device : str or torch.device
        Device the graphs and targets are moved to before the forward pass.

    Returns
    -------
    float
        Sum of absolute errors over all batches divided by ``num``.
    """
    error = 0

    # Evaluation never back-propagates, so disable autograd to avoid
    # building (and holding memory for) a graph on every forward pass.
    with th.no_grad():
        for graphs, targets in loader:
            graphs = graphs.to(device)

            nfeat, efeat = graphs.ndata["attr"], graphs.edata["edge_attr"]
            targets = targets.to(device)
            error += (model(graphs, nfeat, efeat) - targets).abs().sum().item()

    return error / num


if __name__ == "__main__":

    # Step 1: Prepare graph data   ===================================== #
    args = argument()
    label_keys = [args.target]
    print(args)

    dataset = DenseQM9EdgeDataset(label_keys=label_keys)

    # Train/Val/Test Splitting: shuffle all indices once, then carve out
    # validation, test and (labelled) training subsets in that order.
    N = dataset.targets.shape[0]
    all_idx = np.arange(N)
    np.random.shuffle(all_idx)

    val_num = 10000
    test_num = 10000

    val_idx = all_idx[:val_num]
    test_idx = all_idx[val_num : val_num + test_num]
    train_idx = all_idx[
        val_num + test_num : val_num + test_num + args.train_num
    ]

    train_data = Subset(dataset, train_idx)
    val_data = Subset(dataset, val_idx)
    test_data = Subset(dataset, test_idx)

    # The unsupervised set is everything outside validation/test, so it
    # also contains the labelled training samples.
    unsup_idx = all_idx[val_num + test_num :]
    unsup_data = Subset(dataset, unsup_idx)

    # generate supervised training dataloader and unsupervised training dataloader
    train_loader = GraphDataLoader(
        train_data,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    unsup_loader = GraphDataLoader(
        unsup_data,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    # generate validation & testing dataloader
    # (the error is a sum over all samples, so shuffling is unnecessary)
    val_loader = GraphDataLoader(
        val_data,
        batch_size=args.val_batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=False,
    )

    test_loader = GraphDataLoader(
        test_data,
        batch_size=args.val_batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=False,
    )

    print("======== target = {} ========".format(args.target))

    in_dim = dataset[0][0].ndata["attr"].shape[1]

    # Step 2: Create model =================================================================== #
    model = InfoGraphS(in_dim, args.hid_dim)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
    scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.7, patience=5, min_lr=0.000001
    )

    # Step 4: training epochs =============================================================== #
    best_val_error = float("inf")
    test_error = float("inf")

    for epoch in range(args.epochs):
        # Training
        model.train()
        # Learning rate in effect for this epoch (read before the
        # scheduler is stepped below).
        lr = scheduler.optimizer.param_groups[0]["lr"]

        sup_loss_all = 0
        unsup_loss_all = 0
        consis_loss_all = 0

        # zip() stops at the shorter loader, so an epoch covers one pass
        # over the labelled set paired with as many unlabelled batches.
        for sup_batch, unsup_batch in zip(train_loader, unsup_loader):
            sup_graph, sup_target = sup_batch
            unsup_graph, _ = unsup_batch

            sup_graph = sup_graph.to(args.device)
            unsup_graph = unsup_graph.to(args.device)

            sup_nfeat, sup_efeat = (
                sup_graph.ndata["attr"],
                sup_graph.edata["edge_attr"],
            )
            unsup_nfeat, unsup_efeat, unsup_graph_id = (
                unsup_graph.ndata["attr"],
                unsup_graph.edata["edge_attr"],
                unsup_graph.ndata["graph_id"],
            )

            sup_target = sup_target.to(args.device)

            optimizer.zero_grad()

            sup_loss = F.mse_loss(
                model(sup_graph, sup_nfeat, sup_efeat), sup_target
            )
            unsup_loss, consis_loss = model.unsup_forward(
                unsup_graph, unsup_nfeat, unsup_efeat, unsup_graph_id
            )

            loss = sup_loss + unsup_loss + args.reg * consis_loss

            loss.backward()

            sup_loss_all += sup_loss.item()
            unsup_loss_all += unsup_loss.item()
            consis_loss_all += consis_loss.item()

            optimizer.step()

        # Losses are reported as sums over the epoch's batches.
        print(
            "Epoch: {}, Sup_Loss: {:.4f}, Unsup_loss: {:.4f}, Consis_loss: {:.4f}".format(
                epoch, sup_loss_all, unsup_loss_all, consis_loss_all
            )
        )

        # Validation (and testing whenever validation improves)
        model.eval()

        # Average over the actual split sizes rather than the nominal
        # val_num/test_num, which differ if the dataset is too small.
        val_error = evaluate(model, val_loader, len(val_idx), args.device)
        scheduler.step(val_error)

        if val_error < best_val_error:
            best_val_error = val_error
            test_error = evaluate(
                model, test_loader, len(test_idx), args.device
            )

        print(
            "Epoch: {}, LR: {}, val_error: {:.4f}, best_test_error: {:.4f}".format(
                epoch, lr, val_error, test_error
            )
        )
