import argparse
# fmt: off

def add_trainer_args(parser):
    """Register the training-loop options on *parser* in a "Trainer args" group.

    ``--epochs`` and ``--approach`` are mandatory. ``--wandb`` and
    ``--early_stopping`` are ``--flag/--no-flag`` toggles. Every other option
    carries an explicit default.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing
            ``add_argument_group``); it is extended in place.
    """
    group = parser.add_argument_group('Trainer args')
    toggle = argparse.BooleanOptionalAction

    # (option strings, add_argument keywords), listed in registration order.
    specs = [
        (("-e", "--epochs"), dict(type=int, required=True, help="Number of Epochs")),
        (("--lr",), dict(type=float, default=0.0001, help="Outer learning rate of model")),
        (("--approach",), dict(type=str, required=True, help="Approach (e.g., `regular`, `sfw`, `learn_fs`)")),
        (("--seed",), dict(type=int, default=42, help="Seed for Reproducibility")),
        (("--wandb",), dict(action=toggle, help="Log to wandb")),
        (("--res_dir",), dict(type=str, default='checkpoints', help="Path to the directory to save the results")),
        (("--weight_decay",), dict(type=float, default=0.0001, help='Weight decay for optimizer')),
        (("--early_stopping",), dict(action=toggle, default=True, help="Enable early stopping")),
        (("--patience",), dict(type=int, default=10, help="Early stopping patience")),
    ]
    for flags, options in specs:
        group.add_argument(*flags, **options)

def add_dataset_args(parser):
    """Register data-loading options on *parser* in a "Dataset args" group.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing
            ``add_argument_group``); it is extended in place.
    """
    dataset_args = parser.add_argument_group('Dataset args')
    dataset_args.add_argument('--data_dir', type=str, help='Path to the directory containing either the encodings or the images')
    dataset_args.add_argument('--enc_type', type=str, help='Type of encodings to use', default='one_hot_padded')
    dataset_args.add_argument('--batch_size', type=int, default=128)
    dataset_args.add_argument('--num_workers', type=int, default=4)
    # BooleanOptionalAction creates --unconf_split / --no-unconf_split. A
    # boolean value cannot be passed ("Set to False" is not possible), so the
    # help names the negated flag. The value is None when neither flag is given.
    dataset_args.add_argument('--unconf_split', action=argparse.BooleanOptionalAction,
                              help='Whether to split the train set into val and train set, val set becomes test set. '
                                   'Pass --no-unconf_split if you want the original confounded dataset.')
    # The value is a fraction of the train-set size (0.05 == 5%), not a
    # percentage. '%%' is needed because argparse %-formats help strings.
    dataset_args.add_argument('--partial_conf_ratio', type=float, default=0.0,
                              help='Fraction of train set size to remove from test set and add to train set (e.g., 0.05 means 5%% of train size)')
    dataset_args.add_argument('--partial_conf_dir', type=str,
                              help='Path to the directory containing additional test samples (no confounder)')
    

def add_boolean_args(parser):
    """Add the output-saving on/off switches to *parser*.

    Each switch becomes a ``--flag/--no-flag`` pair with no explicit default.
    The value stays ``None`` unless either form is given.
    """
    group = parser.add_argument_group('BooleanOptionalAction args')
    switches = (
        ("--save_model", "Save model"),
        ("--save_confusion_matrix", "Save confusion matrix plots as SVG files"),
    )
    for flag, description in switches:
        group.add_argument(flag, help=description, action=argparse.BooleanOptionalAction)

def add_model_args(parser):
    """Register model options on *parser* in a "model specific args" group.

    Covers:
    - the model choice (its help text calls it the model "for Arthur");
    - optional pretrained weights;
    - SetTransformer sizing;
    - the MLP classifier's width and dropout.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing
            ``add_argument_group``); it is extended in place.
    """
    group = parser.add_argument_group('model specific args')
    add = group.add_argument

    # Model selection and optional pretrained weights.
    add("--model", help="Model for Arthur, e.g., SetTransformer, ResNet18 etc.", type=str)
    add("--pretrained_model", help="Use pretrained model", action=argparse.BooleanOptionalAction)
    add("--pretrained_path", help="Path to pretrained model", type=str)

    # SetTransformer hyperparameters.
    add('--n_heads', help='Number of heads for SetTransformer', type=int, default=4)
    add('--set_transf_hidden', help='Hidden size for SetTransformer', type=int, default=128)

    # MLP classifier hyperparameters.
    add("--hidden_dim", help="Hidden Dimension for MLP Classifier", type=int, default=512)
    add("--dropout", help="Dropout for MLP Classifier", type=float, default=0.3)

def add_feature_selector_args(parser):
    """Register the Merlin/Morgana feature-selector options on *parser*.

    All options go into the "General Feature Selector (SFW or U-Net) args"
    group, in this order:
    - segmentation and mask settings;
    - learning rates;
    - loss weighting and SFW settings;
    - feature-selector architecture;
    - weight decays;
    - the analysis toggles (``--flag/--no-flag`` pairs).

    ``--lr_fs`` and ``--weight_decay_fs`` explicitly default to ``None``. Their
    help text says they give Merlin and Morgana a shared value. How that
    override is resolved is up to the caller and is not visible here.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing
            ``add_argument_group``); it is extended in place.
    """
    group = parser.add_argument_group('General Feature Selector (SFW or U-Net) args')
    add = group.add_argument
    toggle = argparse.BooleanOptionalAction

    # Segmentation / masking.
    add("--segmentation_method", help="Segmentation method for Merlin and Morgana (only topk atm)", type=str)
    add("--mask_size", help="Size of Mask", type=int)

    # Learning rates: per role, or (per the help text) shared via --lr_fs.
    add("--lr_merlin", help="Learning Rate of Merlin either as NN or SFW Optimizer", type=float)
    add("--lr_morgana", help="Learning Rate of Morgana either as NN or SFW Optimizer", type=float)
    add("--lr_fs", help="Learning Rate of Feature Selector if both Merlin and Morgana should have the same learning rate", type=float, default=None)

    # Loss weighting and SFW settings.
    add("--gamma", help="Gamma for weighting the loss between Merlin and Morgana", type=float)
    add("--l1_penalty_coefficient", help="L1 penalty coefficient for SFW", type=float)
    add("--sfw_max_iterations", help="Max iterations for SFW", type=int, default=350)
    add("--sfw_patience", help="Patience for SFW", type=int, default=10)

    # Feature-selector architecture.
    add("--fs_model", help="Feature Selector Model", type=str, default="settransformer")
    add("--fs_hidden_dim", help="Hidden Dimension for MLP or SetTransformer Feature Selector", type=int, default=128)
    add("--fs_dropout", help="Dropout for MLP or SetTransformer Feature Selector", type=float, default=0.1)
    add("--fs_n_heads", help="Number of heads for SetTransformer Feature Selector", type=int, default=4)

    # Weight decays: per role, or (per the help text) shared via --weight_decay_fs.
    add("--weight_decay_merlin", help="Weight decay for Merlin", type=float, default=0.0001)
    add("--weight_decay_morgana", help="Weight decay for Morgana", type=float, default=0.0001)
    add("--weight_decay_fs", help="Weight decay for Feature Selector if both Merlin and Morgana should have the same weight decay", type=float, default=None)

    # Analysis / reporting toggles.
    add("--feature_distribution", help="Compute feature distribution", action=toggle, default=True)
    add("--feat_interp_ncb_s0", help="Use feature interpretations for pretrained NCB model with seed 0 in histogram plots", action=toggle, default=True)
    add("--compute_prec_and_ent", help="Compute precision and entropy", action=toggle, default=False)
    add("--compute_avg_occ", help="Compute average occurrence", action=toggle, default=False)