import argparse

def str2bool(v):
    """Convert a command-line token to a ``bool``; intended as an argparse ``type=``.

    Values that are already ``bool`` are returned unchanged. Strings are compared
    case-insensitively: ``yes``/``true``/``t``/``y``/``1`` give ``True`` and
    ``no``/``false``/``f``/``n``/``0`` give ``False``.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither group, so that
            argparse reports it as an invalid option value.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    truthy_tokens = ('yes', 'true', 't', 'y', '1')
    falsy_tokens = ('no', 'false', 'f', 'n', '0')

    if token in truthy_tokens:
        return True
    if token in falsy_tokens:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def get_args(argv=None):
    """Build the experiment command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the process
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour; passing an explicit
            list lets this be called from notebooks, tests, or other scripts.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Raises:
        SystemExit: argparse exits on unknown options, invalid values, or ``--help``.
    """
    parser = argparse.ArgumentParser(description='Train a feedforward neural network on MNIST')

    # --- options grouped under run_mnist.py ---
    # NOTE: argparse applies `type` only to *string* defaults, so numeric options
    # declared below with `default=1` / `default=0` and `type=float` are ints when
    # not given on the command line. They are deliberately left unchanged because
    # downstream code may depend on their exact repr (e.g. in names or configs).
    parser.add_argument('--c', type=float, default=1, help='the max of the uniform distribution')
    parser.add_argument('--weight_distribution', type=str, default="uniform", help='the distribution to sample weights from')
    parser.add_argument('--weight_gain', type=float, default=1)
    parser.add_argument('--bias_distribution', type=str, default="none")
    parser.add_argument('--bias_gain', type=float, default=1)
    parser.add_argument('--num_hidden_layers', type=int, default=1, help='the number of hidden layers in the neural network')
    parser.add_argument('--width', type=int, default=128, help='the width of each layer in the network')
    parser.add_argument('--wandbdir', type=str, default="./wandb", help='where to log the results')
    parser.add_argument('--datadir', type=str, default="./data", help='where to store the data')

    parser.add_argument('--modelpath', type=str, default=None, help='where to store the models')
    # NOTE(review): `--wandbmode` (dest `wandbmode`) and `--wandb_mode` (dest
    # `wandb_mode`, declared further down) are two independent options with the
    # same default. Setting one does NOT change the other. Confirm which attribute
    # callers read before consolidating them.
    parser.add_argument('--wandbmode', type=str, default="online", help='what mode to run wandb in (online, offline, disabled)')
    parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs to train for')
    parser.add_argument('--wandbgroup', type=str, default="default", help='the group to log to in wandb')
    parser.add_argument('--train_weights', type=str2bool, default=False, help='whether to train the weights of the network')
    parser.add_argument('--input_layer_bias', type=str2bool, default=True, help='whether to include a bias in the input layer')
    parser.add_argument('--output_layer_bias', type=str2bool, default=True, help='whether to include a bias in the output layer')
    parser.add_argument('--middle_layers_bias', type=str2bool, default=True, help='whether to include a bias in the middle layers')
    parser.add_argument('--noise_variance', type=float, default=0, help='the variance of the noise to add to the inputs')
    parser.add_argument('--load_weights_path', type=str, default=None, help='path to the weights to load')
    parser.add_argument('--load_biases_path', type=str, default=None, help='path to the biases to load')
    parser.add_argument('--save_weights_path', type=str, default=None, help='path to the weights to save')
    parser.add_argument('--save_biases_path', type=str, default=None, help='path to the biases to save')
    # NOTE(review): distinct from `--results_dir` below, which has a different
    # default (None). Presumably each is read by a different script; verify.
    parser.add_argument('--results_path', type=str, default="./results")
    parser.add_argument('--wandb_mode', type=str, default="online", help='what mode to run wandb in (online, offline, disabled)')
    parser.add_argument('--dataset', type=str, default="mnist", help='which dataset to use')
    parser.add_argument('--seed', type=int, default=0, help='the random seed')

    parser.add_argument('--l1_weight', type=float, default=0)
    parser.add_argument('--bias_l1_weight', type=float, default=0)
    parser.add_argument('--bias_l1_baseline', type=float, default=-1)

    parser.add_argument('--finetune_output', type=str2bool, default=False, help='whether to only finetune the output layer')

    # --- options grouped under drop_units_and_eval.py ---
    parser.add_argument('--n_threshold_steps', type=int, default=50, help='number of threshold steps to take')
    # No default given, so `args.name` is None unless supplied.
    parser.add_argument("--name",type=str)

    parser.add_argument("--results_dir", type=str, default=None)

    # parse_args(None) falls back to sys.argv[1:], preserving the original behaviour.
    args = parser.parse_args(argv)
    return args