import os
import random
import time
import argparse
import datetime
import numpy as np
import subprocess

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from timm.utils import AverageMeter

from config import get_config
from models import build_model
from data_loader import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from models.frequency_loss import FrequencyLoss
from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_checkpoint_with_name, save_checkpoint_with_name


try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None



def parse_option():
    """Parse the command line and build the experiment configuration.

    Returns:
        tuple: ``(args, config)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``config`` is the object returned by
        ``get_config(args)``.
    """
    cli = argparse.ArgumentParser('MFM pre-training script', add_help=False)
    register = cli.add_argument

    # config file plus free-form KEY VALUE overrides
    register('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    register("--opts", nargs='+', default=None,
             help="Modify config options by adding 'KEY VALUE' pairs. ")

    # shortcuts for frequently changed config entries
    register('--batch-size', type=int, help="batch size for single GPU")
    register('--data-path', type=str, help='path to dataset')
    register('--resume', help='resume from checkpoint')
    register('--accumulation-steps', type=int, help="gradient accumulation steps")
    register('--use-checkpoint', action='store_true',
             help="whether to use gradient checkpointing to save memory")
    register('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
             help='mixed precision opt level, if O0, no amp is used')
    register('--output', default='output', type=str, metavar='PATH',
             help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    register('--tag', help='tag of experiment')

    # distributed-training options
    register("--distributed", action="store_true", help="Using distributed")
    register('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    register('--port', type=int, default=29500, help='port only works when launcher=="slurm"')
    register('--rank', default=-1, type=int, help='node rank for distributed training')
    register('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    register('--dist-backend', default='nccl', type=str, help='distributed backend')
    register('--local_rank', default=-1, type=int, help='local rank for distributed training')

    parsed = cli.parse_args()
    # The attribute name is kept exactly as spelled; presumably get_config
    # reads it under this name -- verify before renaming.
    parsed.studnet_normal_strategy = True

    return parsed, get_config(parsed)


def get_model_info(model, name):
    """Log size statistics for ``model`` and return its unwrapped module.

    Args:
        model: a network, normally wrapped in ``DistributedDataParallel``.
            A model without a ``module`` attribute is accepted too and is
            treated as already unwrapped.
        name: label used as the prefix of the log lines.

    Returns:
        The object stored in ``model.module`` when that attribute exists,
        otherwise ``model`` itself.
    """
    # Parallel wrappers keep the real network in `.module`; fall back to the
    # model itself so an unwrapped model does not raise AttributeError.
    model_without_ddp = getattr(model, 'module', model)

    # only parameters that currently require gradients are counted
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"{name}: number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"{name}: number of GFLOPs: {flops / 1e9}")
    return model_without_ddp
    

def main(config, log_writer):
    """Build the student/teacher pipeline and run the pre-training loop.

    Args:
        config: experiment configuration object (see ``parse_option``).
        log_writer: TensorBoard writer, or ``None``; it is flushed after
            every epoch on rank 0 only.
    """
    data_loader_train = build_loader(config, logger, is_pretrain=True)
    steps_per_epoch = len(data_loader_train)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    student_model, mfm_header, distillation_header = build_model(config, is_pretrain=True, is_student=True)
    teacher_model, _, _ = build_model(config, is_pretrain=True, is_student=False)

    # (log banner, network) pairs, in the order they are moved and printed
    networks = (
        ("student model is:", student_model),
        ("mfm header is:", mfm_header),
        ("distillation header is:", distillation_header),
        ("teacher model is:", teacher_model),
    )
    for _banner, net in networks:
        net.cuda()
    for banner, net in networks:
        logger.info(banner)
        logger.info(str(net))

    # one optimizer per trainable component; the teacher gets none
    student_optimizer = build_optimizer(config, student_model, logger, is_pretrain=True)
    mfm_header_optimizer = build_optimizer(config, mfm_header, logger, is_pretrain=True)
    distillation_header_optimizer = build_optimizer(config, distillation_header, logger, is_pretrain=True)

    freq_cfg = config.MODEL.FREQ_LOSS
    criterion_freq = FrequencyLoss(
        loss_gamma=freq_cfg.LOSS_GAMMA,
        matrix_gamma=freq_cfg.MATRIX_GAMMA,
        patch_factor=freq_cfg.PATCH_FACTOR,
        ave_spectrum=freq_cfg.AVE_SPECTRUM,
        with_matrix=freq_cfg.WITH_MATRIX,
        log_matrix=freq_cfg.LOG_MATRIX,
        batch_matrix=freq_cfg.BATCH_MATRIX,
    ).cuda()
    criterion_student = nn.CosineSimilarity(dim=1).cuda()

    opt_level = config.AMP_OPT_LEVEL
    if opt_level != "O0":
        # NOTE(review): amp.initialize is invoked once per (model, optimizer)
        # pair here; confirm against the apex documentation whether a single
        # call with lists of models/optimizers is required instead.
        student_model, student_optimizer = amp.initialize(
            student_model, student_optimizer, opt_level=opt_level)
        mfm_header, mfm_header_optimizer = amp.initialize(
            mfm_header, mfm_header_optimizer, opt_level=opt_level)
        distillation_header, distillation_header_optimizer = amp.initialize(
            distillation_header, distillation_header_optimizer, opt_level=opt_level)

    def to_ddp(module):
        # wrap on the currently selected CUDA device
        return torch.nn.parallel.DistributedDataParallel(
            module, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)

    student_model = to_ddp(student_model)
    mfm_header = to_ddp(mfm_header)
    distillation_header = to_ddp(distillation_header)
    teacher_model = to_ddp(teacher_model)

    student_model_without_ddp = get_model_info(student_model, 'Student')
    mfm_header_model_without_ddp = get_model_info(mfm_header, 'MFM Header')
    distillation_header_model_without_ddp = get_model_info(distillation_header, 'Distillation Header')
    teacher_model_without_ddp = get_model_info(teacher_model, 'Teacher')

    student_lr_scheduler = build_scheduler(config, student_optimizer, steps_per_epoch)
    mfm_header_lr_scheduler = build_scheduler(config, mfm_header_optimizer, steps_per_epoch)
    distillation_header_lr_scheduler = build_scheduler(config, distillation_header_optimizer, steps_per_epoch)
    # per-iteration teacher momentum, from TRAIN.MOMENTUM_TEACHER up to 1
    momentum_schedule = cosine_scheduler(config.TRAIN.MOMENTUM_TEACHER, 1, config.TRAIN.EPOCHS, steps_per_epoch)

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, logger)
        if not resume_file:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
        else:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        # (module, model key, optimizer, optimizer key, scheduler, scheduler key)
        resume_plan = (
            (student_model_without_ddp, 'model',
             student_optimizer, 'student_optimizer',
             student_lr_scheduler, 'student_lr_scheduler'),
            (mfm_header_model_without_ddp, 'mfm_header_model',
             mfm_header_optimizer, 'mfm_header_optimizer',
             mfm_header_lr_scheduler, 'mfm_header_lr_scheduler'),
            (distillation_header_model_without_ddp, 'distillation_header_model',
             distillation_header_optimizer, 'distillation_header_optimizer',
             distillation_header_lr_scheduler, 'distillation_header_lr_scheduler'),
            # the teacher has neither optimizer nor scheduler state
            (teacher_model_without_ddp, 'teacher_model',
             None, 'None_optimizer',
             None, 'None_lr_scheduler'),
        )
        for module, module_key, optimizer, optimizer_key, scheduler, scheduler_key in resume_plan:
            load_checkpoint_with_name(config, module, module_key, optimizer, optimizer_key,
                                      scheduler, scheduler_key, logger)

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, student_model, mfm_header, distillation_header, teacher_model,
                        data_loader_train, student_optimizer, mfm_header_optimizer, distillation_header_optimizer,
                        epoch, student_lr_scheduler, mfm_header_lr_scheduler, distillation_header_lr_scheduler,
                        criterion_freq, criterion_student, momentum_schedule, log_writer)

        on_rank0 = dist.get_rank() == 0
        # checkpoint every SAVE_FREQ epochs and always after the final epoch
        if on_rank0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint_with_name(config, epoch, student_model_without_ddp, mfm_header_model_without_ddp,
                                      distillation_header_model_without_ddp, teacher_model_without_ddp,
                                      0., student_optimizer, mfm_header_optimizer, distillation_header_optimizer,
                                      student_lr_scheduler, mfm_header_lr_scheduler, distillation_header_lr_scheduler, logger)
        if on_rank0 and log_writer is not None:
            log_writer.flush()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))



def get_norm(student_optimizer, mfm_header_optimizer, distillation_header_optimizer,
             student_model, mfm_header, distillation_header):
    """Return the gradient norms of the student and both headers.

    Reads the module-level ``config``. With AMP enabled
    (``AMP_OPT_LEVEL != "O0"``) the norms are taken over
    ``amp.master_params(optimizer)``; otherwise over ``model.parameters()``.
    When ``config.TRAIN.CLIP_GRAD`` is truthy the gradients are clipped to
    that max norm via ``clip_grad_norm_`` and its return value is reported;
    otherwise ``get_grad_norm`` is used and nothing is clipped here.

    Returns:
        tuple: ``(student_grad_norm, mfm_grad_norm, distillation_grad_norm)``.
    """
    # 1) choose where the parameters come from
    if config.AMP_OPT_LEVEL != "O0":
        optimizers = (student_optimizer, mfm_header_optimizer, distillation_header_optimizer)
        param_sources = [amp.master_params(optimizer) for optimizer in optimizers]
    else:
        modules = (student_model, mfm_header, distillation_header)
        param_sources = [module.parameters() for module in modules]

    # 2) choose how the norm is obtained (clipping vs. measuring only)
    max_norm = config.TRAIN.CLIP_GRAD
    if max_norm:
        norms = [torch.nn.utils.clip_grad_norm_(params, max_norm) for params in param_sources]
    else:
        norms = [get_grad_norm(params) for params in param_sources]

    student_grad_norm, mfm_grad_norm, distillation_grad_norm = norms
    return student_grad_norm, mfm_grad_norm, distillation_grad_norm


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine.

    Args:
        base_value: value at the start of the cosine phase (and the end of
            the warmup, when there is one).
        final_value: value the cosine phase moves towards.
        epochs: total number of epochs covered by the schedule.
        niter_per_ep: iterations per epoch.
        warmup_epochs: epochs of linear warmup placed before the cosine phase.
        start_warmup_value: first value of the linear warmup.

    Returns:
        numpy.ndarray of length ``epochs * niter_per_ep``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    half_range = 0.5 * (base_value - final_value)
    cosine_part = final_value + half_range * (1 + np.cos(np.pi * steps / len(steps)))

    schedule = np.concatenate((warmup_part, cosine_part))
    assert len(schedule) == total_iters
    return schedule


def freeze_model(model):
    """Turn off ``requires_grad`` on every parameter of ``model`` (in place)."""
    for tensor in model.parameters():
        tensor.requires_grad_(False)


def Unfreeze_model(model):
    """Turn on ``requires_grad`` on every parameter of ``model`` (in place)."""
    for tensor in model.parameters():
        tensor.requires_grad_(True)


def _backward_and_get_norms(config, mfm_loss, distillation_loss, student_loss,
                            student_model, mfm_header, distillation_header,
                            student_optimizer, mfm_header_optimizer, distillation_header_optimizer):
    """Backpropagate the three losses and return (student, mfm, distillation) grad norms.

    The caller freezes the student beforehand. The two header losses are
    backpropagated first, the student is then unfrozen, and ``student_loss``
    is backpropagated last.
    """
    if config.AMP_OPT_LEVEL != "O0":
        with amp.scale_loss(mfm_loss, mfm_header_optimizer) as scaled_loss:
            scaled_loss.backward(retain_graph=True)
        with amp.scale_loss(distillation_loss, distillation_header_optimizer) as scaled_loss:
            scaled_loss.backward(retain_graph=True)
        Unfreeze_model(student_model)
        with amp.scale_loss(student_loss, student_optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        # the graph is shared by all three losses, hence retain_graph
        mfm_loss.backward(retain_graph=True)
        distillation_loss.backward(retain_graph=True)
        Unfreeze_model(student_model)
        student_loss.backward()

    return get_norm(student_optimizer, mfm_header_optimizer, distillation_header_optimizer,
                    student_model, mfm_header, distillation_header)


def _ema_update_teacher(student_model, teacher_model, momentum):
    """EMA-update teacher parameters from the student parameters of the same name.

    Both models are unwrapped through ``.module`` when present, so the
    ``module.`` prefix added by a parallel wrapper cannot prevent the names
    from matching. Parameters that exist in only one model are left alone.
    """
    student_core = getattr(student_model, 'module', student_model)
    teacher_core = getattr(teacher_model, 'module', teacher_model)
    with torch.no_grad():
        teacher_params = dict(teacher_core.named_parameters())
        for name, param_q in student_core.named_parameters():
            param_k = teacher_params.get(name)
            if param_k is None:
                continue
            param_k.data.mul_(momentum).add_((1 - momentum) * param_q.detach().data)


def train_one_epoch(config, student_model, mfm_header, distillation_header, teacher_model, 
                    data_loader, student_optimizer, mfm_header_optimizer, distillation_header_optimizer, 
                    epoch, student_lr_scheduler, mfm_header_lr_scheduler, distillation_header_lr_scheduler, 
                    criterion_freq, criterion_student, momentum_schedule, log_writer):
    """Run one training epoch for the student, its two headers and the EMA teacher.

    For every batch: forward student and teacher, compute the MFM loss and the
    cosine-similarity distillation loss, combine them into the weighted
    ``student_loss``, backpropagate, step the three optimizers / LR
    schedulers, EMA-update the teacher and log.

    Args:
        config: experiment configuration.
        student_model, mfm_header, distillation_header: trainable networks.
        teacher_model: network updated only through the EMA of the student.
        data_loader: yields ``(img, img_lq, mask, _)`` batches.
        student_optimizer, mfm_header_optimizer, distillation_header_optimizer:
            one optimizer per trainable network.
        epoch: current epoch index.
        student_lr_scheduler, mfm_header_lr_scheduler,
        distillation_header_lr_scheduler: schedulers exposing ``step_update``.
        criterion_freq: loss applied to ``(mfm_res, new_input)``.
        criterion_student: similarity applied to teacher/distillation outputs.
        momentum_schedule: per-iteration teacher momentum, indexed by the
            global iteration ``epoch * len(data_loader) + idx``.
        log_writer: TensorBoard writer or ``None``.
    """
    student_model.train()
    mfm_header.train()
    distillation_header.train()
    # the teacher never receives gradients; it only follows the student via EMA
    freeze_model(teacher_model)

    student_optimizer.zero_grad()
    mfm_header_optimizer.zero_grad()
    distillation_header_optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    mfm_loss_meter = AverageMeter()
    distillation_loss_meter = AverageMeter()
    student_loss_meter = AverageMeter()

    mfm_norm_meter = AverageMeter()
    distillation_norm_meter = AverageMeter()
    student_norm_meter = AverageMeter()

    accumulation_steps = config.TRAIN.ACCUMULATION_STEPS
    optimizer_scheduler_pairs = (
        (student_optimizer, student_lr_scheduler),
        (mfm_header_optimizer, mfm_header_lr_scheduler),
        (distillation_header_optimizer, distillation_header_lr_scheduler),
    )

    start = time.time()
    end = time.time()
    for idx, (img, img_lq, mask, _) in enumerate(data_loader):
        global_step = epoch * num_steps + idx

        img = img.cuda(non_blocking=True)
        if img_lq is not None:
            # moved to the GPU but not used further in this function
            img_lq = img_lq.cuda(non_blocking=True)
        if mask is not None:
            mask = mask.cuda(non_blocking=True)

        new_input, student_res = student_model(img, mask)
        _, teacher_res = teacher_model(img, mask)
        mfm_res = mfm_header(student_res)
        distillation_res = distillation_header(student_res)

        mfm_loss = criterion_freq(mfm_res, new_input).mean()
        distillation_loss = (1 - criterion_student(F.normalize(teacher_res, p=2, dim=1),
                                                   F.normalize(distillation_res, p=2, dim=1)).mean())
        student_loss = (config.TRAIN.WEIGHT_MFM_LOSS * mfm_loss + 
                        config.TRAIN.WEIGHT_DISTILLATION_LOSS * distillation_loss) / (config.TRAIN.WEIGHT_MFM_LOSS + config.TRAIN.WEIGHT_DISTILLATION_LOSS)

        # student and headers update
        # The student is frozen while the header losses are backpropagated,
        # presumably so those passes only reach the headers; the helper
        # unfreezes it before student_loss.
        # NOTE(review): student_loss also flows through the headers, so header
        # gradients get a second contribution -- confirm this is intended.
        freeze_model(student_model)
        if accumulation_steps > 1:
            # Scale only what is backpropagated; the unscaled losses stay
            # available for the meters and TensorBoard below.
            student_grad_norm, mfm_grad_norm, distillation_grad_norm = _backward_and_get_norms(
                config,
                mfm_loss / accumulation_steps,
                distillation_loss / accumulation_steps,
                student_loss / accumulation_steps,
                student_model, mfm_header, distillation_header,
                student_optimizer, mfm_header_optimizer, distillation_header_optimizer)

            if (idx + 1) % accumulation_steps == 0:
                for optimizer, lr_scheduler in optimizer_scheduler_pairs:
                    optimizer.step()
                    optimizer.zero_grad()
                    lr_scheduler.step_update(global_step)
        else:
            for optimizer, _lr_scheduler in optimizer_scheduler_pairs:
                optimizer.zero_grad()

            student_grad_norm, mfm_grad_norm, distillation_grad_norm = _backward_and_get_norms(
                config, mfm_loss, distillation_loss, student_loss,
                student_model, mfm_header, distillation_header,
                student_optimizer, mfm_header_optimizer, distillation_header_optimizer)

            for optimizer, lr_scheduler in optimizer_scheduler_pairs:
                optimizer.step()
                lr_scheduler.step_update(global_step)

        # EMA update for the teacher
        # The schedule spans all epochs, so it is indexed by the global step;
        # float() turns the numpy scalar into a plain Python number.
        m = float(momentum_schedule[global_step])
        _ema_update_teacher(student_model, teacher_model, m)
        
        #  logging
        torch.cuda.synchronize()

        mfm_loss_meter.update(mfm_loss.item(), img.size(0))
        distillation_loss_meter.update(distillation_loss.item(), img.size(0))
        student_loss_meter.update(student_loss.item(), img.size(0))

        mfm_norm_meter.update(mfm_grad_norm)
        distillation_norm_meter.update(distillation_grad_norm)
        student_norm_meter.update(student_grad_norm)

        batch_time.update(time.time() - end)
        end = time.time()

        mfm_lr = mfm_header_optimizer.param_groups[0]["lr"]
        distillation_lr = distillation_header_optimizer.param_groups[0]["lr"]
        student_lr = student_optimizer.param_groups[0]["lr"]
        mfm_loss_value_reduce = reduce_tensor(mfm_loss).item()
        distillation_loss_value_reduce = reduce_tensor(distillation_loss).item()
        student_loss_value_reduce = reduce_tensor(student_loss).item()

        if log_writer is not None and (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((idx / num_steps + epoch) * 1000)
            log_writer.add_scalar('student_train_loss', student_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('student_grad_norm', student_grad_norm, epoch_1000x)
            log_writer.add_scalar('student_lr', student_lr, epoch_1000x)

            log_writer.add_scalar('mfm_train_loss', mfm_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('mfm_grad_norm', mfm_grad_norm, epoch_1000x)
            log_writer.add_scalar('mfm_lr', mfm_lr, epoch_1000x)

            log_writer.add_scalar('distillation_train_loss', distillation_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('distillation_grad_norm', distillation_grad_norm, epoch_1000x)
            log_writer.add_scalar('distillation_lr', distillation_lr, epoch_1000x)

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t' 
                f'mfm_lr {mfm_lr:.6f} distillation_lr {distillation_lr:.6f} student_lr {student_lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'mfm_loss {mfm_loss_meter.val:.4f} ({mfm_loss_meter.avg:.4f})\t'
                f'distillation_loss {distillation_loss_meter.val:.4f} ({distillation_loss_meter.avg:.4f})\t'
                f'student_loss {student_loss_meter.val:.4f} ({student_loss_meter.avg:.4f})\t'
                f'mfm_norm {mfm_norm_meter.val:.4f} ({mfm_norm_meter.avg:.4f})\t'
                f'distillation_norm {distillation_norm_meter.val:.4f} ({distillation_norm_meter.avg:.4f})\t'
                f'student_norm {student_norm_meter.val:.4f} ({student_norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


if __name__ == '__main__':
    # Script entry point: parse options, initialise SLURM-launched distributed
    # training, scale the learning rates, set up logging, then call main().
    # `config` and `logger` are deliberately left as module-level names because
    # functions in this file read them as globals.
    args, config = parse_option()

    if config.AMP_OPT_LEVEL != "O0":
        assert amp is not None, "amp not installed!"

    cuda_version = torch.version.cuda

    # Check for cuDNN version
    cudnn_version = torch.backends.cudnn.version()

    print(f"CUDA Version: {cuda_version}")
    print(f"cuDNN Version: {cudnn_version}")

    ## initialize slurm distributed training environment
    proc_id = int(os.environ['SLURM_PROCID'])
    print(f"proc_id {proc_id}")
    ntasks = int(os.environ['SLURM_NTASKS'])
    print(f"ntasks {ntasks}")
    node_list = os.environ['SLURM_NODELIST']
    print(f"node_list {node_list}")
    num_gpus = torch.cuda.device_count()
    print(f"num_gpus {num_gpus}")
    print(f"torch.cuda.is_available() {torch.cuda.is_available()}")
    # map the global SLURM task id onto a GPU index of this node
    torch.cuda.set_device(proc_id % num_gpus)
    # the first host of the allocation serves as the rendezvous master
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    # NOTE: --port has a default of 29500 in parse_option, so the first branch
    # is the one taken unless that default is changed to None.
    if args.port is not None:
        os.environ['MASTER_PORT'] = str(args.port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend='nccl')
    # Report the real number of participating processes; the local GPU count
    # differs from it as soon as more than one node is used.
    world_size = dist.get_world_size()
    torch.distributed.barrier()

    # a different seed per rank
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    print(f"seed {seed}")
    # random.seed(seed)
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
    logger.info(f'World Size: {world_size}')
    # only rank 0 writes the config dump and owns the TensorBoard writer
    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")
        log_writer = SummaryWriter(log_dir=config.OUTPUT)
    else:
        log_writer = None

    # print config
    logger.info(config.dump())

    main(config, log_writer)