import glob
import random
from collections import defaultdict, namedtuple

import torch
import numpy as np
from torch.utils.data import DataLoader
# from torchvision import transforms

# from maml.sampler import ClassBalancedSampler
# from datasets.metadataset import Task


# Lightweight immutable record for one task: ``x`` holds the inputs, ``y`` the
# targets and ``task_info`` a tag identifying where the task came from.
Task = namedtuple(typename='Task', field_names=('x', 'y', 'task_info'))

class QuadraticMetaDataset(object):
    """Iterable that generates synthetic quadratic-regression meta-batches.

    For every task a weight vector ``w`` of shape ``(p, 1)`` is drawn from a
    normal distribution with standard deviation ``nu / sqrt(p)`` and a scalar
    ``bias`` is drawn with standard deviation ``sigma``.  Inputs ``X`` are
    drawn element-wise with standard deviation ``sigma`` and targets are drawn
    from a normal distribution centred on ``(X @ w + bias) ** 2`` with standard
    deviation ``sigma``.  All randomness comes from the global ``np.random``
    state.

    Iterating the dataset yields ``num_total_batches`` pairs
    ``(train_tasks, val_tasks)``; each element is a tuple of
    ``meta_batch_size`` ``Task`` records whose ``x`` has shape ``(n, p)`` and
    whose ``y`` has shape ``(n, 1)`` (``n`` is ``nt`` for training tasks and
    ``nv`` for validation tasks).

    TODO: Check if the data loader is fast enough.

    Args:
        name: label stored in the ``task_info`` field of every ``Task``.
        num_total_batches: number of meta-batches yielded by one iteration
            pass.
        nt: number of training samples per task.
        nv: number of validation samples per task.
        meta_batch_size: number of tasks per meta-batch.  ``0`` yields
            empty tuples.
        nr: stored as ``self.nr``; not used inside this class.
        w_0: stored as ``self.w_0``; not used inside this class.
        nu: scale factor of the task weight distribution.
        p: input dimensionality.
        sigma: standard deviation of the inputs, the bias and the target
            noise.
        train: stored as ``self._train``; not used inside this class.
        device: torch device the generated tensors are moved to.
    """

    def __init__(self, name='quadratic', num_total_batches=200000,
                 nt=30, nv=2, meta_batch_size=3, nr=20,
                 w_0=0, nu=0.5, p=60, sigma=1,
                 train=True, device='cuda'):
        self.name = name
        self.nt = nt
        self.nv = nv
        self.meta_batch_size = meta_batch_size
        self.nr = nr
        self.w_0 = w_0
        self.nu = nu
        self.p = p
        self.sigma = sigma
        self._num_total_batches = num_total_batches
        # Private mirror of ``meta_batch_size``; this is the value the
        # sampling code actually reads.
        self._meta_batch_size = meta_batch_size
        self._train = train
        self._device = device

        self.input_size = p
        # NOTE(review): targets produced here have a single column, yet
        # output_size is set to p -- confirm what consumers expect.
        self.output_size = p

    def _to_tensor(self, array):
        """Convert a NumPy array to a float32 tensor on the configured device."""
        return torch.from_numpy(array).float().to(self._device)

    def _sample_task(self):
        """Draw one random task and return its ``(train_task, val_task)`` pair."""
        # The order of the random draws is kept fixed so that seeded runs
        # remain reproducible.
        w_task = np.random.normal(
            loc=0, scale=self.nu / np.sqrt(self.p), size=(self.p, 1))
        X_t = np.random.normal(loc=0, scale=self.sigma, size=(self.nt, self.p))
        X_v = np.random.normal(loc=0, scale=self.sigma, size=(self.nv, self.p))
        bias = np.random.normal(loc=0, scale=self.sigma)
        y_t = np.random.normal(loc=(X_t @ w_task + bias) ** 2, scale=self.sigma)
        y_v = np.random.normal(loc=(X_v @ w_task + bias) ** 2, scale=self.sigma)

        train_task = Task(self._to_tensor(X_t), self._to_tensor(y_t), self.name)
        val_task = Task(self._to_tensor(X_v), self._to_tensor(y_v), self.name)
        return train_task, val_task

    def _make_meta_batch(self):
        """Sample ``meta_batch_size`` tasks.

        Returns:
            A pair ``(train_tasks, val_tasks)`` of equally long tuples of
            ``Task`` records; both are empty when the meta-batch size is 0.
        """
        batches = [self._sample_task() for _ in range(self._meta_batch_size)]
        if not batches:
            # zip(*[]) yields nothing, so the unpacking below would raise.
            return (), ()

        train_tasks, val_tasks = zip(*batches)
        return train_tasks, val_tasks

    def __iter__(self):
        """Yield ``num_total_batches`` freshly sampled meta-batches."""
        for _ in range(self._num_total_batches):
            train_tasks, val_tasks = self._make_meta_batch()
            yield train_tasks, val_tasks

