import os
import pickle
import numpy as np

from .utils import *

from .oxford_pets import OxfordPets
from .dtd import DescribableTextures as DTD
from torchvision import transforms

class CIFAR100(DatasetBase):
    """CIFAR-100 wrapper built from the pickled ``train``/``test``/``meta`` files.

    The constructor reads the three pickle files found under
    ``<root>/cifar100``, converts every sample into a ``Datum`` holding an
    in-memory PIL image, splits the training data into train/val, optionally
    reduces both to a few-shot subset, optionally subsamples classes, and
    finally hands the three splits to ``DatasetBase``.

    NOTE(review): ``DatasetBase``, ``Datum`` and ``Image`` are not imported
    explicitly in this module; they are presumably provided by the
    ``from .utils import *`` star import -- confirm.
    """

    # Sub-directory of ``root`` that holds the dataset files. It is replaced
    # on the instance by the absolute path during ``__init__``.
    dataset_dir = "cifar100"

    def __init__(self, root, num_shots=0, seed=1, subsample_classes='all'):
        """Load the dataset from disk and build the train/val/test splits.

        Args:
            root: Directory containing the ``cifar100`` folder; ``~`` is
                expanded and the path is made absolute.
            num_shots: When ``>= 1``, train is reduced with
                ``generate_fewshot_dataset(num_shots=num_shots)`` and val with
                ``num_shots=min(num_shots, 4)``. Otherwise both are kept as is.
            seed: Accepted for interface compatibility; it is not used
                anywhere in this constructor.
            subsample_classes: Forwarded as ``subsample`` to
                ``OxfordPets.subsample_classes``.

        Raises:
            OSError: If one of the ``train``/``test``/``meta`` files cannot
                be opened.
            KeyError: If a pickle file lacks the expected keys.
        """
        root = os.path.abspath(os.path.expanduser(root))
        self.dataset_dir = os.path.join(root, self.dataset_dir)

        # self.dataset_dir is already absolute here, so it is used directly
        # for all three files (no second join with ``root`` is needed).
        trainval_images, trainval_targets = self._read_cifar_split(
            os.path.join(self.dataset_dir, 'train'))
        test_images, test_targets = self._read_cifar_split(
            os.path.join(self.dataset_dir, 'test'))
        classes = self._read_cifar_classnames(
            os.path.join(self.dataset_dir, "meta"))

        trainval = self._to_datums(trainval_images, trainval_targets, classes)
        test = self._to_datums(test_images, test_targets, classes)

        train, val = OxfordPets.split_trainval(trainval)

        if num_shots >= 1:
            train = self.generate_fewshot_dataset(train, num_shots=num_shots)
            # The validation subset is capped at 4 shots per call.
            val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))

        train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample_classes)

        # Prompt templates: each callable maps a class name to a sentence.
        self.templates = [
            lambda c: f'a photo of a {c}.',
            lambda c: f'a blurry photo of a {c}.',
            lambda c: f'a black and white photo of a {c}.',
            lambda c: f'a low contrast photo of a {c}.',
            lambda c: f'a high contrast photo of a {c}.',
            lambda c: f'a bad photo of a {c}.',
            lambda c: f'a good photo of a {c}.',
            lambda c: f'a photo of a small {c}.',
            lambda c: f'a photo of a big {c}.',
            lambda c: f'a photo of the {c}.',
            lambda c: f'a blurry photo of the {c}.',
            lambda c: f'a black and white photo of the {c}.',
            lambda c: f'a low contrast photo of the {c}.',
            lambda c: f'a high contrast photo of the {c}.',
            lambda c: f'a bad photo of the {c}.',
            lambda c: f'a good photo of the {c}.',
            lambda c: f'a photo of the small {c}.',
            lambda c: f'a photo of the big {c}.',
        ]

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def _read_cifar_split(file_path):
        """Read one pickled split and return ``(images, targets)``.

        ``images`` is the ``"data"`` array reshaped to ``(N, 3, 32, 32)`` and
        transposed to ``(N, H, W, C)``. ``targets`` is ``"fine_labels"`` when
        that key exists, otherwise ``"labels"``.
        """
        # SECURITY: pickle.load can execute arbitrary code while unpickling;
        # only point this loader at dataset files from a trusted source.
        with open(file_path, "rb") as f:
            entry = pickle.load(f, encoding="latin1")
        images = entry["data"]
        targets = entry["fine_labels"] if "fine_labels" in entry else entry["labels"]
        images = images.reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # (N, H, W, C)
        return images, targets

    @staticmethod
    def _read_cifar_classnames(meta_path):
        """Return the ``"fine_label_names"`` list with ``_`` replaced by spaces."""
        # SECURITY: same pickle caveat as in _read_cifar_split.
        with open(meta_path, "rb") as infile:
            meta = pickle.load(infile, encoding="latin1")
        return [name.replace("_", " ") for name in meta["fine_label_names"]]

    @staticmethod
    def _to_datums(images, targets, classes):
        """Wrap each (image array, target) pair into a ``Datum``.

        The ``impath`` field receives a PIL image created with
        ``Image.fromarray`` instead of a filesystem path, because the samples
        live in memory.
        """
        return [
            Datum(impath=Image.fromarray(pixels),
                  label=int(target),
                  classname=classes[target])
            for pixels, target in zip(images, targets)
        ]