import argparse
# fmt: off

def add_trainer_args(parser):
    trainer_args = parser.add_argument_group('Trainer args')
    trainer_args.add_argument('--epochs', type=int, default=30)
    trainer_args.add_argument('--approach', type=str, default='regular')
    trainer_args.add_argument('--lr', type=float, default=0.0001)
    trainer_args.add_argument('--weight_decay', type=float, default=0.0001)
    trainer_args.add_argument('--seed', type=int, default=42)
    trainer_args.add_argument('--res_dir', type=str)
    trainer_args.add_argument('--wandb', action=argparse.BooleanOptionalAction)
    trainer_args.add_argument('--early_stopping', action=argparse.BooleanOptionalAction, default=True)
    trainer_args.add_argument('--patience', type=int, default=10)

def add_dataset_args(parser):
    dataset_args = parser.add_argument_group('Dataset args')
    dataset_args.add_argument('--data_dir', type=str, required=True)
    dataset_args.add_argument('--batch_size', type=int, default=128)
    dataset_args.add_argument('--num_workers', type=int, default=4)
    dataset_args.add_argument('--unconf_split', action=argparse.BooleanOptionalAction)
    dataset_args.add_argument('--num_classes', type=int, default=3)

def add_boolean_args(parser):
    boolean_args = parser.add_argument_group('BooleanOptionalAction args')
    boolean_args.add_argument('--save_model', action=argparse.BooleanOptionalAction)

def add_model_args(parser):
    model_args = parser.add_argument_group('model specific args')
    model_args.add_argument('--model', type=str, default='ResNet18')
    model_args.add_argument('--pretrained_model', action=argparse.BooleanOptionalAction)
    model_args.add_argument('--pretrained_path', type=str)

def add_feature_selector_args(parser):
    feature_selector_args = parser.add_argument_group('General Feature Selector (SFW or U-Net) args')
    feature_selector_args.add_argument('--lr_fs', type=float, default=None)
    feature_selector_args.add_argument('--lr_merlin', type=float, default=0.01)
    feature_selector_args.add_argument('--lr_morgana', type=float, default=0.01)
    feature_selector_args.add_argument('--weight_decay_fs', type=float, default=None)
    feature_selector_args.add_argument('--weight_decay_merlin', type=float, default=0.00001)
    feature_selector_args.add_argument('--weight_decay_morgana', type=float, default=0.00001)
    feature_selector_args.add_argument('--gamma', type=float, default=1)
    feature_selector_args.add_argument('--l1_penalty_coefficient', type=float, default=1)
    feature_selector_args.add_argument('--l2_penalty_coefficient', type=float, default=0)
    feature_selector_args.add_argument('--tv_penalty_coefficient', type=float, default=0)
    feature_selector_args.add_argument('--mask_size', type=int, default=1500)
    feature_selector_args.add_argument('--sfw_max_iterations', type=int, default=350)
    feature_selector_args.add_argument('--sfw_patience', type=int, default=10)
    feature_selector_args.add_argument('--unet_steps', type=int, default=1)