import argparse
import importlib
import logging
import sys

from postprocess_path import set_save_path
from utils import Logger, pprint, set_gpu, set_logging, set_seed


def get_command_line_parser():
    """Create the argument parser holding the options shared by every project.

    Options follow this file's single-dash naming convention (e.g. ``-dataset``).
    Project-specific options are registered on top of this parser by
    ``add_commond_line_parser``.

    Returns:
        argparse.ArgumentParser: a fresh parser. List defaults such as
        ``-milestones`` or ``-num_crops`` are new list objects on every call.
    """
    parser = argparse.ArgumentParser()
    # about dataset and network
    # 'base' is the default, so it must also be an accepted choice; otherwise
    # the default is used silently but an explicit `-project base` is rejected.
    parser.add_argument('-project', type=str, default='base', choices=['base', 'teen', 'cif', 'meta_adapter'])
    parser.add_argument('-dataset', type=str, default='cifar100',
                        choices=['mini_imagenet', 'cub200', 'cifar100'])
    parser.add_argument('-dataroot', type=str, default='')
    parser.add_argument('-temperature', type=float, default=0.1)
    parser.add_argument('-temperature_sup', type=float, default=0.1)
    parser.add_argument('-feat_norm', action='store_true', help='If True, normalize the feature.')

    # about pre-training
    parser.add_argument('-epochs_base', type=int, default=100)
    parser.add_argument('-epochs_base_fast', type=int, default=100)
    parser.add_argument('-epochs_pretrain', type=int, default=100)
    parser.add_argument('-epochs_postrain', type=int, default=100)
    parser.add_argument('-epochs_meta', type=int, default=5)
    parser.add_argument('-fast_adaptation_steps', type=int, default=5)
    parser.add_argument('-epochs_new', type=int, default=100)
    parser.add_argument('-lr_base', type=float, default=0.1)
    parser.add_argument('-meta_step_size', type=float, default=0.1)
    parser.add_argument('-lr_new', type=float, default=0.1)
    parser.add_argument('-lr_meta', type=float, default=0.1)
    parser.add_argument('-noise_scale', type=float, default=0.01)
    parser.add_argument('-proto_loss_scale', type=float, default=0.5)
    parser.add_argument('-rho', type=float, default=1e-3)
    parser.add_argument('-ssl_weight', type=float, default=0.1)
    parser.add_argument('-balance', type=float, default=1.0)
    parser.add_argument('-loss_iter', type=int, default=200)
    parser.add_argument('-alpha', type=float, default=2.0)
    parser.add_argument('-eta', type=float, default=0.1)
    parser.add_argument('-warmup_epochs_base', type=int, default=3)
    parser.add_argument('-cosMargin', type=float, default=0.0)
    parser.add_argument('-average_cosMargin', type=float, default=0.0)
    parser.add_argument('-class_relation', default='None', type=str, choices=['wg', 'feat', 'None'])
    parser.add_argument('-in_domain_feat_cls_weight', default=0.0, type=float)
    parser.add_argument('-backbone_feat_cls_weight', default=1.0, type=float)
    parser.add_argument('-in_domain_feat_cosMargin', default=0.0, type=float)
    parser.add_argument('-in_domain_average_cosMargin', default=0.0, type=float)
    parser.add_argument('-in_domain_class_relation', default='None', type=str, choices=['wg', 'feat', 'None'])

    parser.add_argument('-dropout_rate', default=0.0, type=float)
    parser.add_argument('-in_domain_dropout_rate', default=0.0, type=float)
    parser.add_argument('-in_domain_feat_dim', default=-1, type=int)

    ## optimizer & scheduler
    parser.add_argument('-optim', type=str, default='sgd', choices=['sgd', 'adam'])
    parser.add_argument('-schedule', type=str, default='Step', choices=['Step', 'Milestone', 'Cosine'])
    parser.add_argument('-milestones', nargs='+', type=int, default=[60, 70])
    parser.add_argument('-step', type=int, default=20)
    parser.add_argument('-decay', type=float, default=0.0005)
    parser.add_argument('-decay_incre', type=float, default=0.05)
    parser.add_argument('-decay_meta', type=float, default=0.05)
    parser.add_argument('-momentum', type=float, default=0.9)
    parser.add_argument('-gamma', type=float, default=0.1)
    parser.add_argument('-lam', type=float, default=0.01)
    parser.add_argument('-w_a', type=float, default=0.2)
    parser.add_argument('-w_kd', type=float, default=1.0)
    parser.add_argument('-pseudo_way', type=int, default=5)
    parser.add_argument('-pseudo_shot', type=int, default=5)
    parser.add_argument('-incremental_shot', type=int, default=5)
    parser.add_argument('-tmax', type=int, default=600)  # cosine scheduler
    parser.add_argument('-batch_size_sup_con', type=int, default=128)
    parser.add_argument('-drop_last_batch', action="store_true",
                        help="Drops the last batch if not equal to the assigned batch size")

    parser.add_argument('-not_data_init', action='store_true', help='using average data embedding to init or not')
    parser.add_argument('-batch_size_base', type=int, default=128)
    parser.add_argument('-batch_size_new', type=int, default=0,
                        help='set 0 will use all the availiable training image for new')
    parser.add_argument('-test_batch_size', type=int, default=100)
    parser.add_argument('-base_mode', type=str, default='ft_cos',
                        choices=['ft_dot',
                                 'ft_cos',
                                 'meta_cos', 'meta_dot'])  # ft_dot means using linear classifier, ft_cos means using cosine classifier
    parser.add_argument('-new_mode', type=str, default='avg_cos',
                        choices=['ft_dot', 'ft_cos',
                                 'avg_cos',
                                 'meta_cos',
                                 'meta_dot'])  # ft_dot means using linear classifier, ft_cos means using cosine classifier, avg_cos means using average data embedding and cosine classifier

    # for our new method
    parser.add_argument("-num_crops", type=int, default=[2, 4], nargs="+",
                        help="amount of crops")
    # Historically this was the only option registered with a double dash.
    # The single-dash alias matches the rest of the file. The `--` form
    # still works, and the dest stays 'num_proj_layers'.
    parser.add_argument('-num_proj_layers', '--num_proj_layers', type=int, default=2,
                        help='number of projection layer')
    parser.add_argument('-rand_aug_sup_con', action='store_true', help='')
    parser.add_argument('-prob_color_jitter', type=float, default=0.8)
    parser.add_argument('-min_crop_scale', type=float, default=0.2)
    parser.add_argument("-size_crops", type=int, default=[224, 96], nargs="+",
                        help="resolution of inputs")
    parser.add_argument("-min_scale_crops", type=float, default=[0.14, 0.05], nargs="+",
                        help="min area of crops")
    parser.add_argument("-max_scale_crops", type=float, default=[1, 0.14], nargs="+",
                        help="max area of crops")
    parser.add_argument('-constrained_cropping', action='store_true',
                        help='condition small crops on key crop')
    parser.add_argument('-auto_augment', type=int, default=[], nargs='+',
                        help='Apply auto-augment 50 % of times to the selected crops')
    parser.add_argument('-fantasy', type=str, default='rotation', help='fantasy method to generate virtual classes')

    parser.add_argument('-start_session', type=int, default=0)
    parser.add_argument('-model_dir', type=str, default=None, help='loading model parameter from a specific dir')
    parser.add_argument('-meta_model_dir', type=str, default=None,
                        help='loading meta model parameter from a specific dir')
    parser.add_argument('-only_do_incre', action='store_true', help='Load model and incremental learning...')
    parser.add_argument('-metabase', action='store_true', help='Meta learning...')
    parser.add_argument('-no_SAM', action='store_true', help='')
    parser.add_argument('-no_MIS', action='store_true', help='')
    parser.add_argument('-meta_train_is_done', action='store_true', help='Meta learning...')

    # about training
    parser.add_argument('-gpu', default='0,1,2,3')
    parser.add_argument('-num_workers', type=int, default=8)
    parser.add_argument('-seed', type=int, default=1)
    parser.add_argument('-debug', action='store_true')

    return parser


# def add_commond_line_parser(params):
#     project = params[1]
#     # base parser
#     parser = get_command_line_parser()
#
#     if project == 'base':
#         args = parser.parse_args(params[2:])
#         return args
#
#     elif project == 'teen':
#         parser.add_argument('-softmax_t', type=float, default=16)
#         parser.add_argument('-shift_weight', type=float, default=0.5, help='weights of delta prototypes')
#         parser.add_argument('-soft_mode', type=str, default='soft_proto',
#                             choices=['soft_proto', 'soft_embed', 'hard_proto'])
#         args = parser.parse_args(params[2:])
#         return args
#     elif project == 'cif':
#         args = parser.parse_args(params[2:])
#         return args
#     else:
#         raise NotImplementedError
def add_commond_line_parser(params):
    """Build the parser for the selected project and parse ``params``.

    ``params`` is an argv-style list. ``params[0]`` is the program name.
    ``params[1]`` names the project and only decides which extra options are
    registered. ``params[2:]`` are the options that actually get parsed.
    The value in ``params[1]`` is not stored in the result: ``args.project``
    comes from the separate ``-project`` option.

    Args:
        params: argv-style list of strings, e.g. ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options.

    If ``params`` has no project entry, this exits through ``parser.error``
    (usage message, exit status 2) instead of raising a bare ``IndexError``.
    """
    parser = get_command_line_parser()
    if len(params) < 2:
        parser.error('missing project name (expected: <script> <project> [options])')
    project = params[1]

    # These projects additionally accept -softmax_t, -shift_weight and -soft_mode.
    if project in ('teen', 'cif', 'meta_adapter'):
        parser.add_argument('-softmax_t', type=float, default=16)
        parser.add_argument('-shift_weight', type=float, default=0.5, help='weights of delta prototypes')
        parser.add_argument('-soft_mode', type=str, default='soft_proto',
                            choices=['soft_proto', 'soft_embed', 'hard_proto'])

    args = parser.parse_args(params[2:])
    return args


if __name__ == '__main__':
    # sys.argv[1] selects the project; the remaining entries are its options.
    args = add_commond_line_parser(sys.argv)

    # Seed, echo the configuration, then record how many GPUs set_gpu reports.
    set_seed(args.seed)
    pprint(vars(args))
    args.num_gpu = set_gpu(args)

    # set_save_path fills in args.save_path, which the logging setup below reads.
    set_save_path(args)
    logger = Logger(args, args.save_path)
    set_logging('INFO', args.save_path)
    logging.info(f"save_path: {args.save_path}")

    # The trainer class lives in models/<project>/fscil_trainer.py.
    trainer_module = importlib.import_module(f'models.{args.project}.fscil_trainer')
    trainer = trainer_module.FSCILTrainer(args)
    trainer.train()
