import glob
import random
from collections import defaultdict

import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms

from maml.sampler import ClassBalancedSampler
from maml.datasets.metadataset import Task

class SinusoidMetaDataset(object):
    """Generator of sinusoid-regression meta-batches.

    Every task is a function ``y = amp * sin(x - phase)`` where ``amp`` and
    ``phase`` are drawn uniformly from ``amp_range`` and ``phase_range``.
    For each task, ``num_samples_per_function + num_val_samples`` inputs are
    drawn uniformly from ``input_range``. The first
    ``num_samples_per_function`` points form the train split; the remaining
    ``num_val_samples`` points form the validation split.

    Args:
        name: dataset name attached to every generated ``Task``.
        num_total_batches: number of meta-batches yielded by ``__iter__``.
        num_samples_per_function: train points per task (K for K-shot).
        num_val_samples: validation points per task.
        meta_batch_size: number of tasks in each meta-batch.
        amp_range: ``(low, high)`` bounds for the amplitude.
        phase_range: ``(low, high)`` bounds for the phase.
        input_range: ``(low, high)`` bounds for the inputs ``x``.
        oracle: stored on the instance but not used by this class.
        train: stored on the instance but not used by this class.
        device: torch device that generated tensors are moved to.
        rng: optional random source with a numpy-style
            ``uniform(low, high[, size])`` method, e.g.
            ``np.random.default_rng(seed)``. When ``None`` (the default),
            numpy's global ``np.random`` state is used.
    """

    def __init__(self, name='sinusoid', num_total_batches=200000,
                 num_samples_per_function=5, num_val_samples=15,
                 meta_batch_size=32, amp_range=(0.1, 5.0),
                 phase_range=(0, np.pi), input_range=(-5.0, 5.0),
                 oracle=False, train=True, device='cuda', rng=None):
        # Range defaults are tuples: a shared mutable list default would be
        # aliased across every instance created with the defaults.
        self.name = name
        self._num_total_batches = num_total_batches
        self._num_samples_per_function = num_samples_per_function
        self._num_val_samples = num_val_samples
        self._amp_range = amp_range
        self._phase_range = phase_range
        self._input_range = input_range
        self._oracle = oracle
        self._meta_batch_size = meta_batch_size
        self._train = train
        self._device = device
        # Fall back to the module-level numpy RNG so that existing code
        # seeding via np.random.seed keeps getting identical samples.
        self._rng = np.random if rng is None else rng

        # Scalar regression: one input feature, one output value.
        self.input_size = 1
        self.output_size = 1

    def _sample_function(self):
        """Draw one sinusoid and evaluate it at freshly sampled inputs.

        Random draws happen in the order amp, phase, inputs.

        Returns:
            ``(inputs, outputs)`` numpy arrays, each of shape
            ``(num_samples_per_function + num_val_samples, 1)``.
        """
        rng = self._rng
        amp = rng.uniform(self._amp_range[0], self._amp_range[1])
        phase = rng.uniform(self._phase_range[0], self._phase_range[1])
        num_points = self._num_samples_per_function + self._num_val_samples
        inputs = rng.uniform(self._input_range[0], self._input_range[1],
                             [num_points, 1])
        outputs = amp * np.sin(inputs - phase)
        return inputs, outputs

    def _to_tensor(self, array):
        """Convert a numpy array to a float32 tensor on ``self._device``."""
        return torch.from_numpy(array).float().to(self._device)

    def _make_meta_batch(self):
        """Build one meta-batch of ``meta_batch_size`` sinusoid tasks.

        Returns:
            ``(train_tasks, val_tasks)``: two tuples of ``Task`` objects of
            equal length, where element ``i`` of each comes from the same
            sinusoid. Both are empty when ``meta_batch_size`` is 0.
        """
        split = self._num_samples_per_function
        train_tasks = []
        val_tasks = []
        for _ in range(self._meta_batch_size):
            inputs, outputs = self._sample_function()
            train_tasks.append(Task(self._to_tensor(inputs[:split]),
                                    self._to_tensor(outputs[:split]),
                                    self.name))
            val_tasks.append(Task(self._to_tensor(inputs[split:]),
                                  self._to_tensor(outputs[split:]),
                                  self.name))
        return tuple(train_tasks), tuple(val_tasks)

    def __iter__(self):
        """Yield ``num_total_batches`` freshly sampled meta-batches.

        Each item is a ``(train_tasks, val_tasks)`` pair as returned by
        ``_make_meta_batch``.
        """
        for _ in range(self._num_total_batches):
            yield self._make_meta_batch()

    # for imgs, labels in iter(self._dataloader):
    #     train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
    #     yield train_tasks, val_tasks
    # self._total_samples_per_class = (
    #     num_samples_per_class + num_val_samples)
    # self._dataloader = self._get_bird_data_loader()

    # self.input_size = (img_channel, img_side_len, img_side_len)
    # self.output_size = self._num_classes_per_batch

    # def _get_bird_data_loader(self):
    #     assert self._img_channel == 1 or self._img_channel == 3
    #     resize = transforms.Resize(
    #         (self._img_side_len, self._img_side_len), Image.LANCZOS)
    #     if self._img_channel == 1:
    #         img_transform = transforms.Compose(
    #             [resize, transforms.Grayscale(num_output_channels=1),
    #              transforms.ToTensor()])
    #     else:
    #         img_transform = transforms.Compose(
    #             [resize, transforms.ToTensor()])
    #     dset = BirdMAMLSplit(
    #         self._root, transform=img_transform, train=self._train,
    #         download=True, num_train_classes=self._num_train_classes)
    #     _, labels = zip(*dset._flat_character_images)
    #     sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
    #                                    self._total_samples_per_class,
    #                                    self._num_total_batches, self._train)

    #     batch_size = (self._num_classes_per_batch *
    #                   self._total_samples_per_class *
    #                   self._meta_batch_size)
    #     loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
    #                         num_workers=self._num_workers, pin_memory=True)
    #     return loader

    # def _make_single_batch(self, imgs, labels):
    #     """Split imgs and labels into train and validation set.
    #     TODO: check if this might become the bottleneck"""
    #     # relabel classes randomly
    #     new_labels = list(range(self._num_classes_per_batch))
    #     random.shuffle(new_labels)
    #     labels = labels.tolist()
    #     label_set = set(labels)
    #     label_map = {label: new_labels[i] for i, label in enumerate(label_set)}
    #     labels = [label_map[l] for l in labels]

    #     label_indices = defaultdict(list)
    #     for i, label in enumerate(labels):
    #         label_indices[label].append(i)

    #     # assign samples to train and validation sets
    #     val_indices = []
    #     train_indices = []
    #     for label, indices in label_indices.items():
    #         val_indices.extend(indices[:self._num_val_samples])
    #         train_indices.extend(indices[self._num_val_samples:])
    #     label_tensor = torch.tensor(labels, device=self._device)
    #     imgs = imgs.to(self._device)
    #     train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name)
    #     val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name)

    #     return train_task, val_task

    # def _make_meta_batch(self, imgs, labels):
    #     batches = []
    #     inner_batch_size = (
    #         self._total_samples_per_class * self._num_classes_per_batch)
    #     for i in range(0, len(imgs) - 1, inner_batch_size):
    #         batch_imgs = imgs[i:i+inner_batch_size]
    #         batch_labels = labels[i:i+inner_batch_size]
    #         batch = self._make_single_batch(batch_imgs, batch_labels)
    #         batches.append(batch)

    #     train_tasks, val_tasks = zip(*batches)

    #     return train_tasks, val_tasks

