import copy
import os
import pickle
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from collections import defaultdict

# Per-channel (R, G, B) statistics written on the 0-255 scale and divided by
# 255 here; presumably this matches the value range produced by
# ``transforms.ToTensor`` -- confirm against the torchvision docs.
imgnet_mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
imgnet_std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
normalize = transforms.Normalize(mean=imgnet_mean, std=imgnet_std)

# Evaluation pipeline: ndarray -> PIL image -> tensor -> normalized tensor.
# ``Image.fromarray`` is passed directly instead of being wrapped in a lambda:
# the call is identical, but a module-level function can be pickled whereas a
# lambda cannot, which matters when a dataset holding this pipeline has to be
# sent to DataLoader worker processes.
default_transform = transforms.Compose([
    Image.fromarray,
    transforms.ToTensor(),
    normalize,
])

# Training pipeline: same as above with a padded random crop, colour jitter
# and a random horizontal flip applied on the PIL image. ``np.array`` converts
# the augmented PIL image back to an ndarray before ``ToTensor`` (again passed
# directly rather than through a lambda so the pipeline stays picklable).
aug_transform = transforms.Compose([
    Image.fromarray,
    transforms.RandomCrop(84, padding=8),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomHorizontalFlip(),
    np.array,
    transforms.ToTensor(),
    normalize,
])


class ImageNet(Dataset):
    """Flat dataset over one miniImageNet pickle split.

    The split file is looked up under ``args.data_root`` and must unpickle to
    a mapping with the keys ``'data'`` and ``'labels'``.

    Args:
        args: object exposing ``data_root`` (directory containing the
            ``miniImageNet`` folder) and ``data_aug`` (whether to use the
            augmenting transform on the ``'train'`` partition).
        partition: split name; it is substituted into the file name pattern.
        transform: optional callable applied to each ``uint8`` image array.
            When ``None``, ``aug_transform`` is used for an augmented
            ``'train'`` partition and ``default_transform`` otherwise.

    Each item is the tuple ``(img, target, item)`` where ``target`` is the
    stored label shifted so that the smallest label of the split maps to 0.
    """

    def __init__(self, args, partition='train', transform=None):
        # Zero-argument super(): ``super(Dataset, self)`` would start the MRO
        # lookup *after* ``Dataset`` and therefore skip it.
        super().__init__()
        self.data_root = args.data_root
        self.partition = partition
        self.data_aug = args.data_aug

        if transform is not None:
            self.transform = transform
        elif self.partition == 'train' and self.data_aug:
            self.transform = aug_transform
        else:
            self.transform = default_transform

        if self.partition == "train":
            self.file_pattern = 'miniImageNet/miniImageNet_category_split_train_phase_%s.pickle'
        else:
            self.file_pattern = 'miniImageNet/miniImageNet_category_split_%s.pickle'

        data_file = os.path.join(self.data_root, self.file_pattern % partition)
        print(data_file)
        # NOTE(security): unpickling executes code embedded in the file; only
        # point ``data_root`` at trusted data.
        with open(data_file, 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        self.data = data['data']
        self.labels = data['labels']

    def _min_label(self):
        """Return ``min(self.labels)``, computed on first use and then cached.

        Caching avoids rescanning every label on each ``__getitem__`` call.
        The cached value is not refreshed if ``self.labels`` is replaced
        after the first item access.
        """
        try:
            return self._label_offset
        except AttributeError:
            self._label_offset = min(self.labels)
            return self._label_offset

    def __getitem__(self, item):
        """Return ``(transformed image, zero-based label, item)``."""
        img = np.asarray(self.data[item]).astype('uint8')
        img = self.transform(img)
        target = self.labels[item] - self._min_label()
        return img, target, item

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)


class TieredImageNet(Dataset):
    """Flat dataset over one tieredImageNet split.

    Images are read from ``tieredImageNet/<partition>_images.npz`` (array
    stored under the key ``'images'``) and labels from
    ``tieredImageNet/<partition>_labels.pkl`` (mapping with the key
    ``'labels'``), both relative to ``args.data_root``.

    Args:
        args: object exposing ``data_root`` and ``data_aug``.
        partition: split name substituted into both file name patterns.
        transform: optional callable applied to each ``uint8`` image array.
            When ``None``, ``aug_transform`` is used for an augmented
            ``'train'`` partition and ``default_transform`` otherwise.

    Each item is the tuple ``(img, target, item)`` where ``target`` is the
    stored label shifted so that the smallest label of the split maps to 0.
    """

    def __init__(self, args, partition='train', transform=None):
        # Zero-argument super(): ``super(Dataset, self)`` would start the MRO
        # lookup *after* ``Dataset`` and therefore skip it.
        super().__init__()
        self.data_root = args.data_root
        self.partition = partition
        self.data_aug = args.data_aug

        if transform is not None:
            self.transform = transform
        elif self.partition == 'train' and self.data_aug:
            self.transform = aug_transform
        else:
            self.transform = default_transform

        self.image_file_pattern = 'tieredImageNet/%s_images.npz'
        self.label_file_pattern = 'tieredImageNet/%s_labels.pkl'

        image_file = os.path.join(self.data_root, self.image_file_pattern % partition)
        # The context manager closes the archive's file handle once the
        # array has been read out of it.
        with np.load(image_file) as archive:
            self.data = archive['images']
        label_file = os.path.join(self.data_root, self.label_file_pattern % partition)
        self.labels = self._load_labels(label_file)['labels']

    def _min_label(self):
        """Return ``min(self.labels)``, computed on first use and then cached.

        Caching avoids rescanning every label on each ``__getitem__`` call.
        The cached value is not refreshed if ``self.labels`` is replaced
        after the first item access.
        """
        try:
            return self._label_offset
        except AttributeError:
            self._label_offset = min(self.labels)
            return self._label_offset

    def __getitem__(self, item):
        """Return ``(transformed image, zero-based label, item)``."""
        img = np.asarray(self.data[item]).astype('uint8')
        img = self.transform(img)
        target = self.labels[item] - self._min_label()

        return img, target, item

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    @staticmethod
    def _load_labels(file):
        """Unpickle ``file``, retrying with ``latin1`` decoding on failure.

        The first attempt uses pickle's default string decoding; if it raises,
        the file is reopened and loaded with ``encoding='latin1'``. Any error
        from the second attempt propagates to the caller.

        NOTE(security): unpickling executes code embedded in the file; only
        load trusted label files.
        """
        try:
            with open(file, 'rb') as fo:
                return pickle.load(fo)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit are not intercepted by the fallback.
            with open(file, 'rb') as fo:
                return pickle.load(fo, encoding='latin1')


class MetaDataset(Dataset):
    """Episodic task sampler built on top of a flat labelled dataset.

    Every item is one task made of ``args.n_ways`` classes with
    ``args.n_shots + args.n_queries`` samples per class, drawn from
    ``flat_db.data`` / ``flat_db.labels``.

    Args:
        flat_db: object exposing ``data`` (indexable, with ``.shape``),
            ``labels`` and ``transform``.
        args: object exposing ``n_ways``, ``n_shots``, ``n_queries`` and
            ``n_aug_support_samples`` (the latter is stored but not used by
            this class).
        train_transform: transform for support samples (and for all samples
            in ``"flat"`` mode); defaults to ``flat_db.transform``.
        test_transform: transform for query samples; defaults to
            ``default_transform``.
        fix_seed: in random mode, seed NumPy's global RNG with the task index
            before sampling so that a given index always yields the same task.
        sample_shape: ``"few_shot"`` (default when falsy) or ``"flat"``.
        db_size: number of tasks reported by ``len()`` in random mode.
        output_idx: in ``"flat"`` mode, also return the per-class sample
            indices. It has no effect in ``"few_shot"`` mode.
        no_replacement: pre-generate tasks so that no sample is used in more
            than one task; ``db_size`` is then overwritten with the number of
            generated tasks.
    """

    def __init__(self, flat_db, args, train_transform=None, test_transform=None, fix_seed=True, sample_shape=None, db_size=100,
                 output_idx=False, no_replacement=False):
        # Zero-argument super(): ``super(Dataset, self)`` would start the MRO
        # lookup *after* ``Dataset`` and therefore skip it.
        super().__init__()
        self.data = flat_db.data
        self.labels = flat_db.labels
        self.fix_seed = fix_seed
        self.n_ways = args.n_ways
        self.n_shots = args.n_shots
        self.n_queries = args.n_queries
        self.db_size = db_size
        self.sample_shape = sample_shape if sample_shape else "few_shot"
        self.n_aug_support_samples = args.n_aug_support_samples
        # Task-local class ids 0..n_ways-1 used as targets in few-shot mode.
        self.local_cls = np.arange(self.n_ways)
        self.n_per_class = self.n_shots + self.n_queries
        self.output_idx = output_idx

        if train_transform is None:
            self.train_transform = flat_db.transform
        else:
            self.train_transform = train_transform

        if test_transform is None:
            self.test_transform = default_transform
        else:
            self.test_transform = test_transform

        # Group the samples by label, then freeze each group as a uint8 array.
        self.label_to_data = defaultdict(list)
        for idx in range(self.data.shape[0]):
            self.label_to_data[self.labels[idx]].append(self.data[idx])

        for key in self.label_to_data:
            self.label_to_data[key] = np.asarray(self.label_to_data[key]).astype('uint8')

        self.classes = list(self.label_to_data.keys())

        if no_replacement:
            # NOTE: ``_gen_task_without_replacement_optimal`` is an alternative
            # generator that is currently not wired in; an earlier draft
            # selected it when the sample count divided evenly into tasks,
            # but that condition was judged not specific enough.
            self._gen_task_without_replacement()
            self._get = self._fixed_item
        else:
            self._get = self._random_item

    def _gen_task_without_replacement_optimal(self):
        """Pre-generate tasks that together consume every sample exactly once.

        Seeds NumPy's global RNG with 1024. Class order is built from repeated
        shuffles of the unique labels, with swaps at the seams so that the
        classes completing a partially filled task differ from those already
        in it. Asserts that no samples are left over at the end.

        Sets ``self.clss``, ``self.sample_idx`` and ``self.db_size``.
        """
        np.random.seed(1024)
        no_samples = len(self.labels)
        no_task = no_samples // (self.n_ways * self.n_per_class)
        all_samples = {}
        for key, val in self.label_to_data.items():
            tmp = np.arange(len(val))
            np.random.shuffle(tmp)
            all_samples[key] = tmp

        uniq_labels = np.unique(self.labels)
        cls_order = np.empty((0,))
        for _ in range(no_task*self.n_ways//len(uniq_labels)):
            tmp = copy.copy(uniq_labels)
            np.random.shuffle(tmp)
            if len(cls_order) % self.n_ways == 0:
                cls_order = np.append(cls_order, tmp)
            else:
                # The last task is only partially filled: make sure the
                # classes that will complete it are not already part of it.
                left_over = len(cls_order) % self.n_ways
                a = set(cls_order[-left_over:])
                for i in range(self.n_ways - left_over):
                    if tmp[i] in a:
                        for j in range(self.n_ways - left_over, self.n_ways):
                            if tmp[j] not in a:
                                tmp[i], tmp[j] = tmp[j], tmp[i]
                                break
                cls_order = np.append(cls_order, tmp)

        idx = []

        cls_order = cls_order.reshape(-1, self.n_ways)
        for task_id in range(no_task):
            cls_sampled = cls_order[task_id]

            uniq_idx = []
            for cls in cls_sampled:
                samples = all_samples[cls]
                ids_sampled, others = np.split(samples, [self.n_per_class])
                uniq_idx.append(ids_sampled)
                if len(others) > 0:
                    all_samples[cls] = others
                else:
                    del all_samples[cls]

            uniq_idx = np.asarray(uniq_idx)
            idx.append(uniq_idx)

        assert len(all_samples) == 0

        self.clss = cls_order
        self.sample_idx = idx
        self.db_size = no_task

    def _gen_task_without_replacement(self):
        """Pre-generate tasks in which every sample is used at most once.

        Seeds NumPy's global RNG with 1024. Tasks are drawn while at least
        ``n_ways`` classes can still provide ``n_per_class`` unused samples;
        a class is retired as soon as it cannot fill another slot.

        Sets ``self.clss``, ``self.sample_idx`` and ``self.db_size``.
        """
        np.random.seed(1024)
        all_samples = {}
        for key, val in self.label_to_data.items():
            order = np.arange(len(val))
            # Shuffle before the size check so the RNG stream does not depend
            # on which classes are kept.
            np.random.shuffle(order)
            # A class with fewer than n_per_class samples can never fill a
            # task slot; keeping it would produce ragged index arrays.
            if len(order) >= self.n_per_class:
                all_samples[key] = order

        cls_order = []
        idx = []
        no_task = 0
        while len(all_samples) >= self.n_ways:
            no_task += 1
            cls_sampled = np.random.choice(list(all_samples.keys()), self.n_ways, False)
            cls_order.append(cls_sampled)

            uniq_idx = []
            for cls in cls_sampled:
                ids_sampled, others = np.split(all_samples[cls], [self.n_per_class])
                uniq_idx.append(ids_sampled)
                if len(others) >= self.n_per_class:
                    all_samples[cls] = others
                else:
                    del all_samples[cls]

            idx.append(np.asarray(uniq_idx))

        self.clss = cls_order
        self.sample_idx = idx
        self.db_size = no_task

    def _random_item(self, item):
        """Sample a task at random.

        When ``fix_seed`` is set, NumPy's global RNG is seeded with ``item``
        first (this affects other users of the global RNG).

        Returns:
            ``(xs, cls_sampled, uniq_idx)``: the sampled images grouped per
            class, the sampled class labels and the list of per-class sample
            indices.
        """
        if self.fix_seed:
            np.random.seed(item)
        cls_sampled = np.random.choice(self.classes, self.n_ways, False)

        xs = []
        uniq_idx = []
        for cls in cls_sampled:
            samples = self.label_to_data[cls]
            ids_sampled = np.random.choice(samples.shape[0], self.n_per_class, False)
            uniq_idx.append(ids_sampled)
            xs.append(samples[ids_sampled])

        return np.array(xs), cls_sampled, uniq_idx

    def _fixed_item(self, item):
        """Return the pre-generated task number ``item``.

        Returns:
            ``(xs, cls_sampled, uniq_idx)`` with the same layout as
            ``_random_item``.
        """
        cls_sampled = self.clss[item]
        uniq_idx = self.sample_idx[item]

        xs = [self.label_to_data[cls][uniq_idx[i]] for i, cls in enumerate(cls_sampled)]

        return np.asarray(xs), cls_sampled, uniq_idx

    def __getitem__(self, item):
        """Build task ``item`` in the layout selected by ``sample_shape``.

        Returns:
            ``"flat"``: ``(xs, labels, cls_sampled)`` -- all samples
            transformed with ``train_transform``, the original class label of
            each sample and the sampled classes -- plus the per-class sample
            indices as a fourth element when ``output_idx`` is set.

            ``"few_shot"``: ``(support_xs, support_ys, query_xs, query_ys)``
            where the targets are task-local class ids ``0..n_ways-1``;
            support samples use ``train_transform`` and query samples use
            ``test_transform``.

        Raises:
            ValueError: if ``sample_shape`` is neither of the two modes.
        """
        xs, cls_sampled, uniq_idx = self._get(item)

        # xs is indexed as (class, sample, *image dims).
        data_dims = xs.shape[2:]
        uniq_idx = np.asarray(uniq_idx)

        if self.sample_shape == "flat":
            xs = xs.reshape(-1, *data_dims)
            flat_xs = self.batch_to_tensor(xs, self.train_transform)
            flat_ys = np.repeat(cls_sampled, self.n_per_class)
            if self.output_idx:
                return flat_xs, flat_ys, cls_sampled, uniq_idx
            return flat_xs, flat_ys, cls_sampled

        if self.sample_shape == "few_shot":
            # The first n_shots samples of each class form the support set,
            # the remaining ones the query set.
            support_xs, query_xs = np.split(xs, [self.n_shots], axis=1)

            support_xs = support_xs.reshape((-1, *data_dims))
            query_xs = query_xs.reshape((-1, *data_dims))

            support_xs = self.batch_to_tensor(support_xs, self.train_transform)
            query_xs = self.batch_to_tensor(query_xs, self.test_transform)

            return support_xs, np.repeat(self.local_cls, self.n_shots), query_xs, np.repeat(self.local_cls, self.n_queries)

        # Fail loudly instead of implicitly returning None.
        raise ValueError(
            "unknown sample_shape %r; expected 'flat' or 'few_shot'" % (self.sample_shape,)
        )

    def __len__(self):
        """Return the number of tasks (``db_size``)."""
        return self.db_size

    @staticmethod
    def batch_to_tensor(arr, transform):
        """Apply ``transform`` to every sample of ``arr`` and stack the results.

        ``arr`` is split along axis 0 and each piece is squeezed before being
        transformed; note that ``squeeze`` drops every size-1 axis, not only
        the leading one.
        """
        pieces = np.split(arr, arr.shape[0], axis=0)
        return torch.stack([transform(piece.squeeze()) for piece in pieces])


    
if __name__ == '__main__':
    # Manual smoke test: load the miniImageNet training split from the
    # hard-coded data directory and build non-overlapping 5-way tasks on it.
    from types import SimpleNamespace

    # A plain attribute container standing in for parsed command-line
    # arguments (previously a lambda object was abused for this).
    args = SimpleNamespace(
        n_ways=5,
        n_shots=5,
        n_queries=15,
        data_root=os.path.expanduser('~/workspace/metaL_data'),
        data_aug=False,
        n_test_runs=5,
        n_aug_support_samples=1,
        sample_shape="few_shot",
    )
    imagenet = ImageNet(args, 'train')

    meta_data = MetaDataset(imagenet, args, no_replacement=True)