import argparse
import problems.problem_base

def parse_boolean(value):
    """Convert a command-line string into a bool.

    Intended for use as an ``argparse`` ``type=`` callable. Matching is
    case-insensitive and ignores surrounding whitespace.

    Args:
        value: The raw option string. A value that is already a ``bool`` is
            returned unchanged.

    Returns:
        True for "true", "yes", "y", "1", "t"; False for "false", "no", "n",
        "0", "f".

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not one of the recognised
            spellings. Rejecting unknown input (instead of mapping it to False)
            means a typo such as ``--normalize ture`` is reported by argparse
            rather than silently disabling the option.
    """
    if isinstance(value, bool):
        return value

    normalized = value.strip().lower()
    if normalized in ("true", "yes", "y", "1", "t"):
        return True
    if normalized in ("false", "no", "n", "0", "f"):
        return False

    raise argparse.ArgumentTypeError(
        "invalid boolean value: {!r} (expected true/false, yes/no, y/n, 1/0 or t/f)".format(value)
    )

def get_options(args=None):
    """Define the command-line options for the RL training script and parse them.

    Args:
        args: Optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, which preserves
            the original behaviour. Passing a list lets callers and tests
            drive the parser programmatically.

    Returns:
        argparse.Namespace holding every option defined below. Two fields are
        post-processed after parsing:
          * ``graph_types``: the raw string, if given, is evaluated into a
            Python object. The help text describes a list of
            ``(name, params_dict)`` tuples.
          * ``problem``: the problem name is replaced by whatever
            ``problems.problem_base.load_problem`` returns for it.
    """
    parser = argparse.ArgumentParser(description='RL running machine')
    parser.add_argument('--problem', default='MVC', help="The problem to solve, currently supports GC and MVC")

    parser.add_argument('--graph_types', metavar='GRAPH',
                        default=None,
                        help='List of graph types to optimize. Each element of the list is a tuple '
                             'containing the name of the graph type and a dictionary of parameters.')
    parser.add_argument('--train_graph_nodes', nargs="*", type=int, default=[20, 50],
                        help='Number of nodes for training. Each size in the list is used in equal proportion.')
    parser.add_argument('--val_graph_nodes', nargs="*", type=int, default=[50],
                        help='Number of nodes for validation. Each size in the list is used in equal proportion.')

    parser.add_argument('--val_samples', type=int, default=2000, help='number of graphs in validation dataset')
    parser.add_argument('--train_samples', type=int, default=5000,
                        help='number of training graphs to generate for EACH size in train_graph_nodes')

    parser.add_argument('--batch_size', type=int, default=64)
    # Integer default to match type=int (previously the string '100').
    parser.add_argument('--n_epochs', type=int, metavar='n', default=100, help='number of epochs')
    parser.add_argument('--seed', nargs="*", type=int, default=[1313], help='Random seed(s) to use')

    parser.add_argument('--initialization', default='degree')
    parser.add_argument('--throwback', type=int, default=1, help='number of previous nodes to include in context')
    parser.add_argument('--decoding_type', default='local', help='local, static, or global')
    parser.add_argument('--embed_dim', type=int, default=64, help='dimension of the encoder embeddings')
    parser.add_argument('--n_encoder_layers', type=int, default=3, help='convolution depth for graph embedding')
    parser.add_argument('--n_heads', type=int, default=4, help='number of attention heads')
    parser.add_argument('--normalize', type=parse_boolean, default=True,
                        help='whether to use batch normalization in the encoder')
    parser.add_argument('--shortcuts', type=parse_boolean, default=True,
                        help='whether to use shortcuts in the encoder')

    # NOTE(review): the two file defaults below are the *string* 'None', not
    # the None object; downstream code presumably compares against 'None'.
    # Verify callers before changing these defaults.
    parser.add_argument('--train_file', default='None', type=str,
                        help='loads data from pickle file if not None')
    parser.add_argument('--test_file', default='None', type=str,
                        help='loads data from pickle file for testing if not None')
    parser.add_argument('--ckpt', default='./checkpoints', type=str,
                        help='directory to save results in, makes a new one using current time if None')
    parser.add_argument('--load_model_path', type=str, default=None,
                        help='when not none loads model from the given checkpoint path')
    parser.add_argument('--checkpoint_last_epochs', type=int, default=10,
                        help='number of final epochs in which to checkpoint the best models')
    parser.add_argument('--log_step', type=int, default=100, help='Log info every log_step steps')

    parser.add_argument('--lr', type=float, default=1e-4, help="Set the learning rate for the actor network")
    parser.add_argument('--max_grad_norm', type=float, default=1.0,
                        help='Maximum L2 norm for gradient clipping, default 1.0 (0 to disable clipping)')

    parser.add_argument('--neptune_api_token', default=None, help='Neptune.ai api token')

    opts = parser.parse_args(args)

    # SECURITY: eval() executes arbitrary Python code taken from the command
    # line. This is only tolerable because whoever supplies the CLI arguments
    # already controls this process. Never route untrusted input here. If the
    # graph specs are always plain literals, ast.literal_eval is a safe
    # replacement.
    if opts.graph_types is not None:
        opts.graph_types = eval(opts.graph_types)

    # Swap the problem name for the object produced by the project's loader.
    if opts.problem is not None:
        opts.problem = problems.problem_base.load_problem(opts.problem)

    return opts
