import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm


class MiniImagenetDataset(Dataset):
    """Episodic N-way / K-shot sampler over pre-extracted MiniImageNet arrays.

    Every ``__getitem__`` call draws a fresh random episode from ``self.rng``;
    the index argument is ignored and ``len`` only controls how many episodes
    one pass over the dataset yields.

    Args:
        path: Root directory containing
            ``MiniImageNet/<split>_data/<split>_<i>.npy`` files.
        len: Number of episodes reported by ``__len__`` (name kept for
            backward compatibility even though it shadows the builtin).
        n_way: Number of classes per episode.
        n_support: Support images per class.
        n_query: Query images per class.
        mode: ``"train"`` (64 class files) or ``"val"`` (16 class files); any
            other value loads the 20-file ``"test"`` split.
        seed: Seed for ``np.random.default_rng``; ``None`` uses fresh OS entropy.
        img_subset: Number of leading images kept from each class file.

    Raises:
        ValueError: If ``n_way`` exceeds the number of classes in the split, or
            ``n_support + n_query`` exceeds ``img_subset``.
    """

    # Number of per-class .npy files in each split.
    _SPLIT_CLASSES = {"train": 64, "val": 16, "test": 20}

    def __init__(self, path, len, n_way, n_support, n_query, mode, seed=None,
                 img_subset=400):
        self.len = len
        self.rng = np.random.default_rng(seed)

        # Anything that is not "train"/"val" falls back to the test split.
        split = mode if mode in ("train", "val") else "test"
        n_files = self._SPLIT_CLASSES[split]

        # Validate before loading gigabytes of images; otherwise the error only
        # surfaces later from rng.choice inside __getitem__.
        if n_way > n_files:
            raise ValueError(
                f"n_way={n_way} exceeds the {n_files} classes in the "
                f"'{split}' split")
        if n_support + n_query > img_subset:
            raise ValueError(
                f"n_support + n_query = {n_support + n_query} exceeds "
                f"img_subset={img_subset}")

        # float32 matches what __getitem__ returns (the old float64 buffer was
        # cast with .float() per sample anyway) and halves memory: ~2.2 GB
        # instead of ~4.3 GB for the default 64 x 400 train split.
        self.imgs = np.zeros((n_files, img_subset, 3, 84, 84), dtype=np.float32)
        class_ids = range(n_files)
        if split == "train":
            # Only the large training split gets a progress bar.
            class_ids = tqdm(class_ids)
        for i in class_ids:
            file = f"{path}/MiniImageNet/{split}_data/{split}_{i}.npy"
            # assumes each file holds an array shaped (1, n_images, 3, 84, 84)
            # with n_images >= img_subset — TODO confirm against the data prep.
            self.imgs[i] = np.load(file)[0][:img_subset]

        self.n_classes = self.imgs.shape[0]
        self.n_samples = self.imgs.shape[1]
        self.n_support = n_support
        self.n_query = n_query
        self.n_way = n_way

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Sample one random episode; ``idx`` is ignored.

        Returns:
            A 5-tuple ``(x_s, y_s, x_q, y_q, flag)``:
            ``x_s`` float32 ``(n_way * n_support, 3, 84, 84)`` support images,
            ``y_s`` int64 ``(n_way * n_support,)`` episode-local labels in
            ``[0, n_way)``, ``x_q`` / ``y_q`` the same for the query set, both
            sets independently shuffled, and ``flag`` the constant
            ``torch.Tensor([1])``.
        """
        # NOTE(review): randomness comes solely from self.rng. With
        # DataLoader(num_workers > 0) each worker receives a copy of the same
        # generator state, so workers may emit identical episodes; consider a
        # worker_init_fn that reseeds self.rng per worker.
        n_s, n_q = self.n_support, self.n_query
        classes = self.rng.choice(self.n_classes, self.n_way, replace=False)

        x_s = torch.zeros(self.n_way, n_s, 3, 84, 84)
        x_q = torch.zeros(self.n_way, n_q, 3, 84, 84)
        for i, c in enumerate(classes):
            # Disjoint support/query images drawn from the same class.
            picks = self.rng.choice(self.n_samples, n_s + n_q, replace=False)
            x_s[i] = torch.from_numpy(self.imgs[c, picks[:n_s]])
            x_q[i] = torch.from_numpy(self.imgs[c, picks[n_s:]])

        # Label i denotes the i-th sampled class: [0]*n_s + [1]*n_s + ...
        y_s = torch.arange(self.n_way).repeat_interleave(n_s)
        y_q = torch.arange(self.n_way).repeat_interleave(n_q)

        # Shuffle so images are not grouped by class.
        perm_s = np.arange(self.n_way * n_s)
        self.rng.shuffle(perm_s)
        perm_q = np.arange(self.n_way * n_q)
        self.rng.shuffle(perm_q)
        x_s = x_s.view(self.n_way * n_s, 3, 84, 84)[perm_s]
        y_s = y_s[perm_s]
        x_q = x_q.view(self.n_way * n_q, 3, 84, 84)[perm_q]
        y_q = y_q[perm_q]

        return x_s, y_s, x_q, y_q, torch.Tensor([1])


if __name__ == '__main__':
    # Smoke test: iterate once over 10000 training episodes
    # (5-way, 5-shot, 50 queries per class) in batches of 10.
    # seed is passed explicitly because the constructor takes it as a
    # positional parameter.
    dataset = MiniImagenetDataset("../data/", 10000, 5, 5, 50, "train", seed=0)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
    for xs, ys, xq, yq, _ in dataloader:
        pass
