import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate

from utils import Uniform

class Sinusoid:
  """Callable sine wave ``amplitude * sin(x - phase)``."""

  def __init__(self, amplitude, phase):
    # Both values are stored untouched and read again on every call.
    self.amplitude, self.phase = amplitude, phase

  def __call__(self, x):
    """Evaluate the wave at ``x`` using ``np.sin`` (scalar or array input)."""
    shifted = x - self.phase
    return self.amplitude * np.sin(shifted)


class SinusoidSampler():
  """Samples random sinusoids and noisy regression episodes drawn from them.

  Amplitude and phase each get a ``Uniform`` prior (helper imported from
  ``utils``) built from the bounds passed to the constructor; a value is
  obtained from a prior by calling its ``sample()`` method.
  """

  def __init__(self, ampl_lower, ampl_upper, phase_lower, phase_upper):
    self.ampl_prior = Uniform(ampl_lower, ampl_upper)
    self.phase_prior = Uniform(phase_lower, phase_upper)

  def sample_sinusoid(self):
    """Return a new ``Sinusoid`` whose amplitude and phase are drawn from the priors."""
    return Sinusoid(self.ampl_prior.sample(), self.phase_prior.sample())

  @staticmethod
  def _sample_points(sinusoid, n_points, sigma_data, x_lower, x_upper):
    """Draw ``n_points`` inputs uniformly from the x-range and their noisy targets."""
    x = np.random.uniform(x_lower, x_upper, n_points).astype(np.float32)
    # The Gaussian noise is float64, so the sum is promoted to float64; cast the
    # targets back so inputs and targets share one dtype.
    y = (sinusoid(x) + np.random.normal(0, sigma_data, n_points)).astype(np.float32)
    return x, y

  def sample_episode(self, n_support, n_query, sigma_data, sinusoid=None, x_lower=-5.0, x_upper=5.0):
    """Sample one episode (support set and query set) from a single sinusoid.

    Args:
      n_support: number of support points.
      n_query: number of query points.
      sigma_data: standard deviation of the Gaussian noise added to the targets.
      sinusoid: callable mapping an array of inputs to targets; when ``None`` a
        fresh one is drawn with ``sample_sinusoid``.
      x_lower: lower bound of the uniform input range (default ``-5.0``).
      x_upper: upper bound of the uniform input range (default ``5.0``).

    Returns:
      ``(episode, sinusoid)`` where ``episode`` is a dict with the float32 arrays
      ``'support_im'``, ``'support_labels'``, ``'query_im'`` and ``'query_labels'``,
      and ``sinusoid`` is the function that generated the targets.
    """
    if sinusoid is None:
      sinusoid = self.sample_sinusoid()

    # Support is sampled before query so the order of random draws is unchanged.
    x_supp_samples, y_supp = self._sample_points(sinusoid, n_support, sigma_data, x_lower, x_upper)
    x_query_samples, y_query = self._sample_points(sinusoid, n_query, sigma_data, x_lower, x_upper)

    episode = {
      'support_im': x_supp_samples,
      'support_labels': y_supp,
      'query_im': x_query_samples,
      'query_labels': y_query,
    }
    return episode, sinusoid

class SinusoidDataSet(Dataset):
  """Flat dataset of noisy sinusoid points, generated once at construction.

  ``n_tasks`` episodes are requested from a ``SinusoidSampler`` (built from the
  amplitude/phase bounds). Every episode contributes its ``n_support`` support
  points followed by its ``n_query`` query points, so consecutive runs of
  ``n_support + n_query`` items belong to the same sinusoid.
  """

  def __init__(self, ampl_lower, ampl_upper, phase_lower, phase_upper, n_tasks, n_support, n_query, obs_noise=0.1):
    self.n_tasks, self.n_support, self.n_query = n_tasks, n_support, n_query
    self._points_per_episode = self.n_support + self.n_query

    # obs_noise is handed to sample_episode as the noise level of the targets.
    self.obs_noise = obs_noise
    self.sinusoid_sampler = SinusoidSampler(ampl_lower, ampl_upper, phase_lower, phase_upper)

    sampled = self.sample_dataset()
    self.images, self.labels, self.sinusoids = sampled

  def sample_dataset(self):
    """Sample ``n_tasks`` episodes and flatten them.

    Returns:
      ``(images, labels, sinusoids)``: two flat lists holding the inputs and
      targets of all episodes (support before query within each episode), and
      the list of sinusoids returned by the sampler, one per episode.
    """
    xs, ys, tasks = [], [], []
    split_keys = (('support_im', 'support_labels'), ('query_im', 'query_labels'))
    for _ in range(self.n_tasks):
      ep, task = self.sinusoid_sampler.sample_episode(self.n_support, self.n_query, self.obs_noise)
      tasks.append(task)
      for x_key, y_key in split_keys:
        xs.extend(ep[x_key])
        ys.extend(ep[y_key])
    return xs, ys, tasks

  def __getitem__(self, idx):
    """Return the ``(input, target)`` pair stored at ``idx``."""
    x = self.images[idx]
    y = self.labels[idx]
    return x, y

  def __len__(self):
    """Total number of stored points across all episodes."""
    n_points = len(self.images)
    return n_points

class SinusoidEpisodeBatchSampler:
  """Yields batches of flat item indices that cover whole episodes.

  Episode ``e`` is taken to occupy the contiguous index range starting at
  ``e * (n_support + n_query)``, support points first. ``collate`` turns the
  items fetched for one batch back into a list of per-episode dicts.

  NOTE(review): presumably used as ``batch_sampler=`` of a ``DataLoader`` with
  ``collate_fn=sampler.collate`` -- confirm against the caller.
  ``episodes_per_batch`` is expected to be a positive integer.
  """

  def __init__(self, episodes_per_batch, n_support, n_query, n_tasks, shuffle=True):
    self.episodes_per_batch = episodes_per_batch
    self.n_support = n_support
    self.n_query = n_query
    self.n_tasks = n_tasks

    self.shuffle = shuffle

  def __len__(self):
    """Number of batches produced by one pass of ``__iter__``.

    Defined so that ``len()`` of a ``DataLoader`` using this object as its
    batch sampler works instead of raising ``TypeError``.
    """
    if self.n_tasks <= 0:
      return 0
    if self.episodes_per_batch < 1:
      # __iter__ never fills a batch in this case and emits everything once at the end.
      return 1
    # Ceiling division: the last batch may hold fewer episodes.
    return (self.n_tasks + self.episodes_per_batch - 1) // self.episodes_per_batch

  def __iter__(self):
    """Yield lists of item indices, ``episodes_per_batch`` episodes per list.

    Episodes are visited in a fresh random permutation when ``shuffle`` is set,
    otherwise in order. A trailing partial batch is yielded as well.
    """
    if self.shuffle:
      episode_idxs = np.random.permutation(self.n_tasks)
    else:
      episode_idxs = range(self.n_tasks)
    total_per_ep = self.n_support + self.n_query

    batch_count = 0
    item_idxs = []
    for ep_idx in episode_idxs:
      # int() turns the NumPy integers of the permutation into plain Python ints.
      first = int(ep_idx) * total_per_ep
      item_idxs.extend(range(first, first + total_per_ep))
      batch_count += 1

      if batch_count == self.episodes_per_batch:
        yield item_idxs
        item_idxs = []
        batch_count = 0
    if item_idxs:
      yield item_idxs

  def get_episode(self, subbatch):
    """Collate the items of one episode into support/query tensors.

    Args:
      subbatch: sequence of ``(input, target)`` items for a single episode, the
        first ``n_support`` being the support set and the next ``n_query`` the
        query set.

    Returns:
      Dict with ``"support_im"``, ``"support_labels"``, ``"query_im"`` and
      ``"query_labels"``, each the ``default_collate`` of the matching values.
    """
    support = [subbatch[i] for i in range(self.n_support)]
    query = [subbatch[i] for i in range(self.n_support, self.n_support + self.n_query)]
    return {
      "support_im": default_collate([item[0] for item in support]),
      "support_labels": default_collate([item[1] for item in support]),
      "query_im": default_collate([item[0] for item in query]),
      "query_labels": default_collate([item[1] for item in query]),
    }

  def collate(self, batch):
    """Split a flat batch of items into a list of per-episode dicts.

    Items beyond the last complete episode are ignored.
    """
    total_per_ep = self.n_support + self.n_query
    total_episodes = len(batch) // total_per_ep
    return [
      self.get_episode(batch[i * total_per_ep:(i + 1) * total_per_ep])
      for i in range(total_episodes)
    ]
