import argparse
import os
import torch
import torch.utils.benchmark as benchmark
import wandb
import numpy as np
from tqdm import tqdm

from data_qm9 import QM9
import utils


def apply_to_qm9(model, data):
    """Call ``model`` on ``data`` and hand back whatever it returns.

    Kept as a named function so the benchmark statement has a single
    callable to time.
    """
    return model(data)


# Number of batches that are timed after the warm-up phase.
trials = 100


def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    converter is used for the boolean options instead.

    Args:
        value: Raw option value (or an actual ``bool``).

    Returns:
        The parsed boolean. The empty string maps to ``False``, matching
        what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def build_parser():
    """Build the argument parser for the benchmark script.

    Returns:
        An ``argparse.ArgumentParser`` with run, data, model and
        parallelism options registered.
    """
    parser = argparse.ArgumentParser()
    # Run parameters
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size. Does not scale with number of gpus.')
    parser.add_argument('--lr', type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument('--optimizer', type=str, default="adam",
                        help='optimizer')
    parser.add_argument('--warmup', type=int, default=0,
                        help='number of warmup epochs')
    parser.add_argument('--weight_decay', type=float, default=1e-8,
                        help='weight decay')
    parser.add_argument('--print', type=int, default=100,
                        help='print interval')
    parser.add_argument('--log', type=str2bool, default=False,
                        help='logging flag')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Num workers in dataloader')

    # Data parameters
    parser.add_argument('--dataset', type=str, default="qm9",
                        help='Data set')
    parser.add_argument('--download', type=str2bool, default=False,
                        help='Download flag')
    parser.add_argument('--radius', type=float, default=2,
                        help='Radius (Angstrom) between which atoms to add links.')
    parser.add_argument('--energy_units', type=str, default="meV",
                        help='Convert energy from Hartree to eV/meV')
    parser.add_argument('--feature_type', type=str, default="one_hot",
                        help='Type of input feature: one-hot, or Cormorants charge thingy')

    # Model parameters
    parser.add_argument('--model', type=str, default="segnn",
                        help='Model name')
    parser.add_argument('--target', type=str, default="alpha",
                        help='Target passed to the QM9 dataset')
    parser.add_argument('--hidden_features', type=int, default=128,
                        help='Number of hidden features passed to the model')
    parser.add_argument('--lmax_h', type=int, default=2,
                        help='max degree of hidden rep')
    parser.add_argument('--lmax_pos', type=int, default=3,
                        help='max degree of rel pos embedding')
    parser.add_argument('--layers', type=int, default=7,
                        help='Number of message passing layers')
    parser.add_argument('--scheduler', type=str, default="step",
                        help='Learning rate scheduler')
    parser.add_argument('--norm', type=str, default="instance",
                        help='Normalisation type [instance, batch]')
    parser.add_argument('--pool', type=str, default="avg",
                        help='Pooling type [avg, sum]')
    parser.add_argument('--edge_inference', type=str2bool, default=False,
                        help='Edge inference flag')

    # Parallel computing options
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus to use (assumes all are on one node)')
    return parser


def main():
    """Time forward passes of the selected model on QM9 batches.

    Parses the command line, builds the model and dataloader, runs a fixed
    number of warm-up batches, then times ``apply_to_qm9`` on up to
    ``trials`` batches and prints (and optionally logs to wandb) the mean
    and standard deviation of the per-batch mean run times.

    Raises:
        ValueError: For an unknown dataset, feature type or model, or an
            unsupported ``--gpus`` value.
        RuntimeError: If the dataloader yields no batch to time.
    """
    args = build_parser().parse_args()

    # Select dataset.
    if args.dataset != "qm9":
        raise ValueError("Dataset could not be found")
    # NOTE(review): `train` is not used in this script; the import is kept in
    # case train_qm9 has import-time side effects -- confirm and remove.
    from train_qm9 import train  # noqa: F401
    feature_sizes = {"one_hot": 5, "cormorant": 15}
    if args.feature_type not in feature_sizes:
        # Previously an unknown feature type only failed later with a
        # NameError on `in_features`.
        raise ValueError("Feature type could not be found")
    in_features = feature_sizes[args.feature_type]
    out_features = 1

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.gpus not in (0, 1):
        raise ValueError("--gpus must be 0 (CPU) or 1: Only single gpu is supported")
    gpu = 0

    # Select model
    if args.model != "segnn":
        # Previously an unknown model only failed later with a NameError.
        raise ValueError("Model could not be found")
    from models.segnn.segnn import SEGNNModel
    model = SEGNNModel(in_features,
                       out_features,
                       hidden_features=args.hidden_features,
                       N=args.layers,
                       lmax_h=args.lmax_h,
                       lmax_pos=args.lmax_pos,
                       norm=args.norm,
                       pool=args.pool,
                       edge_inference=args.edge_inference)

    device = 'cpu' if args.gpus == 0 else 'cuda:' + str(gpu)
    # Random 5-digit suffix so that repeated runs get distinct names.
    args.ID = "_".join([args.model, args.target,
                        str(np.random.randint(10 ** 4, 10 ** 5))])
    model.to(device)

    dataset = QM9(root="datasets", target=args.target, radius=args.radius,
                  split="train", feature_type=args.feature_type)
    dataloader = utils.make_dataloader(dataset, args.batch_size, args.num_workers,
                                       args.gpus, gpu, train=False)

    if args.log:
        wandb.init(project="SEGNN " + args.dataset + " benchmark", name=args.ID,
                   config=args, entity="segnn")

    # NOTE(review): this script never calls model.eval() or torch.no_grad(),
    # so the timings below include whatever autograd bookkeeping the forward
    # pass performs -- confirm that this is the intended measurement.
    print("warming up")
    warmup_batches = 100
    # zip() stops as soon as range() is exhausted, so no extra batch is loaded.
    for _, data in tqdm(zip(range(warmup_batches), dataloader), total=warmup_batches):
        model(data.to(device))

    num_threads = torch.get_num_threads()
    timings = []
    for _, data in zip(range(trials), dataloader):
        data = data.to(device)

        print(data.y.shape)

        # The callable is passed through `globals` rather than imported from
        # `__main__`, so timing also works when this module is imported.
        timer = benchmark.Timer(
            stmt='apply_to_qm9(model, data)',
            num_threads=num_threads,
            globals={'apply_to_qm9': apply_to_qm9, 'model': model, 'data': data})

        report = timer.timeit(3)
        print(report)
        timings.append(report.mean)

    if not timings:
        # np.mean([]) would silently report NaN.
        raise RuntimeError("Dataloader yielded no batches to benchmark")

    mean = np.mean(timings)
    std = np.std(timings)
    print("mean, std:", mean, std)
    if args.log:
        wandb.log({"time": mean, "std": std})


if __name__ == "__main__":
    main()
