import torch
import argparse
import os
import numpy as np
import torch.multiprocessing as mp


def _find_free_port():
    """Return a TCP port that is currently free on this host.

    Lets the OS choose the port, so that multiple runs launched on the same
    machine don't clash.

    Returns:
        int: The port number assigned by the operating system.
    """
    import socket

    # Binding to port 0 causes the OS to pick an available port for us.
    # The context manager guarantees the socket is closed even if bind fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    # NOTE: the port is released when the socket closes, so there is still a
    # chance another process takes it before the caller uses it.
    return port


def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options are parsed explicitly instead.

    Args:
        value: The raw argument string, or an already-parsed bool.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def build_parser():
    """Build the command-line argument parser for this training script.

    Returns:
        argparse.ArgumentParser: Parser for run, data, model and parallelism options.
    """
    parser = argparse.ArgumentParser()

    # Run parameters
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size. Does not scale with number of gpus.')
    parser.add_argument('--lr', type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument('--optimizer', type=str, default="adam",
                        help='optimizer')
    parser.add_argument('--weight_decay', type=float, default=1e-8,
                        help='weight decay')
    parser.add_argument('--print', type=int, default=100,
                        help='print interval')
    # Boolean options accept true/false/yes/no/1/0; a bare flag means True.
    parser.add_argument('--log', type=str2bool, nargs='?', const=True, default=False,
                        help='logging flag')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Num workers in dataloader')

    # Data parameters
    parser.add_argument('--dataset', type=str, default="qm9",
                        help='Data set')
    parser.add_argument('--root', type=str, default="datasets",
                        help='Data set location')
    parser.add_argument('--download', type=str2bool, nargs='?', const=True, default=False,
                        help='Download flag')
    parser.add_argument('--radius', type=float, default=2,
                        help='Radius (Angstrom) between which atoms to add links.')
    parser.add_argument('--energy_units', type=str, default="meV",
                        help='Convert energy from Hartree to eV/meV')
    parser.add_argument('--feature_type', type=str, default="one_hot",
                        help='Type of input feature [one_hot, cormorant, gilmer]')

    # Model parameters
    parser.add_argument('--model', type=str, default="segnn",
                        help='Model name')
    parser.add_argument('--target', type=str, default="alpha",
                        help='Target property to predict')
    parser.add_argument('--hidden_features', type=int, default=128,
                        help='Size of the hidden features')
    parser.add_argument('--lmax_h', type=int, default=2,
                        help='max degree of hidden rep')
    parser.add_argument('--lmax_pos', type=int, default=3,
                        help='max degree of rel pos embedding')
    parser.add_argument('--layers', type=int, default=7,
                        help='Number of message passing layers')
    parser.add_argument('--scheduler', type=str, default="step",
                        help='Learning rate scheduler')
    parser.add_argument('--norm', type=str, default="instance",
                        help='Normalisation type [instance, batch]')
    parser.add_argument('--pool', type=str, default="avg",
                        help='Pooling type [avg, sum]')
    parser.add_argument('--edge_inference', type=str2bool, nargs='?', const=True, default=False,
                        help='Edge inference flag')

    # Parallel computing stuff
    parser.add_argument('-g', '--gpus', default=2, type=int,
                        help='number of gpus to use (assumes all are on one node)')

    return parser


def main():
    """Parse arguments, build the selected model and launch training.

    Runs ``train`` directly on the CPU for ``--gpus 0`` or on a single GPU for
    ``--gpus 1``. For more GPUs it spawns one process per GPU through
    ``torch.multiprocessing.spawn``.

    Raises:
        ValueError: If the dataset, feature type or model is not recognised.
    """
    parser = build_parser()
    args = parser.parse_args()
    # A negative count would match none of the launch branches below.
    if args.gpus < 0:
        parser.error("--gpus must be a non-negative integer")

    # Select dataset.
    if args.dataset == "qm9":
        from train_qm9 import train
        # Input feature dimension for each supported qm9 feature type.
        qm9_in_features = {"one_hot": 5, "cormorant": 15, "gilmer": 11}
        if args.feature_type not in qm9_in_features:
            raise ValueError("Feature type %r could not be found" % (args.feature_type,))
        in_features = qm9_in_features[args.feature_type]
        out_features = 1
    else:
        raise ValueError("Dataset could not be found")

    # Select model
    if args.model == "segnn":
        from models.segnn.segnn import SEGNNModel
        model = SEGNNModel(in_features,
                           out_features,
                           hidden_features=args.hidden_features,
                           N=args.layers,
                           lmax_h=args.lmax_h,
                           lmax_pos=args.lmax_pos,
                           norm=args.norm,
                           pool=args.pool,
                           edge_inference=args.edge_inference)
    else:
        raise ValueError("Model could not be found")

    # Random 5-digit suffix (10000..99999), presumably to tell runs apart.
    args.ID = "_".join([args.model, args.target, str(np.random.randint(10000, 100000))])

    print(model)
    if args.gpus == 0:
        print('Starting training on the cpu...')
        args.mode = 'cpu'
        train(0, model, args)
    elif args.gpus == 1:
        print('Starting training on a single gpu...')
        args.mode = 'gpu'
        train(0, model, args)
    else:
        print('Starting training on', args.gpus, 'gpus...')
        args.mode = 'gpu'
        # Rendezvous address/port for the spawned processes (single node).
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        port = _find_free_port()
        print('found free port', port)
        os.environ['MASTER_PORT'] = str(port)
        # mp.spawn calls train(process_index, model, args) in each process.
        mp.spawn(train, nprocs=args.gpus, args=(model, args))


if __name__ == "__main__":
    main()
