import argparse
import yaml


# Pre-parser that only extracts the YAML config file path (-c/--config).
# add_help=False keeps -h/--help out of this parser, so it is left in the
# remaining arguments and handled by the main parser instead.
config_parser = argparse.ArgumentParser(
    description='Training Config', add_help=False)
config_parser.add_argument('-c', '--config', default='./config/config.yaml', type=str, metavar='FILE',
                           help='YAML config file specifying default arguments')

_TRUE_STRINGS = frozenset({'true', 't', 'yes', 'y', 'on', '1'})
_FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', 'off', '0'})


def _str2bool(value):
    """Convert a command-line string such as 'False' or '1' to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy, so ``--flag False`` would
    silently switch the flag on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got %r' % (value,))


# Main parser. Its defaults may be overridden from the YAML config file
# before the remaining command-line arguments are parsed.
parser = argparse.ArgumentParser(description='PyTorch MNIST Training')
# general config
parser.add_argument('--device_rank', default='0',
                    help='device rank/index to use (default: 0)')
parser.add_argument('--t', type=float, default=1,
                    help='parameter initialization distribution variance power(We first assume that each layer is the same width.)')
parser.add_argument('--training_batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--test_batch_size', type=int, default=32,
                    help='input batch size for test (default: 32)')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate for training (default: 1e-3)')
parser.add_argument('--ini_output_dir', type=str,
                    default='/home/xxx/data/saddle_points/test104/')
parser.add_argument('--input_dim', default=1, type=int,
                    help='the input dimension for model (default: 1)')
parser.add_argument('--output_dim', default=1, type=int,
                    help='the output dimension for model (default: 1)')
parser.add_argument('--training_size', default=1000, type=int,
                    help='the training size for model (default: 1000)')
parser.add_argument('--test_size', default=10000, type=int,
                    help='the test size for model (default: 10000)')
parser.add_argument('--plot_epoch', default=1000, type=int,
                    help='step size of plotting interval (default: 1000)')
parser.add_argument('--save_epoch', default=1000, type=int,
                    help='step size of saving interval (default: 1000)')
parser.add_argument('--training_steps', default=100001, type=int,
                    help='the number of training steps (default: 100001)')
parser.add_argument('--stop_loss', default=1e-5, type=float,
                    help='When the loss is less than the stop loss, the training ends. (default: 1e-5)')
parser.add_argument('--one_hot', default=False, type=_str2bool,
                    help='need one hot or not (default: False)')
parser.add_argument('--softmax', default=False, type=_str2bool,
                    help='need softmax or not (default: False)')
parser.add_argument('--data', default='1Dpro', type=str,
                    help='datasets type (default: 1Dpro)')
parser.add_argument('--hidden_layers_width',
                    nargs='+', type=int, default=[100],
                    help='widths of the hidden layers (default: [100])')
parser.add_argument('--cal_deri_epoch', default=1000, type=int,
                    help='step size of calculating derivatives and hessian matrix (default: 1000)')
parser.add_argument('--save_dir',
                    nargs='+', type=str, default=[],
                    help='one or more save directories (default: [])')
parser.add_argument('--dropout', default=False, type=_str2bool,
                    help='need dropout or not(default:False)')
parser.add_argument('--dropout_pro', default=0.2, type=float,
                    help='dropout proportion(default:0.2)')

parser.add_argument('--change_dropout_pro', default=False, type=_str2bool,
                    help='change dropout proportion or not')

parser.add_argument('--spend_time', default=10, type=int,
                    help='the number of epochs after exploration (default:10)')
parser.add_argument('--change_lr', default=False, type=_str2bool,
                    help='need changing lr or not(default:False)')
parser.add_argument('--changed_lr', default=1e-5, type=float,
                    help='changed learning rate(default:1e-5)')
parser.add_argument('--turblence', default=1e-8, type=float,
                    help='std of network parameter disturbance(default:1e-8)')

parser.add_argument('--cal_pca', default=False, type=_str2bool,
                    help='need calculate and save pca matrix or not')
parser.add_argument('--add_tru_on_grad', default=False, type=_str2bool,
                    help='add disturbance on the parameters grad or not')

parser.add_argument('--add_tru_on_weight', default=False, type=_str2bool,
                    help='add disturbance on the parameters or not')

parser.add_argument('--no_training', default=False, type=_str2bool,
                    help='calculate pca matrix with training weights or not')

parser.add_argument('--bias', default=True, type=_str2bool,
                    help='need bias or not')

parser.add_argument('--method', default='pca', type=str,
                    help='hessian or pca or wulei')

parser.add_argument('--use_nesterov', default=False, type=_str2bool,
                    help='use nesterov optimizer or not')


# 1Dpro config
parser.add_argument('--data_boundary', nargs='+', type=int, default=[-1, 1],
                    help='the boundary of 1D data')
# NOTE: the option name contains a space, so its dest is 'Sampling times' and
# it can only be read with getattr(args, 'Sampling times'). Kept unchanged for
# compatibility with existing callers and config files.
parser.add_argument('--Sampling times', default=1000, type=int,
                    help='times for sampling ')
parser.add_argument('--plot_output', default=True, type=_str2bool,
                    help='need plot output or not (default: True)')


def parse_args(argv=None):
    """Parse command-line arguments, taking defaults from a YAML config file.

    Parsing happens in two passes:

    1. ``config_parser`` extracts ``-c/--config``. When a path is given, the
       YAML mapping in that file is installed as defaults on ``parser``
       (keys are argument destinations, e.g. ``lr`` or ``dropout``).
    2. ``parser`` parses the remaining arguments, so flags given explicitly
       on the command line still override values from the config file.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as with ``ArgumentParser.parse_args``.

    Returns:
        A tuple ``(args, args_text)``. ``args`` is the parsed
        ``argparse.Namespace``. ``args_text`` is a YAML dump of it, used to
        save the run's configuration in the output dir later.

    Raises:
        OSError: if the config file cannot be opened.
        yaml.YAMLError: if the config file is not valid YAML.
        TypeError: if the config file's top level is not a mapping.
    """
    # First pass: only look for the config file path.
    args_config, remaining = config_parser.parse_known_args(argv)
    if args_config.config:
        with open(args_config.config, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML file loads as None; treat it as "no overrides"
        # instead of crashing inside set_defaults(**None).
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, dict):
            raise TypeError(
                'config file %r must contain a YAML mapping, got %s'
                % (args_config.config, type(cfg).__name__))
        parser.set_defaults(**cfg)

    # Second pass: the main parser handles everything else; its defaults
    # have already been overridden if a config file was specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later.
    args_text = yaml.safe_dump(vars(args), default_flow_style=False)
    return args, args_text
