import os
import argparse
from shutil import copyfile
import torch.distributed as dist
import torch
import importlib
import datetime
from utils.dist_utils import (
    get_world_size,
)
from omegaconf import OmegaConf
from utils.utils import seed_all

parser = argparse.ArgumentParser(description='VFI')
parser.add_argument('-c', '--config', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('--local_rank', default='0')

args = parser.parse_args()


def main_worker(rank, config):
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if torch.cuda.is_available():
        print(f'Rank {rank} is available')
        config['device'] = f"cuda:{rank}"
        if config['distributed']:
            dist.init_process_group(backend='nccl',
                                    timeout=datetime.timedelta(seconds=5400))
    else:
        config['device'] = 'cpu'

    # cfg_name = os.path.basename(args.config).split('.')[0]
    # config['exp_name'] = cfg_name + '_' + config['exp_name']
    config['save_dir'] = os.path.join(config['save_dir'], config['exp_name'])
    print(config['save_dir'])

    if (not config['distributed']) or rank == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        os.makedirs(f'{config["save_dir"]}/ckpts', exist_ok=True)
        config_path = os.path.join(config['save_dir'],
                                   args.config.split('/')[-1])
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer_name = config.get('trainer_type', 'base_trainer')
    print(f'using GPU {rank} for training')
    if rank == 0:
        print(trainer_name)
    trainer_pack = importlib.import_module('trainers.' + trainer_name)
    trainer = trainer_pack.Trainer(config)

    trainer.train()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    cfg = OmegaConf.load(args.config)
    seed_all(cfg.seed)
    rank = int(args.local_rank)
    torch.cuda.set_device(torch.device(f'cuda:{rank}'))
    # setting distributed cfgurations
    cfg['world_size'] = get_world_size()
    cfg['local_rank'] = rank
    if rank == 0:
        print('world_size:', cfg['world_size'])
    main_worker(rank, cfg)
