import pickle
import random
import numpy as np

from .misc_utils import AttrDict


class MetaClassificationDataset(object):
    """Base class for meta-learning classification datasets stored as pickles.

    The unpickled object is used as a mapping from partition name
    (``'train'``, ``'val'`` and optionally ``'test'``) to a sliceable
    collection of tasks.  Subclasses must implement `sample_from_task`,
    `image_batch_shape` and `n_logits`.
    """

    def __init__(self, path, n_training_tasks=0):
        """Load the dataset found at ``path``.

        Args:
            path: Location of the pickle file to read.
            n_training_tasks: When positive, only the first
                ``n_training_tasks`` entries of the ``'train'`` partition
                are kept; otherwise the partition is left untouched.

        Raises:
            KeyError: If ``'train'`` is missing while truncation is
                requested, or if both ``'test'`` and ``'val'`` are missing.
        """
        self.path = path
        # SECURITY: unpickling can execute arbitrary code, so ``path`` must
        # point to a trusted file.
        with open(path, 'rb') as fin:
            self._data = pickle.load(fin)

        if n_training_tasks > 0:
            self.data['train'] = self.data['train'][:n_training_tasks]

        # Without a dedicated test partition, the validation partition is
        # reused (the same object, not a copy).
        if 'test' not in self.data:
            self.data['test'] = self.data['val']

    def sample_from_task(self, partition, adapt_batch_size, val_batch_size):
        """Sample one adaptation/validation batch pair; subclass hook.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def sample_multiple_tasks(self, partition, n_tasks, adapt_batch_size, val_batch_size):
        """Return a list with ``n_tasks`` independent `sample_from_task` results.

        Args:
            partition: Partition name forwarded to `sample_from_task`.
            n_tasks: Number of samples to draw.
            adapt_batch_size: Forwarded to `sample_from_task`.
            val_batch_size: Forwarded to `sample_from_task`.
        """
        return [
            self.sample_from_task(partition, adapt_batch_size, val_batch_size)
            for _ in range(n_tasks)
        ]

    @property
    def data(self):
        """The object loaded from the pickle file (mutated in ``__init__``)."""
        return self._data

    @staticmethod
    def image_batch_shape(batch_size=None):
        """Return the shape of an image batch; subclass hook.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def n_logits():
        """Return the number of output classes; subclass hook.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class RainbowMNISTDataset(MetaClassificationDataset):
    """Dataset whose tasks each expose ``'images'`` and ``'labels'`` arrays.

    Image batches are reported as ``[batch_size, 28, 28, 3]`` and the
    classifier has 10 logits.
    """

    def sample_from_task(self, partition, adapt_batch_size, val_batch_size):
        """Draw disjoint adaptation and validation batches from one random task.

        Args:
            partition: One of ``'train'``, ``'val'`` or ``'test'``.
            adapt_batch_size: Number of examples in the adaptation batch.
            val_batch_size: Number of examples in the validation batch.

        Returns:
            Tuple ``(adapt_images, adapt_labels, val_images, val_labels)``.
        """
        assert partition in {'train', 'val', 'test'}

        tasks = self.data[partition]
        task = tasks[np.random.choice(len(tasks))]

        # A single draw without replacement keeps the two batches disjoint.
        n_needed = adapt_batch_size + val_batch_size
        shuffled = np.random.choice(task['labels'].shape[0], n_needed, replace=False)
        adapt_idx, val_idx = shuffled[:adapt_batch_size], shuffled[adapt_batch_size:]

        return (
            task['images'][adapt_idx],
            task['labels'][adapt_idx],
            task['images'][val_idx],
            task['labels'][val_idx],
        )

    @staticmethod
    def image_batch_shape(batch_size=None):
        """Return the image batch shape, with ``batch_size`` as leading entry."""
        return [batch_size, 28, 28, 3]

    @staticmethod
    def n_logits():
        """Return the number of output classes (10)."""
        return 10


class BinaryImageNetDataset(MetaClassificationDataset):
    """Binary same-task / different-task classification over image pairs.

    Every example is a reference image concatenated along the last axis
    with a second image; the label is 1 when the second image comes from
    the same task as the reference and 0 when it comes from another task.
    """

    def sampled_diff_images(self, partition, task_id, size, subsample_diff_tasks=10):
        """Sample ``size`` images from tasks other than ``task_id``.

        Args:
            partition: Partition whose tasks are sampled.
            task_id: Index of the task to exclude.
            size: Total number of images to return.
            subsample_diff_tasks: Upper bound on how many other tasks
                contribute images.

        Returns:
            The sampled images concatenated along axis 0.
        """
        tasks = self.data[partition]
        other_tasks = [task for idx, task in enumerate(tasks) if idx != task_id]
        chosen_tasks = np.random.choice(
            other_tasks,
            size=min(subsample_diff_tasks, len(tasks) - 1),
            replace=False,
        )

        # Split the requested total uniformly at random among the chosen tasks.
        n_chosen = len(chosen_tasks)
        per_task_counts = np.random.multinomial(size, np.ones(n_chosen) / n_chosen)

        return np.concatenate(
            [
                task['images'][
                    np.random.choice(len(task['images']), count, replace=False)
                ]
                for task, count in zip(chosen_tasks, per_task_counts)
            ],
            axis=0,
        )

    def sample_from_task(self, partition, adapt_batch_size, val_batch_size):
        """Build adaptation and validation batches of labelled image pairs.

        Args:
            partition: One of ``'train'``, ``'val'`` or ``'test'``.
            adapt_batch_size: Number of pairs in the adaptation batch.
            val_batch_size: Number of pairs in the validation batch.

        Returns:
            Tuple ``(adapt_images, adapt_labels, val_images, val_labels)``.
            Within each batch the same-task pairs (label 1) come first,
            followed by the different-task pairs (label 0).
        """
        assert partition in {'train', 'val', 'test'}
        task_id = np.random.choice(len(self.data[partition]))

        # Half of each batch (rounded down) pairs the reference with an image
        # from the same task; the rest uses images from other tasks.
        adapt_same_size = int(0.5 * adapt_batch_size)
        val_same_size = int(0.5 * val_batch_size)
        adapt_diff_size = adapt_batch_size - adapt_same_size
        val_diff_size = val_batch_size - val_same_size

        # One draw without replacement supplies all references and all
        # same-task partners, so no image is used twice.
        task_images = self.data[partition][task_id]['images']
        n_same_task = int(
            adapt_batch_size + val_batch_size + adapt_same_size + val_same_size
        )
        drawn = task_images[
            np.random.choice(task_images.shape[0], n_same_task, replace=False)
        ]

        adapt_same_start = adapt_batch_size
        val_ref_start = adapt_same_start + adapt_same_size
        val_same_start = val_ref_start + val_batch_size

        adapt_refs = drawn[:adapt_same_start]
        adapt_same = drawn[adapt_same_start:val_ref_start]
        val_refs = drawn[val_ref_start:val_same_start]
        val_same = drawn[val_same_start:]

        diff_images = self.sampled_diff_images(
            partition, task_id, adapt_diff_size + val_diff_size
        )
        adapt_diff = diff_images[:adapt_diff_size]
        val_diff = diff_images[adapt_diff_size:]

        adapt_partners = np.concatenate([adapt_same, adapt_diff], axis=0)
        adapt_images = np.concatenate([adapt_refs, adapt_partners], axis=-1)
        val_partners = np.concatenate([val_same, val_diff], axis=0)
        val_images = np.concatenate([val_refs, val_partners], axis=-1)

        adapt_labels = np.concatenate([
            np.ones(adapt_same_size, dtype=np.int64),
            np.zeros(adapt_diff_size, dtype=np.int64),
        ])
        val_labels = np.concatenate([
            np.ones(val_same_size, dtype=np.int64),
            np.zeros(val_diff_size, dtype=np.int64),
        ])

        return adapt_images, adapt_labels, val_images, val_labels

    @staticmethod
    def image_batch_shape(batch_size=None):
        """Return the image batch shape, with ``batch_size`` as leading entry."""
        return [batch_size, 84, 84, 6]

    @staticmethod
    def n_logits():
        """Return the number of output classes (2)."""
        return 2


class OmniglotDataset(MetaClassificationDataset):
    """N-way classification built from per-class image collections.

    Each entry of a partition exposes an ``'images'`` array.  A sampled task
    picks ``n_logits()`` entries and labels them ``0 .. n_logits() - 1``.
    Concrete subclasses supply ``n_logits``.
    """

    def sample_from_task(self, partition, adapt_batch_size, val_batch_size):
        """Sample an N-way task with disjoint adaptation and validation sets.

        Args:
            partition: One of ``'train'``, ``'val'`` or ``'test'``.
            adapt_batch_size: Adaptation examples drawn per class.
            val_batch_size: Validation examples drawn per class.

        Returns:
            Tuple ``(adapt_images, adapt_labels, val_images, val_labels)``
            with the per-class samples concatenated in label order.

        Raises:
            ValueError: From ``np.random.choice`` when the partition has
                fewer entries than ``n_logits()`` or an entry has fewer than
                ``adapt_batch_size + val_batch_size`` images.
        """
        assert partition in {'train', 'val', 'test'}
        classes = self.data[partition]

        # Classes must be distinct: sampling with replacement could assign two
        # different labels to the same class, making them indistinguishable.
        class_ids = np.random.choice(len(classes), self.n_logits(), replace=False)

        pre_adapt_images = []
        pre_adapt_labels = []
        post_adapt_images = []
        post_adapt_labels = []

        for label, class_id in enumerate(class_ids):
            class_images = classes[class_id]['images']
            # One draw without replacement keeps the two splits disjoint.
            random_indices = np.random.choice(
                class_images.shape[0],
                adapt_batch_size + val_batch_size,
                replace=False
            )
            pre_adapt_indices = random_indices[:adapt_batch_size]
            post_adapt_indices = random_indices[adapt_batch_size:]

            pre_adapt_images.append(class_images[pre_adapt_indices])
            pre_adapt_labels.append(
                np.full_like(pre_adapt_indices, label, dtype=np.int64)
            )
            post_adapt_images.append(class_images[post_adapt_indices])
            post_adapt_labels.append(
                np.full_like(post_adapt_indices, label, dtype=np.int64)
            )

        return (
            np.concatenate(pre_adapt_images), np.concatenate(pre_adapt_labels),
            np.concatenate(post_adapt_images), np.concatenate(post_adapt_labels)
        )

    @staticmethod
    def image_batch_shape(batch_size=None):
        """Return the image batch shape, with ``batch_size`` as leading entry."""
        return [batch_size, 28, 28, 1]


class Omniglot5WayDataset(OmniglotDataset):
    """``OmniglotDataset`` variant whose ``n_logits`` is fixed at 5."""

    @staticmethod
    def n_logits():
        """Return the number of output classes (5)."""
        return 5


class Omniglot10WayDataset(OmniglotDataset):
    """``OmniglotDataset`` variant whose ``n_logits`` is fixed at 10."""

    @staticmethod
    def n_logits():
        """Return the number of output classes (10)."""
        return 10


# Registry mapping short dataset names to the dataset classes defined above.
# The values are classes, not instances.
datasets = {
    'mnist': RainbowMNISTDataset,
    'imagenet': BinaryImageNetDataset,
    'omniglot5': Omniglot5WayDataset,
    'omniglot10': Omniglot10WayDataset,
}