import numpy as np

from simulators.simulator import Simulator


class Homo1(Simulator):
    """Homoscedastic 1-D toy problem: sin(6 x) with a constant noise level, on the interval [0, 5]."""

    def __init__(self, n_points=101, noise_sigma=0.2):
        super().__init__(n_points, noise_sigma)
        self.n_inputs, self.n_outputs = 1, 1

    def mean(self, x):
        """Noise-free response sin(6 x), evaluated element-wise (output has the shape of x)."""
        frequency = 6
        return np.sin(frequency * x)

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, broadcast to the shape of x."""
        unit = np.ones_like(x)
        return self.noise_sigma * unit

    def noise(self, x):
        """Standard-normal draws with one row per input row and one column per output."""
        shape = (x.shape[0], self.n_outputs)
        return np.random.normal(loc=0, scale=1, size=shape)

    def search_space(self):
        """``n_points`` evenly spaced inputs on [0, 5], returned as a column vector of shape (n_points, 1)."""
        grid = np.linspace(0, 5, self.n_points)
        return grid[:, np.newaxis]


class Homo2(Simulator):
    """Homoscedastic 1-D toy problem: x^3 - 6 x^2 + 4 x + 12 with a constant noise level, on [0, 5]."""

    def __init__(self, n_points=101, noise_sigma=1.0):
        super().__init__(n_points, noise_sigma)
        self.n_inputs, self.n_outputs = 1, 1

    def mean(self, x):
        """Noise-free cubic response, evaluated element-wise (output has the shape of x)."""
        cubic = np.power(x, 3)
        quadratic = 6 * np.power(x, 2)
        linear = 4 * x
        # Same left-to-right evaluation order as the polynomial written out in one expression.
        return cubic - quadratic + linear + 12

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, broadcast to the shape of x."""
        unit = np.ones_like(x)
        return self.noise_sigma * unit

    def noise(self, x):
        """Standard-normal draws with one row per input row and one column per output."""
        shape = (x.shape[0], self.n_outputs)
        return np.random.normal(loc=0, scale=1, size=shape)

    def search_space(self):
        """``n_points`` evenly spaced inputs on [0, 5], returned as a column vector of shape (n_points, 1)."""
        grid = np.linspace(0, 5, self.n_points)
        return grid[:, np.newaxis]


class GramacyAndLee1d(Simulator):
    """
    Gramacy and Lee (2012) function: https://www.sfu.ca/~ssurjano/grlee12.html

    f(x) = sin(10 pi x) / (2 x) + (x - 1)^4 on the domain [0.5, 2.5] (which keeps the
    1 / (2 x) term away from x = 0).
    The original function has no noise. This simulator reports a constant noise standard
    deviation of ``noise_sigma``, i.e. N(0, 0.1^2) noise with the default ``noise_sigma=0.1``.
    """
    def __init__(self, n_points=101, noise_sigma=0.1):
        super().__init__(n_points, noise_sigma)
        self.n_inputs = 1
        self.n_outputs = 1

    def mean(self, x):
        """Noise-free response sin(10 pi x) / (2 x) + (x - 1)^4, evaluated element-wise."""
        return np.sin(10*np.pi*x) / (2*x) + np.power(x - 1, 4)

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, broadcast to the shape of x."""
        return self.noise_sigma * np.ones_like(x)

    def noise(self, x):
        """Standard-normal draws of shape (len(x), n_outputs)."""
        return np.random.normal(loc=0, scale=1, size=[x.shape[0], self.n_outputs])

    def search_space(self):
        """``n_points`` evenly spaced inputs on [0.5, 2.5] as a column vector of shape (n_points, 1)."""
        return np.linspace(0.5, 2.5, self.n_points).reshape(-1, 1)


class GramacyAndLee2d(Simulator):
    """
    Gramacy and Lee (2008) function: https://www.sfu.ca/~ssurjano/grlee08.html

    f(x1, x2) = x1 * exp(-x1^2 - x2^2) on the square [-2, 6]^2.
    The original function has no noise. Here we use a constant noise standard deviation of
    ``noise_sigma``, i.e. N(0, 0.05^2) with the default.
    """

    def __init__(self, n_points=100, noise_sigma=0.05):
        super().__init__(n_points, noise_sigma)
        self.n_inputs, self.n_outputs = 2, 1

    def mean(self, x):
        """Noise-free response for inputs of shape (n, 2); returns shape (n, 1)."""
        first, second = x[:, 0], x[:, 1]
        exponent = -np.power(first, 2) - np.power(second, 2)
        return (first * np.exp(exponent)).reshape(-1, 1)

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, one row per input; shape (n, 1)."""
        per_row = self.noise_sigma * np.ones_like(x[:, 0])
        return per_row.reshape(-1, 1)

    def noise(self, x):
        """Standard-normal draws of shape (n, n_outputs)."""
        shape = (x.shape[0], self.n_outputs)
        return np.random.normal(loc=0, scale=1, size=shape)

    def search_space(self):
        """
        Full ``n_points`` x ``n_points`` grid over [-2, 6]^2, one (x1, x2) pair per row.

        Rows are ordered with x1 varying fastest.
        """
        axis1 = np.linspace(-2, 6, self.n_points)
        axis2 = np.linspace(-2, 6, self.n_points)
        # Meshing (x2, x1) with 'ij' indexing and flattening in C order makes x1 the fastest axis.
        grid2, grid1 = np.meshgrid(axis2, axis1, indexing='ij')
        return np.column_stack([grid1.ravel(), grid2.ravel()])


class Higdon1d(Simulator):
    """
    Higdon (2002), Gramacy and Lee (2008) function: https://www.sfu.ca/~ssurjano/hig02grlee08.html
    The original function noise: N(0,0.1^2)

    Piecewise on [0, 20]: sin(pi x / 5) + 0.2 cos(4 pi x / 5) for x < 10, and x / 10 - 1 for x >= 10.
    """
    def __init__(self, n_points=101, noise_sigma=0.1):
        super().__init__(n_points, noise_sigma)
        self.n_inputs = 1
        self.n_outputs = 1

    def mean(self, x):
        """
        Noise-free piecewise response, evaluated element-wise (output has the shape of x).

        Non-floating inputs (e.g. integer grids) are promoted to float first: ``np.piecewise``
        allocates its output with the input's dtype, so an integer x would otherwise have
        every function value silently truncated to an integer.
        """
        x = np.asarray(x)
        if not np.issubdtype(x.dtype, np.inexact):
            x = x.astype(float)
        mean = np.piecewise(x,
                            [x < 10, x >= 10],
                            [lambda xx: np.sin(np.pi * xx / 5) + 0.2 * np.cos(4 * np.pi * xx / 5),
                             lambda xx: xx / 10 - 1])
        return mean

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, broadcast to the shape of x."""
        return self.noise_sigma * np.ones_like(x)

    def noise(self, x):
        """Standard-normal draws of shape (len(x), n_outputs)."""
        return np.random.normal(loc=0, scale=1, size=[x.shape[0], self.n_outputs])

    def search_space(self):
        """``n_points`` evenly spaced inputs on [0, 20] as a column vector of shape (n_points, 1)."""
        return np.linspace(0, 20, self.n_points).reshape(-1, 1)


class Motorcycle(Simulator):
    """
    Motorcycle simulator backed by lookup tables.

    The mean and standard deviation at each of the 101 integer inputs 0, 1, ..., 100 were
    fitted with a heteroscedastic GP (inducing points). Inputs (non-transformed, on [0, 100])
    are mapped onto that grid and the stored values are returned in original output space.
    """

    # GP mean at inputs 0, 1, ..., 100 (table index i corresponds to input i).
    _GP_MEAN = (
        0.5107, 0.5032, 0.4939, 0.4843, 0.4765, 0.4718, 0.4713, 0.4748, 0.4805, 0.4856,
        0.4870, 0.4824, 0.4724, 0.4604, 0.4523, 0.4543, 0.4695, 0.4953, 0.5216, 0.5319,
        0.5067, 0.4284, 0.2868, 0.0825, -0.1722, -0.4559, -0.7433, -1.0112, -1.2435, -1.4339,
        -1.5851, -1.7049, -1.8015, -1.8788, -1.9333, -1.9549, -1.9298, -1.8453, -1.6947, -1.4804,
        -1.2140, -0.9140, -0.6013, -0.2950, -0.0080, 0.2536, 0.4900, 0.7046, 0.9006, 1.0778,
        1.2317, 1.3542, 1.4362, 1.4710, 1.4573, 1.4012, 1.3147, 1.2139, 1.1143, 1.0277,
        0.9589, 0.9056, 0.8605, 0.8146, 0.7605, 0.6959, 0.6241, 0.5533, 0.4935, 0.4538,
        0.4395, 0.4501, 0.4795, 0.5173, 0.5507, 0.5678, 0.5600, 0.5241, 0.4632, 0.3869,
        0.3092, 0.2456, 0.2098, 0.2103, 0.2482, 0.3164, 0.4015, 0.4868, 0.5568, 0.6006,
        0.6147, 0.6033, 0.5769, 0.5487, 0.5311, 0.5324, 0.5549, 0.5950, 0.6441, 0.6919,
        0.7285,
    )

    # GP standard deviation at inputs 0, 1, ..., 100 (same indexing as _GP_MEAN).
    _GP_STDDEV = (
        0.0477, 0.0400, 0.0463, 0.0514, 0.0529, 0.0502, 0.0438, 0.0366, 0.0330, 0.0328, 0.0327,
        0.0318, 0.0315, 0.0331, 0.0363, 0.0433, 0.0560, 0.0702, 0.0780, 0.0733, 0.0615, 0.0799,
        0.1468, 0.2393, 0.3408, 0.4382, 0.5196, 0.5747, 0.5959, 0.5800, 0.5295, 0.4540, 0.3712,
        0.3067, 0.2822, 0.2915, 0.3074, 0.3135, 0.3156, 0.3360, 0.3899, 0.4658, 0.5385, 0.5873,
        0.6045, 0.5975, 0.5860, 0.5901, 0.6133, 0.6377, 0.6380, 0.5951, 0.5022, 0.3669, 0.2227,
        0.1980, 0.3483, 0.5319, 0.6956, 0.8186, 0.8886, 0.8987, 0.8470, 0.7373, 0.5789, 0.3876,
        0.1898, 0.1067, 0.2472, 0.3799, 0.4620, 0.4847, 0.4492, 0.3643, 0.2463, 0.1253, 0.1068,
        0.2006, 0.2817, 0.3230, 0.3195, 0.2769, 0.2092, 0.1377, 0.0911, 0.0856, 0.0919, 0.0940,
        0.1116, 0.1568, 0.2117, 0.2548, 0.2730, 0.2623, 0.2282, 0.1857, 0.1545, 0.1445, 0.1445,
        0.1437, 0.1542,
    )

    def __init__(self, n_points=None, noise_sigma=None):
        super().__init__(n_points, noise_sigma)
        self.n_inputs = 1
        self.n_outputs = 1

    def convert_x_to_list(self, x):
        """
        Map inputs onto indices of the 101-point search-space grid.

        :param x: inputs on the interval [0, 100]; any shape holding one value per input,
            e.g. (n, 1), (n,) or a scalar. Each value is truncated with ``int()`` before lookup.
        :returns: list of Python ints, one grid index per input, in input order.
        :raises IndexError: if a truncated input is not one of the grid values.
        """
        # Build the grid and its value -> index map once, instead of once per input element.
        grid = self.search_space().ravel()
        position = {value: index for index, value in enumerate(grid.tolist())}
        indices = []
        # Flattening avoids int() on 1-element arrays (deprecated in recent NumPy) and
        # also accepts a scalar input.
        for value in np.asarray(x).reshape(-1):
            key = int(value)
            if key not in position:
                raise IndexError(f'Motorcycle input {value} is not on the search-space grid 0..100.')
            indices.append(position[key])
        return indices

    def mean(self, x):
        """
        GP mean looked up for each input.

        :param x: inputs on the integer grid 0..100 (see ``convert_x_to_list``).
        :returns: float array of shape (n, 1).
        """
        indices = self.convert_x_to_list(x)
        return np.asarray(self._GP_MEAN)[indices].reshape(-1, 1)

    def stddev(self, x):
        """
        GP standard deviation looked up for each input.

        :param x: inputs on the integer grid 0..100 (see ``convert_x_to_list``).
        :returns: float array of shape (n, 1).
        """
        indices = self.convert_x_to_list(x)
        return np.asarray(self._GP_STDDEV)[indices].reshape(-1, 1)

    def noise(self, x):
        """Standard-normal draws of shape (len(x), n_outputs)."""
        return np.random.normal(loc=0, scale=1, size=[x.shape[0], self.n_outputs])

    def search_space(self):
        """The 101 integer points 0, 1, ..., 100 as a column vector of shape (101, 1)."""
        # The lookup tables hold exactly 101 values, so the search space must be these
        # 101 points regardless of n_points.
        return np.linspace(0, 100, 101).reshape(-1, 1)


class Branin2d(Simulator):
    """
    Branin function with the modification by Forrester et al. (2008).

    The original function has no noise. Following Picheny, the reference noise level is 5% of
    the standard deviation of the modified objective function, i.e. about 2.83. Note that the
    default ``noise_sigma`` here is four times that value (2.83 * 4 = 11.32).
    All references can be found here: https://www.sfu.ca/~ssurjano/branin.html
    """

    def __init__(self, n_points=100, noise_sigma=2.83*4):
        super().__init__(n_points, noise_sigma)
        self.n_inputs, self.n_outputs = 2, 1

    def mean(self, x):
        """
        Modified Branin response for inputs of shape (n, 2); returns shape (n, 1).

        f = a (x2 - b x1^2 + c x1 - r)^2 + s (1 - t) cos(x1) + s + 5 x1
        """
        a, r, s = 1, 6, 10
        b = 5.1 / (4 * np.power(np.pi, 2))
        c = 5 / np.pi
        t = 1 / (8 * np.pi)
        x1 = x[:, 0]
        x2 = x[:, 1]
        inner = x2 - b * np.power(x1, 2) + c * x1 - r
        # The trailing "+ 5 x1" term is Forrester's modification of the classic Branin.
        value = a * np.power(inner, 2) + s * (1 - t) * np.cos(x1) + s + 5 * x1
        return value.reshape(-1, 1)

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, one row per input; shape (n, 1)."""
        per_row = self.noise_sigma * np.ones_like(x[:, 0])
        return per_row.reshape(-1, 1)

    def noise(self, x):
        """Standard-normal draws of shape (n, n_outputs)."""
        shape = (x.shape[0], self.n_outputs)
        return np.random.normal(loc=0, scale=1, size=shape)

    def search_space(self):
        """
        Full ``n_points`` x ``n_points`` grid over [-5, 10] x [0, 15], one (x1, x2) pair per row.

        Rows are ordered with x1 varying fastest.
        """
        axis1 = np.linspace(-5, 10, self.n_points)
        axis2 = np.linspace(0, 15, self.n_points)
        # Meshing (x2, x1) with 'ij' indexing and flattening in C order makes x1 the fastest axis.
        grid2, grid1 = np.meshgrid(axis2, axis1, indexing='ij')
        return np.column_stack([grid1.ravel(), grid2.ravel()])


class Ishigami3d(Simulator):
    """
    Ishigami function: f = sin(x1) + a sin^2(x2) + b x3^4 sin(x1), with a = 7 and b = 0.1.

    The original function has no noise. Following Picheny, we add 5% of the standard deviation
    of the objective function as noise, i.e. 0.187.
    All references can be found here: http://www.sfu.ca/~ssurjano/ishigami.html
    """

    def __init__(self, n_points=100, noise_sigma=0.187):
        super().__init__(n_points, noise_sigma)
        self.n_inputs, self.n_outputs = 3, 1

    def mean(self, x):
        """Noise-free Ishigami response for inputs of shape (n, 3); returns shape (n, 1)."""
        a, b = 7, 0.1
        sin_x1 = np.sin(x[:, 0])
        value = sin_x1 + a * np.power(np.sin(x[:, 1]), 2) + b * np.power(x[:, 2], 4) * sin_x1
        return value.reshape(-1, 1)

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, one row per input; shape (n, 1)."""
        per_row = self.noise_sigma * np.ones_like(x[:, 0])
        return per_row.reshape(-1, 1)

    def noise(self, x):
        """Standard-normal draws of shape (n, n_outputs)."""
        shape = (x.shape[0], self.n_outputs)
        return np.random.normal(loc=0, scale=1, size=shape)

    def search_space(self):
        """
        Full ``n_points``^3 grid over [-pi, pi]^3, one (x1, x2, x3) triple per row.

        Rows are ordered with x1 varying fastest and x3 slowest.
        """
        axis1 = np.linspace(-np.pi, np.pi, self.n_points)
        axis2 = np.linspace(-np.pi, np.pi, self.n_points)
        axis3 = np.linspace(-np.pi, np.pi, self.n_points)
        # Meshing in reversed order with 'ij' indexing makes x1 the fastest axis after ravel().
        grid3, grid2, grid1 = np.meshgrid(axis3, axis2, axis1, indexing='ij')
        return np.column_stack([grid1.ravel(), grid2.ravel(), grid3.ravel()])


class Hartmann6d(Simulator):
    """
    Hartmann 6-D function: f(x) = -sum_i alpha_i * exp(-sum_j A_ij * (x_j - P_ij)^2).

    The original function has no noise. Following Picheny, we add 5% of the standard deviation
    of the objective function as noise, i.e. 0.0192.
    All references can be found here: https://www.sfu.ca/~ssurjano/hart6.html
    """
    def __init__(self, n_points=100, noise_sigma=0.0192):
        super().__init__(n_points, noise_sigma)
        self.n_inputs = 6
        self.n_outputs = 1

    def mean(self, x):
        """
        Noise-free Hartmann 6-D response.

        :param x: inputs with 6 features per point, shape (n, 6) (a single point of shape (6,)
            is also accepted).
        :returns: array of shape (n, 1).
        """
        alpha = np.array([1, 1.2, 3, 3.2])
        A = np.array([[10, 3, 17, 3.5, 1.7, 8],
                      [.05, 10, 17, .1, 8, 14],
                      [3, 3.5, 1.7, 10, 17, 8],
                      [17, 8, .05, 10, .1, 14]
                      ])
        P = 1e-4 * np.array([[1312, 1696, 5569, 124, 8283, 5886],
                             [2329, 4135, 8307, 3736, 1004, 9991],
                             [2348, 1451, 3522, 2883, 3047, 6650],
                             [4047, 8828, 8732, 5743, 1091, 381]])
        x = np.asarray(x, dtype=float).reshape(-1, 6)
        # (n, 1, 6) - (4, 6) broadcasts to (n, 4, 6): one row per (input, term) pair.
        diff = x[:, np.newaxis, :] - P
        # Each squared distance is weighted by A_ij, as the Hartmann definition requires.
        inner = np.sum(A * np.square(diff), axis=2)
        mean = -np.sum(alpha * np.exp(-inner), axis=1)
        return mean.reshape(-1, 1)

    def stddev(self, x):
        """Constant noise standard deviation ``noise_sigma``, one row per input; shape (n, 1)."""
        stddev = self.noise_sigma * np.ones_like(x[:, 0])
        return stddev.reshape(-1, 1)

    def noise(self, x):
        """Standard-normal draws of shape (n, n_outputs)."""
        return np.random.normal(loc=0, scale=1, size=[x.shape[0], self.n_outputs])

    def search_space(self):
        """
        Random sample of 10^6 points drawn uniformly from [0, 1)^6, shape (1000000, 6).

        A full grid with ``n_points`` values per axis would have n_points^6 rows (10^12 with the
        default), which is far too large, so ``n_points`` is not used here.
        NOTE(review): a fresh sample is drawn from the global NumPy RNG on every call, so two
        calls return different spaces; cache the result if callers need a stable set.
        """
        search_space = np.random.rand(1000000, 6)
        return search_space


def oracle_simulator(args):
    """
    Build the simulator named by ``args.simulator``.

    :param args: any object exposing a ``simulator`` attribute holding the simulator name.
    :returns: a new simulator instance constructed with its default parameters.
    :raises NotImplementedError: if the name matches none of the registered simulators.
    """
    requested = args.simulator
    # Ordered (name, factory) pairs; the lambdas defer looking up each class until its name matches.
    registry = (
        ("homo1", lambda: Homo1()),
        ("homo2", lambda: Homo2()),
        ("motorcycle", lambda: Motorcycle()),
        ("gramacy1d", lambda: GramacyAndLee1d()),
        ("gramacy2d", lambda: GramacyAndLee2d()),
        ("higdon1d", lambda: Higdon1d()),
        ("branin2d", lambda: Branin2d()),
        ("ishigami3d", lambda: Ishigami3d()),
        ("hartmann6d", lambda: Hartmann6d()),
    )
    for name, build in registry:
        if requested == name:
            return build()
    raise NotImplementedError(f'Simulator {requested} is not defined.')