from torch.utils.data import Dataset
import pickle
from PIL import Image
import torch
import numpy as np
import h5py
import os


class TestDataset(Dataset):
    """Class-indexed dataset used to build few-shot test episodes.

    Item ``i`` is a ``(support, query)`` pair of stacked tensors drawn at
    random (without replacement) from class ``self.cl_list[i]``. Samples are
    either images loaded from disk or rows of pre-extracted features read
    from an HDF5 file, depending on ``is_feature``.
    """

    def __init__(self, dataset, dataset_path, transform, n_query, n_support, hdf5_file_path, backbone,
                 is_feature=None, min_support=20, min_query=15):
        """
        :param dataset: test_dataset (name used for both the .pkl and the .hdf5 file)
        :param dataset_path: path to .pkl files
        :param transform: transformation functions
        :param n_query: how many query images per class
        :param n_support: how many support images per class
        :param hdf5_file_path: path to pre-extracted features
        :param backbone: sub-directory of ``hdf5_file_path`` holding the feature files
        :param is_feature: True when features are pre-extracted
        :param min_support: classes with fewer support entries than this are dropped
        :param min_query: classes with fewer query entries than this are dropped
        """

        dataset_pkl = os.path.join(dataset_path, dataset + ".pkl")
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # trusted metadata files.
        with open(dataset_pkl, 'rb') as f:
            self.meta = pickle.load(f)

        # The loaded mapping is indexed as meta[cls][0] (support entries) and
        # meta[cls][1] (query entries). Collect undersized classes first so the
        # mapping is not mutated while it is being iterated.
        undersized = [
            cls for cls, splits in self.meta.items()
            if len(splits[0]) < min_support or len(splits[1]) < min_query
        ]
        for cls in undersized:
            del self.meta[cls]

        self.cl_list = list(self.meta.keys())
        self.n_query = n_query
        self.n_support = n_support
        self.transform = transform
        self.is_feature = is_feature

        if self.is_feature:
            hdf5_file = os.path.join(hdf5_file_path, backbone, 'test', dataset + ".hdf5")
            with h5py.File(hdf5_file, 'r') as f:
                # [...] reads each HDF5 dataset fully into memory, so nothing
                # refers to the file once it is closed.
                self.meta_feat = {
                    cl: [f[cl + "_support"][...], f[cl + "_query"][...]]
                    for cl in self.cl_list
                }

    def __len__(self):
        """Return the number of classes that survived the size filter."""
        return len(self.cl_list)

    def getimage(self, image_path):
        """Load ``image_path`` as an RGB image and apply ``self.transform`` if set."""
        # Image.open is lazy; the context manager releases the file handle as
        # soon as the pixel data has been converted.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def _sample_features(self, feats, n_samples):
        """Draw ``n_samples`` distinct rows of ``feats`` as one float32 tensor."""
        feats = np.asarray(feats)
        picked = np.random.choice(len(feats), n_samples, replace=False)
        # Fancy indexing copies the selected rows, so the returned tensor
        # never aliases the cached feature array.
        return torch.as_tensor(feats[picked], dtype=torch.float32)

    def __getitem__(self, i):
        """Return ``(support, query)`` stacked tensors for the ``i``-th class.

        Raises ``ValueError`` (from ``np.random.choice``) when ``n_support`` or
        ``n_query`` exceeds the number of entries available for the class.
        """
        chosen_cls = self.cl_list[i]

        if self.is_feature:
            cached_support, cached_query = self.meta_feat[chosen_cls][0], self.meta_feat[chosen_cls][1]
            # Support is drawn before query to keep the RNG call order stable.
            support_feats = self._sample_features(cached_support, self.n_support)
            query_feats = self._sample_features(cached_query, self.n_query)
            return support_feats, query_feats

        support_files = np.random.choice(self.meta[chosen_cls][0], self.n_support, replace=False)
        query_files = np.random.choice(self.meta[chosen_cls][1], self.n_query, replace=False)
        support_imgs = torch.stack([self.getimage(image_path) for image_path in support_files])
        query_imgs = torch.stack([self.getimage(image_path) for image_path in query_files])

        return support_imgs, query_imgs


class EpisodicBatchSampler(object):
    """Batch sampler that yields one tensor of class indices per episode.

    Each episode is the first ``n_way`` entries of a fresh random permutation
    of ``range(n_classes)``.
    """

    def __init__(self, n_classes, n_way, n_episodes):
        """
        :param n_classes: total number of classes to draw from
        :param n_way: number of class indices yielded per episode
        :param n_episodes: number of episodes produced by one iteration
        """
        self.n_classes, self.n_way, self.n_episodes = n_classes, n_way, n_episodes

    def __len__(self):
        """Return the number of episodes per iteration."""
        return self.n_episodes

    def __iter__(self):
        """Yield ``n_episodes`` index tensors, one per episode."""
        for _ in range(self.n_episodes):
            # Shuffle all class indices, then keep the leading n_way of them.
            shuffled_classes = torch.randperm(self.n_classes)
            episode_classes = shuffled_classes[:self.n_way]
            yield episode_classes