import argparse
import yaml

def alpha_type(value):
    """Convert the raw ``--alpha`` option value.

    The keyword strings ``'trainable'`` and ``'estimate'`` are passed through
    unchanged. Every other value is converted with ``float()``, which raises
    ``ValueError`` when the value cannot be parsed as a number.
    """
    keyword_modes = ('trainable', 'estimate')
    if value not in keyword_modes:
        return float(value)
    return value

def correct_args(args):
    """Set to None every hyper-parameter that does not apply to this run.

    The rules are applied in this order:

    1. If ``model_name`` is not ``'transformer_abmil'``, the transformer
       encoder options (``transf_att_dim``, ``transf_num_heads``,
       ``transf_num_layers``, ``transf_use_ff``, ``transf_dropout``) are
       set to None.
    2. If ``model_name`` is neither ``'abmil'`` nor ``'transformer_abmil'``,
       ``alpha``, ``smooth_mode``, ``smooth_where``, ``pool_att_dim`` and
       ``feat_ext_name`` are set to None.
    3. If ``alpha`` (after rule 2) equals ``0.0`` or is None, the smoothing
       options (``smooth_mode``, ``smooth_where``, ``use_inst_distances``)
       are set to None.

    The namespace is modified in place, and the same object is returned.
    """
    # NOTE(review): ``transf_smooth_steps`` is not cleared by rule 1, unlike
    # the other ``transf_*`` options -- confirm that this is intentional.
    transformer_opts = (
        'transf_att_dim',
        'transf_num_heads',
        'transf_num_layers',
        'transf_use_ff',
        'transf_dropout',
    )
    attention_pool_opts = (
        'alpha',
        'smooth_mode',
        'smooth_where',
        'pool_att_dim',
        'feat_ext_name',
    )
    smoothing_opts = ('smooth_mode', 'smooth_where', 'use_inst_distances')

    def _clear(names):
        # Assign None to each named attribute, in the given order.
        for name in names:
            setattr(args, name, None)

    if args.model_name != 'transformer_abmil':
        _clear(transformer_opts)

    if args.model_name not in ('abmil', 'transformer_abmil'):
        _clear(attention_pool_opts)

    # Must run after the previous rule, which may already have reset alpha.
    if args.alpha in (0.0, None):
        _clear(smoothing_opts)

    return args

def _build_parser():
    """Create the argument parser with every supported option and its default."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode', default='train_test', type=str, help="Mode to run the code (train/test)")
    parser.add_argument('--use_wandb', action='store_true', help="Use wandb or not")
    parser.add_argument('--wandb_project', default='', type=str, help="Wandb project name")

    parser.add_argument('--num_workers', default=12, type=int, help="Number of workers to load data")
    parser.add_argument('--pin_memory', action='store_true', help="Pin memory or not")
    parser.add_argument('--distributed', action='store_true', help="Use distributed training")
    parser.add_argument('--test_in_cpu', action='store_true', help="Test in cpu")
    parser.add_argument('--use_sparse', action='store_true', help="Use sparse tensors to store the adjacency matrix")

    # path settings
    parser.add_argument('--history_dir', default='/work/SmoothAttention/history/', type=str, metavar='PATH', help="Path to save the history file")
    parser.add_argument('--weights_dir', default='/work/SmoothAttention/weights/', type=str, metavar='PATH', help="Path to save the model weights")
    parser.add_argument('--results_dir', default='results/', type=str, metavar='PATH', help="Path to save the results")

    # experiment settings
    parser.add_argument('--seed', type=int, default=0, help="Seed")
    parser.add_argument('--dataset_name', default='rsna-features_resnet18', type=str, help="Dataset to use")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size of training")
    parser.add_argument('--val_prop', type=float, default=0.2, help="Proportion of validation data")
    parser.add_argument('--epochs', type=int, default=50, help="Training epochs")
    parser.add_argument('--config_file', type=str, default='/work/SmoothAttention/code/experiments/config.yml', help="Config file to load the settings")

    # model settings
    parser.add_argument('--model_name', type=str, default='abmil', help="Model name")
    parser.add_argument('--feat_ext_name', type=str, default='fc_1_512', help="Name of the CNN model")

    # transformer encoder settings
    parser.add_argument('--transf_att_dim', type=int, default=128, help="Dimension of the key and query in the transformer encoder")
    parser.add_argument('--transf_num_heads', type=int, default=8, help="Number of heads in the transformer encoder")
    parser.add_argument('--transf_num_layers', type=int, default=1, help="Number of layers in the transformer encoder")
    parser.add_argument('--transf_use_ff', action='store_true', help="Use feed forward layer or not in the transformer encoder")
    parser.add_argument('--transf_dropout', type=float, default=0.1, help="Dropout rate in the transformer encoder")
    parser.add_argument('--transf_smooth_steps', type=int, default=0, help="Number of steps to smooth the attention")

    # pooling settings
    parser.add_argument('--pool_att_dim', type=int, default=50, help="Value of attention dimension")
    parser.add_argument('--use_inst_distances', action='store_true', help="Use instance distances or not to build the adjacency matrix")
    parser.add_argument('--smooth_mode', type=str, default='approx_10', help="Smooth mode for the smooth attention")
    parser.add_argument('--smooth_where', type=str, default='att_values', help="Where to place the smooth layer")
    parser.add_argument('--alpha', type=alpha_type, default=0.0, help="Alpha for the smooth attention")
    parser.add_argument('--spectral_norm', action='store_true', help="Use spectral normalization or not")

    # training settings
    parser.add_argument('--balance_loss', action='store_true', help="Balance the loss using class weights")
    parser.add_argument('--lr', type=float, default=1e-4, help="Initial learning rate")
    parser.add_argument('--patience', type=int, default=10, help="Patience for early stopping")
    parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay for the optimizer")

    return parser


def _apply_config(args):
    """Copy the ``args.dataset_name`` section of ``args.config_file`` onto ``args``.

    Every key/value pair in that YAML section is assigned with ``setattr``.
    It therefore overrides whatever was parsed from the command line,
    including explicitly passed options. Values are used exactly as
    ``yaml.safe_load`` produced them; argparse ``type=`` converters such as
    ``alpha_type`` are not applied.

    Raises:
        KeyError: if the config file has no section for ``args.dataset_name``.
    """
    with open(args.config_file, 'r', encoding='utf-8') as f:
        # An empty YAML document loads as None; treat it as "no sections".
        config_dict = yaml.safe_load(f) or {}

    try:
        ds_config_dict = config_dict[args.dataset_name]
    except KeyError:
        raise KeyError(
            f"dataset '{args.dataset_name}' has no section in config file "
            f"'{args.config_file}'"
        ) from None

    # A section with no entries also loads as None; nothing to override then.
    for key, value in (ds_config_dict or {}).items():
        setattr(args, key, value)
    return args


def get_arguments(argv=None):
    """Parse the options, merge the dataset's config section, and normalize them.

    Args:
        argv: List of argument strings to parse. When None (the default),
            argparse reads ``sys.argv[1:]``, as the original version did.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args``. It is updated
        with the YAML config section for ``dataset_name`` (when a non-empty
        ``--config_file`` is given) and then passed through ``correct_args``.

    Raises:
        KeyError: if the config file has no section for ``dataset_name``.
        OSError: if the config file cannot be opened.
    """
    args = _build_parser().parse_args(argv)

    # An empty --config_file string skips the YAML override entirely.
    if args.config_file:
        args = _apply_config(args)

    return correct_args(args)