import argparse
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ogb.graphproppred import Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from preprocessing import prepare_dataset
from torch.utils.data import Dataset
from tqdm import tqdm

import dgl
from dgl.dataloading import GraphDataLoader


def aggregate_mean(h, vector_field, h_in):
    """Average ``h`` along dimension 1.

    ``vector_field`` and ``h_in`` are ignored; they are accepted so that this
    function can be called with the same arguments as the other aggregators.
    """
    return h.mean(dim=1)


def aggregate_max(h, vector_field, h_in):
    """Take the element-wise maximum of ``h`` along dimension 1.

    Only the maximum values are returned (not their indices).
    ``vector_field`` and ``h_in`` are ignored; they are accepted so that this
    function can be called with the same arguments as the other aggregators.
    """
    return h.max(dim=1).values


def aggregate_sum(h, vector_field, h_in):
    """Sum ``h`` along dimension 1.

    ``vector_field`` and ``h_in`` are ignored; they are accepted so that this
    function can be called with the same arguments as the other aggregators.
    """
    return h.sum(dim=1)


def aggregate_dir_dx(h, vector_field, h_in, eig_idx=1):
    """Aggregate ``h`` with weights taken from one channel of ``vector_field``.

    The channel ``vector_field[:, :, eig_idx]`` is divided by the sum of its
    absolute values along dimension 1 (plus a small epsilon). The result is
    the absolute value of: the weighted sum of ``h`` along dimension 1, minus
    ``h_in`` scaled by the sum of the weights.

    Args:
        h: values to aggregate; reduced along dimension 1.
        vector_field: 3-D tensor indexed as ``[:, :, eig_idx]``.
        h_in: tensor multiplied by the summed weights and subtracted from the
            weighted sum (presumably the receiving node's own features --
            confirm against the caller).
        eig_idx: which channel of ``vector_field`` supplies the weights.
    """
    field = vector_field[:, :, eig_idx]
    # The epsilon keeps the division finite when the whole channel is zero.
    l1_norm = field.abs().sum(dim=1, keepdim=True) + 1e-8
    weights = (field / l1_norm).unsqueeze(-1)

    weighted_sum = (h * weights).sum(dim=1)
    centre_term = weights.sum(dim=1) * h_in
    return (weighted_sum - centre_term).abs()


class FCLayer(nn.Module):
    """A single biased linear layer with a custom initialisation.

    Weights are drawn with Xavier-uniform using a gain of ``1 / in_size``;
    the bias starts at zero.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.in_size, self.out_size = in_size, out_size
        self.linear = nn.Linear(in_size, out_size, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight matrix and zero the bias."""
        nn.init.xavier_uniform_(self.linear.weight, gain=1 / self.in_size)
        self.linear.bias.data.zero_()

    def forward(self, x):
        """Apply the linear map to ``x``."""
        return self.linear(x)


class MLP(nn.Module):
    """Thin wrapper that holds exactly one ``FCLayer`` under the name ``fc``."""

    def __init__(self, in_size, out_size):
        super().__init__()
        self.in_size, self.out_size = in_size, out_size
        self.fc = FCLayer(in_size, out_size)

    def forward(self, x):
        """Pass ``x`` through the wrapped fully connected layer."""
        return self.fc(x)


class DGNLayer(nn.Module):
    """One message-passing layer built from three stages.

    1. Every edge gets a message computed by ``pretrans`` from the
       concatenated source and destination node features.
    2. Every node reduces its incoming messages with each function in
       ``aggregators`` and concatenates the results.
    3. ``posttrans`` maps the node's input features concatenated with the
       aggregated messages to ``out_dim``, followed by scaling with
       ``snorm_n``, batch normalisation, ReLU and dropout.

    Args:
        in_dim: feature width of the incoming node features.
        out_dim: feature width produced by the layer.
        dropout: dropout probability applied to the layer output.
        aggregators: callables invoked as ``f(messages, vector_field, h_in)``.
    """

    def __init__(self, in_dim, out_dim, dropout, aggregators):
        super().__init__()
        self.dropout = dropout
        self.aggregators = aggregators

        # posttrans consumes the node's own features plus one in_dim-wide
        # slice per aggregator.
        n_slices = len(aggregators) + 1

        self.batchnorm_h = nn.BatchNorm1d(out_dim)
        self.pretrans = MLP(in_size=2 * in_dim, out_size=in_dim)
        self.posttrans = MLP(in_size=n_slices * in_dim, out_size=out_dim)

    def pretrans_edges(self, edges):
        """Edge UDF: build the message and forward the edge's ``eig`` data."""
        endpoints = torch.cat((edges.src["h"], edges.dst["h"]), dim=1)
        return {
            "e": self.pretrans(endpoints),
            "vector_field": edges.data["eig"],
        }

    def message_func(self, edges):
        """Message UDF: send the precomputed edge fields unchanged."""
        return {key: edges.data[key] for key in ("e", "vector_field")}

    def reduce_func(self, nodes):
        """Reduce UDF: apply every aggregator and concatenate along dim 1."""
        own_features = nodes.data["h"]
        messages = nodes.mailbox["e"]
        fields = nodes.mailbox["vector_field"]

        pooled = [
            aggregate(messages, fields, own_features)
            for aggregate in self.aggregators
        ]
        return {"h": torch.cat(pooled, dim=1)}

    def forward(self, g, h, snorm_n):
        """Run the layer on graph ``g`` with node features ``h``.

        Note that ``g`` is modified: its ``ndata["h"]`` and the edge fields
        written by ``pretrans_edges`` are stored on the graph.
        """
        g.ndata["h"] = h

        # pretransformation: one message per edge
        g.apply_edges(self.pretrans_edges)

        # aggregation: ndata["h"] now holds the concatenated aggregator outputs
        g.update_all(self.message_func, self.reduce_func)

        # posttransformation on [input features | aggregated messages]
        out = self.posttrans(torch.cat([h, g.ndata["h"]], dim=1))

        # graph normalisation (snorm_n scaling) followed by batch normalisation
        out = self.batchnorm_h(out * snorm_n)
        out = F.relu(out)

        return F.dropout(out, self.dropout, training=self.training)


class MLPReadout(nn.Module):
    """Readout MLP whose hidden width halves at every layer.

    Hidden layer ``k`` maps ``input_dim // 2**k`` to ``input_dim // 2**(k+1)``
    and is followed by ReLU; a final linear layer maps
    ``input_dim // 2**L`` to ``output_dim`` with no activation.
    """

    def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers
        super().__init__()
        # widths[k] is the feature size entering hidden layer k;
        # widths[L] is the size entering the output layer.
        widths = [input_dim // 2**depth for depth in range(L + 1)]
        stack = [
            nn.Linear(width_in, width_out, bias=True)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]
        stack.append(nn.Linear(widths[-1], output_dim, bias=True))

        self.FC_layers = nn.ModuleList(stack)
        self.L = L

    def forward(self, x):
        """Apply the hidden layers with ReLU, then the output layer."""
        *hidden, head = self.FC_layers
        y = x
        for layer in hidden:
            y = F.relu(layer(y))
        return head(y)


class DGNNet(nn.Module):
    """Graph-level classifier: atom embedding, DGN layers, mean readout, MLP.

    Args:
        hidden_dim: embedding width and the width of every layer but the last.
        out_dim: width produced by the last DGN layer and fed to the readout.
        dropout: dropout probability passed to every DGN layer.
        n_layers: total number of DGN layers.
    """

    def __init__(self, hidden_dim=420, out_dim=420, dropout=0.2, n_layers=4):
        super().__init__()

        self.embedding_h = AtomEncoder(emb_dim=hidden_dim)
        self.aggregators = [
            aggregate_mean,
            aggregate_sum,
            aggregate_max,
            aggregate_dir_dx,
        ]

        # All layers keep hidden_dim except the last one, which emits out_dim.
        layer_out_dims = [hidden_dim] * (n_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            [
                DGNLayer(
                    in_dim=hidden_dim,
                    out_dim=layer_out_dim,
                    dropout=dropout,
                    aggregators=self.aggregators,
                )
                for layer_out_dim in layer_out_dims
            ]
        )

        # 128 out dim since ogbg-molpcba has 128 tasks
        self.MLP_layer = MLPReadout(out_dim, 128)

    def forward(self, g, h, snorm_n):
        """Return the readout output for every graph batched in ``g``.

        The final node features are written to ``g.ndata["h"]`` before the
        per-graph mean is taken.
        """
        h = self.embedding_h(h)
        for conv in self.layers:
            h = conv(g, h, snorm_n)

        g.ndata["h"] = h
        graph_repr = dgl.mean_nodes(g, "h")
        return self.MLP_layer(graph_repr)

    def loss(self, scores, labels):
        """Binary cross-entropy with logits over the labelled entries only."""
        # NaN never equals itself, so this mask drops entries whose label is NaN.
        labelled = labels == labels
        criterion = nn.BCEWithLogitsLoss()
        return criterion(scores[labelled], labels[labelled].float())


def train_epoch(model, optimizer, device, data_loader):
    """Train ``model`` for one pass over ``data_loader`` and score the pass.

    Args:
        model: module called as ``model(graphs, node_feats, snorm_n)`` that
            also provides ``model.loss(scores, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to before the forward pass.
        data_loader: iterable yielding ``(batched_graph, labels, snorm_n)``.

    Returns:
        A ``(mean_batch_loss, ap)`` tuple, where ``ap`` is the ``"ap"`` entry
        reported by the OGB ``ogbg-molpcba`` evaluator over all batches.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    list_scores = []
    list_labels = []
    for batch_graphs, batch_labels, batch_snorm_n in data_loader:
        batch_graphs = batch_graphs.to(device)
        batch_x = batch_graphs.ndata["feat"]  # num x feat
        batch_snorm_n = batch_snorm_n.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_scores = model(batch_graphs, batch_x, batch_snorm_n)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        # Only the values are needed for the metric. Keeping the attached
        # tensor would hold every batch's autograd graph in memory until the
        # end of the epoch.
        list_scores.append(batch_scores.detach())
        list_labels.append(batch_labels)

    if not list_scores:
        raise ValueError("data_loader yielded no batches")

    epoch_loss /= len(list_scores)

    evaluator = Evaluator(name="ogbg-molpcba")
    epoch_train_AP = evaluator.eval(
        {"y_pred": torch.cat(list_scores), "y_true": torch.cat(list_labels)}
    )["ap"]

    return epoch_loss, epoch_train_AP


def evaluate_network(model, device, data_loader):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: module called as ``model(graphs, node_feats, snorm_n)`` that
            also provides ``model.loss(scores, labels)``.
        device: device every batch is moved to before the forward pass.
        data_loader: iterable yielding ``(batched_graph, labels, snorm_n)``.

    Returns:
        A ``(mean_batch_loss, ap)`` tuple, where ``ap`` is the ``"ap"`` entry
        reported by the OGB ``ogbg-molpcba`` evaluator over all batches.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    epoch_test_loss = 0.0
    list_scores = []
    list_labels = []
    with torch.no_grad():
        for batch_graphs, batch_labels, batch_snorm_n in data_loader:
            batch_graphs = batch_graphs.to(device)
            batch_x = batch_graphs.ndata["feat"]
            batch_snorm_n = batch_snorm_n.to(device)
            batch_labels = batch_labels.to(device)

            batch_scores = model(batch_graphs, batch_x, batch_snorm_n)

            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.item()
            list_scores.append(batch_scores)
            list_labels.append(batch_labels)

    # Fail with a clear message instead of an obscure error further down
    # when there is nothing to average or concatenate.
    if not list_scores:
        raise ValueError("data_loader yielded no batches")

    epoch_test_loss /= len(list_scores)

    evaluator = Evaluator(name="ogbg-molpcba")
    epoch_test_AP = evaluator.eval(
        {"y_pred": torch.cat(list_scores), "y_true": torch.cat(list_labels)}
    )["ap"]

    return epoch_test_loss, epoch_test_AP


def train(dataset, params):
    """Train a default ``DGNNet`` on ``dataset`` and print a summary.

    Each epoch trains on the training split and evaluates on the validation
    and test splits. The learning rate is reduced when the validation AP stops
    improving, and training stops early once it falls below ``1e-5``.

    Args:
        dataset: object exposing ``train``, ``val`` and ``test`` splits plus a
            ``collate`` function for batching.
        params: namespace providing ``device`` and ``batch_size``.
    """
    device = params.device
    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    for caption, split in (
        ("Training Graphs: ", trainset),
        ("Validation Graphs: ", valset),
        ("Test Graphs: ", testset),
    ):
        print(caption, len(split))

    model = DGNNet().to(device)

    # view model parameters
    print("MODEL DETAILS:\n")
    total_param = sum(
        np.prod(list(param.data.size())) for param in model.parameters()
    )
    print("DGN Total parameters:", total_param)

    optimizer = optim.Adam(model.parameters(), lr=0.0008, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.8, patience=8, verbose=True
    )

    def make_loader(split, shuffle):
        # All three loaders differ only in the split and whether it is shuffled.
        return GraphDataLoader(
            split,
            batch_size=params.batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate,
            pin_memory=True,
        )

    train_loader = make_loader(trainset, True)
    val_loader = make_loader(valset, False)
    test_loader = make_loader(testset, False)

    train_losses, val_losses = [], []
    train_aps, val_aps, test_aps = [], [], []

    with tqdm(range(450), unit="epoch") as progress:
        for epoch in progress:
            progress.set_description("Epoch %d" % epoch)

            train_loss, train_ap = train_epoch(
                model, optimizer, device, train_loader
            )
            val_loss, val_ap = evaluate_network(model, device, val_loader)
            train_ap_value = train_ap.item()
            val_ap_value = val_ap.item()

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_aps.append(train_ap_value)
            val_aps.append(val_ap_value)

            _, test_ap = evaluate_network(model, device, test_loader)
            test_aps.append(test_ap.item())

            progress.set_postfix(
                train_loss=train_loss,
                train_AP=train_ap_value,
                val_AP=val_ap_value,
                refresh=False,
            )

            # The scheduler runs in "min" mode, so it is fed the negated AP.
            scheduler.step(-val_ap_value)

            current_lr = optimizer.param_groups[0]["lr"]
            if current_lr < 1e-5:
                print("\n!! LR EQUAL TO MIN LR SET.")
                break

            print("")

    best_val_epoch = np.argmax(np.array(val_aps))
    best_train_epoch = np.argmax(np.array(train_aps))

    for template, value in (
        ("Best Train AP: {:.4f}", train_aps[best_train_epoch]),
        ("Best Val AP: {:.4f}", val_aps[best_val_epoch]),
        ("Test AP of Best Val: {:.4f}", test_aps[best_val_epoch]),
        ("Train AP of Best Val: {:.4f}", train_aps[best_val_epoch]),
    ):
        print(template.format(value))


class Subset:
    """Graphs and labels selected by ``indices``, minus very small graphs.

    Only graphs with more than five nodes are kept; each kept graph stays
    paired with the label found at the same index.
    """

    def __init__(self, dataset, labels, indices):
        selected = [(dataset[idx], labels[idx]) for idx in indices]
        kept = [
            (graph, label)
            for graph, label in selected
            if graph.num_nodes() > 5
        ]
        self.dataset = [graph for graph, _ in kept]
        self.labels = [label for _, label in kept]
        self.len = len(self.dataset)

    def __getitem__(self, item):
        """Return the ``(graph, label)`` pair stored at position ``item``."""
        return self.dataset[item], self.labels[item]

    def __len__(self):
        """Return the number of graphs that passed the size filter."""
        return self.len


class PCBADataset(Dataset):
    """Loads the named dataset and exposes filtered train/val/test splits.

    The graphs and split indices come from ``prepare_dataset``. Every graph
    gets an ``eig`` edge field holding its ``subgraph_counts`` cast to float.
    """

    def __init__(self, name):
        print("[I] Loading dataset %s..." % (name))
        self.name = name

        self.dataset, self.split_idx = prepare_dataset(name)
        print("One hot encoding substructure counts... ", end="")
        n_count_columns = self.dataset[0].edata["subgraph_counts"].shape[1]
        self.d_id = [1] * n_count_columns

        for graph in self.dataset:
            graph.edata["eig"] = graph.edata["subgraph_counts"].float()

        all_labels = self.split_idx["label"]
        self.train, self.val, self.test = [
            Subset(self.dataset, all_labels, self.split_idx[split_key])
            for split_key in ("train", "valid", "test")
        ]

        print(
            "train, test, val sizes :",
            len(self.train),
            len(self.test),
            len(self.val),
        )
        print("[I] Finished loading.")

    # form a mini batch from a given list of samples = [(graph, label) pairs]
    def collate(self, samples):
        """Batch ``(graph, label)`` pairs into one graph plus tensors.

        Returns:
            ``(batched_graph, labels, snorm_n)``, where ``snorm_n`` has one
            row per node holding ``sqrt(1 / n)`` and ``n`` is the node count
            of the graph that node belongs to.
        """
        graphs, labels = (list(column) for column in zip(*samples))
        labels = torch.stack(labels)

        per_graph_norm = [
            torch.full((n_nodes, 1), 1.0 / n_nodes, dtype=torch.float32)
            for n_nodes in (graph.num_nodes() for graph in graphs)
        ]
        snorm_n = torch.cat(per_graph_norm).sqrt()

        batched_graph = dgl.batch(graphs)
        return batched_graph, labels, snorm_n


if __name__ == "__main__":
    # (flag, default, help text); every option is parsed as an integer
    cli_options = (
        ("--gpu_id", 0, "Please give a value for gpu id"),
        ("--seed", 41, "Please give a value for seed"),
        ("--batch_size", 2048, "Please give a value for batch_size"),
    )
    parser = argparse.ArgumentParser()
    for flag, default, help_text in cli_options:
        parser.add_argument(flag, default=default, type=int, help=help_text)
    args = parser.parse_args()

    # device: the requested GPU when CUDA is available, otherwise the CPU
    use_cuda = torch.cuda.is_available()
    device_name = "cuda:{}".format(args.gpu_id) if use_cuda else "cpu"
    args.device = torch.device(device_name)

    # setting seeds
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    dataset = PCBADataset("ogbg-molpcba")
    train(dataset, args)
