# -*- coding: utf-8 -*-
# ==========================================
# Title:  syntheticFunctions.py
# Author: Binxin Ru and Ahsan Alvi
# Date:   20 August 2019
# Link:   https://arxiv.org/abs/1906.08878
# ==========================================

import numpy as np

# =============================================================================
# Rosenbrock Function (f_min = 0)
# https://www.sfu.ca/~ssurjano/rosen.html
# =============================================================================


def myrosenbrock(X):
    """
    Scaled 2-D Rosenbrock function.

    X is converted to an array and reshaped to (-1, 2), so it may be a single
    point such as [x1, x2] or any collection of such pairs.  Each row is
    evaluated as 100 * (x2 - x1^2)^2 + (x1 - 1)^2 and the result is divided
    by 300.

    optimal x: (1, 1)
    optimal y: 0

    Returns an array of shape (n, 1), one row per input point.
    """
    # After this reshape the array is always 2-D, so no separate
    # "single observation" code path is needed.
    points = np.asarray(X).reshape((-1, 2))
    x1 = points[:, 0]
    x2 = points[:, 1]
    fx = 100 * (x2 - x1 ** 2) ** 2 + (x1 - 1) ** 2
    return fx.reshape(-1, 1) / 300


def rosenbrock(ht_list, X):
    """
    Scaled Rosenbrock value shifted by a budget-dependent offset, plus noise.

    ht_list[0] is looked up in a fixed table mapping the keys 0..9 to the
    offsets 1e4, 1e3, ..., 1e-5 (a key outside the table raises KeyError).
    The offset is added to myrosenbrock(X), followed by a small random
    perturbation of 1e-6 * np.random.rand(...).

    Returns a float array with the same shape as myrosenbrock(X).
    """
    offsets = dict(enumerate(
        (10000.0, 1000.0, 100.0, 10.0, 1.0,
         0.1, 0.01, 0.001, 0.0001, 0.00001)
    ))
    shifted = myrosenbrock(X) + offsets[ht_list[0]]
    rows, cols = shifted.shape[0], shifted.shape[1]
    noise = 1e-6 * np.random.rand(rows, cols)
    return (shifted + noise).astype(float)


# =============================================================================
# Six-hump Camel Function (f_min = -1.0316)
#  https://www.sfu.ca/~ssurjano/camel6.html
# =============================================================================


def mysixhumpcamp(X):
    """
    Scaled six-hump camel function.

    X is converted to an array and reshaped to (-1, 2), so it may be a single
    point such as [x1, x2] or any collection of such pairs.  Each row is
    evaluated as

        (4 - 2.1 * x1^2 + x1^4 / 3) * x1^2 + x1 * x2 + (-4 + 4 * x2^2) * x2^2

    and the result is divided by 10.

    Returns an array of shape (n, 1), one row per input point.
    """
    # After this reshape the array is always 2-D, so no separate
    # "single observation" code path is needed.
    points = np.asarray(X).reshape((-1, 2))
    x1 = points[:, 0]
    x2 = points[:, 1]
    term1 = (4 - 2.1 * x1 ** 2 + (x1 ** 4) / 3) * x1 ** 2
    term2 = x1 * x2
    term3 = (-4 + 4 * x2 ** 2) * x2 ** 2
    fval = term1 + term2 + term3
    return fval.reshape(-1, 1) / 10


def sixhumpcamp(ht_list, X):
    """
    Scaled six-hump camel value shifted by a budget-dependent offset, plus noise.

    ht_list[0] is looked up in a fixed table mapping the keys 0..9 to the
    offsets 1e4, 1e3, ..., 1e-5 (a key outside the table raises KeyError).
    The offset is added to mysixhumpcamp(X), followed by a small random
    perturbation of 1e-6 * np.random.rand(...).

    Returns a float array with the same shape as mysixhumpcamp(X).
    """
    offset_values = (10000.0, 1000.0, 100.0, 10.0, 1.0,
                     0.1, 0.01, 0.001, 0.0001, 0.00001)
    offset_by_level = dict(zip(range(10), offset_values))
    base = mysixhumpcamp(X)
    shifted = base + offset_by_level[ht_list[0]]
    jitter = np.random.rand(shifted.shape[0], shifted.shape[1])
    noisy = shifted + 1e-6 * jitter
    return noisy.astype(float)


# =============================================================================
# Beale function (f_min = 0)
# https://www.sfu.ca/~ssurjano/beale.html
# =============================================================================


def mybeale(X):
    """
    Scaled Beale function.

    X is converted to a float array and reshaped to (-1, 2), so it may be a
    single point such as [x1, x2] or any collection of such pairs.  Each row
    is evaluated as

        (1.5 - x1 + x1 * x2)^2 + (2.25 - x1 + x1 * x2^2)^2
            + (2.625 - x1 + x1 * x2^3)^2

    and the result is divided by 50.

    Returns an array of shape (n, 1), one row per input point.
    """
    # The inputs used to be halved and then doubled again, which only had the
    # effect of converting them to floating point; do that conversion
    # directly.  After the reshape the array is always 2-D, so no separate
    # "single observation" code path is needed.
    points = np.asarray(X, dtype=float).reshape((-1, 2))
    x1 = points[:, 0]
    x2 = points[:, 1]
    fval = (
        (1.5 - x1 + x1 * x2) ** 2
        + (2.25 - x1 + x1 * x2 ** 2) ** 2
        + (2.625 - x1 + x1 * x2 ** 3) ** 2
    )
    return fval.reshape(-1, 1) / 50


def beale(ht_list, X):
    """
    Scaled Beale value shifted by a budget-dependent offset, plus noise.

    ht_list[0] is looked up in a fixed table mapping the keys 0..9 to the
    offsets 1e4, 1e3, ..., 1e-5 (a key outside the table raises KeyError).
    The offset is added to mybeale(X), followed by a small random
    perturbation of 1e-6 * np.random.rand(...).

    Returns a float array with the same shape as mybeale(X).
    """
    level_offsets = {
        0: 10000.0,
        1: 1000.0,
        2: 100.0,
        3: 10.0,
        4: 1.0,
        5: 0.1,
        6: 0.01,
        7: 0.001,
        8: 0.0001,
        9: 0.00001,
    }
    value = mybeale(X) + level_offsets[ht_list[0]]
    n_rows = value.shape[0]
    n_cols = value.shape[1]
    perturbed = value + 1e-6 * np.random.rand(n_rows, n_cols)
    return perturbed.astype(float)


def func2C(ht_list, X):
    """
    Noisy mixed objective with two categorical inputs and a continuous input.

    ht_list[0] and ht_list[1] each select one base function
    (0: myrosenbrock, 1: mysixhumpcamp, 2: mybeale).  Both selected functions
    are evaluated at 2 * X and summed, then a small random perturbation of
    1e-6 * np.random.rand(...) is added.

    Parameters:
        ht_list: categorical indices; only the first two entries are read.
            ht_list[0] must be 0, 1 or 2.  For ht_list[1], any value other
            than 0 or 1 selects mybeale.
        X: continuous inputs (array or nested sequence); the base functions
            reshape it into rows of two coordinates.

    Returns:
        Float array of shape (n, 1), one row per input point.

    Raises:
        ValueError: if ht_list[0] is not 0, 1 or 2.
    """
    # Convert before scaling: ``X * 2`` on a plain list would repeat the list
    # instead of doubling its values.
    X = np.asarray(X) * 2

    ht1 = ht_list[0]
    ht2 = ht_list[1]

    if ht1 == 0:  # rosenbrock
        f = myrosenbrock(X)
    elif ht1 == 1:  # six hump
        f = mysixhumpcamp(X)
    elif ht1 == 2:  # beale
        f = mybeale(X)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(f"ht_list[0] must be 0, 1 or 2, got {ht1!r}")

    if ht2 == 0:  # rosenbrock
        f = f + myrosenbrock(X)
    elif ht2 == 1:  # six hump
        f = f + mysixhumpcamp(X)
    else:  # beale (also the fallback for any other value)
        f = f + mybeale(X)

    y = f + 1e-6 * np.random.rand(f.shape[0], f.shape[1])
    return y.astype(float)


def budget_func2C(ht_list, X):
    """
    Noisy mixed objective with two categorical inputs and a budget level.

    ht_list[0] and ht_list[1] each select one base function
    (0: myrosenbrock, 1: mysixhumpcamp, 2: mybeale).  Both selected functions
    are evaluated at 2 * X and summed.  ht_list[2] is looked up in a fixed
    table mapping the keys 0..9 to the offsets 1e4, 1e3, ..., 1e-5, and that
    offset is added, followed by a small random perturbation of
    1e-6 * np.random.rand(...).

    Parameters:
        ht_list: categorical indices; the first three entries are read.
            ht_list[0] must be 0, 1 or 2.  For ht_list[1], any value other
            than 0 or 1 selects mybeale.  ht_list[2] must be a key of the
            offset table (0..9).
        X: continuous inputs (array or nested sequence); the base functions
            reshape it into rows of two coordinates.

    Returns:
        Float array of shape (n, 1), one row per input point.

    Raises:
        ValueError: if ht_list[0] is not 0, 1 or 2.
        KeyError: if ht_list[2] is not in the offset table.
    """
    # Convert before scaling: ``X * 2`` on a plain list would repeat the list
    # instead of doubling its values.
    X = np.asarray(X) * 2

    ht1 = ht_list[0]
    ht2 = ht_list[1]
    ht3 = ht_list[2]

    budget_dict = {0: 10000.0, 1: 1000.0, 2: 100.0, 3: 10.0, 4: 1.0,
                   5: 0.1, 6: 0.01, 7: 0.001, 8: 0.0001, 9: 0.00001}

    if ht1 == 0:  # rosenbrock
        f = myrosenbrock(X)
    elif ht1 == 1:  # six hump
        f = mysixhumpcamp(X)
    elif ht1 == 2:  # beale
        f = mybeale(X)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(f"ht_list[0] must be 0, 1 or 2, got {ht1!r}")

    if ht2 == 0:  # rosenbrock
        f = f + myrosenbrock(X)
    elif ht2 == 1:  # six hump
        f = f + mysixhumpcamp(X)
    else:  # beale (also the fallback for any other value)
        f = f + mybeale(X)

    # Shift by the offset associated with the budget level.
    f = f + budget_dict[ht3]

    y = f + 1e-6 * np.random.rand(f.shape[0], f.shape[1])
    return y.astype(float)


def func3C(ht_list, X):
    """
    Noisy mixed objective with three categorical inputs and a continuous input.

    All base functions are evaluated at 2 * X and combined as follows:

    * ht_list[0] selects the first term
      (0: myrosenbrock, 1: mysixhumpcamp, 2: mybeale).
    * ht_list[1] selects the second term (0: myrosenbrock,
      1: mysixhumpcamp, anything else: mybeale).
    * ht_list[2] selects a weighted third term (0: 5 * mysixhumpcamp,
      1: 2 * myrosenbrock, anything else: ht_list[2] * mybeale).

    A small random perturbation of 1e-6 * np.random.rand(...) is added to
    the sum.

    Parameters:
        ht_list: exactly three categorical indices (checked with an assert,
            which Python skips when run with -O).
        X: continuous inputs (array or nested sequence); the base functions
            reshape it into rows of two coordinates.

    Returns:
        Float array of shape (n, 1), one row per input point.

    Raises:
        AssertionError: if len(ht_list) != 3.
        ValueError: if ht_list[0] is not 0, 1 or 2.
    """
    X = np.atleast_2d(X)
    assert len(ht_list) == 3
    ht1 = ht_list[0]
    ht2 = ht_list[1]
    ht3 = ht_list[2]

    X = X * 2
    if ht1 == 0:  # rosenbrock
        f = myrosenbrock(X)
    elif ht1 == 1:  # six hump
        f = mysixhumpcamp(X)
    elif ht1 == 2:  # beale
        f = mybeale(X)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(f"ht_list[0] must be 0, 1 or 2, got {ht1!r}")

    if ht2 == 0:  # rosenbrock
        f = f + myrosenbrock(X)
    elif ht2 == 1:  # six hump
        f = f + mysixhumpcamp(X)
    else:  # beale (also the fallback for any other value)
        f = f + mybeale(X)

    # NOTE: the third choice is deliberately not the same mapping as the
    # first two: 0 adds six-hump and 1 adds rosenbrock.
    if ht3 == 0:  # six hump, weight 5
        f = f + 5 * mysixhumpcamp(X)
    elif ht3 == 1:  # rosenbrock, weight 2
        f = f + 2 * myrosenbrock(X)
    else:  # beale, weighted by the category value itself
        f = f + ht3 * mybeale(X)

    y = f + 1e-6 * np.random.rand(f.shape[0], f.shape[1])

    return y.astype(float)


def budget_func3C(ht_list, X):
    """
    Noisy mixed objective with three categorical inputs and a budget level.

    All base functions are evaluated at 2 * X and combined as follows:

    * ht_list[0] selects the first term
      (0: myrosenbrock, 1: mysixhumpcamp, 2: mybeale).
    * ht_list[1] selects the second term (0: myrosenbrock,
      1: mysixhumpcamp, anything else: mybeale).
    * ht_list[2] selects a weighted third term (0: 5 * mysixhumpcamp,
      1: 2 * myrosenbrock, anything else: ht_list[2] * mybeale).
    * ht_list[3] is looked up in a fixed table mapping the keys 0..9 to the
      offsets 1e4, 1e3, ..., 1e-5, and that offset is added to the sum.

    A small random perturbation of 1e-6 * np.random.rand(...) is added last.

    Parameters:
        ht_list: categorical indices; the first four entries are read.
        X: continuous inputs (array or nested sequence); the base functions
            reshape it into rows of two coordinates.

    Returns:
        Float array of shape (n, 1), one row per input point.

    Raises:
        ValueError: if ht_list[0] is not 0, 1 or 2.
        KeyError: if ht_list[3] is not in the offset table.
    """
    X = np.atleast_2d(X)
    ht1 = ht_list[0]
    ht2 = ht_list[1]
    ht3 = ht_list[2]
    ht4 = ht_list[3]

    budget_dict = {0: 10000.0, 1: 1000.0, 2: 100.0, 3: 10.0, 4: 1.0,
                   5: 0.1, 6: 0.01, 7: 0.001, 8: 0.0001, 9: 0.00001}

    X = X * 2
    if ht1 == 0:  # rosenbrock
        f = myrosenbrock(X)
    elif ht1 == 1:  # six hump
        f = mysixhumpcamp(X)
    elif ht1 == 2:  # beale
        f = mybeale(X)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(f"ht_list[0] must be 0, 1 or 2, got {ht1!r}")

    if ht2 == 0:  # rosenbrock
        f = f + myrosenbrock(X)
    elif ht2 == 1:  # six hump
        f = f + mysixhumpcamp(X)
    else:  # beale (also the fallback for any other value)
        f = f + mybeale(X)

    # NOTE: the third choice is deliberately not the same mapping as the
    # first two: 0 adds six-hump and 1 adds rosenbrock.
    if ht3 == 0:  # six hump, weight 5
        f = f + 5 * mysixhumpcamp(X)
    elif ht3 == 1:  # rosenbrock, weight 2
        f = f + 2 * myrosenbrock(X)
    else:  # beale, weighted by the category value itself
        f = f + ht3 * mybeale(X)

    # Shift by the offset associated with the budget level.
    f = f + budget_dict[ht4]

    y = f + 1e-6 * np.random.rand(f.shape[0], f.shape[1])

    return y.astype(float)
