# Copyright (c) Winci.
# Licensed under the Apache License, Version 2.0 (the "License");

import argparse

def get_default_params(arch):
    """Return default optimizer hyper-parameters for an architecture name.

    Any ``arch`` containing the substring ``"vit"`` gets AdamW settings,
    including a ``final_lr`` entry. Every other architecture gets SGD
    settings *without* a ``final_lr`` entry. A caller that only fills unset
    options from this dict will therefore leave ``final_lr`` untouched for
    non-ViT models.

    Args:
        arch: Architecture name, e.g. ``"resnet50"`` or ``"vit_small"``.

    Returns:
        A new dict mapping option names to their default values.
    """
    if "vit" not in arch:
        return dict(optimizer='sgd', lr=0.5, wd=1e-5, warmup_epochs=2)
    return dict(optimizer='adamw', lr=5e-4, final_lr=1e-4, wd=0.1, warmup_epochs=40)

def _str2bool(value):
    """Convert a command-line string such as ``"True"``/``"false"``/``"1"`` to a bool.

    With ``type=bool``, argparse calls ``bool(string)``. That is ``True`` for
    every non-empty string, including ``"False"``. This converter parses the
    text instead.

    Args:
        value: The raw command-line value. A bool is passed through unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def get_args(argv=None):
    """Build the ReSA command-line parser and parse the arguments.

    After parsing, every option listed in ``get_default_params(args.arch)``
    that was left unset (``None``) is filled in from that dict. Options the
    user passed explicitly are never overwritten.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding all parsed options.
    """
    parser = argparse.ArgumentParser(description="Implementation of ReSA")

    parser.add_argument("--dump_path", type=str, default=".",
                        help="experiment dump path for checkpoints and log")

    parser.add_argument('--seed', default=None, type=int,
                        help='random seed for initializing training.')

    #####################
    #### data params ####
    #####################
    parser.add_argument("--data_path", type=str, default="/path/to/imagenet",
                        help="path to dataset repository")

    # The crop options take parallel lists, one entry per crop group.
    parser.add_argument("--crops_nmb", type=int, default=[1], nargs="+",
                        help="list of number of crops (example: [1, 10])")

    parser.add_argument("--crops_size", type=int, default=[224], nargs="+",
                        help="crops resolutions (example: [224, 96])")

    parser.add_argument("--crops_min_scale", type=float, default=[0.2], nargs="+",
                        help="minimum scale of the crops (example: [0.32, 0.05])")

    parser.add_argument("--crops_max_scale", type=float, default=[1.], nargs="+",
                        help="maximum scale of the crops (example: [1., 0.32])")

    parser.add_argument("--solarization_prob", type=float, default=[0.2], nargs="+",
                        help="solarization prob (example: [0.2, 0.0])")

    parser.add_argument("--size_dataset", type=int, default=-1,
                        help="size of dataset, -1 indicates the full dataset")

    parser.add_argument("--workers", default=8, type=int,
                        help="number of data loading workers per gpu")

    ############################
    ### resa specific params ###
    ############################
    parser.add_argument("--temperature", default=0.4, type=float,
                        help="temperature parameter in training loss")

    parser.add_argument("--momentum", type=float, default=0.996,
                        help="Base EMA parameter")

    #####################
    #### optim params ###
    #####################
    parser.add_argument("--epochs", default=100, type=int,
                        help="number of total epochs to run")

    parser.add_argument("--batch_size", default=256, type=int,
                        help="batch size per gpu, i.e. how many unique instances per gpu")

    # lr / final_lr / wd / optimizer / warmup_epochs default to None so that
    # architecture-specific defaults can be filled in after parsing.
    parser.add_argument('--lr', default=None, type=float,
                        help='initial (base) learning rate for train')

    parser.add_argument('--final_lr', default=None, type=float,
                        help='final learning rate for train')

    parser.add_argument('--wd', default=None, type=float,
                        help='weight decay for train')

    parser.add_argument("--optimizer", type=str, choices=["sgd","adamw"], default=None,
                        help="optimizer")

    parser.add_argument("--warmup_epochs", default=None, type=int,
                        help="number of warmup epochs")

    ####################
    #### dist params ###
    ####################
    parser.add_argument("--world_size", default=1, type=int,
                        help="""number of processes: it is set automatically and
                            should not be passed as argument""")

    parser.add_argument("--rank", default=0, type=int,
                        help="rank of this process: it is set automatically and should not be passed as argument")

    # Dashed option names are stored as underscored attributes
    # (args.dist_backend, args.no_set_device_rank, args.dist_url).
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')

    parser.add_argument("--no-set-device-rank", default=False, action="store_true",
                        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).")

    parser.add_argument("--dist-url", default="env://", type=str,
                        help="url used to set up distributed training")

    ############################
    #### architecture params ###
    ############################
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        help='model architecture (e.g. resnet18, resnet50, vit_small, vit_base)')

    parser.add_argument('--patch_size', default=16, type=int,
                        help='Patch resolution of the vision transformer.')

    parser.add_argument("--mlp_layers", type=int, default=3,
                        help="number of FC layers in projector")

    parser.add_argument("--mlp_dim", type=int, default=2048,
                        help="size of FC layers in projector/predictor")

    # Stored as args.pred (True unless --no_pred is given).
    parser.add_argument("--no_pred", dest="pred", action="store_false",
                        help="do not use an extra predictor")

    parser.add_argument("--emb", type=int, default=2048,
                        help="embedding dimension of the projector")

    parser.add_argument("--drop_path", type=float, default=0.,
                        help="Stochastic Depth")

    ########################
    #### evaluate params ###
    ########################
    parser.add_argument('--train_percent', default=100, type=int,
                        choices=(100, 10, 1),
                        help='size of traing set in percent')

    parser.add_argument('--num_classes', default=1000, type=int,
                        help='number of classes')

    parser.add_argument('--weights', default='freeze', type=str,
                        choices=('finetune', 'freeze'),
                        help='finetune or freeze pretrained encoder weights')

    parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
                        help='path to checkpoint for evaluation(default: none)')

    parser.add_argument('--ckpt_from_impre', default=None, type=str, metavar='PATH',
                        help='path to checkpoint from imagenet pretraining')

    parser.add_argument('--lr_encoder', default=0.0002, type=float, metavar='LR',
                        help='encoder base learning rate')

    parser.add_argument('--lr_classifier', default=40, type=float, metavar='LR',
                        help='classifier base learning rate')

    parser.add_argument("--scheduler", type=str, default="step", choices=('step', 'cos'),
                        help="learning rate scheduler")

    parser.add_argument('--n_last_blocks', default=4, type=int,
                        help="""Concatenate [CLS] tokens for the `n` last blocks. 
                            We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")

    parser.add_argument('--avgpool_patchtokens', default=False, action="store_true",
                        help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
                            We typically set this to False for ViT-Small and to True with ViT-Base.""")

    # Parsed with _str2bool: without a type, "--use_cuda False" produced the
    # truthy string "False", so features could never be kept off the GPU.
    parser.add_argument('--use_cuda', type=_str2bool, default=True,
                        help="""Should we store the features on GPU in knn evaluation? 
                            We recommend setting this to False if you encounter OOM""")

    parser.add_argument('--clip_fusion', type=str, default=None,
                        help="clip fusion method")

    # Which part-clustering method to use.
    parser.add_argument('--is_parts', type=str, default=None,
                        help="使用什么part clustering")
    # Whether to use part, global, or global+part features.
    parser.add_argument('--part_method', type=str, default="_global_part", choices=[None,"_part","_global","_global_part"],
                        help="是否使用part global part_global")
    parser.add_argument('--n_parts', type=int, default=4,
                        help="number of parts")
    # Whether to use attention as an aid. Parsed with _str2bool because
    # type=bool would turn "--attn False" into True.
    parser.add_argument('--attn', type=_str2bool, default=False,
                        help="是否使用attention帮助")
    # Path from which text information is read.
    parser.add_argument("--text_path", type=str, default=None,
                        help="文本信息的读取")
    # Whether to use text information.
    parser.add_argument("--with_texts", type=str, default=None,
                        help="是否使用text文本信息")
    # NOTE(review): help text is copied from --text_path ("reading text
    # information"). Presumably an index selecting a QA/text entry; confirm.
    parser.add_argument("--qa_idx", type=int, default=2,
                        help="文本信息的读取")
    # Whether to use a reconstruction loss: the obtained text is the target,
    # and the image-derived features are reconstructed towards it.
    parser.add_argument("--is_recon", type=str, default=None,
                        help="是否使用重建损失，将得到的text作为目标，重构img出来的特征")

    ########################
    #### t-SNE params ###
    ########################
    parser.add_argument("--tsne_perplexity", default=30, type=int,
                        help="t-SNE perplexity parameter")

    parser.add_argument("--max_points_per_superclass", default=2000, type=int,
                        help="maximum number of points to plot per superclass")

    args = parser.parse_args(argv)

    # Fill in architecture-dependent optimizer defaults for any option the
    # user did not pass explicitly (those options default to None above).
    default_params = get_default_params(args.arch)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args