import os
from argparse import ArgumentParser

def arg_parser():
  parser = ArgumentParser(description='Train a Pytorch-Lightning diffusion model on a TSP dataset.')
  parser.add_argument('--task', type=str, required=True)
  parser.add_argument('--storage_path', type=str, required=True)
  parser.add_argument('--training_split', type=str, default='data/tsp/tsp50_train_debug.txt')
  parser.add_argument('--training_split_label_dir', type=str, default=None,
                      help="Directory containing labels for training split (used for MIS).")
  parser.add_argument('--validation_split', type=str, default='data/tsp/tsp50_test_concorde.txt')
  parser.add_argument('--test_split', type=str, default='data/tsp/tsp50_test_concorde.txt')
  parser.add_argument('--predict_split', type=str, default='data/tsp/tsp50_test_concorde.txt')
  parser.add_argument('--validation_examples', type=int, default=64)
  parser.add_argument('--prediction_examples', type=int, default=128)

  parser.add_argument('--batch_size', type=int, default=64)
  parser.add_argument('--num_epochs', type=int, default=50)
  parser.add_argument('--learning_rate', type=float, default=1e-4)
  parser.add_argument('--weight_decay', type=float, default=0.0)
  parser.add_argument('--lr_scheduler', type=str, default='constant')

  parser.add_argument('--num_workers', type=int, default=16)
  parser.add_argument('--fp16', action='store_true')
  parser.add_argument('--use_activation_checkpoint', action='store_true')

  parser.add_argument('--diffusion_type', type=str, default='gaussian')
  parser.add_argument('--diffusion_schedule', type=str, default='linear')
  parser.add_argument('--diffusion_steps', type=int, default=1000)
  parser.add_argument('--inference_diffusion_steps', type=int, default=1000)
  parser.add_argument('--inference_schedule', type=str, default='linear')
  parser.add_argument('--sequential_sampling', type=int, default=1)
  parser.add_argument('--parallel_sampling', type=int, default=1)

  parser.add_argument('--n_layers', type=int, default=12)
  parser.add_argument('--hidden_dim', type=int, default=256)
  parser.add_argument('--sparse_factor', type=int, default=-1)
  parser.add_argument('--aggregation', type=str, default='sum')
  parser.add_argument('--two_opt_iterations', type=int, default=1000)
  parser.add_argument('--save_numpy_heatmap', action='store_true')
  parser.add_argument("--norm_scheme", type=str, default="layer", help="Normalization scheme")
  parser.add_argument("--simnorm_dim", type=int, default=-1)

  parser.add_argument('--logger_name', type=str, default=None)
  parser.add_argument('--ckpt_path', type=str, default=None)
  parser.add_argument('--resume_weight_only', action='store_true')

  parser.add_argument('--do_train', action='store_true')
  parser.add_argument('--do_test', action='store_true')
  parser.add_argument('--do_valid_only', action='store_true')
  parser.add_argument('--do_predict', action='store_true')
  parser.add_argument("--fusion_method", type=str, default=None)
  parser.add_argument("--min_visited", type=int, default=1)
  parser.add_argument("--sub_graph_size", type=int, default=50)
  parser.add_argument("--process_num", type=int, default=1)

  parser.add_argument("--eps", type=float, default=0.001)
  parser.add_argument("--use_multi_processing", action='store_true')
  parser.add_argument("--exhaustive", action="store_true")
  parser.add_argument("--num_trials", type=int, default=100)
  parser.add_argument("--multip_batchsize", type=int, default=0)

  args = parser.parse_args()

  if args.task == "tsp":
    args.diffusion_type = "gaussian"
  else:
    args.diffusion_type = "categorical"


  if args.norm_scheme == "simnorm" and args.simnorm_dim < 0:
    raise ValueError("simnorm dim cannot be negative")
  return args

def save_args(args, exp_name):
  targetDir = os.path.join(args.storage_path, f'models', exp_name)
  if not os.path.exists(targetDir):
    os.makedirs(targetDir)
  with open(os.path.join(targetDir, "args.txt"), "w") as f_args:
    for k, v in args.__dict__.items():
      f_args.write("%s:%s\n"%(k, v))