import numpy as np
import torch

from utils.initial_design import sample_initial_inputs


class Simulator:
    """Base class for noisy simulators.

    An observation is ``mean(x) + stddev(x) * noise(x)``. Subclasses must
    override ``mean``, ``stddev``, ``noise`` and ``search_space``; the base
    implementations raise ``NotImplementedError``.
    """

    def __init__(self, n_points, noise_sigma):
        # Number of data points in the discretized data set
        self.n_points = n_points

        # Level of noise
        self.noise_sigma = noise_sigma

        # Number of inputs and outputs. Initialised to None here; presumably
        # set by subclasses (``n_outputs`` is read by sample_initial_data).
        self.n_inputs = None
        self.n_outputs = None

    def query(self, x):
        """Return a noisy observation ``mean(x) + stddev(x) * noise(x)``.

        If the result has at least two dimensions and its second dimension
        has size 1, it is flattened to 1-D. Results that are already 1-D
        (or 0-d) are returned unchanged.
        """
        y = self.mean(x) + self.stddev(x) * self.noise(x)
        # Guard the rank before indexing shape[1]; a 1-D result used to raise
        # IndexError here.
        if len(y.shape) > 1 and y.shape[1] == 1:
            y = y.reshape(-1)
        return y

    def mean(self, x):
        """Noise-free response at ``x``. Must be overridden."""
        raise NotImplementedError

    def stddev(self, x):
        """Noise scale at ``x``. Must be overridden."""
        raise NotImplementedError

    def noise(self, x):
        """Noise draw at ``x`` (scaled by ``stddev``). Must be overridden."""
        raise NotImplementedError

    def search_space(self):
        """Input domain passed to ``sample_initial_inputs``. Must be overridden."""
        # NB: Remember to always have a feature dimension
        raise NotImplementedError

    def sample_initial_data(self, n_samples, space_filling_design, seed=None):
        """Draw an initial design and label it with ``query``.

        Args:
            n_samples: Number of input points to sample.
            space_filling_design: Design method name forwarded as ``method=``
                to ``sample_initial_inputs``.
            seed: If not None, ``np.random.seed(seed)`` is called first. Note
                that this reseeds NumPy's *global* RNG.

        Returns:
            ``(x, y)`` as tensors (float32 when converted from non-tensors).
            ``x`` always has a feature dimension. When ``n_outputs == 1``,
            ``y`` has shape ``(n_samples,)``, including when ``n_samples == 1``.
        """
        if seed is not None:
            # NOTE(review): only NumPy's global RNG is seeded; torch-based
            # noise, if any subclass uses it, is not covered.
            np.random.seed(seed)

        # Sample training data
        x = np.array(sample_initial_inputs(n_samples, self.search_space(), method=space_filling_design))

        # Get labels from oracle
        y = self.query(x)

        return self._to_float_tensors(x, y)

    def _to_float_tensors(self, x, y):
        """Convert sampled inputs/labels to tensors with canonical shapes.

        Non-tensors become float32 tensors; existing tensors pass through.
        A 1-D ``x`` is reshaped to ``(n, 1)``. When ``n_outputs == 1``, a
        trailing singleton axis on a >=2-D ``y`` is dropped.
        """
        x = x if torch.is_tensor(x) else torch.tensor(x, dtype=torch.float32)
        if x.dim() == 1:
            x = x.view(-1, 1)
        y = y if torch.is_tensor(y) else torch.tensor(y, dtype=torch.float32)
        # Only squeeze a real trailing singleton axis. An unconditional
        # squeeze(-1) would collapse a single sample of shape (1,) into a
        # 0-d scalar, losing the sample dimension.
        if self.n_outputs == 1 and y.dim() > 1 and y.shape[-1] == 1:
            y = y.squeeze(-1)
        return x, y