
from argparse import ArgumentParser
from utils.utils import str2bool

def parse_args():
    """Build the complete command-line parser, parse the CLI and validate it.

    The parser is created by ``get_parser_for_basic_args()`` and then passed
    through every ``_add_*_args`` helper listed below, in order. After
    parsing, two post-processing steps are applied:

    * the weight-decay options are checked for consistency;
    * ``save_interval`` falls back to ``eval_interval`` when not provided.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.

    Raises:
        AssertionError: if the weight-decay options are inconsistent.
    """
    parser = get_parser_for_basic_args()

    # Each helper registers its options and returns the parser it was given.
    # ``_add_finetune_args`` exists in this module but is currently not
    # registered here.
    extenders = (
        _add_training_args,
        _add_model_args,
        _add_dataset_args,
        _add_regularization_args,
        _add_checkpointing_args,
        _add_initialization_args,
        _add_validation_args,
        _add_logging_args,
        _add_env_args,
    )
    for extend in extenders:
        parser = extend(parser)

    args = parser.parse_args()

    if args.weight_decay_incr_style == "constant":
        # A constant schedule only makes sense when both ends coincide.
        assert args.start_weight_decay == args.end_weight_decay, (
            "--start-weight-decay and --end-weight-decay must be equal when "
            "--weight-decay-incr-style is 'constant'"
        )
    else:
        assert args.start_weight_decay is not None, (
            "--start-weight-decay must be set"
        )
        assert args.end_weight_decay is not None, (
            "--end-weight-decay must be set"
        )

    # Checkpoint at the evaluation cadence unless told otherwise.
    if args.save_interval is None:
        args.save_interval = args.eval_interval

    return args

def get_parser_for_basic_args():
    """Create the root parser carrying the basic options.

    Registers the model choice, the checkpoint path to load, the tokenizer
    version and the scalar-tokenizer / discretization settings.

    Note:
        ``"Basic Configuration"`` is passed as the first positional argument
        of ``ArgumentParser``, which is ``prog``.

    Returns:
        ArgumentParser: a freshly created parser.
    """
    basic_parser = ArgumentParser("Basic Configuration")
    add = basic_parser.add_argument

    add(
        "--model",
        type=str,
        default="transformer_xl",
        choices=["transformer_xl", "unoffical_gato"],
        help="Choose the language model to use.",
    )
    add("--load-dir", type=str, help="Path of checkpoint to load.")
    add("--tokenizer-ver", type=str, default="v1")

    # (flag, type, default): scalar tokenizer first, then discretization.
    numeric_options = (
        ("--num-discrete-values", int, 1024),
        ("--num-continuous-bin", int, 1024),
        ("--discretize-mu", float, 100.0),
        ("--discretize-M", float, 256.0),
    )
    for flag, caster, fallback in numeric_options:
        add(flag, type=caster, default=fallback)

    return basic_parser

def _add_env_args(parser):
    group = parser.add_argument_group(title="environment")
    group.add_argument(
        "--bp-max-item-num",
        type=int,
        default=7,
        help="Max item quantity of 01BP env",
    )
    group.add_argument(
        "--tsp-city-num",
        type=int,
        default=10,
        help="Node num of TSP env",
    )
    group.add_argument(
        "--pctsp-node-num",
        type=int,
        default=20,
        help="Node num of PCTSP env",
    )
    group.add_argument(
        "--op-node-num",
        type=int,
        default=20,
        help="Node num of OP env",
    )
    group.add_argument(
        "--cvrp-node-num",
        type=int,
        default=20,
        help="Node num of CVRP env",
    )

    group.add_argument(
        "--data-num-tsp", 
        type=int, 
        default=0,
        help="The Num of data used for training. "
        "The whole dataset will be split into train-set and prompt-set"
        "and the train-set will be further split into train-set and vaild-set"
        "Set 0 to use the entire dataset"
    )
    group.add_argument(
        "--data-num-01bp", 
        type=int, 
        default=0,
        help="The Num of data used for training. "
        "The whole dataset will be split into train-set and prompt-set"
        "and the train-set will be further split into train-set and vaild-set"
        "Set 0 to use the entire dataset"
    )
    group.add_argument(
        "--data-num-pctsp", 
        type=int, 
        default=0,
        help="The Num of data used for training. "
        "The whole dataset will be split into train-set and prompt-set"
        "and the train-set will be further split into train-set and vaild-set"
        "Set 0 to use the entire dataset"
    )
    group.add_argument(
        "--data-num-op", 
        type=int, 
        default=0,
        help="The Num of data used for training. "
        "The whole dataset will be split into train-set and prompt-set"
        "and the train-set will be further split into train-set and vaild-set"
        "Set 0 to use the entire dataset"
    )
    group.add_argument(
        "--data-num-cvrp", 
        type=int, 
        default=0,
        help="The Num of data used for training. "
        "The whole dataset will be split into train-set and prompt-set"
        "and the train-set will be further split into train-set and vaild-set"
        "Set 0 to use the entire dataset"
    )

    return parser

def _add_model_args(parser):
    """Register the model-architecture options.

    The options are added directly to ``parser`` (no argument group is
    created). Boolean options are converted with ``str2bool``.

    Args:
        parser: the ``ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    # language model
    # NOTE(review): the default is a dict while a value parsed from the
    # command line is a list (nargs='+') - confirm what consumers expect.
    parser.add_argument(
        '--linear-emb-items',
        nargs='+',
        default={},
        help='Each item here corresponding to a linear layer for embedding, and the "item_name" items are the obs item name of MDP episode data, '
        'all the corresponding data will be directly embed by the linear layer like DT instead of tokenize & embedding like normal GPT model'
    )
    parser.add_argument(
        "--n-embed",
        type=int,
        default=768,
        help="Dimensionality of the embeddings and hidden states.",
    )
    parser.add_argument(
        "--n-position",
        type=int,
        default=1024,
        help="The maximum sequence length that this model "
        "might ever be used with. Typically set this to "
        "something large just in case (e.g., 512 or 1024 "
        "or 2048)",
    )
    parser.add_argument(
        "--n-layer",
        type=int,
        default=12,
        help="Number of hidden layers in the Transformer encoder",
    )
    parser.add_argument(
        "--n-head",
        type=int,
        default=12,
        help="Number of attention heads for each attention layer in "
        "the Transformer encoder.",
    )
    parser.add_argument(
        "--n-inner",
        type=int,
        default=None,
        help="Dimensionality of the inner feed-forward layers. "
        "`None` will set it to 4 times n_embd",
    )
    parser.add_argument(
        "--activation-fn",
        type=str,
        default="gelu",
        help="Activation function, to be selected in the list "
        "['relu', 'gelu', 'tanh', 'sigmoid', 'geglu']",
    )
    parser.add_argument(
        "--layer-norm-epsilon",
        type=float,
        default=1e-5,
        help="The epsilon to use in the layer normalization layers.",
    )
    parser.add_argument(
        "--share-input-output-embedding",
        type=str2bool,
        default=False,
        help="Whether to share embedding weights between input and output layer"
    )
    parser.add_argument("--fp16", type=str2bool, default=True)

    # TransformerXL args
    parser.add_argument(
        "--mem-len",
        type=int,
        default=None,
        help="the memory length used during evaluation",
    )
    parser.add_argument(
        "--use-mem",
        type=str2bool,
        default=True,
        help="enable the Transformer-XL memory or not. (0: off, 1: on)",
    )
    parser.add_argument(
        "--pre-lnorm",
        type=str2bool,
        default=True,
        help="The position of layernorm"
    )
    parser.add_argument(
        "--same-length",
        type=str2bool,
        default=True,
        help="Whether to use same context length in attention masks",
    )
    parser.add_argument(
        "--untie-r",
        type=str2bool,
        default=False,
        help="Whether to use disjoined relative positional vector u, v",
    )
    parser.add_argument(
        "--drop",
        type=float,
        default=0.1,
        help="Dropout of embeddings and ffn in transformer XL",
    )
    parser.add_argument(
        "--dropattn",
        type=float,
        default=0.0,
        help="Dropout of attention output"
    )
    parser.add_argument(
        "--embd-pdrop",
        type=float,
        default=0.1,
        help="The dropout ratio for the embeddings.",
    )
    parser.add_argument(
        "--use-deepnorm",
        type=str2bool,
        default=False,
        help="Whether to use DeepNorm described in https://arxiv.org/pdf/2203.00555.pdf"
    )

    return parser

def _add_training_args(parser):
    """Register the training-loop, optimizer and LR-schedule options.

    All options live in a "training" argument group. Boolean options are
    converted with ``str2bool``.

    Args:
        parser: the ``ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title="training")
    group.add_argument(
        "--snapshot-save-interval",
        type=int,
        default=5,
        help="Interval between two snapshots saved when using DDP"
    )
    group.add_argument(
        "--use-early-stopping",
        type=str2bool,
        default=False,
        help="Whether to use the early stopping mechanism"
    )
    group.add_argument(
        "--early-stopping-patience",
        type=int,
        default=7,
        help="How long to wait after last time validation loss improved."
    )
    # Parsed as float so that fractional thresholds such as 0.001 are
    # accepted; integer inputs still parse (as the equal float value).
    group.add_argument(
        "--early-stopping-delta",
        type=float,
        default=0.0,
        help="Minimum change in the monitored quantity to qualify as an improvement."
    )
    group.add_argument(
        "--use-amp",
        type=str2bool,
        default=False,
        help="Whether use torch.cuda.amp speeding up training or not"
    )
    group.add_argument(
        "--dataset-distribution",
        type=str,
        default=None,
        help="the node distribution of dataset"
    )
    group.add_argument(
        "--dataset-weights",
        type=str,
        default=None,
        help="the sample weights for each dataset when the model trained on a mixed dataset"
    )
    group.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Batch size per model instance."
    )
    group.add_argument(
        "--batch-grad-accum-step",
        type=int,
        default=1,
        help="Batch step of Gradient Accumulation. "
        "The equivalent batch_size is batch_size * batch_grad_accum_step"
    )
    group.add_argument(
        "--batch-num",
        type=int,
        default=500,
        help="Batch num per training epoch."
    )
    group.add_argument(
        "--train-iters",
        type=int,
        default=100,
        help="Total number of iterations to train over all "
        "training runs. Note that either train-iters or "
        "train-samples should be provided.",
    )
    group.add_argument(
        "--dataloader-type",
        type=str,
        default="random",
        choices=["sequential", "random", "DDP"],
        help="Fetch data sequentially or out-of-order",
    )
    group.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "sgd", "adamw"],
        help="Optimizer function",
    )
    group.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="Dataloader number of workers."
    )
    group.add_argument(
        "--lr-decay-style",
        type=str,
        default="linear",
        choices=["constant", "linear", "cosine"],
        help="Learning rate decay function.",
    )
    group.add_argument(
        "--lr-decay-factor",
        type=float,
        default=20,
        help="The final stable learning rate = (lr_max/lr_decay_factor)",
    )
    group.add_argument(
        "--lr-warmup-ratio",
        type=float,
        default=0.1,
        help="the proportion of the number of warmup iterations to the total number of iterations",
    )
    group.add_argument(
        "--lr-decay-ratio",
        type=float,
        default=0.8,
        help="the proportion of the number of decay iterations to the total number of iterations",
    )
    group.add_argument(
        "--lr-begin",
        type=float,
        default=1e-7,
        help="The initial learning rate."
    )
    group.add_argument(
        "--lr-max",
        type=float,
        default=5e-5,
        help="The learning rate peak at end of warmup process. "
        "It's as same as lr-begin if we use constant lr",
    )
    group.add_argument(
        "--grad-accum-step-incr-style",
        type=str,
        default="linear",
        choices=["constant", "linear", "power"],
        help="grad accum step incr function.",
    )
    group.add_argument(
        "--start-grad-accum",
        type=int,
        default=1,
        help="The initial grad accum step",
    )
    group.add_argument(
        "--end-grad-accum",
        type=int,
        default=10,
        help="The final grad accum step",
    )

    # NOTE(review): with const=False a bare `--override-opt-param-scheduler`
    # (no value) yields False, i.e. the same as the default. Confirm whether
    # const=True was intended; behavior is left unchanged here.
    group.add_argument(
        "--override-opt-param-scheduler",
        type=str2bool,
        default=False,
        nargs="?",
        const=False,
        help="Reset the values of the scheduler (learning rate, "
        "warmup iterations, minimum learning rate, maximum "
        "number of iterations, and decay style) from input "
        "arguments and ignore values from checkpoints. Note "
        "that all the above values will be reset.",
    )
    group.add_argument(
        "--use-checkpoint-opt-param-scheduler",
        type=str2bool,
        default=True,
        nargs="?",
        const=True,
        help="Use checkpoint to set the values of the scheduler "
        "(learning rate, warmup iterations, minimum learning "
        "rate, maximum number of iterations, and decay style) "
        "from checkpoint and ignore input arguments.",
    )

    return parser

def _add_regularization_args(parser):
    group = parser.add_argument_group(title="regularization")
    group.add_argument(
        "--hidden-dropout",
        type=float,
        default=0.1,
        help="Dropout probability for hidden state transformer.",
    )
    group.add_argument(
        "--weight-decay-incr-style",
        type=str,
        default="constant",
        choices=["constant", "linear", "cosine"],
        help="Weight decay increment function.",
    )
    group.add_argument(
        "--weight-decay",
        type=float,
        default=0.1,
        help="Weight decay coefficient for L2 regularization.",
    )
    group.add_argument(
        "--start-weight-decay",
        type=float,
        default=0.1,
        help="Initial weight decay coefficient for L2 regularization.",
    )
    group.add_argument(
        "--end-weight-decay",
        type=float,
        default=0.1,
        help="End of run weight decay coefficient for L2 regularization.",
    )
    group.add_argument(
        "--clip-grad",
        type=float,
        default=1.0,
        help="Gradient clipping based on global L2 norm.",
    )
    group.add_argument(
        "--adam-beta1",
        type=float,
        default=0.9,
        help="First coefficient for computing running averages "
        "of gradient and its square",
    )
    group.add_argument(
        "--adam-beta2",
        type=float,
        default=0.95,
        help="Second coefficient for computing running averages "
        "of gradient and its square",
    )
    group.add_argument(
        "--adam-eps",
        type=float,
        default=1e-08,
        help="Term added to the denominator to improve" "numerical stability",
    )
    group.add_argument(
        "--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd"
    )

    return parser

def _add_validation_args(parser):
    """Register the evaluation / validation options.

    Most options live in a "validation" argument group. ``--use-ddp-env``,
    ``--use-default-policy-obj``, ``--strict-length`` and
    ``--minimal-expert-data`` are registered on ``parser`` itself, so they
    are listed outside that group in ``--help``; this placement is kept
    as it was.

    Args:
        parser: the ``ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title="validation")
    group.add_argument(
        "--split",
        type=str,
        default="90,10",
        # '%%' is required: argparse %-formats help strings.
        help="Comma-separated list of proportions for training"
        " and validation. For example the split `90,10` will use 90%% of data for training, 10%% for validation",
    )
    group.add_argument(
        "--eval-iters-COP",
        type=int,
        default=100,
        help="Number of iterations to run for COP task evaluation, 0 for entire problem set"
    )
    group.add_argument(
        "--eval-iters-RL",
        type=int,
        default=5,
        help="Number of iterations to run for RL task evaluation"
    )
    group.add_argument(
        "--eval-interval",
        type=int,
        default=5,
        help="Interval between calculating eval loss on validation set.",
    )
    group.add_argument(
        "--eval-policy-interval",
        type=int,
        default=5,
        help="Interval between policy rollouting on problem set.",
    )
    group.add_argument(
        "--eval-dataset-names",
        nargs="*",
        default=[],
        help="RL dataset names that used to train and test"
    )
    group.add_argument(
        "--eval-env-names",
        nargs="*",
        default=[],
        help="RL env names that used to test"
    )
    group.add_argument(
        "--eval-batch-size",
        type=int,
        default=64,
        help="Batch size per model instance when evaluating."
    )
    group.add_argument(
        "--eval-batch-num",
        type=int,
        default=500,
        help="Batch num per training epoch when evaluating."
    )
    group.add_argument(
        "--problem-batch-size",
        type=int,
        default=100,
        help="Vector env num per model instance when evaluating."
    )
    group.add_argument(
        "--problem-batch-num",
        type=int,
        default=20,
        help="Problem batch num per evaluating epoch."
    )
    parser.add_argument(
        "--use-ddp-env",
        type=str2bool,
        default=False,
        nargs="?",
        const=True,
        help="Whether to use ddp env in evaluation.",
    )
    parser.add_argument(
        "--use-default-policy-obj",
        type=str2bool,
        default=False,
        nargs="?",
        const=True,
        help="Whether to use the default random policy obj value to calculate epi quality in evaluation. "
        "Default random policy obj value is the average result calculated on 10,000 samples"
    )
    group.add_argument(
        "--eval-max-step-size",
        type=int,
        default=100000,
        help="Max rollout step when evaluating. "
        "This para is useless for COP task so just set a big number as default"
    )
    parser.add_argument("--strict-length", type=str2bool, default=True)
    parser.add_argument("--minimal-expert-data", type=str2bool, default=True)

    return parser

def _add_dataset_args(parser):
    """Register the dataset and prompt-construction options.

    All options live in a "dataset" argument group. Boolean options are
    converted with ``str2bool``. ``choices`` are given as lists so that the
    order shown in ``--help`` and in error messages is deterministic.

    Args:
        parser: the ``ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title="dataset")
    # XXX(ziyu): now only test mmap
    group.add_argument(
        "--data-path",
        nargs="*",
        default=None,
        help="Path to the training dataset. Accepted format: "
        "1) a single data path with its dataset type, 2) multiple datasets in the "
        "form: dataset1-weight dataset1-path dataset1-type dataset2-weight "
        'dataset2-path dataset2-type ..., dataset types currently are ["rl", "nlp"]',
    )
    group.add_argument(
        "--traj-type",
        type=str,
        default="all",
        choices=["all", "complete"],
        help="The length of epi fragment in training dataset. "
        "Set `all` to cut raw episode into fragments of at least 2 timestep length to construct training dataset. "
        "Set `complete` to use subseq of raw episodes which has the same length (trans num) as model supported"
    )
    group.add_argument(
        "--eval-problem-set",
        type=str,
        default='problem',
        choices=['problem', 'train_problem'],
        help="The problem dataset used to eval policy during training."
    )

    group.add_argument(
        "--data-impl",
        type=str,
        default="infer",
        choices=["lazy", "cached", "mmap", "infer"],
        help="Implementation of indexed datasets.",
    )

    # RL Dataset
    group.add_argument(
        "--use-prompt",
        type=str2bool,
        default=True,
        help="enable prompt or not. (0: off, 1: on)",
    )
    group.add_argument(
        "--use-prefix",
        type=str2bool,
        default=False,
        help="enable prefix or not. (0: off, 1: on)",
    )
    group.add_argument(
        "--disable-visited-obs",
        type=str2bool,
        default=False,
        help="Remove visited obs when constructing MDP episode like DB1",
    )

    group.add_argument(
        "--prompt-ratio",
        type=float,
        default=0.5,
        help="Ratio of prepending prompt in a rl sequence.",
    )

    group.add_argument(
        "--prompt-prob",
        type=float,
        default=0.25,
        help="Probability of prepending prompt to a rl sequence.",
    )
    group.add_argument(
        "--prompt-at-final-transition-prob",
        type=float,
        default=0.5,
        help="Probability of use the last transitions of an episode.",
    )
    group.add_argument(
        "--mask-prompt-action-loss",
        type=str2bool,
        default=True,
        help="Whether to ignore action loss for prompt actions.",
    )

    group.add_argument(
        "--prompt-strategy",
        type=str,
        default="stochastic_timestep;moving_prompt",
        choices=[
            "stochastic_timestep;moving_prompt",
            "stochastic_subseq;moving_prompt",
            "stochastic_timestep;fixed_prompt",
            "stochastic_subseq;fixed_prompt",
        ],
    )

    return parser

def _add_logging_args(parser):
    """Register the experiment-naming, TensorBoard and logger options.

    All options live in a "logging" argument group. The three on/off
    switches share the same shape: converted with ``str2bool``, default
    ``False``, and ``True`` when the flag is given without a value.

    Args:
        parser: the ``ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    logging_group = parser.add_argument_group(title="logging")
    add = logging_group.add_argument

    add(
        "--exp-profile",
        type=str,
        default=None,
        # The run of spaces in the second piece is part of the help string.
        help="The brief name of this experiment, which will be used to "
        "        construct the ckpt storage path, and used for the group name on wandb log ",
    )
    add(
        "--tensorboard-dir",
        type=str,
        default=None,
        help="Write TensorBoard logs to this directory.",
    )
    add(
        "--tensorboard-queue-size",
        type=int,
        default=1000,
        help="Size of the tensorboard queue for pending events "
        "and summaries before one of the ‘add’ calls forces a "
        "flush to disk.",
    )

    # (flag, help) for the optional-value boolean switches.
    switches = (
        ("--wandb", "Synchronize the experiment curve to wandb"),
        ("--policy-logger", "whether to check the generated episodes during training"),
        ("--traindata-logger", "whether to log the sample idx during training"),
    )
    for flag, description in switches:
        add(
            flag,
            type=str2bool,
            nargs="?",
            const=True,
            default=False,
            help=description,
        )

    return parser

def _add_checkpointing_args(parser):
    """Register the checkpoint / snapshot saving options.

    All options live in a "checkpointing" argument group. The two on/off
    switches are converted with ``str2bool``, default to ``False`` and
    become ``True`` when the flag is given without a value.

    Args:
        parser: the ``ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    ckpt_group = parser.add_argument_group(title="checkpointing")
    add = ckpt_group.add_argument

    # (flag, help) for the optional-value boolean switches.
    switches = (
        ("--save-ckpt", "whether to save ckpt during training"),
        ("--save-snapshot", "whether to save snapshot during DDP training"),
    )
    for flag, description in switches:
        add(
            flag,
            type=str2bool,
            nargs="?",
            const=True,
            default=False,
            help=description,
        )

    add(
        "--save-dir",
        type=str,
        default=None,
        help="Output directory to save checkpoints to.",
    )
    add(
        "--save-strategy",
        type=str,
        choices=["best", "interval"],
        default='interval',
        help="Ckpt saving strategy.",
    )
    add(
        "--save-interval",
        type=int,
        default=None,
        help="Number of iterations between checkpoint saves.",
    )
    return parser

def _add_initialization_args(parser):
    group = parser.add_argument_group(title="initialization")
    group.add_argument(
        "--seeds", 
        type=int, 
        nargs='+',
        default=[42, 43, 44], 
        help="Random seed for numpy/torch"
    )

    group.add_argument(
        "--init-method-std",
        type=float,
        default=0.02,
        help="Standard deviation of the zero mean normal "
        "distribution used for weight initialization.",
    )
    return parser


def _add_finetune_args(parser):
    group = parser.add_argument_group(title="finetune")
    group.add_argument(
        "--num-rl-fewshot_episodes",
        type=int,
        default=None,
        help="Number of episoes used when finetuning on RL environment",
    )
    return parser