import datetime
import os
import sys
import time

import random

import torch
import torch.distributed as dist
import torchvision

from warmup_scheduler_local.scheduler import GradualWarmupScheduler
import utils
from utils import print_or_log, dist_collect_other

try:
    from apex import amp
except ImportError:
    amp = None


def get_loss(q, k, criterion, noise_batch, t=0.07, device='cuda'):
    """Compute a contrastive loss of each query against one positive and shared negatives.

    Args:
        q: Query embeddings of shape (N, C).
        k: Positive keys, one per query; viewed as (N, C, 1).
        criterion: Callable invoked as ``criterion(logits, labels)``.
        noise_batch: Negative keys of shape (K, C), shared by all queries.
        t: Temperature; the logits are divided by it.
        device: Kept for backward compatibility only. The labels are created
            on the device of the logits, which is the only placement
            ``criterion`` can work with.

    Returns:
        Tuple ``(loss, prob)`` where ``loss`` is the criterion output and
        ``prob`` is the percentage of queries whose positive logit equals the
        row-wise maximum (ties count as hits).
    """
    N, C = q.shape
    l_pos = torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1)  # positive logit N x 1
    l_neg = torch.mm(q.view(N, C), noise_batch.transpose(0, 1))  # negative logit N x K
    logits = torch.cat([l_pos, l_neg], dim=1) / t
    # The positive sits in column 0, so the target class is 0 for every row.
    # Allocate the labels next to the logits instead of on a hard-coded device.
    labels = torch.zeros(N, dtype=torch.long, device=logits.device)
    prob = torch.mean((logits[:, 0] == logits.max(1)[0]).float()) * 100
    loss = criterion(logits, labels)
    return loss, prob

def compute_feats(model, hyp, video1, audio1, video2=None, audio2=None, feats1=None):
    """Run the model on the original input and, if requested, on a transformed input.

    Args:
        model: Callable invoked as ``model(video, audio)`` returning a
            ``(video_feature, audio_feature)`` pair.
        hyp: Hypothesis string; everything after its first character selects
            the transformation (``'time'`` or ``'asynced'``).
        video1, audio1: Original inputs.
        video2, audio2: Second view of the inputs. ``video2 is None`` switches
            the transformed pass off entirely, also for ``'time'``.
        feats1: Optional precomputed ``(feat_v, feat_a)`` for the original
            inputs; when given, the model is not re-run on them.

    Returns:
        ``(feat_v, feat_a)`` when ``video2`` is None, otherwise
        ``((feat_v, feat_a), (feat_vT, feat_aT))``.

    Raises:
        ValueError: If a second view is given but ``hyp`` names no known
            transformation.
    """
    feat_v, feat_a = model(video1, audio1) if feats1 is None else feats1
    # Without a second view there is nothing to transform.
    if video2 is None:
        return feat_v, feat_a

    transformation = hyp[1:]
    if transformation == 'time':
        # Arrow of time: flip video along dim 2 and audio along its last dim
        # (presumably the temporal axes -- confirm against the data layout).
        feat_vT, feat_aT = model(video1.flip(2), audio1.flip(-1))
    elif transformation == 'asynced':
        # Synchronisation: the second view is the transformed sample.
        feat_vT, feat_aT = model(video2, audio2)
    else:
        raise ValueError(
            f"Unknown hypothesis {hyp!r}: expected a name ending in 'time' or 'asynced'"
        )
    return (feat_v, feat_a), (feat_vT, feat_aT)

def get_pos_neg(hyp, feats1, feats2=None, head=-1, other=True, num_negatives=None):
    """Assemble queries, positive keys and negative keys for the cross-modal loss.

    Args:
        hyp: ``'base'`` for the plain case; otherwise the first character
            selects the mode: ``'i'`` (invariant) or ``'v'`` (variant).
        feats1: ``(feat_v, feat_a)`` of the original input.
        feats2: ``(feat_vT, feat_aT)`` of the transformed input; required for
            every hypothesis other than ``'base'``.
        head: Index applied to each feature (``feat[head]``) before use;
            ``-1`` uses the features as they are.
        other: If true, negatives are fetched with ``dist_collect_other``
            (presumably the samples held by the other processes/batch entries
            -- confirm in ``utils``). If false, only the variant hard
            negatives are used.
        num_negatives: Number of negatives to keep, ``-1`` for all. ``None``
            (default) falls back to the module-level ``args.num_negatives``.

    Returns:
        ``(pairs1, pairs2)``. Each is a list
        ``[q_video, q_audio, pos_for_video, pos_for_audio, neg_for_video,
        neg_for_audio]``; ``pairs2`` covers the transformed input and is
        ``None`` for ``'base'``.

    Raises:
        ValueError: For an unknown ``hyp``, a missing ``feats2``, or
            ``other=False`` with a hypothesis that adds no hard negatives
            (there would be no negatives at all).
    """
    if hyp == 'base':
        transf = "basecase"
    elif hyp.startswith('i'):
        transf = "invariant"
    elif hyp.startswith('v'):
        transf = "variant"
    else:
        raise ValueError(
            f"Unknown hypothesis {hyp!r}: expected 'base' or a name starting with 'i' or 'v'"
        )
    transformed = transf != "basecase"

    if transformed and feats2 is None:
        raise ValueError(f"Hypothesis {hyp!r} requires transformed features (feats2)")
    if not other and transf != "variant":
        # Only the variant mode contributes hard negatives; without them the
        # negative lists would be empty and torch.cat would fail further down.
        raise ValueError(
            f"other=False leaves no negatives for hypothesis {hyp!r}; "
            "it is only supported for variant hypotheses"
        )

    feat_v, feat_a = feats1
    if transformed:
        feat_vT, feat_aT = feats2

    if head != -1:
        feat_v, feat_a = feat_v[head], feat_a[head]
        if transformed:
            feat_vT, feat_aT = feat_vT[head], feat_aT[head]

    # Keys are constants for the loss: no gradient flows through them.
    with torch.no_grad():
        # POS keys are always taken from the other modality.
        if transf == "invariant":
            # positive is the transformed other-modality input: (v, aT), (a, vT), (vT, a), (aT, v)
            feat_v_pos, feat_a_pos = feat_aT.detach(), feat_vT.detach()
            feat_vT_pos, feat_aT_pos = feat_a.detach(), feat_v.detach()
        elif transf == "variant":
            # positive is the equally-transformed other-modality input: (v, a), (a, v), (vT, aT), (aT, vT)
            feat_v_pos, feat_a_pos = feat_a.detach(), feat_v.detach()
            feat_vT_pos, feat_aT_pos = feat_aT.detach(), feat_vT.detach()
        else:
            feat_v_pos, feat_a_pos = feat_a.detach(), feat_v.detach()

        # NEG keys for the original input: the positives of the other samples.
        # The call order is kept fixed in case these are collective operations.
        if other:
            feat_a_neg = dist_collect_other(feat_a_pos, return_before_cat=True)
            feat_v_neg = dist_collect_other(feat_v_pos, return_before_cat=True)
        else:
            feat_a_neg = []
            feat_v_neg = []

        if transformed:
            # NEG keys for the transformed input, built the same way.
            if other:
                feat_aT_neg = dist_collect_other(feat_aT_pos, return_before_cat=True)
                feat_vT_neg = dist_collect_other(feat_vT_pos, return_before_cat=True)
            else:
                feat_aT_neg = []
                feat_vT_neg = []

            if transf == "variant":
                # Hard negatives: each query additionally has to reject the
                # positive of its differently-transformed counterpart.
                feat_aT_neg += [feat_a_pos]
                feat_vT_neg += [feat_v_pos]
                feat_a_neg += [feat_aT_pos]
                feat_v_neg += [feat_vT_pos]

            feat_aT_neg = torch.cat(feat_aT_neg, dim=0)
            feat_vT_neg = torch.cat(feat_vT_neg, dim=0)
        feat_a_neg = torch.cat(feat_a_neg, dim=0)
        feat_v_neg = torch.cat(feat_v_neg, dim=0)

        # Optionally keep only a subset of the negatives.
        if num_negatives is None:
            # `args` is the module-level namespace created by the script entry point.
            num_negatives = args.num_negatives
        if num_negatives != -1:
            feat_a_neg, feat_v_neg = utils.reduce_negatives(feat_a_neg, feat_v_neg, num_negatives)
            # The transformed negatives only exist outside the base case.
            if transformed:
                feat_aT_neg, feat_vT_neg = utils.reduce_negatives(feat_aT_neg, feat_vT_neg, num_negatives)

        pairs1 = [feat_v, feat_a, feat_v_pos, feat_a_pos, feat_v_neg, feat_a_neg]
        if transformed:
            pairs2 = [feat_vT, feat_aT, feat_vT_pos, feat_aT_pos, feat_vT_neg, feat_aT_neg]
        else:
            pairs2 = None

    return pairs1, pairs2

def get_losses(pairs1, pairs2, crit1, crit2, hyp='basecase'):
    """Compute the video/audio contrastive losses and a dict of scalars for logging.

    Args:
        pairs1: ``[q_video, q_audio, pos_video, pos_audio, neg_video,
            neg_audio]`` for the original input.
        pairs2: Same layout for the transformed input; a falsy value skips
            the transformed terms.
        crit1: Criterion used for the video-query losses.
        crit2: Criterion used for the audio-query losses.
        hyp: Prefix for the keys of the returned dict.

    Returns:
        ``(loss, loss_dict)``. ``loss`` is the sum of all individual losses,
        each weighted by 0.5. ``loss_dict`` holds Python scalars: the summed
        losses, the averaged hit rates, the ``_og`` values of the original
        input and, when ``pairs2`` is used, the ``_tsf`` values of the
        transformed input.
    """
    # The temperature comes from the module-level argument namespace.
    temperature = args.nce_t

    v_loss_og, v_prob_og = get_loss(pairs1[0], pairs1[2], crit1, pairs1[4], t=temperature)
    a_loss_og, a_prob_og = get_loss(pairs1[1], pairs1[3], crit2, pairs1[5], t=temperature)
    loss = 0.5 * v_loss_og + 0.5 * a_loss_og

    if pairs2:
        v_loss_tsf, v_prob_tsf = get_loss(pairs2[0], pairs2[2], crit1, pairs2[4], t=temperature)
        a_loss_tsf, a_prob_tsf = get_loss(pairs2[1], pairs2[3], crit2, pairs2[5], t=temperature)
        v_loss_all = v_loss_og + v_loss_tsf
        a_loss_all = a_loss_og + a_loss_tsf
        v_prob_all = 0.5 * (v_prob_og + v_prob_tsf)
        a_prob_all = 0.5 * (a_prob_og + a_prob_tsf)
        loss += 0.5 * v_loss_tsf + 0.5 * a_loss_tsf
    else:
        v_loss_all, a_loss_all = v_loss_og, a_loss_og
        v_prob_all, a_prob_all = v_prob_og, a_prob_og

    prefix = f'{hyp}_'
    loss_dict = {
        prefix + 'video_loss': v_loss_all.item(),
        prefix + 'video_loss_og': v_loss_og.item(),
        prefix + 'prob_vid': v_prob_all.item(),
        prefix + 'prob_vid_og': v_prob_og.item(),
        prefix + 'audio_loss': a_loss_all.item(),
        prefix + 'audio_loss_og': a_loss_og.item(),
        prefix + 'prob_aud': a_prob_all.item(),
        prefix + 'prob_aud_og': a_prob_og.item(),
    }
    if pairs2:
        loss_dict[prefix + 'video_loss_tsf'] = v_loss_tsf.item()
        loss_dict[prefix + 'prob_vid_tsf'] = v_prob_tsf.item()
        loss_dict[prefix + 'audio_loss_tsf'] = a_loss_tsf.item()
        loss_dict[prefix + 'prob_aud_tsf'] = a_prob_tsf.item()
    return loss, loss_dict

def _require_second_view(video2, hyp):
    """Fail with a clear message when a hypothesis needs the second data view."""
    if video2 is None:
        raise ValueError(
            f"Hypothesis {hyp!r} needs a second view of each sample; enable args.dualdata"
        )


def train_one_epoch(
        args,
        data_loader,
        model,
        crit_v,
        crit_a,
        optimizer,
        device,
        epoch,
        print_freq,
        lr_scheduler,
        apex=False,
        logger=None,
        writer=None,
):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        args: Argument namespace; the fields read here include ``dualdata``,
            ``twohead``, ``headcount``, ``epochs``, ``finetune_epochs``,
            ``distributed``, ``global_rank`` and ``output_dir``.
        data_loader: Iterable of batches whose first two entries are the
            video and audio tensors.
        model: Network invoked through ``compute_feats``.
        crit_v: Criterion for the video-query losses.
        crit_a: Criterion for the audio-query losses.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Index of the current epoch.
        print_freq: Logging frequency handed to the metric logger.
        lr_scheduler: Not used here; the caller steps it after the epoch.
        apex: If true, the backward pass goes through ``amp.scale_loss``.
        logger: Optional logger for ``print_or_log``.
        writer: Optional summary writer handed to the metric logger.

    Returns:
        The average of the ``loss`` meter over the epoch.

    Raises:
        ValueError: If a multi-head mode is requested with
            ``args.headcount <= 1`` or a hypothesis needs the second view
            while ``args.dualdata`` is off.
    """
    # Temporal finetune phase: switch the clip length / batch size settings.
    # NOTE(review): this only mutates `args`; the data loader passed in has
    # already been built, so it presumably has no effect on this loader.
    if epoch > args.epochs - args.finetune_epochs:
        args.clip_len = 30
        args.batch_size = 8

    model.train()
    metric_logger = utils.MetricLoggerGDT(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))

    # --twohead mode -> (hypothesis trained on head 0, hypothesis trained on head 1)
    twohead_hyps = {
        1: ('vtime', 'vasynced'),
        2: ('iasynced', 'vasynced'),
        3: ('itime', 'vasynced'),
    }

    header = 'Epoch: [{}]'.format(epoch)
    for batch_idx, batch in metric_logger.log_every(data_loader, print_freq, header, logger, writer, 'train',
                                                    epoch=epoch):
        video, audio, _, _, _ = batch
        if batch_idx == 0:
            print_or_log((video.shape, audio.shape), logger=logger)
        start_time = time.time()

        # With dualdata two views are stacked along dim 1 (3 + 3 for video,
        # 1 + 1 for audio); otherwise there is no second view.
        video2 = audio2 = None
        if args.dualdata:
            video, video2 = torch.split(video, [3, 3], dim=1)
            audio, audio2 = torch.split(audio, [1, 1], dim=1)
            video2, audio2 = video2.to(device), audio2.to(device)

        video, audio = video.to(device), audio.to(device)

        # form positive and negative pairs dependent on hypothesis
        if args.twohead in twohead_hyps:
            if args.headcount <= 1:
                raise ValueError("twohead training requires args.headcount > 1")
            hyp1, hyp2 = twohead_hyps[args.twohead]
            _require_second_view(video2, hyp1)

            feats1, feats2 = compute_feats(model, hyp1, video, audio, video2=video2, audio2=audio2)
            pairs1, pairs2 = get_pos_neg(hyp1, feats1, feats2, head=0)
            loss1, loss_dict1 = get_losses(pairs1, pairs2, crit_v, crit_a, hyp=hyp1)

            if hyp2[1:] != hyp1[1:]:
                # Different transformation: keep the features of the original
                # input and only compute those of the newly transformed one.
                feats1, feats2 = compute_feats(
                    model, hyp2, video, audio, video2=video2, audio2=audio2, feats1=feats1
                )
            pairs1, pairs2 = get_pos_neg(hyp2, feats1, feats2, head=1)
            loss2, loss_dict2 = get_losses(pairs1, pairs2, crit_v, crit_a, hyp=hyp2)

            loss = 0.25 * (loss1 + loss2)
        else:
            # basecase: base-loss with all-gpu negatives.
            hyp1 = 'base'
            hyp2 = utils.get_hyp_str(args)
            feats1 = compute_feats(model, hyp1, video, audio)
            pairs1, pairs2 = get_pos_neg(hyp1, feats1)
            loss1, loss_dict1 = get_losses(pairs1, pairs2, crit_v, crit_a, hyp=hyp1)

            if hyp2 != 'base':
                # hypothesis loss with less negatives
                _require_second_view(video2, hyp2)
                feats1, feats2 = compute_feats(
                    model, hyp2, video, audio, video2=video2, audio2=audio2, feats1=feats1
                )
                pairs1, pairs2 = get_pos_neg(hyp2, feats1, feats2, other=False)
                loss2, loss_dict2 = get_losses(pairs1, pairs2, crit_v, crit_a, hyp=hyp2)
                loss = 0.25 * (loss1 + loss2)
            else:
                loss_dict2 = {}
                loss = 0.5 * loss1

        # Backward pass
        optimizer.zero_grad()
        if apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # signal received, relaunch experiment
        # (.get: a missing variable simply means no signal was received)
        if os.environ.get('SIGNAL_RECEIVED') == 'True':
            args.resume = 'True'
            if args.global_rank == 0:
                print_or_log("Beginning reqeue", logger=logger)
                utils.trigger_job_requeue(os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth'))

        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        for key, value in loss_dict1.items():
            metric_logger.meters[key].update(value, n=batch_size)
        for key, value in loss_dict2.items():
            metric_logger.meters[key].update(value, n=batch_size)
        elapsed = time.time() - start_time
        metric_logger.meters['batch_t/s'].update(elapsed)
        metric_logger.meters['clips/s'].update(batch_size / elapsed)
    if args.distributed:
        dist.barrier()
    torch.cuda.empty_cache()
    return metric_logger.loss.avg


def main(args):
    """Set up model, optimizer, scheduler and data, then run the training loop.

    The steps are: validate the apex request, create the output directory,
    initialise distributed mode / signal handling / logging, build the model
    (optionally with SyncBN, DDP and apex), restore a checkpoint, build the
    data loader and train for ``args.epochs`` epochs, saving a checkpoint
    after each one.

    Args:
        args: Namespace from ``parse_args``. ``args.output_dir``,
            ``args.aug_audio`` and ``args.start_epoch`` are overwritten here.

    Raises:
        RuntimeError: If ``args.apex`` is set but apex could not be imported.
    """
    # Set up mixed precision training
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training."
            )

    # Make output dir
    if args.model_name is None:
        model_name = f'av_GDT_{args.vid_base_arch}_{args.aud_base_arch}_epochs_{args.epochs}_bsz_{args.batch_size}_optim_SGD_lr_{args.lr}_scheduler_{args.use_scheduler}'
    else:
        model_name = args.model_name
    args.output_dir = os.path.join(args.output_dir, model_name)
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # Init distributed mode
    if torch.cuda.is_available():
        utils.init_distributed_mode(args)

    # init signal handler
    utils.init_signal_handler()

    # Set up logger
    logger = None
    if args.distributed:
        filename = str(args.job_id) + '_' + str(args.global_rank) + '_log.out'
        logger = utils.setup_logger(
            "Video_reader, classification",
            args.output_dir,
            True,
            logname=filename
        )

    # Set up tensorboard (master process only)
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    global_rank = args.global_rank if args.distributed else 0
    is_master = global_rank == 0
    writer = None
    if is_master:
        writer = utils.setup_tbx(
            tbx_path,
            is_master
        )
        writer.add_text("namespace", repr(args))

    # Log version information
    print_or_log(args, logger=logger)
    print_or_log(f"torch version: {torch.__version__}", logger=logger)
    print_or_log(f"torchvision version: {torchvision.__version__}", logger=logger)

    # Select device
    # NOTE(review): for world sizes up to 8 this overrides args.device with
    # 'cuda:0' -- confirm this is intended for multi-GPU single-node runs.
    device = torch.device(args.device)
    if args.world_size <= 8:
        device = 'cuda:0'

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Create model
    print_or_log("Creating model", logger=logger)
    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        norm_feat=args.norm_feat,
        use_mlp=args.use_mlp,
        mlptype=args.mlptype,
        headcount=args.headcount,
        use_max_pool=args.use_max_pool
    )
    model.to(device)
    if args.distributed and args.sync_bn:
        print_or_log("Sync BN on model", logger=logger)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False
        )

    # Replace the boolean flag by the augmentation parameters of the chosen
    # strength ('na' leaves the flag untouched).
    if args.aug_audio:
        if args.audio_augtype == 'mild':
            args.aug_audio = [1, 1, 2, 5]
        elif args.audio_augtype == 'medium':
            args.aug_audio = [1, 1, 3, 6]
        elif args.audio_augtype == 'heavy':
            args.aug_audio = [2, 2, 3, 6]

    # Warm up batch-norm
    if args.warmup_bn and not args.resume:
        print_or_log('Warming up BN', logger=logger)
        dataset, _dl = utils.get_dataloader(args, 0)
        utils._warmup_batchnorm(args, model, dataset, device, batches=100)
        del dataset, _dl

    # Set up training optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # For Mixed Precision training
    if args.apex:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.apex_opt_level
        )

    # Set up LR scheduler; milestones are shifted by the warmup length.
    milestones = [int(milestone) - args.lr_warmup_epochs for milestone in args.lr_milestones.split(',')]
    lr_scheduler = None
    if args.use_scheduler:
        if args.scheduler_type == 'multi_step':
            print_or_log('Using Multi-Step LR scheduler', logger=logger)
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_gamma
            )
        else:
            print_or_log('Using Cosine Annealing LR scheduler', logger=logger)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
        if args.lr_warmup_epochs > 0:
            # Wrap the base scheduler so it takes over once warmup is done.
            lr_scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=args.world_size,
                total_epoch=args.lr_warmup_epochs,
                after_scheduler=lr_scheduler
            )

    # Load criterions
    criterion_video = torch.nn.CrossEntropyLoss().to(device)
    criterion_audio = torch.nn.CrossEntropyLoss().to(device)

    # Checkpointing restart
    to_restore = {'epoch': 0}
    utils.restart_from_checkpoint(
        args,
        run_variables=to_restore,
        model=model,
        optimizer=optimizer,
    )
    args.start_epoch = to_restore['epoch']
    if args.use_scheduler:
        # Fast-forward the scheduler to the restored epoch.
        for _ in range(to_restore['epoch']):
            lr_scheduler.step()

    # Set LR if temporal finetuning (1 node)
    # NOTE(review): this resets every param group to args.lr after the
    # scheduler has been fast-forwarded -- confirm that is intended on resume.
    if args.world_size <= 8:
        print_or_log(f'World size: {args.world_size}, lr: {args.lr}', logger=logger)
        for pg in optimizer.param_groups:
            pg['lr'] = args.lr

    # Create dataloader
    ds = utils.get_ds(args, 0)

    print("Creating data loaders", flush=True)
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(ds)

    data_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=args.batch_size,
        # The distributed sampler shuffles on its own; without it the loader
        # has to, otherwise every epoch sees the samples in the same order.
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None,
        drop_last=True
    )

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print_or_log(f'Start training epoch: {epoch}', logger=logger)
        train_one_epoch(
            args,
            data_loader,
            model,
            criterion_video,
            criterion_audio,
            optimizer,
            device,
            epoch,
            args.print_freq,
            lr_scheduler,
            args.apex,
            logger=logger,
            writer=writer,
        )
        if lr_scheduler:
            lr_scheduler.step()
        if args.output_dir:
            utils.save_checkpoint(args, epoch, model, optimizer, lr_scheduler)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print_or_log(f'Training time {total_time_str}', logger=logger)


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    import argparse

    def str2bool(v):
        """Convert a textual flag ('yes', 'true', 't', '1', ...) to a bool."""
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        # ArgumentTypeError makes argparse show this message to the user.
        raise argparse.ArgumentTypeError('Boolean argument needs to be true or false. '
                                         'Instead, it is %s.' % v)

    parser = argparse.ArgumentParser(description='Video Representation Learning')
    # Options declared with type='bool' are converted by str2bool; argparse
    # also runs string defaults such as 'False' through it.
    parser.register('type', 'bool', str2bool)

    # AUDIO UTILS
    parser.add_argument('--aud-sample-rate', default=48000, type=int, help='audio sample rate')
    parser.add_argument(
        '--aud-spec-type', default=1, type=int,
        help='audio spec type'  # 1 : (40, 99), (257, 199)
    )
    parser.add_argument('--use-volume-jittering', default='False', type='bool', help='use volume jittering')
    parser.add_argument('--use-temporal-jittering', default='False', type='bool', help='use temporal jittering')
    parser.add_argument('--num-sec', default=1, type=int, help='Number of seconds')
    parser.add_argument('--z-normalize', default='False', type='bool', help='normalize audio')

    parser.add_argument(
        '--twohead', default=0, type=int, metavar='TH',
        help="twohead: 0:(basecase), 1:(variant-time,variant-async), 2:(invariant-async,variant-async), 3:(invariant-time,variant-async)"
    )
    parser.add_argument(
        '--arrowtime', default=0, type=int, metavar='AT',
        help='arrowtime: 0:(basecase), 1:(reversed==postive), -1:(reversed==additional negative)'
    )
    parser.add_argument(
        '--asynced', default=0, type=int, metavar='SY',
        help='asynced: 0:(basecase), 1:(asynced==postive), -1:(asynced==additional negative)'
    )
    parser.add_argument(
        '--dualdata', type='bool', default='True',
        help='use dataloader that returns two samples per video'
    )
    parser.add_argument('--stochastic-block', type='bool', default='False', help='use stochastic blocking')
    parser.add_argument('--headcount', type=int, default=1, help='how many heads each modality has')
    parser.add_argument('--nce-t', type=float, default=0.07, help='softmax weighting')
    parser.add_argument(
        '--num-negatives', default=-1, type=int,
        help='number of negatives in contrastive loss'
    )

    ### DATA
    parser.add_argument('--dataset', default='kinetics', help='name of dataset')
    parser.add_argument('--colorjitter', default='False', type='bool', help='Apply random color jitter')
    parser.add_argument(
        '--use-scale-jittering', default='False', type='bool',
        help='scale jittering as augmentations'
    )
    parser.add_argument('--augtype', default=1, type=int, help='augmentation type (default: 1)')
    parser.add_argument('--aug-audio', default='False', type='bool', help='whether to augment audio')
    parser.add_argument(
        '--audio-augtype', default='mild', type=str,
        choices=['na', 'mild', 'medium', 'heavy'],
        help='type of audio-augment default: mild'
    )
    parser.add_argument(
        '--sample-aud-ind', default='False', type='bool',
        help='whether to sample audio independently'
    )
    parser.add_argument('--num-data-samples', default=None, type=int, help='number of samples in dataset')
    parser.add_argument(
        '--use-temp-jitter', default='True', type='bool',
        help='Get clips from random timestamps each epoch'
    )
    parser.add_argument(
        '--center-crop', default='False', type='bool',
        help='Use center cropping instead of random cropping'
    )
    # The default stays the int 1 (argparse only converts string defaults),
    # while a value given on the command line is a str.
    parser.add_argument('--fold', default=1, type=str, help='fold to use (default: 1)')
    parser.add_argument('--clip-len', default=30, type=int, help='number of frames per clip')
    parser.add_argument('--target-fps', default=30, type=int, help='target fps')
    parser.add_argument(
        '--clips-per-video', default=1, type=int,
        help='number of clips to sample from video'
    )
    parser.add_argument(
        '-j', '--workers', default=0, type=int, metavar='N',
        help='number of data loading workers (default: 0)'
    )
    parser.add_argument('--train-crop-size', default=112, type=int, help='Size of spatial crops')
    parser.add_argument(
        '--sample-rate', default=1, type=int,
        help='Subsampling rate: num frames between clips'
    )

    ### MODEL
    parser.add_argument('--model', default='av_gdt', help='model', choices=['av_gdt'])
    parser.add_argument(
        '--vid-base-arch', default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50']
    )
    parser.add_argument(
        '--aud-base-arch', default='vgg_audio',
        help='Audio Base Arch for A-V model',
        choices=['resnet9', 'resnet18', 'vgg_audio', 'resnet34', 'resnet50']
    )
    parser.add_argument(
        "--pretrained", type='bool', default='False',
        help="Use pre-trained models from the modelzoo",
    )
    parser.add_argument('--use-mlp', default='True', type='bool', help='Use MLP projection head')
    parser.add_argument('--use-max-pool', default='False', type='bool', help='Use max pool instead of GAP')
    parser.add_argument('--mlptype', default=0, type=int, help='MLP type (default: 0)')

    ### TRAINING
    parser.add_argument('-b', '--batch-size', default=4, type=int)
    parser.add_argument('--epochs', default=45, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument(
        '--finetune-epochs', default=0, type=int,
        help='number of epochs to finetune on (32, 224, 224)'
    )
    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument(
        '--use-linear-scaling', default='False', type='bool',
        help='Linearly scale learning rate'
    )
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument(
        '--wd', '--weight-decay', default=1e-4, type=float, metavar='W',
        help='weight decay (default: 1e-4)',
        dest='weight_decay'
    )
    parser.add_argument("--use-scheduler", type='bool', default='True', help="Use LR scheduler")
    parser.add_argument(
        "--scheduler-type", type=str, default='multi_step',
        choices=['multi_step', 'cosine'],
        help="Type of LR scheduler",
    )
    parser.add_argument('--lr-milestones', default='20,30,40', type=str, help='decrease lr on milestones')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='number of warmup epochs')
    parser.add_argument("--sync-bn", type='bool', default='False', help="Use sync batch norm")
    parser.add_argument("--warmup-bn", type='bool', default='False', help="Warmup batchnorm")
    parser.add_argument("--norm-feat", type='bool', default='True', help="Normalize embeddings")

    ### LOGGING
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--model-name', default=None, help='exp desc')

    ### CHECKPOINTING
    parser.add_argument('--resume', type='bool', default='False', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch')

    # Mixed precision training parameters
    parser.add_argument(
        '--apex', type='bool', default='False',
        help='Use apex for mixed precision training'
    )
    parser.add_argument(
        '--apex-opt-level', default='O1', type=str,
        help='For apex mixed precision training. '
             'O0 for FP32 training, O1 for mixed precision training. '
             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
    )

    # distributed training parameters
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('--distributed', type='bool', default='False', help="ddp mode")
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--debug_slurm', type='bool', default='False', help="Debug SLURM")
    parser.add_argument('--local_rank', default=-1, type=int, help='Local rank of node')
    parser.add_argument('--master_port', default=-1, type=int, help='Master port of Job')

    return parser.parse_args(argv)


if __name__ == "__main__":
    import torch.multiprocessing as mp

    # `args` is deliberately a module-level name: get_pos_neg and get_losses
    # read it as a global.
    args = parse_args()

    # Select the 'forkserver' start method for multi-processing; a
    # RuntimeError from the call is ignored and whatever method is already
    # in effect stays.
    try:
        mp.set_start_method('forkserver')
    except RuntimeError:
        pass

    main(args)
