import argparse
import os
import torch

def int_or_float(value):
    """Convert a command-line string to an ``int`` if possible, else a ``float``.

    Intended for use as an argparse ``type=`` callable. Integer literals
    (``"5"``, ``"-7"``) become ``int``. Anything else that ``float()``
    accepts becomes ``float``: ``"0.5"``, ``"1.0"``, and scientific
    notation such as ``"1e-3"``, which a plain ``'.'`` check would miss.

    Raises:
        ValueError: if ``value`` is neither a valid int nor float literal;
            argparse reports this as a usage error.
    """
    # Normalise through str() so non-string inputs (e.g. a float passed
    # programmatically) are not silently truncated by int().
    text = str(value).strip()
    try:
        return int(text)
    except ValueError:
        return float(text)

def _comma_separated_list(value):
    """Split a comma-separated CLI value into a list of stripped, non-empty items."""
    # Tolerates "MNIST, EuroSAT" and a stray trailing comma ("MNIST,").
    return [item.strip() for item in value.split(",") if item.strip()]


def parse_arguments():
    """Build the experiment's argument parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options, plus two post-processing steps:
        ``device`` is set to ``"cuda"`` when ``torch.cuda.is_available()``,
        else ``"cpu"``. A ``--load`` list with exactly one entry is collapsed
        to that single string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default="/path/to/datasets",
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=_comma_separated_list,
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ",
    )
    parser.add_argument(
        "--loss-fn",
        default='entropy',
        type=str,
        help="Loss function to use.",
        choices=["entropy", "cross_entropy"]
    )
    parser.add_argument(
        "--ind-dataset",
        default=None,
        type=str,
        help="Which single dataset to use for starting the learning. ",
    )
    parser.add_argument(
        "--number-of-random-matrices",
        type=int,
        default=3,
        help="Number of random matrices to generate for each task vector.",
    )
    parser.add_argument(
        "--lp-reg",
        default=None,
        type=int,
        choices=[1, 2],
        help="Regularisation applied to the learned coefficients."
    )
    parser.add_argument(
        "--blockwise-coef",
        default=False,
        action="store_true",
        help="Use different coefficients on different parameter blocks."
    )
    parser.add_argument(
        "--subsample",
        default=1.0,
        type=int_or_float,
        help="Subsample the datasets (accepts an int or a float).",
    )
    parser.add_argument(
        "--start-from-block-number",
        type=int,
        default=0,
        help="Start from this block number.",
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        type=_comma_separated_list,
        help="Which dataset(s) to patch on.",
    )
    parser.add_argument(
        "--source-dataset-name",
        type=str,
        default=None,
        help="The name of the single source dataset to use for creating the task vector.",
    )
    parser.add_argument(
        "--source-dataset",
        type=str,
        default=None,
        required=False,
        help="Source dataset used to initialize the model for full fine-tuning.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only.",
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results, else does not store",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help="The type of model (e.g. ViT-B-32).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--num-grad-accumulation",
        type=int,
        default=1,
        help="Number of gradient accumulation steps.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of data loader workers.",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.1, help="Weight decay")
    parser.add_argument(
        "--warmup_length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--load",
        type=_comma_separated_list,
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",  # noqa: E501
    )
    parser.add_argument(
        "--save",
        type=str,
        default="/path/to/save/results",
        help="Where to load zero-shot weights and task vectors",
    )
    parser.add_argument(
        "--logdir",
        type=str,
        default='/path/to/save/logs',
        help="Where to save results",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder",
    )
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default=os.path.expanduser("~/openclip-cachedir/open_clip"),
        help="Directory for caching models from OpenCLIP",
    )
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of processes for distributed training.",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=-1,
        help="How often to checkpoint the model.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=None,
        help="Port for distributed training.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed.",
    )
    parser.add_argument(
        "--finetuning-mode",
        default='standard',
        choices=["standard", "linear"],
        help="Whether to use linearized models or not.",
    )
    parser.add_argument(
        "--partition",
        type=int,
        default=None,
        help="Run atlas x K where the task vectors are randomly partitioned",
    )
    parser.add_argument(
        "--total-trainable-params",
        type=int,
        default=None,
        help="Specify the exact total number of trainable coefficients. Overrides --partition and --blockwise-coef if set.",
    )
    parser.add_argument(
        "--target-dataset-name",
        type=str,
        default="MNIST",
        help="Target dataset for task arithmetic",
    )
    parser.add_argument(
        "--topK",
        type=int,
        default=76,
        help="Number of top components to keep in merge_task_vectors."
    )
    parser.add_argument(
        "--svd-threshold",
        type=float,
        default=0.1,
        help="Threshold for zeroing out top singular values in SVD.",
    )
    parser.add_argument(
        "--no-use-half",
        action="store_true",
        default=False,
        help="If set, don't use half precision (float16) for task vectors. Use full precision (float32) instead, which is needed for SVD operations.",
    )
    parser.add_argument(
        "--keep-top-values",
        action="store_true",
        default=False,
        help="If set, keep top singular values and zero out the rest. If not set (default), zero out top singular values and keep the rest.",
    )
    parser.add_argument(
        "--sorting-descending",
        action="store_true",
        default=False,
        help="If set, sort singular values in descending order for merge_task_vectors selection. Defaults to False (ascending).",
    )
    parser.add_argument(
        "--end-index",
        type=int,
        default=24,
        help="Index to end aggregation at.",
    )
    parser.add_argument(
        "--resume-from-idx",
        type=int,
        default=0,
        help="Index to resume aggregation level from.",
    )
    parser.add_argument(
        "--isoc",
        default=False,
        action="store_true",
        help="Use the merge_task_vectors function to merge task vectors with SVD-based averaging."
    )
    parser.add_argument(
        "--load-only-classification-head",
        action="store_true",
        default=False,
        help="If set, only the classification head will be loaded from the checkpoint file."
    )
    parser.add_argument(
        "--corruption",
        type=str,
        default="impulse_noise",
        # CORRUPTION_NAMES is a module-level constant; it only needs to exist
        # by the time this function is called, not when it is defined.
        choices=CORRUPTION_NAMES,
        help="Type of corruption to apply."
    )
    parser.add_argument(
        "--severity-number",
        type=int,
        default=5,
        # Enforce the documented 1-5 range instead of silently accepting
        # out-of-range severities.
        choices=range(1, 6),
        help="Severity of the corruption applied during evaluation (1-5)."
    )

    parsed_args = parser.parse_args()
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # A single --load path is passed around as a plain string rather than a
    # one-element list.
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]
    return parsed_args

def get_corruption_names():
    """Return the list of supported corruption type names.

    The order is fixed: the ``*_noise`` entries, then the ``*_blur``
    entries, then the remaining types. A fresh list is built on each call,
    so callers may mutate the result without affecting later calls.
    """
    noise_kinds = (
        "gaussian_noise",
        "shot_noise",
        "impulse_noise",
        "speckle_noise",
    )
    blur_kinds = (
        "gaussian_blur",
        "glass_blur",
        "defocus_blur",
        "motion_blur",
        "zoom_blur",
    )
    remaining_kinds = (
        "snow",
        "spatter",
        "contrast",
        "brightness",
        "saturate",
        "jpeg_compression",
        "pixelate",
        "elastic_transform",
    )
    return [*noise_kinds, *blur_kinds, *remaining_kinds]

# Module-level list of valid corruption names, evaluated once at import time.
# parse_arguments() passes it as the ``choices`` for ``--corruption``.
CORRUPTION_NAMES = list(get_corruption_names())
