"""
Default arguments for the CITNP package.
"""

import argparse


def parse_default_args(argv=None) -> argparse.Namespace:
    """
    Build the CITNP argument parser and parse a list of arguments with it.

    Args:
        argv: Optional sequence of argument strings to parse (without the
            program name), e.g. ``["--debug", "--seed", "3"]``. When left as
            ``None`` (the default), argparse reads the arguments from
            ``sys.argv[1:]``, which is what this function always did before
            the parameter existed. Passing an explicit list makes the
            function usable from tests, notebooks, or wrapper scripts.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
        Note that the ``--no-normalise`` flag is stored under the attribute
        ``normalise`` (``True`` unless the flag is given).

    Raises:
        SystemExit: Raised by argparse on unknown options, invalid values, or
            when ``-h``/``--help`` is requested.

    NOTE(review): the previous docstring stated that these defaults match
    ``return_default_args()``; that function is not defined in this module, so
    confirm it still exists and is kept in sync.
    """
    parser = argparse.ArgumentParser(
        description="Default arguments for the CITNP package."
    )

    parser.add_argument(
        "--work_dir",
        type=str,
        default="./",
        help="Path to the working directory.",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default="3var_neuralnet",
        help="Name of the data file to use.",
    )

    parser.add_argument(
        "--seed", type=int, default=0, help="Seed for random number generation."
    )
    # Inverted flag: giving --no-normalise stores False into args.normalise.
    parser.add_argument(
        "--no-normalise",
        action="store_false",
        dest="normalise",
        default=True,
        help="Disable normalisation in the collator (default is normalise=True).",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Whether to run in debug mode.",
    )
    parser.add_argument("--model_type", choices=["cnp", "locallatent"], default="cnp")
    parser.add_argument("--d_model", type=int, default=256, help="Model dimension.")
    parser.add_argument("--emb_depth", type=int, default=2, help="Embedding depth.")
    parser.add_argument("--decoder_depth", type=int, default=0, help="Decoder depth.")
    parser.add_argument(
        "--dim_feedforward",
        type=int,
        default=1024,
        help="Feedforward network dimension.",
    )
    parser.add_argument(
        "--nhead", type=int, default=8, help="Number of attention heads."
    )
    parser.add_argument(
        "--num_layers_encoder",
        type=int,
        default=8,
        help="Number of layers in the encoder.",
    )
    parser.add_argument(
        "--sample_attn_mode",
        type=str,
        default="MHCA",
        help="Type of sample attention: MHSA or MHCA.",
    )
    parser.add_argument(
        "--linear_attention",
        action="store_true",
        default=False,
        help="Whether to use linear attention.",
    )
    parser.add_argument(
        "--mean_loss_across_samples",
        action="store_true",
        default=False,
        help="Whether to average the loss across samples.",
    )
    # MoG args
    parser.add_argument(
        "--num_mixture_components",
        type=int,
        default=5,
        help="Number of components in MoG.",
    )
    # Local VI args
    parser.add_argument(
        "--num_z_samples_train",
        type=int,
        default=16,
        help="Number of z samples for training.",
    )
    parser.add_argument(
        "--num_z_samples_eval",
        type=int,
        default=16,
        help="Number of z samples for evaluation.",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help='Computation device to use (e.g., "cuda" or "cpu").',
    )

    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help='Torch data type (e.g., "bfloat16" or "float32").',
    )

    parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
    parser.add_argument(
        "--sample_size", type=int, default=500, help="Number of samples."
    )
    parser.add_argument(
        "--num_variables", type=int, default=2, help="Number of variables."
    )
    parser.add_argument("--intervention_range_multiplier", type=int, default=4)
    parser.add_argument(
        "--function_generator",
        type=str,
        default="resnet",
        help="Type of function to generate.",
    )
    # nargs="+" options accept one or more values and always yield a list.
    parser.add_argument(
        "--graph_type", nargs="+", default=["ER"], help="Graph type(s)."
    )
    parser.add_argument(
        "--graph_degrees",
        nargs="+",
        type=int,
        default=[1, 2, 3],
        help="Degree(s) for the graph.",
    )
    parser.add_argument(
        "--iterations_per_epoch",
        type=int,
        default=10,
        help="Number of iterations per epoch.",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=0.0005, help="Learning rate."
    )
    parser.add_argument("--run_name", type=str, default=None, help="Name of the run.")
    parser.add_argument(
        "--entity_name",
        type=str,
        default="npcausalinf",
        help="Name of the wandb entity.",
    )

    # Training args
    parser.add_argument("--epochs", type=int, default=2, help="Number of epochs.")
    parser.add_argument(
        "--no_wandb",
        action="store_true",
        default=False,
        help="Dont use Weights & Biases.",
    )
    parser.add_argument(
        "--lr_warmup_ratio",
        type=float,
        default=0.02,
        help="Warmup ratio for learning rate.",
    )
    parser.add_argument(
        "--results_path",
        type=str,
        default="./experiments",
        help="Path where to store the results.",
    )
    parser.add_argument("--log_step", type=int, default=200, help="Log step.")

    parser.add_argument(
        "--save_checkpoint_every_n_steps",
        type=int,
        default=10000,
        help="Number of steps after which to save checkpoint.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="Number of workers for data loading.",
    )
    parser.add_argument(
        "--model_chkpt_path",
        type=str,
        default="",
        help="Checkpoint path for testing.",
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the
    # original command-line behaviour for existing callers.
    return parser.parse_args(argv)


def overwrite_debug_args(args):
    """
    Apply the debug-mode overrides to ``args``.

    Every entry in the table below is assigned onto ``args`` as an attribute,
    replacing whatever value was there before. Attributes not listed are left
    untouched. The object is modified in place and the same object is
    returned.
    """
    debug_overrides = (
        ("sample_size", 100),
        ("iterations_per_epoch", 100),
        ("epochs", 1),
        ("d_model", 32),
        ("dim_feedforward", 32),
        ("num_layers_encoder", 4),
        ("lr_warmup_ratio", 0.0),
        ("no_wandb", True),
        ("run_name", "debug"),
        ("nhead", 2),
    )
    for option, value in debug_overrides:
        setattr(args, option, value)
    return args
