# n_way 1_shot meta training
from PIL import Image
from torch.utils.data.dataset import Dataset
import numpy as np
import pickle
import torch
import h5py
import os
import random
from .cut_mix import cutmix, mixup


class MetaTrainDataset(Dataset):
    """Episode sampler for meta-training: each item is one class.

    Index ``i`` selects ``self.classes[i]``. ``__getitem__`` then returns
    ``n_shot`` support samples and one query sample from that class. Samples
    are either images loaded from disk and transformed, or precomputed
    features read from HDF5 files when ``is_feature`` is true.

    Args:
        dataset: Dataset name; must be ``"ImageNet"`` or ``"miniImageNet"``.
            Used to build the metadata / HDF5 file names.
        n_way: Stored as ``self.n_way``; not used by this class itself.
        n_shot: Number of support samples drawn per item.
        dataset_path: Directory containing ``<dataset>.pkl`` and
            ``<dataset>_classwise.pkl``.
        hdf5_path: Directory containing ``<dataset>.hdf5`` and
            ``<dataset>_classwise.hdf5``. Only read when ``is_feature``.
        transform: Mapping from split name (``'support'`` / ``'query'``) to a
            callable applied to each PIL image, or a falsy value for none.
        transform_method: Stored as ``self.transform_method``; not used here.
        is_feature: If true, sample precomputed features instead of images.
    """

    def __init__(self,
                 dataset,
                 n_way,
                 n_shot,
                 dataset_path,
                 hdf5_path,
                 transform,
                 transform_method,
                 is_feature):

        super(MetaTrainDataset, self).__init__()
        assert dataset in ["ImageNet", "miniImageNet"]
        metadata_path = os.path.join(dataset_path, dataset + ".pkl")
        metadata_classwise_path = os.path.join(dataset_path, dataset + "_classwise.pkl")

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # metadata files that come from a trusted source.
        with open(metadata_path, 'rb') as f:
            # Names of all images, in the form "<image>_<class name>".
            self.metadata = pickle.load(f)
        with open(metadata_classwise_path, 'rb') as f:
            # Image names grouped by class: dict of class name -> image paths.
            self.metadata_classwise = pickle.load(f)
        self.classes = list(self.metadata_classwise.keys())
        self.transform = transform
        self.n_way = n_way
        self.n_support = n_shot
        self.transform_method = transform_method
        self.is_feature = is_feature

        if self.is_feature:
            # The whole "complete" dataset is read eagerly into memory.
            with h5py.File(os.path.join(hdf5_path, dataset + ".hdf5"), 'r') as f:
                self.meta_feat = f["complete"][...]

            # One in-memory array per class. __getitem__ indexes it as
            # [0] -> support feature pool, [1] -> query feature pool.
            with h5py.File(os.path.join(hdf5_path, dataset + "_classwise.hdf5"), 'r') as f:
                self.classwise_feat = {cl: f[cl][...] for cl in self.classes}

    def __len__(self):
        """Return the number of classes (one item per class)."""
        return len(self.classes)

    def getimage(self, image_path, split=None):
        """Load ``image_path`` as RGB and apply the transform for ``split``.

        Args:
            image_path: Path of the image file.
            split: ``'support'`` or ``'query'``; selects which entry of
                ``self.transform`` is applied.

        Returns:
            The transformed image, or the RGB PIL image when no transform is set.
        """
        assert split in ['support', 'query']
        # The context manager releases the file handle promptly instead of
        # relying on garbage collection. This matters in long-lived
        # DataLoader workers. convert() returns an independent, loaded image.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform[split](img)
        return img

    def __getitem__(self, index):
        """Sample one episode from class ``self.classes[index]``.

        Returns:
            ``(support, query)``: ``support`` stacks ``n_support`` samples and
            ``query`` stacks one sample, both from the same class. In feature
            mode these are float tensors built from the class-wise feature
            arrays. Otherwise they are stacks of transformed images.
        """
        chosen_cls = self.classes[index]

        if self.is_feature:
            cls_feat = self.classwise_feat[chosen_cls]
            support_pool, query_pool = cls_feat[0], cls_feat[1]

            # Draw indices over the support pool's own length, matching how
            # the query indices are drawn over the query pool. Using the outer
            # array's length would only ever touch its first few rows.
            support_ids = np.random.choice(len(support_pool), self.n_support, replace=False)
            support_feats = torch.stack([torch.Tensor(support_pool[i]) for i in support_ids])

            query_ids = np.random.choice(len(query_pool), 1, replace=False)
            query_feats = torch.stack([torch.Tensor(query_pool[i]) for i in query_ids])

            return support_feats, query_feats

        class_files = self.metadata_classwise[chosen_cls]
        support_files = np.random.choice(class_files, self.n_support, replace=False)
        # NOTE(review): the query is drawn independently of the support set,
        # so it may be one of the support images (seen through the 'query'
        # transform). Confirm this is intended.
        query_files = np.random.choice(class_files, 1)
        support_imgs = torch.stack([self.getimage(path, split='support') for path in support_files])
        query_imgs = torch.stack([self.getimage(path, split='query') for path in query_files])
        return support_imgs, query_imgs




