import os
import glob
import random
from collections import defaultdict

import torch
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.datasets.utils import list_files

from scripts.sampler import ClassBalancedSampler
from scripts.datasets.metadataset import Task


class OmniglotMAMLSplit(Omniglot):
    """Omniglot with a train / test split over character classes.

    Implements a similar train / test split for Omniglot as
    https://github.com/cbfinn/maml/blob/master/data_generator.py

    Uses torchvision.datasets.Omniglot for downloading and checking
    dataset integrity. The parent constructor is invoked with
    ``background=True`` and ``download=True``; afterwards ``background`` is
    flipped to ``False`` and ``download()`` is called a second time so that
    the testing data is fetched and verified as well.

    Every path matched three levels below ``self.root`` is treated as one
    character class. The paths are sorted, the first ``num_train_classes``
    form the training split and the remainder form the test split.

    NOTE(review): the three-level glob presumably expects ``self.root`` to
    point at the extracted dataset folder (alphabet and character
    directories two and three levels down) -- confirm against the installed
    torchvision version.

    Args:
        root: dataset root directory handed to the parent class.
        train: if True expose the first ``num_train_classes`` characters,
            otherwise expose all remaining ones.
        num_train_classes: number of character classes in the training
            split.
        **kwargs: forwarded unchanged to the parent constructor.

    Raises:
        RuntimeError: if the integrity check fails after the second
            download.
    """
    def __init__(self, root, train=True, num_train_classes=1100, **kwargs):
        super(OmniglotMAMLSplit, self).__init__(root, download=True,
                                                background=True, **kwargs)

        self._train = train
        self._num_train_classes = num_train_classes

        # download testing data and test integrity
        self.background = False
        self.download()
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted')

        # glob returns entries in file-system dependent order. Sort them so
        # the train / test split contains the same classes on every machine
        # and on every run.
        all_character_dirs = sorted(glob.glob(self.root + '/**/**/**'))
        if self._train:
            self._characters = all_character_dirs[:self._num_train_classes]
        else:
            self._characters = all_character_dirs[self._num_train_classes:]

        # One list of (image_path, class_index) pairs per character; the
        # class index is the position of the character within this split.
        self._character_images = [
            [(img_path, class_idx)
             for img_path in sorted(glob.glob(char_path + '/*'))]
            for class_idx, char_path in enumerate(self._characters)
        ]

        # Flat view used for integer indexing in __getitem__.
        self._flat_character_images = [
            sample
            for char_images in self._character_images
            for sample in char_images
        ]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target
            character class.
        """
        image_path, character_class = self._flat_character_images[index]
        # convert() returns a new, fully loaded image, so the file handle
        # can be released as soon as the with-block ends.
        with Image.open(image_path, mode='r') as raw_image:
            image = raw_image.convert('L')

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class


class OmniglotMetaDataset(object):
    """Iterable yielding meta-batches of few-shot Omniglot tasks.

    Iterating over an instance yields ``(train_tasks, val_tasks)`` pairs:
    two tuples of equal length holding one ``Task`` per sampled task. Within
    a task the classes are relabelled to ``0 .. num_classes_per_batch - 1``
    in random order and all images of a class are rotated by the same
    random multiple of 90 degrees.

    TODO: Check if the data loader is fast enough.

    Args:
        name: dataset name passed on to every ``Task``
        root: path to omniglot dataset
        img_side_len: images are scaled to this size
        img_channel: number of image channels, must be 1 or 3. For 3 the
            single-channel image is tiled along the channel axis
        num_classes_per_batch: number of classes to sample for each task
        num_samples_per_class: number of training samples per class in
            each task (K for K shot learning). The validation samples are
            sampled in addition to these
        num_total_batches: total number of tasks to generate
        num_val_samples: number of validation samples per class in each task
        meta_batch_size: number of tasks in one yielded meta-batch
        train: whether to create the data loader from the training split
            (True) or from the test split (False)
        num_train_classes: number of character classes in the training split
        num_workers: number of worker processes used by the DataLoader
        device: device the task tensors are moved to
    """
    def __init__(self, name='Omniglot', root='data',
                 img_side_len=28, img_channel=1,
                 num_classes_per_batch=5, num_samples_per_class=6,
                 num_total_batches=200000,
                 num_val_samples=1, meta_batch_size=40, train=True,
                 num_train_classes=1100, num_workers=0, device='cpu'):
        self.name = name
        self._root = root
        self._img_side_len = img_side_len
        self._img_channel = img_channel
        self._num_classes_per_batch = num_classes_per_batch
        self._num_samples_per_class = num_samples_per_class
        self._num_total_batches = num_total_batches
        self._num_val_samples = num_val_samples
        self._meta_batch_size = meta_batch_size
        self._num_train_classes = num_train_classes
        self._train = train
        self._num_workers = num_workers
        self._device = device

        # Every class contributes its training AND validation samples to a
        # task, so the sampler has to draw both at once.
        self._total_samples_per_class = (
            num_samples_per_class + num_val_samples)
        self._dataloader = self._get_omniglot_data_loader()

        self.input_size = (img_channel, img_side_len, img_side_len)
        self.output_size = self._num_classes_per_batch

    def _get_omniglot_data_loader(self):
        """Create the DataLoader whose batches each hold one meta-batch.

        The batch size is chosen so that a single DataLoader batch contains
        the samples of ``meta_batch_size`` tasks laid out back to back.
        """
        assert self._img_channel in (1, 3), 'img_channel must be 1 or 3'

        pipeline = [
            transforms.Resize(self._img_side_len, Image.LANCZOS),
            transforms.ToTensor(),
            # invert the pixel intensities
            transforms.Lambda(lambda x: 1.0 - x),
        ]
        if self._img_channel > 1:
            # tile the image along the channel axis
            pipeline.append(transforms.Lambda(
                lambda x: x.repeat(self._img_channel, 1, 1)))
        img_transform = transforms.Compose(pipeline)

        dset = OmniglotMAMLSplit(self._root, transform=img_transform,
                                 train=self._train,
                                 num_train_classes=self._num_train_classes)
        _, labels = zip(*dset._flat_character_images)
        sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
                                       self._total_samples_per_class,
                                       self._num_total_batches, self._train)

        batch_size = (self._num_classes_per_batch
                      * self._total_samples_per_class
                      * self._meta_batch_size)
        return DataLoader(dset, batch_size=batch_size, sampler=sampler,
                          num_workers=self._num_workers, pin_memory=True)

    def _make_single_batch(self, imgs, labels):
        """Split the samples of one task into a train and a validation Task.

        Note that ``imgs`` is modified in place by the random rotation.
        TODO: check if this might become the bottleneck

        Args:
            imgs: image tensor of one task, indexed by sample along the
                first dimension.
            labels: tensor with the dataset class label of every sample.

        Returns:
            tuple: (train_task, val_task)
        """
        # relabel classes randomly
        new_labels = list(range(self._num_classes_per_batch))
        random.shuffle(new_labels)
        labels = labels.tolist()
        label_map = dict(zip(set(labels), new_labels))
        labels = [label_map[label] for label in labels]

        # positions of the samples belonging to each (new) label
        label_indices = defaultdict(list)
        for position, label in enumerate(labels):
            label_indices[label].append(position)

        # rotate randomly to create new classes; one rotation per class
        # TODO: move this to torch once supported.
        for indices in label_indices.values():
            rotation = np.random.randint(4)
            for position in indices:
                rotated = np.rot90(imgs[position].numpy(), k=rotation,
                                   axes=(1, 2))
                # copy here for contiguity (rot90 only returns a view)
                imgs[position] = torch.from_numpy(np.copy(rotated))

        # assign samples to train and validation sets: the first
        # num_val_samples of every class are held out for validation
        val_indices = []
        train_indices = []
        for indices in label_indices.values():
            val_indices.extend(indices[:self._num_val_samples])
            train_indices.extend(indices[self._num_val_samples:])

        label_tensor = torch.tensor(labels, device=self._device)
        imgs = imgs.to(self._device)
        train_task = Task(imgs[train_indices], label_tensor[train_indices],
                          self.name)
        val_task = Task(imgs[val_indices], label_tensor[val_indices],
                        self.name)

        return train_task, val_task

    def _make_meta_batch(self, imgs, labels):
        """Cut one DataLoader batch into tasks.

        Only complete tasks are built: trailing samples that do not fill a
        whole task are dropped instead of being turned into a task with too
        few samples.

        Args:
            imgs: samples of one DataLoader batch.
            labels: labels matching ``imgs``.

        Returns:
            tuple: (train_tasks, val_tasks), two tuples with one Task per
            complete task. Both are empty when ``imgs`` holds fewer samples
            than a single task needs.
        """
        inner_batch_size = (self._total_samples_per_class
                            * self._num_classes_per_batch)
        # the last valid start offset leaves room for one complete task
        last_start = len(imgs) - inner_batch_size
        batches = [
            self._make_single_batch(imgs[start:start + inner_batch_size],
                                    labels[start:start + inner_batch_size])
            for start in range(0, last_start + 1, inner_batch_size)
        ]
        if not batches:
            return (), ()

        train_tasks, val_tasks = zip(*batches)

        return train_tasks, val_tasks

    def __iter__(self):
        """Yield ``(train_tasks, val_tasks)`` for every non-empty meta-batch."""
        for imgs, labels in self._dataloader:
            train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
            if not train_tasks:
                # no complete task fitted into this batch
                continue
            yield train_tasks, val_tasks
