import os
import torch
from torch.utils.data import Dataset
import pickle
from torchvision.transforms import transforms
import numpy as np
import collections
from PIL import Image
import csv
import random


class MetaMiniImagenet(Dataset):
    """
    Episodic (N-way, K-shot) dataset for meta-learning on ImageNet / miniImageNet.

    NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set.
    batch: contains several sets (episodes).
    set: contains n_way * k_shot images for the meta-train (support) set and
         n_way * k_query images for the meta-test (query) set.

    Each item returned by ``__getitem__`` is one episode
    ``(support_x, support_y, query_x, query_y)``. Labels are relative
    (0 .. n_way-1) and consistent between the support and the query set.
    """

    def __init__(self,
                 dataset,
                 batchsz,
                 n_way,
                 k_shot,
                 dataset_path,
                 hdf5_path,
                 transform,
                 transform_method,
                 is_feature,
                 k_query=1):
        """
        :param dataset: "ImageNet" or "miniImageNet"; selects ``<dataset>_classwise.pkl``.
        :param batchsz: number of episodes (sets) to pre-sample; this is ``len(self)``.
        :param n_way: number of classes per episode.
        :param k_shot: number of support images per class.
        :param dataset_path: directory holding the class-wise metadata pickle; every
            image entry from that pickle is joined onto it (``os.path.join``) before
            being handed to ``transform``.
        :param hdf5_path: stored as ``self.hdf5_path``; not used by this class.
        :param transform: callable applied to each image path; its results must be
            tensor-convertible and share one shape so they can be stacked.
        :param transform_method: stored as ``self.transform_method``; not used by this class.
        :param is_feature: stored as ``self.is_feature``; not used by this class.
        :param k_query: number of query images per class (default 1, the value that
            used to be hard-coded).
        :raises ValueError: if ``dataset`` is not a supported name.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if dataset not in ("ImageNet", "miniImageNet"):
            raise ValueError(
                f"dataset must be 'ImageNet' or 'miniImageNet', got {dataset!r}")

        self.dataset = dataset
        self.batchsz = batchsz  # batch of sets, not batch of images
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.setsz = self.n_way * self.k_shot  # number of support samples per set
        self.querysz = self.n_way * self.k_query  # number of query samples per set

        # These were previously accepted but never stored, so __getitem__ failed
        # with AttributeError on self.transform / self.path.
        # NOTE(review): assumes the pickle's image entries are relative to
        # dataset_path (absolute entries are unaffected by os.path.join) — confirm.
        self.path = dataset_path
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.transform_method = transform_method
        self.is_feature = is_feature

        metadata_classwise_path = os.path.join(dataset_path, dataset + "_classwise.pkl")
        # Dict mapping class name -> list of image paths of that class.
        # NOTE: pickle.load can execute arbitrary code; only load trusted metadata files.
        with open(metadata_classwise_path, 'rb') as f:
            self.metadata_classwise = pickle.load(f)
        self.classes = list(self.metadata_classwise.keys())

        self.create_batch(batchsz)

    def create_batch(self, batchsz):
        """
        Pre-sample ``batchsz`` episodes for meta-learning.

        For every episode, ``n_way`` distinct classes are drawn. From each class,
        ``k_shot + k_query`` distinct images are drawn. The per-class support and
        query lists are then shuffled independently, so the order of the support
        set does not reveal the order of the query set. The global class index of
        every per-class list is recorded alongside it, so labels never have to be
        re-derived from filenames.

        :param batchsz: number of episodes to sample.
        """
        self.support_x_batch = []  # [episode][class] -> list of image entries
        self.query_x_batch = []
        self.support_y_batch = []  # [episode][class] -> global class index
        self.query_y_batch = []
        n_per_class = self.k_shot + self.k_query
        for _ in range(batchsz):
            # 1. Select n_way distinct classes (by index, so any hashable key type works).
            class_ids = np.random.choice(len(self.classes), self.n_way, replace=False)
            support_sets = []
            query_sets = []
            for class_id in class_ids:
                label = int(class_id)
                class_imgs = self.metadata_classwise[self.classes[label]]
                # 2. Select k_shot + k_query distinct images. choice() without
                # replacement already returns them in random order.
                picked = np.random.choice(len(class_imgs), n_per_class, replace=False)
                support_sets.append(([class_imgs[i] for i in picked[:self.k_shot]], label))
                query_sets.append(([class_imgs[i] for i in picked[self.k_shot:]], label))

            # Shuffle support and query class order independently; labels travel
            # with their image lists, so correctness is unaffected.
            random.shuffle(support_sets)
            random.shuffle(query_sets)

            self.support_x_batch.append([imgs for imgs, _ in support_sets])
            self.support_y_batch.append([lab for _, lab in support_sets])
            self.query_x_batch.append([imgs for imgs, _ in query_sets])
            self.query_y_batch.append([lab for _, lab in query_sets])

    def _flatten_episode(self, sets, labels):
        """Flatten per-class image lists into parallel (full path, global label) lists."""
        paths, ys = [], []
        for items, label in zip(sets, labels):
            for item in items:
                paths.append(os.path.join(self.path, item))
                ys.append(label)
        return paths, ys

    def _load_images(self, paths):
        """Apply ``self.transform`` to each path and stack the results as float32."""
        return torch.stack([torch.as_tensor(self.transform(p)) for p in paths]).float()

    def __getitem__(self, index):
        """
        Build one episode.

        :param index: index of the set, ``0 <= index <= batchsz - 1``.
        :return: ``(support_x, support_y, query_x, query_y)``. The ``*_x`` tensors
            stack the transformed images: ``[setsz, ...]`` and ``[querysz, ...]``.
            The ``*_y`` tensors are int64 relative labels in ``[0, n_way)``.
        """
        support_paths, support_y = self._flatten_episode(
            self.support_x_batch[index], self.support_y_batch[index])
        query_paths, query_y = self._flatten_episode(
            self.query_x_batch[index], self.query_y_batch[index])

        # Map global class indices to a randomly permuted relative label 0..n_way-1,
        # using the same mapping for the support and the query set.
        unique = np.unique(support_y)
        np.random.shuffle(unique)
        global_to_relative = {int(g): r for r, g in enumerate(unique)}
        support_y_relative = torch.tensor(
            [global_to_relative[g] for g in support_y], dtype=torch.long)
        query_y_relative = torch.tensor(
            [global_to_relative[g] for g in query_y], dtype=torch.long)

        support_x = self._load_images(support_paths)
        query_x = self._load_images(query_paths)

        return support_x, support_y_relative, query_x, query_y_relative

    def __len__(self):
        # Up to batchsz sets were pre-built; a DataLoader may draw smaller batches of sets.
        return self.batchsz
        