import sys

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, ImageNet, ImageFolder
from zs_datasets.cars import Standfordcars_test_rn50, Standfordcars_test_b32
from zs_datasets.FGVCAircraft import FGVCAircraft_test
from zs_datasets.dtd import DTD_test_dataset
from zs_datasets.eurosat import eurosat_test
from zs_datasets.sun397 import sun397_test_rn50, sun397_test_b32
from zs_datasets.flower import flowers_test_dataset
from zs_datasets.food import food_test_dataset
from zs_datasets.pets import pets_test_dataset
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)
from PIL import Image
from zs_datasets.imagenet import get_imagenet_dataset
from zs_datasets.esc50 import ESC50
from zs_datasets.us8k import UrbanSound8K


def _convert_image_to_rgb(image):
    """Return ``image.convert("RGB")``.

    Exists as a named callable so it can be dropped into a transform
    pipeline as one of its steps.
    """
    target_mode = "RGB"
    return image.convert(target_mode)

def _transform(n_px):
    """Build the image preprocessing pipeline for inputs of size ``n_px``.

    The pipeline applies, in order: bicubic ``Resize(n_px)``,
    ``CenterCrop(n_px)``, conversion to RGB, ``ToTensor()``, and a
    per-channel ``Normalize`` with fixed mean / std constants.

    Returns:
        A ``Compose`` object chaining the steps above.
    """
    # NOTE(review): these look like the normalisation statistics shipped with
    # CLIP image encoders - confirm against the model being evaluated.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)


def refine_classname(class_names):
    """Normalise class names in place and return the same sequence.

    Each entry is lower-cased and every ``_`` or ``-`` is replaced by a
    single space. The sequence passed in is modified element by element
    (no copy is made), and that same object is returned.
    """
    # One translation table handles both separators in a single pass.
    separators_to_space = str.maketrans("_-", "  ")
    for position, raw_name in enumerate(class_names):
        class_names[position] = raw_name.lower().translate(separators_to_space)
    return class_names


def _build_eval_loader(dataset, args):
    """Wrap ``dataset`` in the sequential (non-shuffled) evaluation ``DataLoader``.

    Batch size and worker count are read from ``args.batch_size`` and
    ``args.num_workers``; ``pin_memory`` is always enabled.
    """
    return DataLoader(dataset, batch_size=args.batch_size, pin_memory=True,
                      num_workers=args.num_workers, shuffle=False)


def _class_names_by_index(label_to_class_idx):
    """Return the keys of ``label_to_class_idx`` ordered by their mapped value.

    Underscores in each key are replaced by spaces (case is left untouched).
    """
    ordered = sorted(label_to_class_idx, key=lambda label: label_to_class_idx[label])
    return [label.replace("_", " ") for label in ordered]


def _unsupported_backbone_error(args):
    """Build the ``ValueError`` raised when ``args.VLM_Base`` has no matching dataset class."""
    return ValueError(
        f"Unsupported VLM_Base {args.VLM_Base!r} for dataset {args.dataset!r}; "
        "expected 'RN50' or 'ViT-B/32'"
    )


def load_test_datasets(args, preprocess):
    """Build the evaluation loader(s) and text prompts for ``args.dataset``.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``batch_size`` and
            ``num_workers``; additionally ``VLM_Base`` for ``"stanford_cars"``
            and ``"sun397"``. It is also forwarded unchanged to
            ``get_imagenet_dataset`` for ``"imagenet1k"``.
        preprocess: Transform handed to the dataset constructor (as
            ``transform`` for image datasets, ``transform_audio`` for the
            audio datasets).

    Returns:
        A ``(val_loader, texts)`` tuple.

        * ``val_loader`` is a single ``DataLoader`` for the image datasets.
          For ``"esc50"`` it is a list of 5 loaders (folds 1-5) and for
          ``"urbansound8k"`` a list of 10 loaders (folds 1-10).
        * ``texts`` holds one prompt per class, produced by formatting the
          dataset-specific template with the class name. For the two audio
          datasets each prompt is wrapped in its own single-element list.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names, or
            if ``args.VLM_Base`` is neither ``"RN50"`` nor ``"ViT-B/32"`` for
            a dataset that selects its implementation by backbone.

    Side effects:
        Prints ``texts`` to stdout. CIFAR-100 is constructed with
        ``download=True``. ``refine_classname`` rewrites the class-name list
        it is given in place, so the dataset object's own list is normalised
        too.
    """
    dataset_name = args.dataset

    if dataset_name == 'cifar100':
        template = 'a photo of a {}'
        val_dataset = CIFAR100("/datasets/cifar-100", transform=preprocess,
                               download=True, train=False)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = refine_classname(val_dataset.classes)

    elif dataset_name == "imagenet1k":
        template = 'a photo of a {}'
        imagenet_path = "/datasets/imagenet"
        val_dataset, all_class_names = get_imagenet_dataset(imagenet_path, args, preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        # Each entry of all_class_names is a group of alternative names for
        # one class; they are merged into a single comma-separated label.
        class_names = refine_classname(
            [", ".join(name_set) for name_set in all_class_names]
        )

    elif dataset_name == "stanford_cars":
        template = 'a photo of a {}'
        if args.VLM_Base == "RN50":
            val_dataset = Standfordcars_test_rn50("/datasets/stanfordCars/", transform=preprocess)
        elif args.VLM_Base == "ViT-B/32":
            # NOTE(review): unlike every other dataset here, no
            # transform=preprocess is passed - confirm this is intentional.
            val_dataset = Standfordcars_test_b32("/datasets/stanfordCars/")
        else:
            raise _unsupported_backbone_error(args)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = val_dataset.classnames

    elif dataset_name == "sun397":
        template = 'a photo of a {}'
        if args.VLM_Base == "RN50":
            val_dataset = sun397_test_rn50("/datasets/sun397/", transform=preprocess)
        elif args.VLM_Base == "ViT-B/32":
            val_dataset = sun397_test_b32("/datasets/sun397/", transform=preprocess)
        else:
            raise _unsupported_backbone_error(args)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = refine_classname(val_dataset.classnames)

    elif dataset_name == "DTD":
        template = '{} texture'
        val_dataset = DTD_test_dataset("/datasets/dtd", transform=preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = val_dataset.classnames

    elif dataset_name == "pets":
        template = 'a photo of a {}, a type of pet.'
        val_dataset = pets_test_dataset("/datasets/oxford-iiit-pet", transform=preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = refine_classname(val_dataset.classnames)

    elif dataset_name == "Food101":
        template = 'a photo of a {}'
        val_dataset = food_test_dataset("/datasets/food-101", transform=preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = refine_classname(val_dataset.classnames)

    elif dataset_name == "Flowers102":
        template = 'a photo of a {}, a type of flower.'
        val_dataset = flowers_test_dataset("/datasets/flowers102", transform=preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = refine_classname(val_dataset.classnames)

    elif dataset_name == "eurosat":
        template = "a centered satellite photo of {}."
        val_dataset = eurosat_test("/datasets/eurosat", transform=preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        class_names = refine_classname(val_dataset.classnames)

    elif dataset_name == "FGVCAircraft":
        template = 'a photo of a {}, a type of aircraft.'
        val_dataset = FGVCAircraft_test("/datasets/FGVCAircraft", transform=preprocess)
        val_loader = _build_eval_loader(val_dataset, args)
        # Only underscores are replaced here; case is preserved.
        class_names = [c.replace("_", " ") for c in val_dataset.classnames]

    elif dataset_name == "esc50":
        template = "{}"
        fold_datasets = [
            ESC50(root="/datasets/ESC-50/", train=False,
                  transform_audio=preprocess, fold=fold_no)
            for fold_no in range(1, 6)
        ]
        val_loader = [_build_eval_loader(fold_ds, args) for fold_ds in fold_datasets]
        # The label mapping is taken from the first fold.
        class_names = _class_names_by_index(fold_datasets[0].label_to_class_idx)

    elif dataset_name == "urbansound8k":
        template = "{}"
        val_loader = []
        for fold_no in range(1, 11):
            fold_ds = UrbanSound8K(root="/datasets/Urbansound8k/", train=False,
                                   transform_audio=preprocess, fold=fold_no)
            # Overwritten on every iteration: the last fold's mapping is used.
            sound_dict = fold_ds.label_to_class_idx
            val_loader.append(_build_eval_loader(fold_ds, args))
        class_names = _class_names_by_index(sound_dict)

    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # on `template` / `class_names` further down.
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")

    if dataset_name == "esc50" or dataset_name == "urbansound8k":
        # Audio datasets return each prompt wrapped in its own list.
        texts = [[template.format(label)] for label in class_names]
    else:
        texts = [template.format(label) for label in class_names]
    print(texts)
    return val_loader, texts
