import argparse
import torch

parser = argparse.ArgumentParser()


def _str2bool(value):
    """Convert a command-line token into a bool.

    argparse has no built-in boolean converter. Without one, an option such as
    ``--use_jk False`` stores the non-empty string ``'False'``, which is truthy.
    Accepted spellings (case-insensitive, surrounding whitespace ignored):
    true/t/yes/y/1 and false/f/no/n/0.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling;
            argparse reports this to the user as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


### GNN hyperparameters ###
parser.add_argument('--dataset', type=str, default='MUTAG', choices=['MUTAG', 'BA3', 'FC', 'MNIST'])
parser.add_argument('--model', type=str, default='GCN', choices=['GCN', 'GIN'])
parser.add_argument('--hidden', type=int, default=64, help='Number of hidden units.')
parser.add_argument('--nlayers', type=int, default=4, help='Number of hidden layers.')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout ratio.')
parser.add_argument('--pool_type', type=str, default='sum', choices=['mean', 'sum', 'max'], help='Pooling type.')
# Parsed with _str2bool so `--use_jk False` really yields False. A bare
# `--use_jk` (no value) means True via const. The default stays the bool True.
parser.add_argument('--use_jk', type=_str2bool, nargs='?', const=True, default=True,
                    help='Use Jumping Knowledge (true/false).')

### Explainer hyperparameters ###
parser.add_argument('--explainer_name', type=str, default='pgexplainer', help='explainer to be used.')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate.')

# Phase 2 ("round1") settings.
parser.add_argument('--round1_lr', type=float, default = 0.0001, help='Phase 2 lr')
parser.add_argument('--round1_epochs', type=int, default = 100, help='Phase 2 epochs')
parser.add_argument('--alpha0', type=float, default=0.10, help='Base quantile cutoff (0, 0.5) for pseudo-labeling')
parser.add_argument('--c', type=float, default=0.50, help='Skewness adjustment factor (>0) for asymmetric quantiles')
parser.add_argument('--w', default=10.0, type=float, help='weight for loss')

# Phase 3 ("second phase") settings.
# --second_phase_epochs has a fixed default of 100. It does NOT follow
# --round1_epochs, so the help text says so explicitly.
# The other second_* options default to None. Their help text promises a
# fallback to the matching Phase 2 value, but nothing here resolves it.
# NOTE(review): confirm the code that consumes `args` substitutes the fallback.
parser.add_argument('--second_phase_epochs', type=int, default=100, help='Phase 3 epochs (default: 100; independent of --round1_epochs)')
parser.add_argument('--second_phase_lr', type=float, default=None, help='Phase 3 learning rate (defaults to round1_lr)')
parser.add_argument('--second_alpha0', type=float, default=None, help='Phase 3 base quantile cutoff (defaults to alpha0)')
parser.add_argument('--second_c', type=float, default=None, help='Phase 3 skewness adjustment factor (defaults to c)')
parser.add_argument('--second_w', type=float, default=None, help='Loss pos_weight for Phase 3 (defaults to w)')

### Miscellaneous ###
# Checkpoint directory, CUDA device index and RNG seed.
parser.add_argument(
    "--ckpt_path",
    default="test/ckpts/",
    type=str,
    help="Location for saving checkpoints",
)
parser.add_argument(
    "--gpu",
    default=7,
    type=int,
    help="GPU device id to use",
)
parser.add_argument(
    "--seed",
    default=0,
    type=int,
    help="Random seed.",
)

# Parsed once, when this module is first imported: the result reflects
# sys.argv of the running process.
args = parser.parse_args()
