import os
import argparse

import torch

def parse_arguments(argv=None):
    """Build the experiment CLI parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            preserves the original behavior for existing callers.

    Returns:
        An ``argparse.Namespace`` with every option below, plus:
          * ``device`` -- ``"cuda"`` if ``torch.cuda.is_available()`` else
            ``"cpu"``.
          * ``load`` -- ``None`` when not given; a single string when exactly
            one comma-separated path was given; otherwise a list of strings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser('~/data'),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=lambda x: x.split(","),
        help="Which datasets to use for evaluation. Split by comma, e.g. CIFAR101,CIFAR102."
             " Note that the same model is used for all datasets, so they must have the same"
             " classnames for zero shot.",
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        help="For fine tuning or linear probe, which dataset to train on",
    )
    parser.add_argument(
        "--template",
        type=str,
        default=None,
        help="Which prompt template is used. Leave as None for linear probe, etc.",
    )
    parser.add_argument(
        "--classnames",
        type=str,
        default="openai",
        help="Which class names to use.",
    )
    parser.add_argument(
        "--alpha",
        default=[0.5],
        nargs='*',
        type=float,
        help=(
            'Interpolation coefficient for ensembling. '
            'Users should specify N-1 values, where N is the number of '
            'models being ensembled. The specified numbers should sum to '
            'less than 1. Note that the order of these values matters, and '
            'should be the same as the order of the classifiers being ensembled.'
        )
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only."
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results, else does not store",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="The type of model (e.g. RN50, ViT-B/32).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        help="Learning rate."
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.1,
        help="Weight decay"
    )
    parser.add_argument(
        "--ls",
        type=float,
        default=0.0,
        help="Label smoothing."
    )
    parser.add_argument(
        "--warmup_length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--load",
        type=lambda x: x.split(","),
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
    )
    parser.add_argument(
        "--save",
        type=str,
        default=None,
        help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
    )
    parser.add_argument(
        "--freeze-encoder",
        default=False,
        action="store_true",
        help="Whether or not to freeze the image encoder. Only relevant for fine-tuning."
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder",
    )
    # NOTE(review): the two Fisher options below are undocumented upstream
    # ("TODO" help); --fisher is parsed as a comma-separated list.
    parser.add_argument(
        "--fisher",
        type=lambda x: x.split(","),
        default=None,
        help="TODO",
    )
    parser.add_argument(
        "--fisher_floor",
        type=float,
        default=1e-8,
        help="TODO",
    )
    # Regularization arguments
    parser.add_argument("--lambda_reg", type=float, default=0.0, help="Regularization strength")
    parser.add_argument("--max_degree", type=int, default=40, help="Maximum degree for polynomial fitting")
    parser.add_argument("--resolution", type=int, default=100, help="Resolution for interpolation")
    parser.add_argument("--miu", type=float, default=1.0, help="Miu parameter")
    parser.add_argument("--have_const", action="store_true", default=False, help="Have constant term")
    parser.add_argument("--use_norm", action="store_true", default=False, help="Use normalization")
    parser.add_argument("--num_pairs", type=int, default=1, help="Number of pairs for regularization")
    parser.add_argument("--label", action="store_true", default=False, help="Use label information")
    parser.add_argument("--sam", action="store_true", default=False, help="Use SAM optimizer")
    parser.add_argument("--smooth", action="store_true", default=False, help="Use smoothing")
    parser.add_argument("--reg_anneal", action="store_true", default=False, help="Use regularization annealing")
    parser.add_argument("--min_lambda_reg", type=float, default=0.0, help="Minimum lambda regularization")
    parser.add_argument("--dryrun", action="store_true", default=False, help="Dry run")
    parser.add_argument("--warmup_type", type=str, default="normal", help="Type of warmup")
    parser.add_argument("--pca_reg", action="store_true", default=False, help="Use PCA for regularization")
    parser.add_argument('--square', action='store_true', default=False, help='use square instead of abs')
    parser.add_argument('--degree_mode', type=str, default='index', help='degrees vector, index means 0 1 2 3 ..., else means 0 0 1 1 1 ...')
    parser.add_argument("--random_alpha", action="store_true", default=False, help="Use random alpha values")
    parser.add_argument("--pca_k", type=int, default=1, help="Number of PCA components to keep")

    # argv=None makes argparse read sys.argv[1:], matching the old behavior.
    parsed_args = parser.parse_args(argv)
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # A single --load path is unwrapped to a plain string so downstream code
    # can distinguish "load one classifier" from "ensemble several".
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]
    return parsed_args
