import argparse
import os
import time

import torch

import utils.misc as misc


def _build_parser():
    """Create the command-line parser with every option the training script accepts."""
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Continual Learning Training')

    # data/setting
    parser.add_argument('--data', default='', metavar='DIR', help='path to dataset')
    # argparse never validates a default against `choices`, so the default value
    # must be listed explicitly or passing it on the command line is rejected.
    parser.add_argument('--dataset', default='i21k_unique', type=str,
                        choices=['i21k_unique', 'ImageNet10k', 'cglm'], help='')
    parser.add_argument('--steps', default=500, type=int, metavar='N', help='')
    parser.add_argument('--setting', default='class_incremental', type=str, help='')
    parser.add_argument('--label_ratio', default=0.5, type=float, help='')
    parser.add_argument('--split', default=10, type=int, metavar='N', help='')

    # output
    # NOTE(review): both --output and --output_dir exist; only output_dir is
    # read in this file — confirm whether --output is still used elsewhere.
    parser.add_argument('--output', default='', metavar='DIR',
                        help='path to output')
    parser.add_argument('--output_dir', default='', metavar='DIR', help='path to output')
    parser.add_argument('--run-name', default='')
    parser.add_argument('--debug', dest='debug', action='store_true', help='')

    # backbone
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
    parser.add_argument('--pretrained_model', metavar='DIR', help='use pre-trained model')
    parser.add_argument('--model-name', default='mae_vit_base_patch16_dec512d8b',
                        choices=['mae_vit_base_patch16_dec512d8b', 'mae_vit_base_patch16_dec512d1b'])

    # load data
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4) per GPU')

    # training
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument('--resume', dest='resume', action='store_true', help='')
    parser.add_argument('--resume_dir', default='')

    parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='')
    parser.add_argument('--accum_bs', default=1024, type=int, metavar='N', help='')
    parser.add_argument('--accum', default=None, type=int, metavar='N', help='')
    parser.add_argument('--accum_bs_s1', default=4096, type=int, metavar='N', help='')
    parser.add_argument('--blr', default=5e-4, type=float, help='initial ft learning rate')
    # NOTE(review): help text below looks copy-pasted from --blr — confirm intended wording.
    parser.add_argument('--min_lr', default=1e-4, type=float, help='initial ft learning rate')
    parser.add_argument('--lr_extra_rate', default=0.1, type=float,
                        help='multiplier applied to lr to obtain lr_extra')

    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.05, type=float, metavar='W',
                        help='weight decay (default: 0.05)', dest='weight_decay')
    parser.add_argument('--layer-decay', default=0.65, type=float)

    # continual learning mechanism
    parser.add_argument('--unsup_loss', dest='unsup_loss', action='store_true')
    parser.add_argument('--unlabeled_selection', dest='unlabeled_selection', action='store_true')
    parser.add_argument('--replay_first', dest='replay_first', action='store_true')
    parser.add_argument('--mask_cur_loss', dest='mask_cur_loss', action='store_true')
    parser.add_argument('--mask_ratio', default=0.75, type=float, help='')
    parser.add_argument('--size_replay_buffer', default=-1, type=int,
                        help='size of the experience replay buffer (per gpu, so if you have 8 gpus, each gpu will have size_replay_buffer number of samples in the buffer)')
    parser.add_argument('--sampling', default='', type=str, help='')
    parser.add_argument('--method', default='supmae', type=str, help='')

    # hyperparameters
    # NOTE(review): several *_s1 help strings repeat the text of their non-s1
    # counterparts (and --mem_ecof_s1 repeats the sampling-rate text) — confirm.
    parser.add_argument('--min_budget', default=500, type=int, help='')
    parser.add_argument('--batch_split', default=0.3, type=float, help='ratio of unlabeled data for every batch')
    parser.add_argument('--batch_split_s1', default=0.9, type=float, help='ratio of unlabeled data for every batch')
    parser.add_argument('--mem_sampling_rate', default=0.4, type=float, help='ratio of buffer data for every batch')
    parser.add_argument('--mem_sampling_rate_s1', default=0.05, type=float, help='ratio of buffer data for every batch')
    parser.add_argument('--mem_ecof_s1', default=1.0, type=float, help='ratio of buffer data for every batch')
    parser.add_argument('--labeled_coef', default=1, type=float)
    parser.add_argument('--unlabeled_coef', default=50.0, type=float)
    parser.add_argument('--unlabel_distill_coef', default=50.0, type=float)

    # evaluate
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--ablation', dest='ablation', action='store_true',
                        help='')
    # The default must be an int literal: argparse applies `type` only to string
    # defaults, so a float default such as 1e6 would leak through as a float.
    parser.add_argument('--eval-freq', default=1000000, type=int, metavar='N',
                        help='validation frequency (default: 1000000)')
    parser.add_argument('--no-eval', dest='no_evaluate', action='store_true', help='')
    parser.add_argument('--light-eval', dest='light_eval', action='store_true', help='')

    # log
    parser.add_argument('-p', '--print-freq', default=50, type=int, metavar='N', help='print frequency (default: 50)')
    parser.add_argument('--val-print-freq', default=200, type=int, metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--wandb_log', dest='wandb_log', action='store_true', help='')

    # distributed
    parser.add_argument('--world-size', default=1, type=int, help='')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:50272', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--dist_eval', dest='dist_eval', action='store_true', help='')
    parser.add_argument('--ngpus_per_node', default=None, type=int,
                        help='set this when use only partial of gpus of a node')
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true', dest="multiprocessing_distributed",
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    return parser


def _derive_settings(args):
    """Fill in the attributes computed from the parsed options (modifies ``args`` in place)."""
    # adapt training budget according to gpus
    if not args.ngpus_per_node:
        # device_count() is 0 without CUDA; fall back to 1 so the divisions
        # below do not raise ZeroDivisionError on CPU-only hosts.
        args.ngpus_per_node = max(1, torch.cuda.device_count())
    args.batch_size_total = args.ngpus_per_node * args.batch_size
    if args.batch_size_total <= 0:
        raise ValueError(
            'total batch size must be positive, got {} '
            '(ngpus_per_node={}, batch_size={})'.format(
                args.batch_size_total, args.ngpus_per_node, args.batch_size))

    # data setting
    args.input_size = 224

    # Pick the number of gradient-accumulation steps needed to reach accum_bs
    # samples, unless the user fixed --accum explicitly.
    if args.accum is None:
        if args.batch_size_total < args.accum_bs:
            args.accum = args.accum_bs // args.batch_size_total
        else:
            args.accum = 1

    # Learning rate is the base rate scaled linearly by (effective batch / 256).
    eff_batch_size = args.batch_size_total * args.accum * args.batch_split
    args.lr = args.blr * eff_batch_size / 256
    args.warmup = 0
    # NOTE(review): the 1024 here is hard-coded rather than tied to --accum_bs;
    # presumably "steps" are counted in units of 1024 samples — confirm.
    args.iters = int(args.steps * 1024 / args.batch_size_total)
    args.lr_extra = args.lr * args.lr_extra_rate


def setup_args(argv=None):
    """Parse command-line options and derive the run configuration.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        The ``argparse.Namespace`` holding every parsed option plus the derived
        attributes ``ngpus_per_node`` (if it was unset), ``batch_size_total``,
        ``input_size``, ``accum`` (if it was unset), ``lr``, ``warmup``,
        ``iters``, ``lr_extra``, ``output_dir`` (if it was empty) and
        ``start_time``.

    Raises:
        SystemExit: Raised by argparse on invalid command-line input.
        ValueError: If the resulting total batch size is not positive.

    Side effects:
        Creates ``args.output_dir`` if it does not exist and prints the full
        configuration to stdout.
    """
    args = _build_parser().parse_args(argv)
    _derive_settings(args)

    # setup output dir
    if not args.output_dir:
        args.output_dir = misc.create_output_dir(args)
    os.makedirs(args.output_dir, exist_ok=True)
    print("{}".format(args).replace(', ', ',\n'))

    # for log
    args.start_time = time.time()

    return args
