import sys
import argparse
import pprint
import warnings

from CIL.YoooP import YoooP
from CIL.utils.setting import Setting
from CIL.utils import str2bool
from CIL.utils.logging import ReDirectSTD
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp


class Configs(object):
    """Command-line configuration for the class-incremental (CIL) experiments.

    Builds an ``argparse`` parser holding every experiment option and stores
    the parsed result on ``self.args`` (an ``argparse.Namespace``).
    """

    def __init__(self, argv=None):
        """Parse the experiment options.

        Args:
            argv: Optional list of argument strings to parse. ``None`` (the
                default) makes argparse read ``sys.argv[1:]``, exactly as
                before; passing a list lets callers and tests build a config
                without touching ``sys.argv``.
        """
        parser = argparse.ArgumentParser()
        # ==> Global setting
        # NOTE(review): ``type=eval`` (here, for --resize_h_w and for
        # --staircase_decay_at_epochs) executes arbitrary Python taken from the
        # command line. ``ast.literal_eval`` accepts the same tuple literals,
        # e.g. "(0, 1)", without that risk -- consider switching.
        parser.add_argument('-d', '--sys_device_ids', type=eval, default=(0,))
        # Put all datasets under the same root path.
        parser.add_argument('--dataset_path', type=str, default='')
        # ILSVRC2012 = ImageNet-1000; TinyImageNet has 200 classes.
        parser.add_argument('--dataset', type=str, default='CIFAR-100',
                            choices=['CIFAR-100', 'TinyImageNet', 'ImageNet-100', 'ImageNet-1000'])

        # ==> Preprocess setting
        parser.add_argument('--resize_h_w', type=eval, default=(256, 256))
        parser.add_argument('--crop_prob', type=float, default=0.0)
        parser.add_argument('--crop_ratio', type=float, default=1.0)
        # Per the original note, rotate_prob is the random-erasing probability
        # (despite its name). rotate_prob and mirror_type apply to the train
        # dataset only.
        parser.add_argument('--rotate_prob', type=float, default=0.4)
        parser.add_argument('--mirror_type', type=str, default='random', choices=['none', 'random', 'always'])

        # ==> Train dataset setting
        parser.add_argument('--threads', type=int, default=16)
        # Class number for each task.
        parser.add_argument('--init_class_number', type=int, default=10)
        parser.add_argument('--phase', type=int, default=10)
        parser.add_argument('--ids_per_batch', type=int, default=16)
        parser.add_argument('--ims_per_id', type=int, default=4)
        parser.add_argument('--init_ids_per_batch', type=int, default=16)
        parser.add_argument('--init_ims_per_id', type=int, default=4)

        # ==> Test dataset setting
        parser.add_argument('--test_batch_size', type=int, default=64)

        # ==> Train/test process setting
        parser.add_argument('--resume', type=str2bool, default=False)
        parser.add_argument('--model', type=str, default='Resnet18')
        parser.add_argument('--epochs_per_val', type=int, default=2)
        # Backbone training, per task.
        parser.add_argument('--epochs_per_task', type=int, default=80)
        parser.add_argument('--backbone_t', type=float, default=0.0)

        # Distillation, per task.
        parser.add_argument('--update_scale', type=int, default=1)
        parser.add_argument('--distill_factor', type=float, default=1.0)

        # Only useful with a single GPU and a small batch size.
        parser.add_argument('--mini_batch', type=int, default=1)

        # Only useful if the model needs to embed the feature.
        parser.add_argument('--model_embeding_size', type=int, default=0)

        # ==> Log setting
        parser.add_argument('--steps_per_log', type=int, default=5)

        # ==> Backbone learning-rate setting
        parser.add_argument('--base_lr', type=float, default=0.01)
        parser.add_argument('--lr_decay_type', type=str, default='exp',
                            choices=['exp', 'staircase', 'warmup', 'epochs'])
        # 'exp' strategy.
        parser.add_argument('--exp_decay_at_epoch', type=int, default=21)
        # 'staircase' strategy: epochs at which the LR is decayed.
        parser.add_argument('--staircase_decay_at_epochs',
                            type=eval, default=(20, 60, 120, 200))
        # Decay factor shared by the 'warmup', 'epochs' and 'staircase' strategies.
        parser.add_argument('--decay_factor', type=float, default=0.1)
        # 'epochs' strategy.
        parser.add_argument('--epoch_decay_steps', type=float, default=10)

        # Setting for different parameter groups.
        # -> lr_satus (sic; the option name is kept for compatibility) is the
        #    setting for f_new_parameters.
        parser.add_argument('--lr_satus', type=float, default=1.5)

        # ==> Loss setting
        parser.add_argument('--backbone_feature_loss', type=str, default='msaloss',
                            choices=['arcface', 'softmax_nn', 'KLloss', 'JSloss', 'msaloss', 'CEloss'])
        parser.add_argument('--backbone_local_loss', type=str, default='KLloss', choices=['KLloss', 'JSloss', 'ammloss', 'dsam'])
        parser.add_argument('--backbone_local_loss_weight', type=float, default=1.0)
        # -> For the arcface loss and the msa loss.
        parser.add_argument('--arcface_s', type=int, default=16)
        parser.add_argument('--arcface_m2', type=float, default=0.0)
        parser.add_argument('--arcface_m3', type=float, default=0.0)
        # -> For the local loss: amm loss.
        parser.add_argument('--amm_part', type=str, default='both', choices=['both', 'pos', 'neg'])
        parser.add_argument('--margin', type=float, default=0.8)

        # ==> File setting
        # -> Log file path.
        parser.add_argument('--exp_dir', type=str, default='')
        # -> Weight save path.
        parser.add_argument('--train_path', type=str, default='')
        # -> Only used for testing.
        parser.add_argument('--test_path', type=str, default='')

        # ==> Execution mode
        parser.add_argument('--phase_train', type=str2bool, default=False)
        parser.add_argument('--up2now_test', type=str2bool, default=False)

        # parse_args(None) falls back to sys.argv[1:], preserving old behavior.
        self.args = parser.parse_args(argv)


def main(args):
    """Build the experiment setting from parsed CLI ``args`` and run one mode.

    The function does the following, in order:

    - enables cuDNN (with ``benchmark`` on);
    - ignores all Python warnings;
    - picks the image mean/std for ``args.dataset``;
    - builds a ``Setting`` and passes ``all_setting.stdout_file`` to
      ``ReDirectSTD`` for 'stdout' (presumably teeing stdout into that log
      file -- confirm);
    - pretty-prints the full setting.

    It then runs ``YoooP.phase_train()`` if ``phase_train`` is set, otherwise
    ``YoooP.phase_test()`` if ``up2now_test`` is set. If neither flag is set, a
    notice is printed instead of exiting silently.

    Args:
        args: ``argparse.Namespace`` as produced by ``Configs``.
    """
    cudnn.benchmark = True
    cudnn.enabled = True
    # NOTE: this silences *all* Python warnings for the whole process.
    warnings.filterwarnings("ignore")
    ###################################
    cifar_mean = [0.5071, 0.4867, 0.4408]
    cifar_std = [0.2675, 0.2565, 0.2761]
    imgnet_mean = [0.485, 0.456, 0.406]
    imgnet_std = [0.229, 0.224, 0.225]
    # CIFAR-100 uses its own statistics; every other dataset choice
    # (TinyImageNet, ImageNet-100, ImageNet-1000) uses the ImageNet ones.
    if 'CIFAR' in args.dataset:
        im_mean, im_std = cifar_mean, cifar_std
    else:
        im_mean, im_std = imgnet_mean, imgnet_std
    all_setting = Setting(args, im_mean=im_mean, im_std=im_std, scale=True)
    ReDirectSTD(all_setting.stdout_file, 'stdout', True)
    ###################################
    print('-' * 60)
    print('Recoverable Memory Bank -- All Setting')
    pprint.pprint(all_setting.__dict__)
    print('-' * 60)
    ###################################

    executor = YoooP(all_setting)
    if all_setting.phase_train:
        print('-'*30+'CIL Training'+30*'-')
        executor.phase_train()
        return
    if all_setting.up2now_test:
        print('-'*30+'CIL Testing'+30*'-')
        executor.phase_test()
        return
    # Previously the script ended here without any output, which looks like a
    # successful run; make the no-op explicit.
    print('Nothing to run: set --phase_train or --up2now_test.')


if __name__ == '__main__':
    # Choose the "spawn" start method before any worker process exists
    # (presumably so that CUDA-using workers start cleanly -- TODO confirm).
    mp.set_start_method("spawn")
    # Raise Python's recursion ceiling to one million frames.
    # NOTE(review): the reason is not visible here (deeply nested objects
    # being pickled?) -- confirm it is still needed.
    sys.setrecursionlimit(10 ** 6)
    cli_configs = Configs()
    main(cli_configs.args)
