import argparse


def _build_parser():
    """Create the ArgumentParser that declares every command-line option."""
    parser = argparse.ArgumentParser(description="PyTorch Training")

    # primary
    parser.add_argument(
        "--configs", type=str, default=None, help="configs file",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default="./trained_models",
        help="directory to save results",
    )
    # No default: the attribute is None unless --exp-name is given.
    parser.add_argument(
        "--exp-name",
        type=str,
        help="Name of the experiment (creates dir with this name in --result-dir)",
    )
    parser.add_argument(
        "--cvae_num_epoch",
        type=int,
        default=800,
        help="epoch of training",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["predefine", "train"],
        default="predefine",
    )
    parser.add_argument(
        "--seed", type=int, default=110, help="random seed",
    )
    parser.add_argument(
        "--gtseed", type=int, default=117, help="random seed for gt GNN",
    )
    parser.add_argument(
        "--protect", type=int, default=1, help="disadvantaged group",
    )
    # NOTE(review): no default and no help text; the value is None when the
    # flag is omitted. Its meaning is not evident from this file.
    parser.add_argument("--lf", type=float)

    # gcn setting
    parser.add_argument(
        "--A", type=int, default=5, help="dimension of A",
    )
    parser.add_argument(
        "--hidden_channels",
        type=int,
        default=8,
        help="dimension of each hidden sample",
    )

    # cvae setting
    parser.add_argument(
        "--latent_dim",
        type=int,
        default=16,
        help="dimension of latent space of cvae",
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=1,
        help="number of hidden layers of GNN",
    )

    # generate dataset
    parser.add_argument(
        "--gen_A",
        type=int,
        default=1,
        help="dimension of A in synthetic dataset",
    )
    parser.add_argument(
        "--k", type=int, default=101, help="k-neighbors for knn",
    )
    # NOTE(review): no help text; purpose is not evident from this file.
    parser.add_argument("--coff_A", type=float, default=1.0)
    parser.add_argument(
        "--scale",
        type=float,
        default=0.1,
        help="noise scale for generated data",
    )
    parser.add_argument(
        "--mapping_function",
        type=str,
        default="LinearMapping",
        help="Mapping function from xs -> y",
    )

    # dataset
    parser.add_argument(
        "--dataset", type=str, default="credit", help="name of the dataset",
    )
    parser.add_argument(
        "--input_size", type=int, default=31, help="dimension of feature",
    )
    parser.add_argument(
        "--sensitive_size",
        type=int,
        default=1,
        help="dimension of sensitive attribute",
    )
    parser.add_argument("--z_size", type=int, default=31)
    parser.add_argument("--num_class", type=int, default=1)
    parser.add_argument(
        "--root",
        type=str,
        default="data/credit",
        help="directory of data root",
    )
    parser.add_argument(
        "--filename",
        type=str,
        default="UCI_Credit_Card.csv",
        help="name of data file",
    )

    return parser


def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches
            the previous zero-argument behavior. Passing an explicit list
            allows the parser to be driven from tests or other code.

    Returns:
        argparse.Namespace with one attribute per option. Dashes in option
        names become underscores (``--result-dir`` -> ``result_dir``,
        ``--exp-name`` -> ``exp_name``).

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    return _build_parser().parse_args(argv)