import argparse

from mas_sat import dataset_registry, graph_registry, model_registry, learner_registry

parser = argparse.ArgumentParser(description="MAS-SAT")

# main arguments
# evaluate mode can be multiple modes, comma separated
parser.add_argument("command", type=str, choices=["generate", "train", "evaluate"])
parser.add_argument("mode", type=str)
parser.add_argument("--evaluate-mode", type=str, default=None)

# dataset related
parser.add_argument("--dataset", type=str, choices=dataset_registry.registered_classes(), default="satlib")
parser.add_argument("--set", type=str, default="uf50-218")
parser.add_argument("--split", type=str, choices=["sat", "unsat", "all"], default="sat")
parser.add_argument("--data-dir", type=str, default="./data")
parser.add_argument("--num-train", type=int, default=800)
parser.add_argument("--num-valid", type=int, default=100)
parser.add_argument("--num-test", type=int, default=100)
## SATLIB specific
parser.add_argument("--domain", type=str, default="RND3SAT")

# graph related
parser.add_argument("--graph", type=str, choices=graph_registry.registered_classes(), default="vcgl")

# envirionment related
parser.add_argument("--prop-limit", type=int, default=0)
parser.add_argument("--step-limit", type=int, default=0)
parser.add_argument("--prop-weight", type=float, default=0.005)
parser.add_argument("--step-weight", type=float, default=0.05)
parser.add_argument("--penalty", type=float, default=1)
parser.add_argument("--budget", type=int, default=0)
parser.add_argument("--async-mode", action="store_true")
parser.add_argument("--decide-interval", type=int, default=1)

# agent related
parser.add_argument("--agent", type=str, choices=["solver", "random", "model", "model_influence"], default="model")

# model related
parser.add_argument("--model", type=str, choices=model_registry.registered_classes(), default="modell")
parser.add_argument("--dim", type=int, default=64)
parser.add_argument("--num-step", type=int, default=1)
parser.add_argument("--num-evaluate-step", type=str, default=None)
parser.add_argument("--recurrent", action="store_true")
parser.add_argument("--head", type=str, choices=["heuristic", "assignment", "multi"], default="multi")

# learning related
parser.add_argument("--learner", type=str,choices=learner_registry.registered_classes(), default="reinforce")
parser.add_argument("--gpu", type=int, default=None)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--grad-clip", type=float, default=1)
parser.add_argument("--grad-clip-norm-type", type=float, default=2)
parser.add_argument("--grad-alpha", type=float, default=0.6)
parser.add_argument("--learn-step", type=int, default=50000)
parser.add_argument("--evaluate-interval", type=int, default=1000)
## RL specific
parser.add_argument("--gamma", type=float, default=0.99)
## buffer specific
parser.add_argument("--buffer-size", type=int, default=20000)
## DQN specific
parser.add_argument("--learn-interval", type=int, default=4)
parser.add_argument("--target-update-interval", type=int, default=10)
parser.add_argument("--eps-init", type=float, default=1.0)
parser.add_argument("--eps-final", type=float, default=0.01)
parser.add_argument("--exploration-step", type=int, default=5000)
parser.add_argument("--eps-decay-step", type=int, default=30000)
## multi specific
parser.add_argument("--heuristic-loss-weight", type=float, default=1)
parser.add_argument("--assignment-loss-weight", type=float, default=10)

# test and recover related
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--finetune", type=str, default=None)

# reproducibility related
parser.add_argument("--seed", type=int, default=42)

# output related
parser.add_argument("--experiment-dir", type=str, default="./experiments/checkpoints")

args = parser.parse_args()
