import ANONYMOUStorch.prototype as ANONYMOUS_ABCpt
import torch as th
import torch.multiprocessing as mp
from ANONYMOUS import ANONYMOUS_ABC_instantiate, link_hyd_run, load_class
from ANONYMOUS.logging import Wandb
from ANONYMOUS.random import reset_global_seed
from ANONYMOUStorch.ddp import ddp_utils
from ANONYMOUStorch.trainer import trainer_save_cfg

from datasets import get_dataset
from modules import import_fns


def _configure_device(cfg):
    """Set up the GPU used by this process and keep ``cfg.trainer.gpu`` in sync.

    Distributed runs pass the index already stored in ``cfg.trainer.gpu`` to
    ``set_gpu_mode``; single-process runs ask ``set_best_device`` for an index
    and write it back into ``cfg.trainer.gpu``.
    """
    if not cfg.is_dist:
        # Imported lazily: only the single-process path needs it.
        from ANONYMOUStorch import set_best_device

        cfg.trainer.gpu = set_best_device(mem_prior=1.0)
        return
    ANONYMOUS_ABCpt.set_gpu_mode(True, cfg.trainer.gpu)


def run(cfg):  # pylint: disable= too-many-locals
    """Assemble model, optimizer, data and trainer from ``cfg``, then train.

    Serves both as the single-process entry point and, via ``mock_run``, as
    the per-process body of a distributed run (``cfg.is_dist`` truthy).

    Args:
        cfg: Experiment config. Fields read here include ``seed``, ``log``,
            ``is_dist``, ``model``, ``optimizer.optim``, ``data`` (with
            ``data.dataloader``) and ``trainer`` (``gpu``, ``rank``,
            ``world_size``). ``cfg.trainer.gpu`` is overwritten on the
            non-distributed path.
    """
    reset_global_seed(cfg.seed)

    # Run-directory linking and W&B setup happen on the master process only.
    if ddp_utils.is_master():
        link_hyd_run()
        Wandb.launch(cfg, cfg.log, True)

    _configure_device(cfg)

    init_model, loss_fn_wrapper, trainer_register = import_fns(cfg.model)
    model = init_model(cfg.model)

    if cfg.is_dist:
        trainer_path = "trainer.ddp_trainer.Trainer"
    else:
        trainer_path = "trainer.trainer.Trainer"
    trainer_cls = load_class(trainer_path)
    loss_fn = loss_fn_wrapper(cfg)
    trainer = trainer_cls(cfg.trainer, loss_fn)

    trainer.set_model_optim(
        model, ANONYMOUS_ABC_instantiate(cfg.optimizer.optim, model.parameters())
    )

    trainset, valset = get_dataset(cfg.data)
    loaders = ANONYMOUS_ABC_instantiate(
        cfg.data.dataloader,
        trainset,
        valset,
        rank=cfg.trainer.rank,
        world_size=cfg.trainer.world_size,
    )
    train_loader, train_sampler, val_loader, val_sampler = loaders

    # Samplers are only handed to the trainer for distributed runs.
    if cfg.is_dist:
        trainer.set_sampler(train_sampler, val_sampler)
    trainer.set_dataset(trainset, valset)
    trainer.set_dataloader(train_loader, val_loader)
    trainer_register(trainer, cfg)

    # Re-queried here rather than cached from the check above.
    if ddp_utils.is_master():
        trainer_save_cfg(trainer, cfg)
        trainer.set_monitor(cfg.log)
        trainer.save_ckpt()

    trainer.train()

    # NOTE(review): Wandb.launch is only called on the master process, but
    # finish() runs on every process — confirm Wandb.finish() is safe
    # without an active run.
    Wandb.finish()


@ddp_utils.ddp_runner
def mock_run(cfg):
    """Worker target passed to ``mp.spawn`` by ``dist_run``; delegates to ``run``.

    ``mp.spawn`` calls its target as ``fn(process_index, *args)``, and
    ``dist_run`` supplies ``args=(world_size, None, cfg)``. The
    ``ddp_utils.ddp_runner`` decorator presumably consumes those spawn
    arguments (and sets up the process group) before calling this with just
    ``cfg`` — TODO confirm against ``ddp_runner``.
    """
    return run(cfg)


def dist_run(cfg):
    """Spawn one worker process per visible CUDA device, each running ``mock_run``.

    Args:
        cfg: Experiment config. It is passed to ``ddp_utils.prepare_cfg``
            (presumably adjusted in place for DDP — TODO confirm) before it is
            handed to every worker.

    Raises:
        RuntimeError: If no CUDA device is visible. Without this check
            ``mp.spawn`` would be called with ``nprocs=0``, start no workers,
            and return immediately, so the run would silently do nothing.
    """
    world_size = th.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            "dist_run requires at least one visible CUDA device, found none"
        )
    ddp_utils.prepare_cfg(cfg)
    # mp.spawn invokes mock_run(process_index, *args) in each child process.
    mp.spawn(mock_run, args=(world_size, None, cfg), nprocs=world_size, join=True)
