import copy
from pathlib import Path

import click
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from logzero import logger
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP
from modules.initializers import GlorotOrthogonal
from qm9 import QM9
from ruamel.yaml import YAML
from sklearn.metrics import mean_absolute_error
from torch.utils.data import DataLoader

import dgl
from dgl.data.utils import Subset


def split_dataset(
    dataset, num_train, num_valid, shuffle=False, random_state=None
):
    """Partition a dataset into training, validation and test subsets.

    The index order (consecutive, or a random permutation when ``shuffle``
    is set) is cut into three contiguous runs: the first ``num_train``
    indices, the next ``num_valid`` indices, and everything that remains.

    Parameters
    ----------
    dataset
        Object for which ``len(dataset)`` gives the number of datapoints.
        It is passed to ``Subset`` together with the selected indices.
    num_train : int
        Number of datapoints in the training subset.
    num_valid : int
        Number of datapoints in the validation subset.
    shuffle : bool, optional
        If True, the indices are randomly permuted before being cut;
        otherwise the split is consecutive (the default).
    random_state : None, int or array_like, optional
        Seed forwarded to ``numpy.random.RandomState``; only used when
        ``shuffle`` is True.

    Returns
    -------
    list of length 3
        ``Subset`` objects for training, validation and test, in that order.

    Raises
    ------
    AssertionError
        If ``num_train + num_valid`` is not strictly smaller than
        ``len(dataset)`` (i.e. the test subset would be empty).
    """
    total = len(dataset)
    assert num_train + num_valid < total
    sizes = (num_train, num_valid, total - num_train - num_valid)

    if shuffle:
        rng = np.random.RandomState(seed=random_state)
        order = rng.permutation(total)
    else:
        order = np.arange(total)

    # Walk through the index order, slicing off one contiguous run per subset.
    subsets = []
    start = 0
    for size in sizes:
        stop = start + size
        subsets.append(Subset(dataset, order[start:stop]))
        start = stop
    return subsets


@torch.no_grad()
def ema(ema_model, model, decay):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Every entry of ``ema_model.state_dict()`` is overwritten with
    ``decay * ema_value + (1 - decay) * model_value``, where ``model_value``
    is the entry with the same key in ``model.state_dict()``.

    Parameters
    ----------
    ema_model
        Module holding the running average; modified in place.
    model
        Module providing the current values; left untouched.
    decay : float
        Weight given to the existing averaged value.
    """
    current_state = model.state_dict()
    for name, averaged in ema_model.state_dict().items():
        latest = current_state[name].detach()
        blended = averaged * decay + (1.0 - decay) * latest
        # copy_ writes into the tensor backing ema_model's state.
        averaged.copy_(blended)


def edge_init(edges):
    """Compute per-edge length and displacement from the endpoint ``"R"`` features.

    Parameters
    ----------
    edges
        Edge batch exposing ``edges.src["R"]`` and ``edges.dst["R"]``
        (presumably atom coordinates -- confirm against the dataset).

    Returns
    -------
    dict
        ``"d"``: Euclidean norm of ``src - dst`` over the last axis
        (bond length); ``"o"``: the displacement ``src - dst`` itself
        (bond orientation).
    """
    displacement = edges.src["R"] - edges.dst["R"]
    squared_length = displacement.pow(2).sum(dim=-1)
    # relu clamps the squared length at zero before taking the square root.
    length = F.relu(squared_length).sqrt()
    return {"d": length, "o": displacement}


def _collate_fn(batch):
    """Collate ``(graph, line_graph, label)`` samples into one batch.

    Parameters
    ----------
    batch
        Iterable of 3-tuples ``(graph, line_graph, label)``.

    Returns
    -------
    tuple
        ``(batched_graph, batched_line_graph, labels)`` where the graphs are
        merged with ``dgl.batch`` and ``labels`` is a ``float32`` tensor.
    """
    # Transpose the list of samples into three parallel lists.
    graphs, line_graphs, labels = (list(column) for column in zip(*batch))
    batched_graph = dgl.batch(graphs)
    batched_line_graph = dgl.batch(line_graphs)
    label_tensor = torch.tensor(labels, dtype=torch.float32)
    return batched_graph, batched_line_graph, label_tensor


def train(device, model, opt, loss_fn, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Parameters
    ----------
    device
        Device the graphs and labels are moved to before the forward pass.
    model
        Module called as ``model(g, l_g)``; switched to training mode.
    opt
        Optimizer stepped once per batch.
    loss_fn
        Callable ``loss_fn(predictions, labels)``; labels are reshaped to a
        column (``view(-1, 1)``) before the call.
    train_loader
        Iterable yielding ``(g, l_g, labels)`` batches.

    Returns
    -------
    float
        Per-batch losses averaged with each batch weighted by its number
        of labels.
    """
    model.train()
    weighted_loss_sum = 0
    samples_seen = 0

    for graph, line_graph, targets in train_loader:
        graph = graph.to(device)
        line_graph = line_graph.to(device)
        targets = targets.to(device)
        batch_size = len(targets)

        predictions = model(graph, line_graph)
        batch_loss = loss_fn(predictions, targets.view(-1, 1))

        opt.zero_grad()
        batch_loss.backward()
        opt.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        weighted_loss_sum += batch_loss.item() * batch_size
        samples_seen += batch_size

    return weighted_loss_sum / samples_seen


@torch.no_grad()
def evaluate(device, model, valid_loader):
    """Run ``model`` over ``valid_loader`` and gather predictions and labels.

    Parameters
    ----------
    device
        Device the graphs are moved to before the forward pass.
    model
        Module called as ``model(g, l_g)``; switched to evaluation mode.
    valid_loader
        Iterable yielding ``(g, l_g, labels)`` batches.

    Returns
    -------
    tuple of numpy.ndarray
        ``(predictions, labels)``. Predictions are the flattened model
        outputs of all batches; labels are the per-batch label arrays
        concatenated along the first axis. Both are empty arrays when the
        loader yields no batches.
    """
    model.eval()
    prediction_batches, label_batches = [], []

    for g, l_g, labels in valid_loader:
        logits = model(g.to(device), l_g.to(device))
        prediction_batches.append(logits.view(-1).cpu().numpy())
        # Convert one array per batch instead of one Python object per label;
        # iterating a tensor element-wise is slow and yields 0-d tensors.
        if torch.is_tensor(labels):
            labels = labels.detach().cpu().numpy()
        label_batches.append(np.asarray(labels))

    # np.concatenate rejects an empty list, so handle an empty loader here.
    if not prediction_batches:
        return np.array([]), np.array([])

    return np.concatenate(prediction_batches), np.concatenate(label_batches)


def _select_output_init(targets):
    """Pick the initializer of the final output layer from the target names."""
    names = [targets] if isinstance(targets, str) else list(targets)
    zero_init_targets = {"mu", "homo", "lumo", "gap", "zpve"}
    # ``targets`` is a list of names, so it must be checked element-wise;
    # comparing the whole list against the names never matches.
    if names and all(name in zero_init_targets for name in names):
        return nn.init.zeros_
    # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
    return GlorotOrthogonal


def _make_loader(data, train_params, shuffle):
    """Build a DataLoader over ``data`` using the shared training settings."""
    return DataLoader(
        data,
        batch_size=train_params["batch_size"],
        shuffle=shuffle,
        collate_fn=_collate_fn,
        num_workers=train_params["num_workers"],
    )


def _build_model(model_name, model_params):
    """Instantiate DimeNet or DimeNet++ from the ``model`` config section."""
    if model_name not in ("dimenet", "dimenet++"):
        raise ValueError(f"Invalid Model Name {model_name}")

    # Keyword arguments accepted by both architectures.
    shared = {
        "emb_size": model_params["emb_size"],
        "num_blocks": model_params["num_blocks"],
        "num_spherical": model_params["num_spherical"],
        "num_radial": model_params["num_radial"],
        "cutoff": model_params["cutoff"],
        "envelope_exponent": model_params["envelope_exponent"],
        "num_before_skip": model_params["num_before_skip"],
        "num_after_skip": model_params["num_after_skip"],
        "num_dense_output": model_params["num_dense_output"],
        "num_targets": len(model_params["targets"]),
        "output_init": model_params["output_init"],
    }
    if model_name == "dimenet":
        return DimeNet(num_bilinear=model_params["num_bilinear"], **shared)
    return DimeNetPP(
        out_emb_size=model_params["out_emb_size"],
        int_emb_size=model_params["int_emb_size"],
        basis_emb_size=model_params["basis_emb_size"],
        extensive=model_params["extensive"],
        **shared,
    )


@click.command()
@click.option(
    "-m",
    "--model-cnf",
    type=click.Path(exists=True),
    # Without a config there is nothing to run; let click report the missing
    # option instead of failing later with a TypeError in Path(None).
    required=True,
    help="Path of model config yaml.",
)
def main(model_cnf):
    """Train a DimeNet/DimeNet++ model on QM9, or test a pretrained one."""
    yaml = YAML(typ="safe")
    model_cnf = yaml.load(Path(model_cnf))
    model_name = model_cnf["name"]
    model_params = model_cnf["model"]
    train_params = model_cnf["train"]
    pretrain_params = model_cnf["pretrain"]
    logger.info(f"Model name: {model_name}")
    logger.info(f"Model params: {model_params}")
    logger.info(f"Train params: {train_params}")

    model_params["output_init"] = _select_output_init(model_params["targets"])

    logger.info("Loading Data Set")
    dataset = QM9(label_keys=model_params["targets"], edge_funcs=[edge_init])

    # data split
    train_data, valid_data, test_data = split_dataset(
        dataset,
        num_train=train_params["num_train"],
        num_valid=train_params["num_valid"],
        shuffle=True,
        random_state=train_params["data_seed"],
    )
    logger.info(f"Size of Training Set: {len(train_data)}")
    logger.info(f"Size of Validation Set: {len(valid_data)}")
    logger.info(f"Size of Test Set: {len(test_data)}")

    # data loaders; only the training data is reshuffled every epoch
    train_loader = _make_loader(train_data, train_params, shuffle=True)
    valid_loader = _make_loader(valid_data, train_params, shuffle=False)
    test_loader = _make_loader(test_data, train_params, shuffle=False)

    # check cuda
    gpu = train_params["gpu"]
    device = f"cuda:{gpu}" if gpu >= 0 and torch.cuda.is_available() else "cpu"

    # model initialization
    logger.info("Loading Model")
    model = _build_model(model_name, model_params).to(device)

    if pretrain_params["flag"]:
        # Pretrained mode: load the checkpoint of the first target, report
        # the test error and exit without training.
        target = model_params["targets"][0]
        checkpoint = Path(pretrain_params["path"]) / f"{target}.pt"
        # map_location lets a checkpoint saved on a GPU load on this device.
        model.load_state_dict(torch.load(checkpoint, map_location=device))

        logger.info("Testing with Pretrained model")
        predictions, labels = evaluate(device, model, test_loader)
        test_mae = mean_absolute_error(labels, predictions)
        logger.info(f"Test MAE {test_mae:.4f}")
        return

    # define loss function and optimization
    loss_fn = nn.L1Loss()
    opt = optim.Adam(
        model.parameters(),
        lr=train_params["lr"],
        weight_decay=train_params["weight_decay"],
        amsgrad=True,
    )
    scheduler = optim.lr_scheduler.StepLR(
        opt, train_params["step_size"], gamma=train_params["gamma"]
    )

    # model training
    best_mae = 1e9
    no_improvement = 0

    # EMA for valid and test: a frozen copy that tracks the trained weights
    logger.info("EMA Init")
    ema_model = copy.deepcopy(model)
    for p in ema_model.parameters():
        p.requires_grad_(False)

    best_model = copy.deepcopy(ema_model)

    logger.info("Training")
    for epoch in range(train_params["epochs"]):
        train_loss = train(device, model, opt, loss_fn, train_loader)
        ema(ema_model, model, train_params["ema_decay"])

        if epoch % train_params["interval"] == 0:
            # Validation is run on the averaged weights, not the raw model.
            predictions, labels = evaluate(device, ema_model, valid_loader)
            valid_mae = mean_absolute_error(labels, predictions)
            logger.info(
                f"Epoch {epoch} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}"
            )

            if valid_mae > best_mae:
                no_improvement += 1
                if no_improvement >= train_params["early_stopping"]:
                    logger.info("Early stop.")
                    break
            else:
                no_improvement = 0
                best_mae = valid_mae
                # Snapshot, since ema_model keeps changing in later epochs.
                best_model = copy.deepcopy(ema_model)
        else:
            logger.info(f"Epoch {epoch} | Train Loss {train_loss:.4f}")

        scheduler.step()

    logger.info("Testing")
    predictions, labels = evaluate(device, best_model, test_loader)
    test_mae = mean_absolute_error(labels, predictions)
    logger.info(f"Test MAE {test_mae:.4f}")


# Invoke the click command only when this file is run as a script.
if __name__ == "__main__":
    main()
