import os
import argparse

import torch


def _comma_separated_list(value):
    """Split a comma-separated CLI value into a list of clean, non-empty items.

    Whitespace around each item is stripped and empty items (e.g. from a
    trailing comma) are dropped, so ``"MNIST, EuroSAT,"`` becomes
    ``["MNIST", "EuroSAT"]``.
    """
    return [item.strip() for item in value.split(",") if item.strip()]


def parse_arguments(args=None) -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Args:
        args: Optional sequence of argument strings to parse. When ``None``
            (the default), ``sys.argv[1:]`` is parsed, which matches the
            behaviour of calling this function without arguments.

    Returns:
        The populated ``argparse.Namespace``. In addition to the declared
        options it carries a ``device`` attribute set to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise.
        ``load`` is ``None`` when not given, a plain string when exactly one
        entry was supplied, and a list of strings otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir", type=str, default=os.path.expanduser("~/data"), help="The root directory for the datasets."
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=_comma_separated_list,
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ",
    )
    parser.add_argument(
        "--train-dataset", default=None, type=_comma_separated_list, help="Which dataset(s) to patch on."
    )
    parser.add_argument(
        "--exp_name", type=str, default=None, help="Name of the experiment, for organization purposes only."
    )
    parser.add_argument("--results-db", type=str, default=None, help="Where to store the results, else does not store")
    parser.add_argument("--model", type=str, default="ViT-B-32", help="The type of model (e.g. RN50, ViT-B-32).")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-grad-accumulation", type=int, default=1, help="Number of gradient accumulation steps.")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument("--warmup_length", type=int, default=500)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument(
        "--load",
        type=_comma_separated_list,
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
    )
    parser.add_argument("--cache-dir", type=str, default=None, help="Directory for caching features and encoder")
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default=os.path.expanduser("~/openclip-cachedir/open_clip"),
        help="Directory for caching models from OpenCLIP",
    )
    parser.add_argument("--world-size", type=int, default=1, help="Number of processes for distributed training.")
    parser.add_argument("--checkpoint-every", type=int, default=-1, help="How often to checkpoint the model.")
    parser.add_argument("--port", type=int, default=12355, help="Port for distributed training.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument(
        "--n-eval-points",
        type=int,
        default=21,
        help="Number of evaluation points used to find optimal coefficient in task arithmetic.",
    )
    parser.add_argument(
        "--finetuning-mode",
        type=str,
        default="standard",
        help="Finetuned mode; standard for nonlinear finetune; none for zeroshot",
    )
    # NOTE(review): the two defaults below are absolute paths under a specific
    # user's shared-network directory; they will need overriding on any other
    # machine. How --data-location relates to --data-dir is not visible from
    # this function -- confirm against the callers.
    parser.add_argument(
        "--model-location",
        type=str,
        default="/shared-network/rskorobogat/papers/model-merging/checkpoints",
        help="Directory for model location",
    )
    parser.add_argument(
        "--data-location",
        type=str,
        default="/shared-network/rskorobogat/papers/model-merging/data",
        help="Directory for data location",
    )

    # Passing ``args`` through lets callers (and tests) parse an explicit
    # argument list; ``None`` keeps the default of reading ``sys.argv``.
    parsed_args = parser.parse_args(args)
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # A single --load entry is exposed as a bare string rather than a
    # one-element list.
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]
    return parsed_args
