import os
import argparse


def get_args(argv=None):
    """Parse and validate the command-line arguments for a TSB-AD run.

    Integer 0/1 flags are converted to ``bool``, option combinations are
    validated, the output directories are created, and a ``save_prefix``
    string encoding the main run settings is attached to the namespace.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default (``None``) keeps the usual
            command-line behaviour; passing a list allows programmatic use
            and testing.

    Returns:
        argparse.Namespace: The parsed (and normalised) arguments, plus a
        ``save_prefix`` attribute.

    Raises:
        ValueError: If more than one of ``use_ts_from_fft``, ``use_forecast``
            and ``ensemble_outputs`` is enabled, or if ``use_finetuned_models``
            is enabled while ``AD_Name`` is not ``TSPulse_ZS`` or
            ``model_path`` is unset.
    """
    ## ArgumentParser
    parser = argparse.ArgumentParser(description="Running TSB-AD")
    parser.add_argument(
        "--filename",
        "-f",
        type=str,
        default=None,
    )
    parser.add_argument("--data_direc", "-dd", type=str, default="Datasets/TSB-AD-U/")
    parser.add_argument("--save_results", "-sr", type=int, default=0)
    parser.add_argument("--save_models", "-sm", type=int, default=0)
    parser.add_argument("--AD_Name", "-ad", type=str, default="TSPulse_FT_All")
    parser.add_argument("--plot", "-pl", type=int, default=0)
    parser.add_argument(
        "--save_dir",
        "-sd",
        type=str,
        default="tspulse/ad/results/",
    )
    parser.add_argument("--windowed_detector", "-wd", type=int, default=0)
    parser.add_argument("--aggr_win_size", "-aws", type=int, default=None)
    parser.add_argument("--use_pipeline", "-up", type=int, default=0)
    parser.add_argument("--use_ts_from_fft", "-utf", type=int, default=0)
    parser.add_argument("--use_forecast", "-uf", type=int, default=0)
    parser.add_argument("--ensemble_outputs", "-eo", type=int, default=0)
    parser.add_argument("--tspulse_decoder_mode", "-tdm", type=str, default=None)
    parser.add_argument("--model_path", "-mp", type=str, default=None)
    parser.add_argument("--do_tuning", "-dt", type=int, default=0)
    parser.add_argument("--freeze_backbone", "-fb", type=int, default=0)
    parser.add_argument("--dataset_name", "-dn", type=str, default=None)
    parser.add_argument("--use_finetuned_models", "-ufm", type=int, default=0)
    parser.add_argument("--enable_fft_prob_loss", "-efpl", type=int, default=1)
    parser.add_argument("--window_position", "-wp", type=str, default="last")
    parser.add_argument("--batch_size", "-bs", type=int, default=50_000)
    parser.add_argument("--finetune_num_epochs", "-fne", type=int, default=20)

    args = parser.parse_args(argv)

    # These options are accepted as integers on the command line (0/1) and
    # normalised to real booleans for downstream code.
    bool_flags = (
        "save_results",
        "save_models",
        "windowed_detector",
        "use_ts_from_fft",
        "plot",
        "use_pipeline",
        "use_forecast",
        "ensemble_outputs",
        "do_tuning",
        "freeze_backbone",
        "use_finetuned_models",
        "enable_fft_prob_loss",
    )
    for flag in bool_flags:
        setattr(args, flag, bool(getattr(args, flag)))

    # --- Validation: done before any filesystem side effects so that an
    # invalid invocation does not leave behind freshly created directories.
    if [args.use_ts_from_fft, args.use_forecast, args.ensemble_outputs].count(True) > 1:
        raise ValueError("Only one can be True among these: use_ts_from_fft, use_forecast, ensemble_outputs.")

    if args.use_finetuned_models:
        if args.AD_Name != "TSPulse_ZS":
            raise ValueError("`AD_Name` must be `TSPulse_ZS` if `use_finetuned_models=1`.")
        if args.model_path is None:
            raise ValueError(
                "`model_path` must be set to the directory containing all finetuned TSPulse models when `use_finetuned_models=1`."
            )

    # --- Derived settings and side effects.
    if "ZS" in args.AD_Name:
        print("Forcing TSPulse's `decoder_mode` to `common_channel` since zero-shot workflow is being used.")
        args.tspulse_decoder_mode = "common_channel"

    os.makedirs(args.save_dir, exist_ok=True)
    if args.plot:
        os.makedirs(os.path.join(args.save_dir, "plots"), exist_ok=True)

    # Multivariate datasets are detected from the data directory path.
    args.save_prefix = "multi" if "AD-M" in args.data_direc else "uni"

    # Any decoder mode other than "mix_channel" (including None) is tagged "C".
    decoder_model_prefix = "M" if args.tspulse_decoder_mode == "mix_channel" else "C"

    # NOTE: the "-tdm-" separator differs from the "_" used elsewhere; it is
    # kept as-is so that result file names stay compatible with earlier runs.
    args.save_prefix += (
        f"_AD-{args.AD_Name}"
        f"_utf-{int(args.use_ts_from_fft)}"
        f"_uf-{int(args.use_forecast)}"
        f"_eo-{int(args.ensemble_outputs)}"
        f"_wd-{int(args.windowed_detector)}"
        f"_aws-{str(args.aggr_win_size)}"
        f"-tdm-{decoder_model_prefix}"
    )

    return args
