# Sinusoid tasks as described in the original MAML paper
# each task has a different phase in [0, pi] and a different amplitude in [0.1, 5.0]
# the points for each task are sampled uniformly in [-5, 5]
# we sample both the training and test points
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
from hydra.utils import to_absolute_path
import numpy as np
from PIL import Image


@partial(jax.jit, static_argnames=("meta_batch_size", "inner_batch_size_train", "inner_batch_size_eval", "linspace_eval"))
def sample_task_batch(key, meta_batch_size=25, inner_batch_size_train=10, inner_batch_size_eval=10, linspace_eval=False):
    """Sample a batch of sinusoid regression tasks with train and test points.

    One phase in [0, pi] and one amplitude in [0.1, 5.0] is drawn per task and
    shared by the train and test points of that task.

    Args:
        key (jax.random.PRNGKey): Key consumed to draw the batch.
        meta_batch_size (int, optional): Number of tasks. Defaults to 25.
        inner_batch_size_train (int, optional): Number of training points per
            task. Defaults to 10.
        inner_batch_size_eval (int, optional): Number of test points per task.
            Defaults to 10.
        linspace_eval (bool, optional): If True, the test inputs are an evenly
            spaced grid on [-5, 5] instead of uniform samples. Only affects
            the test points. Defaults to False.

    Returns:
        training_task_batch_train (jax.numpy.ndarray): Stacked inputs and
            outputs of the training points, with dimensions
            2 x `meta_batch_size` x `inner_batch_size_train` x 1.
        training_task_batch_test (jax.numpy.ndarray): Stacked inputs and
            outputs of the test points, with dimensions
            2 x `meta_batch_size` x `inner_batch_size_eval` x 1.
        key (jax.random.PRNGKey): New key to use for the next call.
    """
    # One independent key per random quantity, plus a fresh key handed back
    # to the caller.
    key, phase_key, amp_key, train_key, test_key = jax.random.split(key, 5)
    phase = jax.random.uniform(phase_key, shape=(meta_batch_size,), minval=0.0, maxval=jnp.pi)
    amplitude = jax.random.uniform(amp_key, shape=(meta_batch_size,), minval=0.1, maxval=5.0)
    training_task_batch_train = sample_tasks(train_key, meta_batch_size, phase, amplitude, inner_batch_size_train)
    training_task_batch_test = sample_tasks(test_key, meta_batch_size, phase, amplitude, inner_batch_size_eval, linspace_eval)
    return training_task_batch_train, training_task_batch_test, key

def sample_tasks(key, meta_batch_size, phase, amplitude, inner_batch_size=10, linspace_eval=False):
    """Sample input/output points for a batch of sinusoid tasks.

    Args:
        key (jax.random.PRNGKey): Key used to draw the inputs. Not used when
            `linspace_eval` is True.
        meta_batch_size (int): Number of tasks.
        phase (jax.numpy.ndarray): Phase of each task, indexed by task along
            its first axis.
        amplitude (jax.numpy.ndarray): Amplitude of each task, indexed by task
            along its first axis.
        inner_batch_size (int, optional): Number of points per task.
            Defaults to 10.
        linspace_eval (bool, optional): If True, every task uses the same
            evenly spaced inputs on [-5, 5] instead of uniform samples.
            Defaults to False.

    Returns:
        jax.numpy.ndarray: The task inputs and the task outputs stacked along
            a new leading axis, i.e. `result[0]` holds the inputs and
            `result[1]` the outputs, each of shape
            `meta_batch_size` x `inner_batch_size` x 1.
    """
    points_shape = (meta_batch_size, inner_batch_size, 1)
    if not linspace_eval:
        xs = jax.random.uniform(key, shape=points_shape, minval=-5.0, maxval=5.0)
    else:
        # the same grid is shared by all the tasks
        grid = jnp.linspace(-5.0, 5.0, inner_batch_size)
        xs = jnp.broadcast_to(grid[None, :, None], points_shape)
    # insert two singleton axes so that per-task values broadcast over points
    per_task_amplitude = jnp.expand_dims(amplitude, (1, 2))
    per_task_phase = jnp.expand_dims(phase, (1, 2))
    ys = per_task_amplitude * jnp.sin(xs - per_task_phase)
    return jnp.stack((xs, ys))


def rot90_traceable(m, k=1, axes=(0, 1)):
    """Rotate `m` by `k` quarter turns, with `k` allowed to be a traced value.

    Adapted from https://github.com/google/jax/issues/55#issuecomment-1241661455

    Args:
        m (jax.numpy.ndarray): Array to rotate.
        k (int or scalar integer array, optional): Number of quarter turns,
            taken modulo 4. Defaults to 1.
        axes (tuple, optional): The two axes spanning the rotation plane,
            forwarded to `jnp.rot90`. Defaults to (0, 1).

    Returns:
        jax.numpy.ndarray: The rotated array.
    """
    # NOTE(review): `jax.lax.switch` presumably needs all branches to return
    # the same shape, i.e. the two rotated axes to have equal length - confirm.
    quarter_turns = k % 4
    # one branch per possible number of quarter turns, selected at run time
    branches = [partial(jnp.rot90, m, k=turns, axes=axes) for turns in (0, 1, 2, 3)]
    return jax.lax.switch(quarter_turns, branches)


class OmniglotDataset:
    """Omniglot dataset for few-shot learning.

    Args:
        eval_characters_dir (List, optional): List of characters to use for evaluation. Defaults to None,
            in which case a random subset of the characters found in `data_dir` is held out.
        data_dir (str, optional): Path to the Omniglot dataset. Defaults to "omniglot_resized/all_images".
        eval_prop (float, optional): Proportion of characters to use for evaluation. Defaults to 1/16.
        seed (int, optional): Seed for the random number generator. Defaults to 0.
        inner_batch_size_train (int, optional): Number of images per character in the training set
            of a task. Defaults to 10.
        inner_batch_size_eval (int, optional): Number of images per character in the test set
            of a task. Defaults to 10.
        k_ways (int, optional): Number of classes in the classification problem. Defaults to 5.
        image_size (int, optional): Size of the images. Defaults to 28.
        rotation_augmentation (bool, optional): Whether to use rotation augmentation. Defaults to False.

    Raises:
        ValueError: If `inner_batch_size_train + inner_batch_size_eval` is not 20,
            the number of images of each character.
    """

    # Number of drawings available for every character.
    IMAGES_PER_CHARACTER = 20

    def __init__(
        self,
        eval_characters_dir=None,
        data_dir="omniglot_resized/all_images",
        eval_prop=1/16,
        seed=0,
        inner_batch_size_train=10,
        inner_batch_size_eval=10,
        k_ways=5,
        image_size=28,
        rotation_augmentation=False,
    ):
        # How does Omniglot for few-shot learning work?
        # You have a set of alphabets, each of which has a set of characters.
        # Each character has been hand drawn by different people, 20 times for each character.
        # Each task for few-shot learning is defined by taking k_ways characters from all the
        # alphabets and creating a multi-class classification problem.
        # Inside a task, each character has some images in the training set and some images in the
        # test set. The images are resized to 28x28 pixels.
        # In order to train MAML/iMAML, we need to first select characters that will define
        # some tasks for the meta-training.

        # Every image of a selected character goes either to the training set or
        # to the test set of a task, so the two sizes must add up exactly.
        if inner_batch_size_train + inner_batch_size_eval != self.IMAGES_PER_CHARACTER:
            raise ValueError(
                "inner_batch_size_train + inner_batch_size_eval must be equal to "
                f"{self.IMAGES_PER_CHARACTER}, got {inner_batch_size_train} + {inner_batch_size_eval}"
            )
        self.eval_characters_dir = eval_characters_dir
        self.data_dir = Path(to_absolute_path(data_dir))
        self.languages_dir = sorted([p for p in self.data_dir.glob("*") if p.is_dir()])
        # characters are the subdirectories of the languages
        self.characters_dir = sorted([char for lang in self.languages_dir for char in lang.glob("*") if char.is_dir()])
        self.num_characters = len(self.characters_dir)
        if self.eval_characters_dir is None:
            # seeded draw, so that the train/eval split is reproducible
            eval_idx = np.random.RandomState(seed).choice(self.num_characters, int(self.num_characters * eval_prop), replace=False)
            self.eval_characters_dir = sorted([self.characters_dir[i] for i in eval_idx])
        held_out = set(self.eval_characters_dir)
        self.train_characters_dir = sorted([char for char in self.characters_dir if char not in held_out])
        # images and labels are only loaded on demand by `load_characters_labels`
        self.train_characters = None
        self.train_labels = None
        self.eval_characters = None
        self.eval_labels = None
        self.num_train_characters = len(self.train_characters_dir)
        self.num_eval_characters_dir = len(self.eval_characters_dir)
        self.inner_batch_size_train = inner_batch_size_train
        self.inner_batch_size_eval = inner_batch_size_eval
        self.k_ways = k_ways
        self.image_size = image_size
        self.rotation_augmentation = rotation_augmentation

    @staticmethod
    def _read_grayscale(image_path):
        """Read one image file as an 8-bit grayscale numpy array, closing the file afterwards."""
        with Image.open(image_path) as image:
            return np.asarray(image.convert("L"))

    def load_characters_labels(self, train=True):
        """Load the images and the one-hot labels of the training or evaluation characters.

        The result is cached on the instance (`train_characters`/`train_labels` or
        `eval_characters`/`eval_labels`), so the files are read only once per split.

        Args:
            train (bool, optional): Whether to load the meta-training characters or the
                meta-evaluation ones. Defaults to True.

        Returns:
            jax.numpy.ndarray: The images of the split divided by 255, stacked per character
                (first axis) and per image of the character (second axis).

        Raises:
            ValueError: If a character directory does not contain exactly 20 png images.
        """
        cached = self.train_characters if train else self.eval_characters
        if cached is not None:
            return cached
        characters_dir = self.train_characters_dir if train else self.eval_characters_dir
        per_character_images = []
        for char_dir in characters_dir:
            # each character has 20 images
            image_paths = sorted([p for p in char_dir.glob("*.png") if p.is_file()])
            if len(image_paths) != self.IMAGES_PER_CHARACTER:
                raise ValueError(
                    f"Expected {self.IMAGES_PER_CHARACTER} png images in {char_dir}, found {len(image_paths)}"
                )
            per_character_images.append(np.stack([self._read_grayscale(p) for p in image_paths]))
        # stack on the host first, so that there is a single transfer to the device
        characters = jnp.asarray(np.stack(per_character_images)) / 255.0
        # one hot encoding of the labels
        labels = jnp.eye(len(characters_dir))
        if train:
            self.train_characters = characters
            self.train_labels = labels
        else:
            self.eval_characters = characters
            self.eval_labels = labels
        return characters

    def _sample_single_task(self, key, characters):
        """Sample the training and test sets of one k-ways task from `characters`."""
        # independent keys for the three random decisions of a task
        choice_key, shuffle_key, rotation_key = jax.random.split(key, 3)
        # let's select k_ways characters
        selected_characters = jax.random.choice(choice_key, characters.shape[0], shape=(self.k_ways,), replace=False)
        task_characters = characters[selected_characters]
        # now we need to split the images of each character into a training set and a test set:
        # the images of each character are permuted as whole images along their first axis,
        # with one key per character
        shuffled_characters = jax.vmap(jax.random.permutation)(
            jax.random.split(shuffle_key, self.k_ways),
            task_characters,
        )
        if self.rotation_augmentation:
            # a single rotation is drawn per task, so inside a task
            # all the images share the same rotation
            rotation = jax.random.randint(rotation_key, (), minval=0, maxval=4)
            shuffled_characters = rot90_traceable(shuffled_characters, rotation, axes=(-2, -1))
        n_train = self.inner_batch_size_train
        # flatten characters, and associate corresponding labels:
        # images are laid out character after character, which matches repeating
        # each label row once per image of the character
        train_characters = shuffled_characters[:, :n_train].reshape(-1, self.image_size, self.image_size, 1)
        eval_characters = shuffled_characters[:, n_train:].reshape(-1, self.image_size, self.image_size, 1)
        # inside a task the i-th selected character is class i, i.e. the labels of
        # the distinct selected characters form an identity matrix
        task_labels = jnp.eye(self.k_ways)
        train_labels = jnp.repeat(task_labels, n_train, axis=0)
        eval_labels = jnp.repeat(task_labels, self.inner_batch_size_eval, axis=0)
        return train_characters, train_labels, eval_characters, eval_labels

    @partial(jax.jit, static_argnames=("self", "meta_batch_size", "train"))
    def sample_tasks(self, key, meta_batch_size=25, train=True):
        """Sample a batch of tasks for meta-training or meta-evaluation.

        Use `load_characters_labels` to load the characters and labels before calling this method.
        `self` is a static argument of the jitted function: the attributes of the dataset are read
        when the function is traced.

        Args:
            key (jax.random.PRNGKey): Jax PRNG key.
            meta_batch_size (int, optional): Number of tasks to sample. Defaults to 25.
            train (bool, optional): Whether to sample tasks for meta-training or meta-evaluation.
                Defaults to True.

        Returns:
            tuple: A tuple containing:
                - `tasks_inputs_train` (jax.numpy.ndarray): A tensor of shape
                    (`meta_batch_size`, `self.k_ways * inner_batch_size_train`, `self.image_size`, `self.image_size`, 1)
                    containing the inputs for the training set of each task.
                - `tasks_outputs_train` (jax.numpy.ndarray): A tensor of shape
                    (`meta_batch_size`, `self.k_ways * inner_batch_size_train`, `self.k_ways`) containing the
                    one-hot encoded labels for the training set of each task.
                - `tasks_inputs_eval` (jax.numpy.ndarray): A tensor of shape
                    (`meta_batch_size`, `self.k_ways * inner_batch_size_eval`, `self.image_size`, `self.image_size`, 1)
                    containing the inputs for the evaluation set of each task.
                - `tasks_outputs_eval` (jax.numpy.ndarray): A tensor of shape
                    (`meta_batch_size`, `self.k_ways * inner_batch_size_eval`, `self.k_ways`) containing the
                    one-hot encoded labels for the evaluation set of each task.

        Raises:
            RuntimeError: If the characters of the requested split have not been loaded.
        """
        # each task consists in selecting a certain number of characters of the
        # requested split, then for each of the selected characters, splitting
        # its images into a training set and a test set.
        characters = self.train_characters if train else self.eval_characters
        if characters is None:
            raise RuntimeError(
                "Characters are not loaded: call `load_characters_labels(train=...)` before `sample_tasks`."
            )
        # one key per task; the tasks are sampled in parallel and stacked along the first axis
        task_keys = jax.random.split(key, meta_batch_size)
        return jax.vmap(lambda task_key: self._sample_single_task(task_key, characters))(task_keys)
