import argparse

# Command-line configuration shared by the training scripts. Parsing happens
# at import time: importing this module parses sys.argv and exposes the
# result as the module-level `args` namespace (and the parser as `parser`).

# for GIN
# Training settings
# Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.
parser = argparse.ArgumentParser(
    description='PyTorch graph convolutional neural net for whole-graph classification')
parser.add_argument('--dataset', type=str, default="MUTAG",
                    help='name of dataset (default: MUTAG)')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training (default: 16)')
parser.add_argument('--iters_per_epoch', type=int, default=50,
                    help='number of iterations per each epoch (default: 50)')
parser.add_argument('--epochs', type=int, default=350,
                    help='number of epochs to train (default: 350)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed for splitting the dataset into 10 (default: 0)')
parser.add_argument('--fold_idx', type=int, default=0,
                    help='the index of fold in 10-fold validation. Should be less then 10.')
parser.add_argument('--num_layers', type=int, default=5,
                    help='number of layers INCLUDING the input one (default: 5)')
parser.add_argument('--num_mlp_layers', type=int, default=2,
                    help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
parser.add_argument('--hidden_dim', type=int, default=64,
                    help='number of hidden units (default: 64)')
parser.add_argument('--final_dropout', type=float, default=0.5,
                    help='final layer dropout (default: 0.5)')
parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
                    help='Pooling for over nodes in a graph: sum or average')
parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
                    help='Pooling for over neighboring nodes: sum, average or max')
parser.add_argument('--learn_eps', action="store_true",
                    help='Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy though.')
parser.add_argument('--degree_as_tag', action="store_true",
                    help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
parser.add_argument('--filename', type=str, default="",
                    help='output file')

# Device, augmentation and projection-head settings.
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--aug_type', type=str, default='diff', help='random or diff or TODO')
parser.add_argument('--drop_prob', type=float, default=0.5)
parser.add_argument('--projection_hidden_size', type=int, default=512)  # alternative value: 2048
parser.add_argument('--projection_size', type=int, default=256)
# NOTE(review): `--num_layer` is a separate option from `--num_layers` above
# (default 4 vs 5); presumably consumed by a different model component — confirm.
parser.add_argument('--num_layer', type=int, default=4)
# parser.add_argument('--hid_dim', type=int, default=256)  # disabled

# add negative samples
# `store_true` is equivalent to store_const(const=True, default=False).
parser.add_argument('--use-neg-loss', dest='use_neg_loss', action='store_true')
parser.add_argument('--neg_times', type=int, default=0)
parser.add_argument('--alpha', type=float, default=0.3)
parser.add_argument('--beta', type=float, default=1.0)

# NOTE: below is for semi-supervised graph classification
parser.add_argument('--use-unsup-loss', dest='use_unsup_loss', action='store_true')
parser.add_argument('--use-supcon-loss', dest='use_supcon_loss', action='store_true')
parser.add_argument('--start_supcon_epoch', type=int, default=30, help='number of epochs for starting using supervised contrastive loss')

parser.add_argument('--use-selftrain', dest='use_selftrain', action='store_true')
parser.add_argument('--selftrain_iter', type=int, default=20, help='self training iterations')
parser.add_argument('--num_samples_per_iter', type=int, default=10, help='number of samples per self training iteration')
parser.add_argument('--start_selftrain_epoch', type=int, default=30, help='number of epochs for starting self training')
# The threshold is fractional (default 0.95), so it must be parsed as float;
# with type=int any command-line value such as "0.9" was rejected.
parser.add_argument('--selftrain_threshold', type=float, default=0.95, help='threshold for self training (default: 0.95)')

# temperature
parser.add_argument('--temp', type=float, default=0.07, help='temperature for loss function')
parser.add_argument('--test_ratio', type=float, default=0.1)
# parser.add_argument('--valid_ratio', type=float, default=0.1)  # disabled
parser.add_argument('--train_ratio', type=float, default=0.1, help='ratio of labeled training set (in QM9, 0.05)')
parser.add_argument('--unsup_ratio', type=float, default=0.8)

# Parsed at import time so other modules can use the shared `args` namespace.
args = parser.parse_args()

# NOTE: the ratio sanity check below is disabled because it depends on the
# (also disabled) --valid_ratio option.
#if args.unsup_ratio != 1-args.test_ratio-args.valid_ratio:
#if args.train_ratio + args.test_ratio != args.unsup_ratio != 1-args.test_ratio-args.valid_ratio:
#    print("Warning: the train/val/test ratios are not specified properly")
