import os
import sys
import argparse
import torch
from datetime import datetime

from core.utils import set_seed
from core.logging import log_and_print, config_logfile


def parse_args(custom_parse_fn=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str)
    parser.add_argument("--trial-codename", type=str)
    parser.add_argument("--log-dir", type=str)
    parser.add_argument("--abort-ckpt", action="store_true")
    parser.add_argument("--asserting-debug", action="store_true")

    parser.add_argument("--n-variables", type=int, default=3)
    parser.add_argument("--rank", type=int, default=3)
    parser.add_argument("--n-width", type=int, default=20)
    parser.add_argument("--n-layers", type=int, default=2)
    parser.add_argument("--optimizer", type=str, default="Adam",
                        choices=["SGD", "Adam"])
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("--init-var", type=float, default=1e-2)
    parser.add_argument("--soft-logic", default="minmax-prob",
                        choices=["naive-prob", "minmax-prob",
                                 "naive-prob-matmul"])
    parser.add_argument("--final-forall-logic", default="minmax-prob",
                        choices=["naive-prob", "minmax-prob",
                                 "naive-prob-matmul"])
    parser.add_argument("--loss-fn", default="square",
                        choices=["square", "nll"])
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--rl-gamma", type=str, default=0.99)
    parser.add_argument("--sparse-dropout-last", type=int, default=-1)

    parser.add_argument("--seed", type=int, default=0)
    # parser.add_argument("--n-steps", type=int, default=50000)
    parser.add_argument("--n-epochs", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--batch-total-size", type=float, default=1e6)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--relation-directional", action="store_true")
    parser.add_argument("--entropy-regularization", type=float, default=0)
    parser.add_argument("--entropy-reg-increasing", type=str, default="none",
                        choices=["none", "linear", "square"])
    parser.add_argument("--distinct-variables", action="store_true")
    parser.add_argument("--cached-num", type=int, default=100)

    parser.add_argument("--negative", action="store_true")
    parser.add_argument("--negative-range", type=int, nargs=2, default=(0, -1))
    parser.add_argument("--negative-file", type=str)
    parser.add_argument("--negative-index", type=int)
    parser.add_argument("--erase-nodes", type=int, default=0)
    parser.add_argument("--erase-type", type=str)
    parser.add_argument("--all-range", type=int, nargs=2, default=(0, -1))
    parser.add_argument("--temporal-only", action="store_true")
    parser.add_argument("--events-only", action="store_true")
    parser.add_argument("--no-mention-order", action="store_true")
    parser.add_argument("--deduce-attribute", type=str, default=None)
    parser.add_argument("--deduce-relation", type=str, default=None)
    parser.add_argument("--sorting-length-range", type=int, nargs=2,
                        default=(4, 5))
    parser.add_argument("--max-step", type=int, default=10)

    parser.add_argument("--num-cpus", type=int, default=1)
    parser.add_argument("--log-interval", type=int, default=500)
    parser.add_argument("--print-parameters", action="store_true")

    parser.add_argument('--top_k', default=10, type=int)
    if custom_parse_fn is not None:
        custom_parse_fn(parser)
    args = parser.parse_args()

    torch.set_num_threads(args.num_cpus)
    args.device = torch.device("cpu") if (
        args.gpu < 0 or torch.cuda.device_count() == 0) else \
        torch.device(f"cuda:{args.gpu}")
    if args.trial_codename is not None:
        log_dir = os.path.join(args.log_dir, args.trial_codename)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
            print("making directory:", log_dir)
        args.ckpt_file = os.path.join(log_dir, "model.pt")
        args.log_file = os.path.join(log_dir, "log.txt")
        config_logfile(args.log_file)
        args.print = log_and_print
    else:
        print("No codename provided. This trial will not be logged")
        args.ckpt_file = None
        args.log_file = None
        args.print = print

    args.print()
    args.print("=" * 30)
    args.print(' '.join(sys.argv))
    args.print(datetime.now().strftime("%D %H:%M:%S"))
    args.print(args)
    set_seed(args.seed)
    return args
