from PIL import Image
from torch.utils.data.dataset import Dataset
import numpy as np
import pickle
import torch
import h5py
import os
import random
from .cut_mix import cutmix, mixup


class TrainDataset(Dataset):
    """Pair dataset yielding ``(support, query, label)`` for similarity training.

    Items with an even index are positive pairs (support drawn from the
    query's class, label 1); items with an odd index are negative pairs
    (support drawn from a different class, label 0), which keeps positives
    and negatives balanced across an epoch.

    When ``is_feature`` is true, precomputed HDF5 features are returned
    instead of images and no image files are read.
    """

    def __init__(self,
                 dataset,
                 dataset_path,
                 hdf5_path,
                 transform,
                 transform_method,
                 is_feature,
                 class_dir_index=4):
        """Load pickled metadata and, optionally, HDF5 features.

        Args:
            dataset: One of "ImageNet", "miniImageNet", "CompleteDataset";
                selects ``<dataset>.pkl`` / ``<dataset>_classwise.pkl`` and,
                in feature mode, ``<dataset>.hdf5`` / ``<dataset>_classwise.hdf5``.
            dataset_path: Directory containing the pickle metadata files.
            hdf5_path: Directory containing the HDF5 feature files (only read
                when ``is_feature`` is true).
            transform: Mapping with "support" and "query" callables applied to
                loaded PIL images, or a falsy value for no transform.
            transform_method: Stored as ``self.transform_method``; not read
                anywhere in this class.
            is_feature: If true, return precomputed features instead of images.
            class_dir_index: Position of the class-name component when an image
                path is split on "/". Defaults to 4, the original hard-coded
                layout.

        Raises:
            ValueError: If ``dataset`` is not a supported name.
        """
        super(TrainDataset, self).__init__()
        if dataset not in ("ImageNet", "miniImageNet", "CompleteDataset"):
            raise ValueError(f"Unsupported dataset: {dataset!r}")
        metadata_path = os.path.join(dataset_path, dataset + ".pkl")
        metadata_classwise_path = os.path.join(dataset_path, dataset + "_classwise.pkl")

        # NOTE: pickle can execute arbitrary code; only point this at trusted files.
        with open(metadata_path, 'rb') as f:
            # Sequence of image paths, indexed directly by __getitem__.
            self.metadata = pickle.load(f)
        with open(metadata_classwise_path, 'rb') as f:
            # Dict mapping class name -> list of image paths of that class.
            self.metadata_classwise = pickle.load(f)
        self.classes = list(self.metadata_classwise.keys())
        self.transform = transform
        self.transform_method = transform_method
        self.is_feature = is_feature
        self.class_dir_index = class_dir_index

        if self.is_feature:
            with h5py.File(os.path.join(hdf5_path, dataset + ".hdf5"), 'r') as f:
                # Row ``index`` is returned as the query feature for item ``index``.
                self.meta_feat = f["complete"][...]
            with h5py.File(os.path.join(hdf5_path, dataset + "_classwise.hdf5"), 'r') as f:
                # Row i of each class array is looked up by the position of the
                # image path within metadata_classwise[class].
                self.classwise_feat = {cl: f[cl][...] for cl in self.classes}

    def __len__(self):
        """Return the number of query images (one pair per query)."""
        return len(self.metadata)

    def getimage(self, image_path, split=None):
        """Load ``image_path`` as RGB and apply the transform for ``split``.

        Args:
            image_path: Path of the image file.
            split: Either "support" or "query"; selects ``self.transform[split]``.

        Returns:
            The RGB PIL image, or whatever the selected transform returns.

        Raises:
            ValueError: If ``split`` is not "support" or "query".
        """
        if split not in ('support', 'query'):
            raise ValueError(f"split must be 'support' or 'query', got {split!r}")
        # The context manager guarantees the file handle is closed, including for
        # multi-frame formats where Pillow keeps it open after loading.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform[split](img)
        return img

    def _sample_positive_path(self, query_class, query_img_path):
        """Return a random image path of ``query_class`` other than the query."""
        candidates = [p for p in self.metadata_classwise[query_class]
                      if p != query_img_path]
        if not candidates:
            # The original rejection-sampling loop would spin forever here.
            raise ValueError(
                f"Class {query_class!r} has no image other than "
                f"{query_img_path!r}; cannot build a positive pair")
        return candidates[np.random.randint(len(candidates))]

    def _sample_negative_class(self, query_class):
        """Return a random class name different from ``query_class``."""
        candidates = [c for c in self.classes if c != query_class]
        if not candidates:
            # The original rejection-sampling loop would spin forever here.
            raise ValueError(
                f"Only class {query_class!r} is available; cannot build a negative pair")
        return candidates[np.random.randint(len(candidates))]

    def __getitem__(self, index):
        """Return a ``(support, query, label)`` pair for ``index``.

        Even indices yield a positive pair and odd indices a negative pair,
        so positives and negatives stay balanced.

        Returns:
            ``(support_feat, query_feat, label)`` when ``is_feature`` is true,
            otherwise ``(support_img, query_img, label)``. ``label`` is a
            scalar ``torch.tensor``: 1 for a same-class pair, 0 otherwise.

        Raises:
            ValueError: If a positive pair is requested for a class with no
                other image, or a negative pair when only one class exists.
        """
        query_img_path = self.metadata[index]
        query_class = query_img_path.split('/')[self.class_dir_index]

        if index % 2 == 0:  # positive pair
            support_class = query_class
            support_img_path = self._sample_positive_path(query_class, query_img_path)
        else:  # negative pair
            support_class = self._sample_negative_class(query_class)
            class_paths = self.metadata_classwise[support_class]
            support_img_path = class_paths[np.random.randint(len(class_paths))]

        label = torch.tensor(int(support_class == query_class))

        if self.is_feature:
            # Feature mode returns precomputed features only; image files are
            # never opened.
            query_feat = self.meta_feat[index]
            support_img_id = self.metadata_classwise[support_class].index(support_img_path)
            support_feat = self.classwise_feat[support_class][support_img_id]
            return support_feat, query_feat, label

        query_img = self.getimage(query_img_path, "query")
        # Mixup / CutMix augmentation of the query is currently disabled.
        # To re-enable it (applied 50% of the time), restore:
        #     if random.randint(1, 10) > 5:
        #         mix_class = <random class != query_class>
        #         mix_img_path = <random path from self.metadata_classwise[mix_class]>
        #         mix_img = self.getimage(mix_img_path, "query")
        #         query_img = random.choice([cutmix, mixup])(query_img.clone(), mix_img)
        support_img = self.getimage(support_img_path, "support")
        return support_img, query_img, label




