import torch
import numpy as np
from torch.utils.data import Dataset

__all__ = ['FunctionDataSet']

class FunctionDataSet(Dataset):
    """Dataset whose items are curves generated on demand.

    Every access asks the wrapped ``GPCurvesReader`` for a new curve; the
    requested index is not used to select anything. The reported length is
    simply the ``total_iter`` value given at construction.

    Args:
        total_iter: Value returned by ``len()`` for this dataset.
        num_total_points: Number of points requested for each curve.
        x_size: Forwarded to ``GPCurvesReader``.
        y_size: Forwarded to ``GPCurvesReader``.
        l1_scale: Forwarded to ``GPCurvesReader``.
        sigma_scale: Forwarded to ``GPCurvesReader``.
        testing: Forwarded to ``GPCurvesReader``.
        visualize: Accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        total_iter,
        num_total_points=400,
        x_size=1,
        y_size=1,
        l1_scale=0.4,
        sigma_scale=1.0,
        testing=True,
        visualize=False,
    ):
        self.num_total_points = num_total_points
        self.total_iter = total_iter

        # The generator owns all kernel hyper-parameters; this class only
        # remembers how many points to ask for and how long to claim to be.
        self.GPCurvesGen = GPCurvesReader(
            x_size=x_size,
            y_size=y_size,
            l1_scale=l1_scale,
            sigma_scale=sigma_scale,
            testing=testing,
        )

    def __len__(self):
        """Return the configured number of iterations."""
        return self.total_iter

    def __getitem__(self, index):
        """Generate and return one ``(target_x, target_y)`` pair.

        ``index`` is ignored: a new curve is generated on each call.
        """
        xs, ys = self.GPCurvesGen.generate_curves(self.num_total_points)
        return xs, ys

class GPCurvesReader(object):
    """Sample curves from a Gaussian-process prior on a regular grid.

    The inputs form an evenly spaced grid on ``[-2, 2)``. The outputs are
    drawn from a zero-mean Gaussian whose covariance is a squared-exponential
    kernel evaluated on that grid, plus a small diagonal noise term.

    Args:
        x_size: Size of the last dimension of the length-scale tensor.
        y_size: Number of output dimensions (independent curves per sample).
        l1_scale: Length scale used for every kernel dimension.
        sigma_scale: Output scale (standard deviation) of the kernel.
        testing: Stored on the instance; not used by the methods of this
            class.
    """

    def __init__(self, x_size=1, y_size=1, l1_scale=0.4, sigma_scale=1.0, testing=True):
        # One curve set is generated per call; batching is left to the caller.
        self._batch_size = 1
        self._x_size = x_size
        self._y_size = y_size
        self._l1_scale = l1_scale
        self._sigma_scale = sigma_scale
        self._testing = testing

    def _gaussian_kernel(self, xdata, l1, sigma_f, sigma_noise=2e-2):
        """Evaluate the squared-exponential kernel on ``xdata``.

        Args:
            xdata: Tensor of shape ``[B, num_total_points, x_size]``.
            l1: Length scales, shape ``[B, y_size, x_size]``.
            sigma_f: Output scales, shape ``[B, y_size]``.
            sigma_noise: Standard deviation of the noise added to the
                diagonal.

        Returns:
            Tensor of shape ``[B, y_size, num_total_points, num_total_points]``.
        """
        num_total_points = xdata.shape[1]

        # Pairwise differences between all points.
        xdata1 = torch.unsqueeze(xdata, dim=1)  # [B, 1, num_total_points, x_size]
        xdata2 = torch.unsqueeze(xdata, dim=2)  # [B, num_total_points, 1, x_size]
        diff = xdata1 - xdata2  # [B, num_total_points, num_total_points, x_size]

        # Scaled squared distance, summed over the input dimension.
        # [B, y_size, num_total_points, num_total_points, x_size]
        norm = torch.pow(diff[:, None, :, :, :] / l1[:, :, None, None, :], 2)
        norm = torch.sum(norm, -1)  # [B, y_size, num_total_points, num_total_points]

        # [B, y_size, num_total_points, num_total_points]
        kernel = torch.pow(sigma_f, 2)[:, :, None, None] * torch.exp(-0.5 * norm)

        # Add some noise to the diagonal to make the Cholesky work. The
        # identity is created with the kernel's dtype/device so the in-place
        # add never mixes tensor types.
        jitter = torch.eye(num_total_points, dtype=kernel.dtype, device=kernel.device)
        kernel += (sigma_noise ** 2) * jitter

        return kernel

    def generate_curves(self, num_total_points):
        """Draw one sample of curves on a grid of ``num_total_points`` points.

        Args:
            num_total_points: Number of grid points; must be positive.

        Returns:
            A tuple ``(target_x, target_y)`` with shapes
            ``[num_total_points, 1]`` and ``[num_total_points, y_size]``.
        """
        # Build the grid from an integer index instead of a float-step
        # ``arange``: with a fractional step the element count is derived from
        # ``ceil((end - start) / step)`` and can come out one too large
        # (e.g. for 49 points), which then breaks the matmul below.
        step = 4.0 / num_total_points
        grid = torch.arange(num_total_points, dtype=torch.float64) * step - 2.0
        x_values = grid.to(torch.get_default_dtype())
        x_values = x_values.unsqueeze(0).repeat([self._batch_size, 1])
        x_values = x_values.unsqueeze(-1)  # [batch_size, num_total_points, 1]

        # Set kernel parameters
        l1 = torch.ones(self._batch_size, self._y_size, self._x_size) * self._l1_scale
        sigma_f = torch.ones(self._batch_size, self._y_size) * self._sigma_scale

        # Pass the x_values through the Gaussian kernel
        # [batch_size, y_size, num_total_points, num_total_points]
        kernel = self._gaussian_kernel(x_values, l1, sigma_f)

        # Calculate Cholesky in double precision for better stability, then
        # cast back to the working dtype. Prefer ``torch.linalg.cholesky``;
        # fall back to the deprecated ``torch.cholesky`` on older releases.
        cholesky_fn = (
            getattr(getattr(torch, "linalg", None), "cholesky", None)
            or torch.cholesky
        )
        cholesky = cholesky_fn(kernel.double()).to(kernel.dtype)

        # Sample a curve
        # [batch_size, y_size, num_total_points, 1]
        noise = torch.randn(
            [self._batch_size, self._y_size, num_total_points, 1],
            dtype=kernel.dtype,
        )
        y_values = torch.matmul(cholesky, noise)

        # [batch_size, num_total_points, y_size]
        y_values = torch.transpose(torch.squeeze(y_values, 3), 1, 2)

        # Select the targets (drop the batch dimension of size 1)
        target_x = x_values.squeeze(0)
        target_y = y_values.squeeze(0)
        return target_x, target_y
