import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.transforms import transforms
import os


class MiniImagenetDataset(Dataset):
    """Random few-shot episodes drawn from per-class MiniImageNet ``.npy`` files.

    Every item is one freshly sampled episode: ``n_way`` classes are picked
    at random and, for each class, ``n_support + n_query`` distinct images.
    The index passed to ``__getitem__`` is ignored; ``len`` only sets how many
    episodes an epoch contains.

    Args:
        len: Value reported by ``__len__`` (episodes per epoch). The name
            shadows the builtin but is kept for backward compatibility.
        n_way: Number of classes per episode.
        n_support: Support images per class.
        n_query: Query images per class.
        mode: ``"train"`` or ``"val"``; any other value selects the test split.
        device: Device the returned tensors are moved to.
    """

    # mode -> (file/dir prefix, number of class files). Modes not listed here
    # are treated as the test split.
    _SPLITS = {"train": ("train", 64), "val": ("val", 16)}
    _DEFAULT_SPLIT = ("test", 20)

    def __init__(self, len, n_way, n_support, n_query, mode, device):
        super().__init__()

        self.len = len

        # Fixed seed: the sequence of sampled episodes is reproducible.
        # NOTE(review): with DataLoader num_workers > 0 each worker presumably
        # gets its own copy of this generator and would repeat the same
        # episodes -- confirm before enabling workers.
        self.rng = np.random.default_rng(0)

        # Images kept per class (the original note says the maximum is 600).
        img_subset = 200

        self.transform = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        # The cache name only encodes the mode, so a stale cache is reused
        # even if img_subset is changed; delete the file to rebuild it.
        cache_path = f'imgs_{mode}.npy'
        if os.path.exists(cache_path):
            self.imgs = np.load(cache_path)
        else:
            self.imgs = self._load_split(mode, img_subset)
            np.save(cache_path, self.imgs)
            print("Saved")

        self.n_classes = self.imgs.shape[0]
        self.n_samples = self.imgs.shape[1]
        self.n_support = n_support
        self.n_query = n_query
        self.n_way = n_way

        self.device = device

    @staticmethod
    def _load_split(mode, img_subset):
        """Stack the first ``img_subset`` images of every class file of a split.

        Returns an array of shape ``(n_classes, img_subset, 3, 84, 84)``.
        """
        split, n_classes = MiniImagenetDataset._SPLITS.get(
            mode, MiniImagenetDataset._DEFAULT_SPLIT
        )
        imgs = np.zeros((n_classes, img_subset, 3, 84, 84))

        class_ids = range(n_classes)
        if mode == "train":
            # Only the (largest) train split shows a progress bar.
            class_ids = tqdm(class_ids)

        for i in class_ids:
            path = f"./data/MiniImageNet/{split}_data/{split}_{i}.npy"
            # Rows are always indexed from 0. The previous val branch wrote to
            # imgs[64 + i] in a 16-row array and raised IndexError.
            imgs[i] = np.load(path)[0][0:img_subset]
        return imgs

    def __len__(self):
        """Return the configured number of episodes per epoch."""
        return self.len

    def __getitem__(self, idx):
        """Sample one episode; ``idx`` is ignored.

        Returns:
            ``(x_s, y_s, x_q, y_q)`` on ``self.device`` where ``x_s`` has
            shape ``(n_way * n_support, *image_shape)``, ``x_q`` has shape
            ``(n_way * n_query, *image_shape)`` and ``y_s`` / ``y_q`` are 1-D
            long tensors of episode-relative labels in ``[0, n_way)``.
            Support and query sets are shuffled independently.
        """
        # Take the per-image shape from the data instead of hard-coding it.
        img_shape = tuple(self.imgs.shape[2:])
        n_s, n_q = self.n_support, self.n_query

        classes = self.rng.choice(np.arange(self.n_classes), self.n_way, replace=False)
        x_s = torch.zeros(self.n_way, n_s, *img_shape)
        x_q = torch.zeros(self.n_way, n_q, *img_shape)
        for i, cls in enumerate(classes):
            # One draw without replacement keeps support and query disjoint.
            picks = self.rng.choice(np.arange(self.n_samples), n_s + n_q, replace=False)
            x_s[i] = torch.from_numpy(self.imgs[cls, picks[:n_s]]).float()
            x_q[i] = torch.from_numpy(self.imgs[cls, picks[n_s:]]).float()

        # Label i is the position of the class within this episode, repeated
        # once per image, matching the row-major flattening below.
        y_s = torch.arange(self.n_way).repeat_interleave(n_s)
        y_q = torch.arange(self.n_way).repeat_interleave(n_q)

        perm_s = np.arange(self.n_way * n_s)
        self.rng.shuffle(perm_s)
        perm_q = np.arange(self.n_way * n_q)
        self.rng.shuffle(perm_q)

        x_s = x_s.view(self.n_way * n_s, *img_shape)[perm_s]
        y_s = y_s[perm_s]
        x_q = x_q.view(self.n_way * n_q, *img_shape)[perm_q]
        y_q = y_q[perm_q]

        x_s = self.transform(x_s)
        x_q = self.transform(x_q)

        return x_s.to(self.device), y_s.to(self.device), x_q.to(self.device), y_q.to(self.device)


if __name__ == '__main__':
    # Smoke run: build the train episodes and iterate once over every batch.
    episodes = MiniImagenetDataset(
        len=10000,
        n_way=5,
        n_support=10,
        n_query=50,
        mode="train",
        device="cpu",
    )
    loader = DataLoader(episodes, batch_size=10, shuffle=True)
    # Each batch is (support images, support labels, query images, query labels).
    for xs, ys, xq, yq in loader:
        pass
