import numpy as np
import time
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms

def get_transform(
    img_size, random_crop=False, random_horizontal_flip=False, normalize_mean=(0.5,), normalize_std=(0.5,)
):
    """Compose the image preprocessing pipeline.

    The pipeline always produces an ``img_size`` x ``img_size`` image. It uses a
    ``RandomResizedCrop`` when ``random_crop`` is set and a plain ``Resize``
    otherwise. An optional ``RandomHorizontalFlip`` follows. The image is then
    converted with ``ToTensor`` and normalized with ``normalize_mean`` /
    ``normalize_std``.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.
    """
    square = (img_size, img_size)
    sizing = transforms.RandomResizedCrop(square) if random_crop else transforms.Resize(square)
    flips = [transforms.RandomHorizontalFlip()] if random_horizontal_flip else []
    steps = [sizing, *flips, transforms.ToTensor(), transforms.Normalize(normalize_mean, normalize_std)]
    return transforms.Compose(steps)

def get_dataset(data_name, meta_level, img_size):
    """Return the ``EpisodicDataset`` for a named benchmark.

    Args:
        data_name: ``"cifar"`` (meta CIFAR-100) or ``"tiny"`` (meta Tiny-ImageNet).
        meta_level: sub-directory appended to the dataset root.
        img_size: forwarded to ``EpisodicDataset``.

    Raises:
        NotImplementedError: if ``data_name`` is not one of the known names.
    """
    known = (
        ("cifar", f"/st2/anonymous/data/meta_cifar_100/{meta_level}"),
        ("tiny", f"/st2/anonymous/data/meta_TIN/{meta_level}"),
    )
    for name, root in known:
        if data_name == name:
            return EpisodicDataset(root, img_size)
    raise NotImplementedError

class EpisodicDataset:
    """Pool of images from which random few-class tasks are sampled.

    The ``train_*`` and ``test_*`` ``.npy`` files under ``path`` are merged into
    one pool:

    * ``self.images`` and ``self.labels`` concatenate both splits.
    * ``self.class_idx[c]`` holds the pool positions of every example of class ``c``.
      Test positions are shifted by the number of train images.
    """

    def __init__(self, path, img_size):
        """Load the arrays under ``path`` and build the train/test transforms.

        Args:
            path: directory holding ``{train,test}_{image,label,class_idx}.npy``.
            img_size: side length that images are resized or cropped to.
        """
        self.path = path

        # NOTE(review): allow_pickle=True can execute pickled code; only load trusted files.
        train_images = np.load(f"{path}/train_image.npy", allow_pickle=True)
        train_labels = np.load(f"{path}/train_label.npy", allow_pickle=True)
        train_class_idx = np.load(f"{path}/train_class_idx.npy", allow_pickle=True)

        test_images = np.load(f"{path}/test_image.npy", allow_pickle=True)
        test_labels = np.load(f"{path}/test_label.npy", allow_pickle=True)
        test_class_idx = np.load(f"{path}/test_class_idx.npy", allow_pickle=True)

        ragged = self._has_ragged_class_idx(path)

        # Test examples follow the train examples in the merged pool, so shift their indices.
        offset = len(train_images)
        if ragged:
            test_class_idx = self._to_object_array(
                [[c + offset for c in teidx] for teidx in test_class_idx]
            )
        else:
            test_class_idx += offset

        self.images = np.concatenate([train_images, test_images], axis=0)
        self.labels = np.concatenate([train_labels, test_labels], axis=0)

        if ragged:
            # list(...) guarantees concatenation even if a row was loaded as an ndarray.
            self.class_idx = self._to_object_array(
                [list(tridx) + list(teidx) for tridx, teidx in zip(train_class_idx, test_class_idx)]
            )
        else:
            # Rectangular index arrays (one row per class): join along the example axis.
            self.class_idx = np.concatenate([train_class_idx, test_class_idx], axis=1)

        self.train_transform = get_transform(img_size, random_crop=True, random_horizontal_flip=True)
        self.test_transform = get_transform(img_size, random_crop=False, random_horizontal_flip=False)

    @staticmethod
    def _has_ragged_class_idx(path):
        """Return True for datasets whose per-class index entries are handled as variable-length lists."""
        return "CUB" in path or "cars" in path or "aircraft" in path

    @staticmethod
    def _to_object_array(lists):
        """Pack a sequence of lists into a 1-D object array with one list per entry.

        ``np.array(lists, dtype=object)`` silently builds a 2-D array when all lists
        have the same length. Later ``+`` on its rows then adds element-wise
        instead of concatenating.
        """
        out = np.empty(len(lists), dtype=object)
        for i, lst in enumerate(lists):
            out[i] = lst
        return out

    def get_task(self, rank, num_classes, batch_size, num_workers=0, pin_memory=False, drop_last=True):
        """Sample a new task and return ``(train_dataloader, test_dataloader)``.

        Both ``NumpyDataset``s are built with the same seed, so they pick the same
        ``num_classes`` classes and split those classes' examples into disjoint
        halves. They are also stored on ``self.train_ds`` / ``self.test_ds`` for
        ``get_data_by_index``. The loaders do not shuffle; the datasets are
        already permuted.
        """
        # NOTE(review): round(time.time() % 1000) can be 0, in which case every rank
        # gets seed 0 (identical tasks across ranks) — confirm this is acceptable.
        seed = (rank + 1) * round(time.time() % 1000)

        self.train_ds = NumpyDataset(
            self.path,
            self.images,
            self.labels,
            self.class_idx,
            num_classes,
            seed,
            train=True,
            transform=self.train_transform,
        )
        self.test_ds = NumpyDataset(
            self.path,
            self.images,
            self.labels,
            self.class_idx,
            num_classes,
            seed,
            train=False,
            transform=self.test_transform,
        )
        train_dataloader = DataLoader(
            self.train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
        )
        test_dataloader = DataLoader(
            self.test_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
        )
        return train_dataloader, test_dataloader

    def get_data_by_index(self, index, train=True):
        """Gather the given items of the current task into batched float tensors.

        Requires a prior ``get_task`` call, which sets ``train_ds`` / ``test_ds``.

        Args:
            index: sized iterable of dataset positions.
            train: read from the train split if True, else from the test split.

        Returns:
            A tuple with one tensor per item field, each shaped ``(len(index), ...)``.
            Non-tensor fields (e.g. labels) become ``(len(index), 1)``.
            Returns ``None`` when ``index`` is empty.
        """
        ds = self.train_ds if train else self.test_ds
        batched_data = None
        for i, idx in enumerate(index):
            # Fetch once: __getitem__ may apply random augmentation, so a second
            # call would be wasted work and return a different sample.
            item = ds[idx]
            if batched_data is None:
                batched_data = tuple(
                    torch.zeros((len(index),) + (tuple(d.size()) if isinstance(d, torch.Tensor) else (1,)))
                    for d in item
                )
            for bd, d in zip(batched_data, item):
                bd[i] = d
        return batched_data

class NumpyDataset(Dataset):
    """One half (train or test) of a random few-class task drawn from numpy arrays.

    ``num_classes`` classes are drawn at random from ``class_idx``, and their
    examples are shuffled. The first half goes to the train split and the rest
    to the test split. Labels are remapped to ``0 .. num_classes-1`` in the order
    the classes were drawn. Items are ``(image, label, index)``.
    """

    def __init__(self, path, images, labels, class_idx, num_classes, seed, train=True, transform=None):
        """Sample the task and keep this dataset's half.

        Args:
            path: dataset path. Only used to detect the ragged ``class_idx`` layout.
            images, labels: full example pool, indexed by positions in ``class_idx``.
            class_idx: per-class example positions. Either a 2-D array (one row per
                class) or, for CUB/cars/aircraft, a 1-D object array of lists.
            num_classes: number of classes in the task.
            seed: RNG seed. Building train and test with the same seed yields the
                same classes and disjoint halves.
            train: keep the first half if True, otherwise the second half.
            transform: optional callable applied to each PIL image.
        """
        super().__init__()

        # NOTE: seeds NumPy's *global* RNG (kept as-is for compatibility with callers).
        np.random.seed(seed)

        # Randomly sample classes.
        cidx = np.random.permutation(len(class_idx))[:num_classes]
        selected = class_idx[cidx]
        if "CUB" in path or "cars" in path or "aircraft" in path:
            flat = []
            for sci in selected:
                flat.extend(sci)  # linear-time flatten; also accepts ndarray rows
            idx = np.array(flat)
        else:
            idx = selected.reshape(-1)
        self.images, self.labels = images[idx], labels[idx]

        # Randomly shuffle examples, then split them in half between train and test.
        N = len(self.images)
        eidx = np.random.permutation(N)
        self.images, self.labels = self.images[eidx], self.labels[eidx]
        n = N // 2
        if train:
            self.images, self.labels = self.images[:n], self.labels[:n]
        else:
            self.images, self.labels = self.images[n:], self.labels[n:]

        # Original class id -> position in the sampled order (0 .. num_classes-1).
        self._label_map = {c: i for i, c in enumerate(cidx.tolist())}
        self.transform = transform
        # A bound method rather than a local closure, so the dataset stays picklable
        # (required by DataLoader workers under the "spawn" start method).
        self.target_transform = self._relabel
        self.length = len(self.images)

    def _relabel(self, y):
        """Map original class id ``y`` to its task-local label.

        Raises:
            ValueError: if ``y`` is not one of the sampled classes.
        """
        try:
            return self._label_map[y]
        except KeyError:
            raise ValueError(
                f"label {y!r} is not one of the sampled classes {list(self._label_map)}"
            ) from None

    def __getitem__(self, index):
        """Return ``(transformed image, task-local label, index)``."""
        img = Image.fromarray(self.images[index])
        label = self.labels[index]
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)
        return img, label, index

    def __len__(self):
        return self.length
