import argparse
import os


def get_args_parser(argv=None):
    """Build the experiment argument parser and parse the command line.

    Note: despite its name, this returns the parsed ``argparse.Namespace``,
    not the parser itself.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace with all experiment options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--VLM_Base', default="ViT-B/32",
                        choices=["ViT-B/32", "RN50", "audio_full", "audio_partial"], type=str)
    parser.add_argument('--dataset', default="eurosat",
                        choices=["cifar100", "imagenet1k", "sun397", "stanford_cars", "DTD", "pets", "Food101",
                                 "Flowers102", "FGVCAircraft", "eurosat", "esc50", "urbansound8k", "mscoco",
                                 "flickr30k", "esc50_querying", "urbansound8k_querying"], type=str)

    parser.add_argument('--hyper_level', default=5, type=int)

    parser.add_argument('--hyper_strength', default=0.5, type=float)

    parser.add_argument('--wavelet_base', default="db6", type=str)

    parser.add_argument('--sims_mode', default="Wave", choices=["Org", "Wave", "DN"], type=str)

    parser.add_argument('--num_samples', default=100, type=int)

    parser.add_argument('--num_experiments', default=1, type=int)

    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch_size')

    parser.add_argument('--print_freq', type=int, default=50,
                        help='print frequency')
    parser.add_argument('--acc_type', type=str, default="acc")

    # Each retrieval direction must be its own list element; a single
    # comma-joined string would make argparse reject every value passed.
    # The option name keeps its original spelling because callers read
    # ``args.retrival_type``.
    parser.add_argument('--retrival_type', type=str, default="text2audio",
                        choices=["image2text", "text2image", "text2audio"])

    return parser.parse_args(argv)
def get_features_paths(args, *, create_dir=True):
    """Resolve the feature folder for ``args.dataset`` / ``args.VLM_Base``.

    ``flickr30k`` and ``mscoco`` map to a fixed location regardless of the
    backbone. Image-classification datasets require ``VLM_Base`` to be
    ``"ViT-B/32"`` or ``"RN50"``. Audio datasets (including their
    ``_querying`` variants, which share the base dataset's folder) require
    ``"audio_full"`` or ``"audio_partial"``.

    Args:
        args: Object with ``dataset`` and ``VLM_Base`` string attributes
            (e.g. the namespace returned by ``get_args_parser``).
        create_dir: When True (the default), create the folder if it does
            not already exist.

    Returns:
        The folder path as a string.

    Raises:
        ValueError: If no folder is defined for the dataset/backbone pair.
    """
    dataset_name = args.dataset
    vlm_base = args.VLM_Base

    # Retrieval datasets live at a fixed location independent of the backbone.
    fixed_paths = {
        "flickr30k": "/datasets/flickr30k",
        "mscoco": "/datasets/mscoco/mscoco2014",
    }
    # Image datasets: sub-folder name under the chosen backbone's root.
    image_subdirs = {
        "cifar100": "cifar-100",
        "sun397": "sun397",
        "imagenet1k": "imagenet1k",
        "stanford_cars": "cars",
        "pets": "pets",
        "DTD": "DTD",
        "Food101": "Food101",
        "Flowers102": "Flowers102",
        "eurosat": "eurosat",
        "FGVCAircraft": "FGVCAircraft",
    }
    image_roots = {
        "ViT-B/32": "/vit-b32/features",
        "RN50": "/vit-res50/features",
    }
    # Audio datasets: the "_querying" variants reuse the base dataset folder.
    audio_subdirs = {
        "esc50": "esc50",
        "esc50_querying": "esc50",
        "urbansound8k": "urbansound8k",
        "urbansound8k_querying": "urbansound8k",
    }
    audio_roots = {
        "audio_full": "/audioclip_full",
        "audio_partial": "/audioclip_partial",
    }

    if dataset_name in fixed_paths:
        folder_name = fixed_paths[dataset_name]
    elif dataset_name in image_subdirs and vlm_base in image_roots:
        folder_name = f"{image_roots[vlm_base]}/{image_subdirs[dataset_name]}"
    elif dataset_name in audio_subdirs and vlm_base in audio_roots:
        folder_name = f"{audio_roots[vlm_base]}/{audio_subdirs[dataset_name]}"
    else:
        # Previously this case fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"No feature folder defined for dataset={dataset_name!r} "
            f"with VLM_Base={vlm_base!r}"
        )

    if create_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(folder_name, exist_ok=True)
    return folder_name