from opts import get_options
import torch
import os
import pprint as pp
import trainer.trainer
try:
    import neptune.new as neptune
except ImportError:
    pass
import trainer.model
from multiprocessing import Process
from trainer.trainconfig import TrainConfig

def init_model(opts):
    """Build the attention model configured by ``opts``.

    The ``node_dim`` handed to the model is 1, unless ``opts.initialization``
    starts with ``'positional'``, in which case it equals ``opts.embed_dim``.
    All remaining model settings are forwarded from ``opts`` unchanged.
    """
    uses_positional_init = opts.initialization.startswith('positional')
    input_dim = opts.embed_dim if uses_positional_init else 1

    return trainer.model.Attention(
        opts.problem,
        opts.embed_dim,
        opts.n_encoder_layers,
        opts.n_heads,
        node_dim=input_dim,
        decoding_type=opts.decoding_type,
        throwback=opts.throwback,
        normalize=opts.normalize,
        shortcuts=opts.shortcuts,
    )


def _existing_file_or_none(path):
    """Return ``path`` if it names an existing file, otherwise ``None``.

    Both a real ``None`` and the literal string ``'None'`` count as "no file".
    """
    if path is None or path == 'None' or not os.path.isfile(path):
        return None
    return path


def train_task(opts, rank):
    """Trains a model with seed equal to opts.seed[rank] and according to the parameters specified in opts.

    Args:
        opts: Options namespace. ``opts.train_file`` and ``opts.test_file`` are
            reset to ``None`` in place when they do not point to an existing file.
        rank: Index into ``opts.seed`` selecting the seed of this task. The task
            with rank 0 additionally prints the full option set.

    Raises:
        ImportError: If a Neptune API token is given but the ``neptune`` package
            could not be imported.
    """
    if rank == 0:
        pp.pprint(vars(opts))

    if opts.ckpt is not None:
        # Several seed processes may reach this point at the same time;
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(opts.ckpt, exist_ok=True)

    opts.train_file = _existing_file_or_none(opts.train_file)
    opts.test_file = _existing_file_or_none(opts.test_file)

    run = None
    if opts.neptune_api_token is not None:
        # The module-level neptune import is optional, so the name may be missing.
        try:
            neptune_client = neptune
        except NameError as err:
            raise ImportError(
                "A Neptune API token was provided but the 'neptune' package is not installed"
            ) from err

        run = neptune_client.init(
            project="glukas/Node-Labeling",
            api_token=opts.neptune_api_token,
            source_files=["trainer/model.py", "trainer/dataset.py", "trainer/runner.py", "trainer/opts.py",
                          "trainer/baseline.py"]
        )  # your credentials

    try:
        if run is not None:
            run["parameters/vm_id"] = "main"
            # NOTE(review): vars(opts) contains neptune_api_token as well, so the
            # token is logged with the run parameters - confirm this is intended.
            run["parameters/commandline_args"] = vars(opts)
            run["parameters/rank"] = rank

        device = torch.device("cuda:0" if torch.cuda.is_available() and not opts.no_cuda else "cpu")

        # torch is seeded with a value derived from the task seed; the raw seed
        # itself is passed on separately as rng_seed below.
        torch.manual_seed(opts.seed[rank] * 11 % 38047)

        model = init_model(opts)

        if opts.load_model_path:
            # map_location lets a checkpoint saved on another device (e.g. a GPU)
            # be loaded when this task runs on a different one (e.g. CPU only).
            model.load_state_dict(torch.load(opts.load_model_path, map_location=device))

        # Batch size can be at most the number of samples per graph size
        batch_size = min(opts.batch_size, opts.train_samples)

        train_config = TrainConfig(n_epochs=opts.n_epochs, checkpoint_dir=opts.ckpt, batch_size=batch_size,
                                   checkpoint_last_epochs=opts.checkpoint_last_epochs, log_step=opts.log_step,
                                   clip_grad_norm=opts.max_grad_norm, lr=opts.lr,
                                   device=device, rng_seed=opts.seed[rank])

        trainer.trainer.train_multi_size(opts.train_graph_nodes, opts.val_graph_nodes, train_config,
                                         opts.train_samples, opts.val_samples, graph_types=opts.graph_types,
                                         initialization=opts.initialization, problem=opts.problem,
                                         neptune_run=run, rng_seed=opts.seed[rank], model=model)
    finally:
        # Close the Neptune run even when setup or training raises.
        if run is not None:
            run.stop()


def main(opts, rank=0):
    """Run a single training task in the current process.

    Args:
        opts: Options namespace forwarded to ``train_task``.
        rank: Index into ``opts.seed`` selecting the seed to train with.
            Defaults to 0, i.e. the first seed.
    """
    # train_task requires the rank argument; calling it with opts alone raises TypeError.
    train_task(opts, rank)


# Parallel entry point: one training process per seed.
def main_parallel_seeds(opts):
    """Train one model per entry of ``opts.seed``, each in its own process.

    Every process runs ``train_task`` with the position of its seed as rank.
    The call returns once all processes have been joined.
    """
    workers = []
    for rank, _ in enumerate(opts.seed):
        worker = Process(target=train_task, args=(opts, rank))
        workers.append(worker)
        worker.start()

    # Block until every seed's training process has finished.
    for worker in workers:
        worker.join()


if __name__ == '__main__':
    # Read the options first, then train one model per configured seed.
    cli_opts = get_options()
    main_parallel_seeds(cli_opts)
