import argparse

data_dir = "DIRECTORY TO ORIGINAL DATA"
save_checkpoint = "DIRECTORY TO SAVE MODEL"

batch_size=4

def str2bool(string):
    if string == "true" or string == "True":
        return True
    else:
        return False


def get_args_dodh():
    parser = argparse.ArgumentParser(
        "TSADC with Contaminated Dta."
    )

    # General args
    parser.add_argument(
        "--save_dir",
        type=str,
        default=save_checkpoint,
    )
    parser.add_argument(
        "--gpus",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--gpu_id",
        default=[3],
        type=int,
        nargs="+",
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0')
    parser.add_argument(
        "--load_model_path",
        type=str,
        default=None
    )
    parser.add_argument(
        "--do_train",
        default=True,
        type=str2bool,
    )
    parser.add_argument(
        "--raw_data_dir",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=7500,
    )
    parser.add_argument(
        "--output_seq_len",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--rand_seed",
        type=int,
        default=123)

    parser.add_argument(
        "--num_nodes",
        type=int,
        default=16,
    )

    ### Decontaminator args
    parser.add_argument(
        "--masking",
        type=str,
        default='bom',
        choices=('rm', 'mnr', 'bom'),
    )
    parser.add_argument(
        "--masking_r",
        default=1500,
    )

    parser.add_argument(
        "--masking_r_test",
        default=1500,
    )
    parser.add_argument(
        "--in_channels",
        default=16,
    )
    parser.add_argument(
        "--out_channels",
        default=16,
    )
    parser.add_argument(
        "--num_res_layers",
        default=8,
    )
    parser.add_argument(
        "--res_channels",
        default=256,
    )
    parser.add_argument(
        "--skip_channels",
        default=256,
    )
    parser.add_argument(
        "--diffusion_step_embed_dim_in",
        default=128,
    )
    parser.add_argument(
        "--diffusion_step_embed_dim_mid",
        default=128,
    )
    parser.add_argument(
        "--diffusion_step_embed_dim_out",
        default=128,
    )
    parser.add_argument(
        "--s4_max",
        default=100,
    )
    parser.add_argument(
        "--s4_d_state",
        default=64,
    )
    parser.add_argument(
        "--s4_dropout",
        default=0.0,
    )
    parser.add_argument(
        "--s4_bidirectional",
        default=1,
    )
    parser.add_argument(
        "--s4_layernorm",
        default=1,
    )
    parser.add_argument(
        "--diffuse_T",
        default=200,
    )
    parser.add_argument(
        "--diffuse_beta_0",
        default=0.0001,
    )
    parser.add_argument(
        "--diffuse_beta_T",
        default=0.02,
    )

    parser.add_argument(
        "--step_in_seq",
        default=3750,
    )

    parser.add_argument(
        "--step_in_seq_test",
        default=3750,
    )

    #Variable Dependency Modeling args
    parser.add_argument(
        "--num_temporal_layers",
        type=int,
        default=4
    )
    parser.add_argument(
        "--input_dim",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--output_dim",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--num_gcn_layers",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--activation_fn",
        type=str,
        default="leaky_relu",
    )
    parser.add_argument(
        "--gin_mlp",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--train_eps",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--edge_top_perc",
        type=float,
        default=0.2,
    )
    parser.add_argument(
        "--prune_method",
        type=str,
        default="thresh_abs",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--temporal_pool",
        type=str,
        default="mean",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=1250,
    )
    parser.add_argument(
        "--negative_slope",
        type=float,
        default=0.2,
    )
    parser.add_argument(
        "--knn",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--graph_learn_metric",
        type=str,
        default="self_attention",
    )
    parser.add_argument(
        "--adj_embed_dim",
        type=int,
        default=16,
    )
    parser.add_argument(
        "--regularizations",
        type=str,
        nargs="+",
        default=["feature_smoothing", "degree", "sparse"],
    )
    parser.add_argument(
        "--residual_weight",
        type=float,
        default=0.6,
    )
    parser.add_argument(
        "--feature_smoothing_weight",
        type=float,
        default=1,
    )
    parser.add_argument(
        "--degree_weight",
        type=float,
        default=0.05,
    )
    parser.add_argument(
        "--sparse_weight",
        type=float,
        default=0.5,
    )
    parser.add_argument(
        "--bidirectional",
        type=str2bool,
        default="false",
    )
    parser.add_argument(
        "--state_dim",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--prenorm",
        type=str2bool,
        default="false",
    )
    parser.add_argument(
        "--postact",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--channels",
        type=int,
        default=1,
    )

    # Anomaly scoring args
    parser.add_argument(
        "--lambda_1",
        type=int,
        default=0.01,
    )
    parser.add_argument(
        "--lambda_2",
        type=int,
        default=1.2,
    )
    parser.add_argument(
        "--tau",
        type=int,
        default=None,
    )

    #Training args
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=batch_size,
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        default="timm_cosine",
    )
    parser.add_argument(
        "--t_initial",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--lr_min",
        type=float,
        default=1e-5,
    )
    parser.add_argument(
        "--cycle_decay",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--warmup_lr_init",
        type=float,
        default=1e-6,
    )
    parser.add_argument(
        "--warmup_t",
        type=int,
        default=5,
    )
    parser.add_argument(
        "--cycle_limit",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.1,
    )

    parser.add_argument(
        "--lr_init",
        type=float,
        default="8e-4"
    )
    parser.add_argument(
        "--l2_wd",
        type=float,
        default=5e-3,
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=5.0,
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--test_samples",
        type=int,
        default=310,
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=20,
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--pos_weight",
        default=1.0,
        type=float,
    )


    args = parser.parse_args()

    return args
