from argparse import ArgumentParser

def add_model_args(parent_parser: ArgumentParser):
    parser = parent_parser.add_argument_group('transformer model args')

    # model args
    parser.add_argument('--d_model', type=int, default=256)
    parser.add_argument('--nhead', type=int, default=8)
    parser.add_argument('--dim_feedforward', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--num_layers', type=int, default=6)
    parser.add_argument('--pe_type', type=str, default='learnable')
    parser.add_argument('--pe_scale_factor', type=float, default=1.0)

    # distillation args
    parser.add_argument('--distillation', action='store_true', default=False)
    parser.add_argument('--online_distillation', action='store_true', default=False)
    parser.add_argument('--teacher_feature_distillation_layers', nargs='*')
    parser.add_argument('--student_feature_distillation_layers', nargs='*')
    parser.add_argument('--teacher_attention_weight_distillation_layers', nargs='*')
    parser.add_argument('--student_attention_weight_distillation_layers', nargs='*')
    parser.add_argument('--feature_distillation_loss_weight', type=float, default=0.0)
    parser.add_argument('--attention_weight_distillation_loss_weight', type=float, default=0.0)
    parser.add_argument('--warmup_epochs', type=int, default=-1)
    parser.add_argument('--warmup_task_loss_weight', type=float, default=0.0)
    parser.add_argument('--warmup_feature_distillation_loss_weight', type=float, default=0.0)
    parser.add_argument('--warmup_attention_weight_distillation_loss_weight', type=float, default=0.0)
    parser.add_argument('--teacher_checkpoint_save_path', type=str)

    # finetune args
    parser.add_argument('--finetune', action='store_true', default=False)
    parser.add_argument('--finetune_checkpoint_save_path', type=str)
    parser.add_argument('--change_dropout', type=float, default=10.0)
    parser.add_argument('--freeze_layers', nargs='*')

    # training args
    parser.add_argument('--validate', action='store_true', default=False)
    parser.add_argument('--test', action='store_true', default=False)
    parser.add_argument('--save_top_k', type=int, default=20)
    parser.add_argument('--learning_rate', type=float, default=3e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-2)

    # wandb
    parser.add_argument('--wandb_offline', action='store_true', default=False)
    return parent_parser

def add_data_args(parent_parser: ArgumentParser):
    parser = parent_parser.add_argument_group('datamodule args')
    parser.add_argument('--dataset_name', type=str)
    parser.add_argument('--num_workers', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--seed', type=int)
    parser.add_argument('--multi_hop_max_dist', type=int, default=5)
    parser.add_argument('--spatial_pos_max', type=int, default=1024)

    return parent_parser

def to_int_list(arg_list):
    if arg_list is None:
        return None
    else:
        return list(map(int, arg_list))