import torch.nn.parallel.distributed
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import trainer.model
import trainer.trainer
import problems.GC as GC
import problems.MVC as MVC
import numpy as np
import neptune.new as neptune
import random
import torch.autograd

def init_process(rank, size, backend='gloo', master_addr='127.0.0.1', master_port='29500'):
    """Initialize the distributed environment for this process.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT`` in the environment and joins the
    default ``torch.distributed`` process group.

    Args:
        rank: Index of this process within the group.
        size: Total number of processes (world size).
        backend: ``torch.distributed`` backend name; defaults to ``'gloo'``.
        master_addr: Rendezvous address. Defaults to loopback, i.e. all
            processes run on this machine.
        master_port: Rendezvous port (int or str). Override it to run several
            trainings side by side on one host without port clashes.
    """
    # The environment only holds strings, so coerce ints such as a port number.
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)


def train_distributed(rank, world_size, train_n, val_n, train_config, train_samples, val_samples, initialization, problem, graph_types, seed, neptune_api_token):
    """Per-process training entry point (``rank`` is the first argument).

    Seeds the global RNGs, joins the process group, wraps the default model
    for ``problem`` in ``DistributedDataParallel`` and runs
    ``trainer.trainer.train_model_multi_size`` on this rank's share of the
    training samples. The process group, and the Neptune run if one was
    created, are always torn down on exit, including when training raises.

    Args:
        rank: Index of this process (0 is the master).
        world_size: Total number of processes.
        train_n, val_n: Forwarded unchanged to ``train_model_multi_size``.
        train_config: Training configuration. Its ``checkpoint_dir`` is set to
            ``None`` on every rank except 0.
        train_samples: Total number of training samples across all ranks.
            Each rank gets ``train_samples // world_size``; any remainder is
            dropped.
        val_samples: Number of validation samples, forwarded unchanged.
        initialization, problem, graph_types: Forwarded unchanged.
        seed: Base seed. Global RNGs get ``seed`` on every rank. The training
            graph seed is derived from ``seed + rank``, so it differs per rank.
            The validation seed is derived from ``seed`` alone, so it is shared.
        neptune_api_token: If not ``None``, rank 0 logs to Neptune.
    """
    PRIME_NUMBER = 181081

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    random.seed(seed)

    # Needs to be distinct for each process, so each rank trains on different graphs.
    rng = np.random.default_rng(seed=seed + rank)
    graph_seed = int(rng.integers(low=1, high=PRIME_NUMBER, dtype=np.int_))

    # SHARED: every rank validates on the same graphs.
    val_seed = (seed * 658) % PRIME_NUMBER

    init_process(rank, world_size)
    print(
        f"Rank {rank + 1}/{world_size} process initialized. Seeds {graph_seed} and {val_seed}.\n"
    )

    run = None
    try:
        dist.barrier()

        model = trainer.model.make_default_model(problem)
        ddp_model = torch.nn.parallel.distributed.DistributedDataParallel(model, device_ids=None, output_device=None)

        # Integer floor division; avoids a float round-trip. Remainder samples are dropped.
        train_samples_per_rank = int(train_samples // world_size)

        if rank == 0:
            # The master keeps checkpoint_dir whether or not Neptune logging is enabled.
            if neptune_api_token is not None:
                run = neptune.init(
                    project="glukas/Node-Labeling",
                    api_token=neptune_api_token,
                    source_files=["trainer/model.py", "trainer/dataset.py", "trainer/runner.py", "trainer/opts.py",
                                  "trainer/baseline.py"]
                )  # your credentials
                run["parameters/vm_id"] = "distributed"
                run["parameters/seed"] = seed
                run["parameters/train_samples_per_rank"] = train_samples_per_rank
                run["parameters/world_size"] = world_size
        else:
            # Only checkpoint in master. Presumably train_model_multi_size skips
            # checkpointing when checkpoint_dir is None (verify).
            train_config.checkpoint_dir = None

        # Anomaly detection adds overhead but surfaces NaN/inf sources in backward.
        with torch.autograd.set_detect_anomaly(True):
            trainer.trainer.train_model_multi_size(ddp_model, train_n=train_n, val_n=val_n, train_config=train_config,
                                                   train_samples=train_samples_per_rank, val_samples=val_samples,
                                                   initialization=initialization, problem=problem,
                                                   graph_types=graph_types, train_seed=graph_seed,
                                                   val_seed=val_seed, neptune_run=run)
    finally:
        # Release the Neptune run and the process group even if training fails.
        if run is not None:
            run.stop()
        dist.destroy_process_group()


def spawn_distributed_training(world_size, train_n, val_n, train_config, train_samples=6000, val_samples=2000, initialization="degree", problem=GC, graph_types=None, seed=123, neptune_api_token=None):
    """Launch ``world_size`` worker processes, each running ``train_distributed``.

    ``mp.spawn`` calls the worker as ``train_distributed(rank, *args)``. The
    arguments below are packed in the order ``train_distributed`` expects
    after ``rank``. The call blocks until every worker has finished
    (``join=True``).

    Args:
        world_size: Number of processes to start.
        train_n, val_n, train_config, initialization, problem, graph_types:
            Forwarded to ``train_distributed`` unchanged.
        train_samples: Total training samples, split across the workers.
        val_samples: Number of validation samples.
        seed: Base seed for all workers.
        neptune_api_token: Optional Neptune token; only rank 0 logs.
    """
    worker_args = (
        world_size,
        train_n,
        val_n,
        train_config,
        train_samples,
        val_samples,
        initialization,
        problem,
        graph_types,
        seed,
        neptune_api_token,
    )
    mp.spawn(train_distributed, args=worker_args, nprocs=world_size, join=True)