import numpy
import torch


class Xs:
    """Reproducible collections of randomly sized point sets.

    Every split is a 1-D ``numpy`` object array; each element is a float array
    of shape ``(set_size, feat_dim)`` whose values are drawn uniformly from
    [-5, 5).
    """

    def __init__(self):
        # One fixed seed per split so the data is identical on every run.
        self.train = self.generate_sets(500, 2, 5, seed=0)
        self.valid = self.generate_sets(100, 2, 5, seed=1)
        self.test_small = self.generate_sets(100, 2, 5, seed=2)
        # Set sizes 10-12, i.e. larger than anything in train (2-4).
        self.test_large = self.generate_sets(100, 10, 13, seed=3)

    def generate_one_set(self, size, val_low, val_high, feat_dim=1):
        """Return a ``(size, feat_dim)`` array uniformly sampled from [val_low, val_high).

        Draws from numpy's global random state.
        """
        return numpy.random.rand(size, feat_dim) * (val_high - val_low) + val_low

    def generate_sets(self, num, size_low, size_high, feat_dim=1, seed=0):
        """Generate ``num`` random sets whose sizes lie in [size_low, size_high).

        Args:
            num: number of sets to generate.
            size_low: smallest allowed set size (inclusive).
            size_high: upper bound on set size (exclusive).
            feat_dim: number of features per set element.
            seed: seed applied to numpy's *global* RNG before sampling.

        Returns:
            1-D object array of length ``num``; element ``i`` is a float array
            of shape ``(size_i, feat_dim)`` with values in [-5, 5).
        """
        # Reseeds the global RNG (kept as-is so existing data stays identical).
        numpy.random.seed(seed=seed)
        # Sets have different lengths and cannot be stacked into a regular
        # ndarray: numpy.array() on such ragged input raises ValueError on
        # NumPy >= 1.24. Filling an object array explicitly works on every
        # version and keeps the result 1-D even if all sizes happen to match.
        sets = numpy.empty(num, dtype=object)
        for i in range(num):
            # Same draw order as before: size first, then the set's values.
            size = numpy.random.randint(size_low, size_high)
            sets[i] = self.generate_one_set(size, -5, 5, feat_dim=feat_dim)
        return sets

    @classmethod
    def calc_ys(cls, xs, func):
        """Apply ``func`` to every set in ``xs`` and stack the results into a tensor.

        All ``func`` results must share one shape. They are gathered as float64
        and then handed to ``torch.Tensor``, which converts to torch's default
        tensor type.
        """
        return torch.Tensor(numpy.array([func(x) for x in xs], dtype=numpy.float64))

    @classmethod
    def calc_ys(cls, xs, func):
        return torch.Tensor(numpy.array([func(x) for x in xs], dtype=numpy.float64))


# Module-level dataset, built once at import time. Constructing Xs reseeds
# numpy's global random state (seeds 0-3, one per split), so importing this
# module also changes subsequent numpy.random draws.
xs = Xs()


def func1(x):
    """Column-wise sum of the set ``x`` (reduces over the element axis 0)."""
    return numpy.sum(x, axis=0)


def func2(x):
    """Signed cube root of the column-wise sum of cubes of ``x``.

    The root is taken on the magnitude and the sign is reapplied afterwards,
    so negative sums yield real (negative) results instead of NaN.
    """
    cube_sum = (x ** 3).sum(axis=0)
    root = numpy.power(numpy.abs(cube_sum), 1. / 3)
    return root * numpy.sign(cube_sum)


def func3(x):
    """Column-wise product of the elements of ``x``."""
    return x.prod(axis=0)
def func4(x):
    """Column-wise Euclidean norm of ``x`` (square root of the sum of squares)."""
    squares = x ** 2
    return numpy.sqrt(squares.sum(axis=0))
def func5(x):
    """Column-wise maximum over the elements of ``x``."""
    return numpy.max(x, axis=0)
def func6(x):
    """Column-wise sum of the strictly positive entries of ``x``.

    Non-positive entries are zeroed by multiplying with the boolean mask.
    """
    is_positive = x > 0
    return numpy.sum(x * is_positive, axis=0)


def func7(X):
    """Left-fold the rows of ``X`` with ``f(a, b) = a + b + a * b / 2``.

    The first row seeds the accumulator; each later row is combined into it.
    """
    acc = X[0, :]
    for row in X[1:]:
        acc = acc + row + acc * row / 2
    return acc


def func8(X):
    """Left-fold the rows of ``X`` with ``f(a, b) = 2 + 2 * (a + b) + a * b``.

    The first row seeds the accumulator; each later row is combined into it.
    """
    acc = X[0, :]
    for row in X[1:]:
        acc = 2 + 2 * (acc + row) + acc * row
    return acc


# Pairwise rule f(x, y) = x + y + 1, folded over all rows in closed form.
def func9(X):
    """Column-wise sum of ``X`` plus one for every pairwise combination.

    Folding ``f(x, y) = x + y + 1`` over ``n`` rows adds ``n - 1`` to the sum.
    """
    n_rows = len(X)
    return X.sum(axis=0) + n_rows - 1
def func10(X):
    """Constant target: always a length-1 float64 array holding 1, ignoring ``X``."""
    return numpy.ones(1, dtype=numpy.float64)


def func11(X):
    """Cube of the column-wise sum of the signed cube roots of ``X``.

    Roots are taken on magnitudes with the sign reapplied, so negative entries
    contribute real (negative) roots rather than NaN.
    """
    magnitudes = numpy.power(numpy.abs(X), 1. / 3)
    cube_roots = magnitudes * numpy.sign(X)
    total = cube_roots.sum(axis=0)
    return total ** 3
