"""
Provides three methods for initial sampling.
"""

import pickle
import numpy as np
from numpy.linalg import pinv
import matplotlib.pyplot as plt
import seaborn
seaborn.set_style("white")
EPS = 1e-8

from scipy.optimize import minimize
import random
# np.random.seed(0)


def random_with_bound(low_bounds: np.array, up_bounds: np.array, nb: int) -> np.array:
    """
    Draw points uniformly at random between per-dimension bounds.

    Args:
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension; must have the
            same length as low_bounds.
        nb: int (n), number of points to draw.
    Returns:
        array (n, d); every coordinate is ``low + u * (up - low)`` with ``u``
        taken from ``np.random.rand`` (i.e. ``u`` in [0, 1)).
    Raises:
        AssertionError: if the two bound arrays differ in length.
    """
    dim = len(low_bounds)
    assert dim == len(up_bounds)

    # Sample in the unit cube first, then stretch and shift into the box.
    unit_cube = np.random.rand(nb, dim)
    span = up_bounds - low_bounds
    return unit_cube * span + low_bounds


def branin_hoo_func(x: np.ndarray) -> np.array:
    """
    Evaluate the Branin-Hoo test function on a batch of 2-D points.

    Computes ``a * (x2 - b * x1**2 + c * x1 - r)**2 + s * (1 - t) * cos(x1) + s``
    with the constants defined in the body.

    Args:
        x: array (n, 2); column 0 is x1 and column 1 is x2.
    Returns:
        array (n,), the function value of every row.
    Raises:
        AssertionError: if ``x`` is not 2-D with exactly two columns.
    """
    assert x.ndim == 2 and x.shape[1] == 2

    a = 1.0
    b = 5.1 / (4 * np.pi**2)
    c = 5.0 / np.pi
    r = 6.0
    s = 10.0
    t = 1 / (8 * np.pi)

    x1 = x[:, 0]
    x2 = x[:, 1]
    quadratic_term = a * (x2 - b * x1**2 + c * x1 - r) ** 2
    cosine_term = s * (1 - t) * np.cos(x1)
    return quadratic_term + cosine_term + s


def branin_hoo_noise(x: np.ndarray, h: np.ndarray) -> np.array:
    """
    Compute the noise term ``5 * h * sin(x1 + x2)`` for every sample.

    Args:
        x: (n, 2), in [(-5, 0), (10, 15)]; only columns 0 and 1 are read.
        h: (n,), in [1, 2]
    Returns:
        array (n,), the noise value of every row.
    Raises:
        AssertionError: if ``x`` is not 2-D or ``h`` is not 1-D.
    """
    assert x.ndim == 2 and h.ndim == 1
    coord_sum = x[:, 0] + x[:, 1]
    return 5 * h * np.sin(coord_sum)


def lhs(low_bounds: np.array, up_bounds: np.array, nb: int = 6) -> np.ndarray:
    """
    Latin hypercube sample of ``nb`` points between per-dimension bounds.

    The range of every dimension is cut into ``nb`` equal-width strata. For
    each dimension the strata are assigned to the ``nb`` points by an
    independent random permutation, so every stratum is used exactly once per
    dimension, and each point is placed uniformly inside its stratum.

    All randomness comes from ``np.random`` so that ``np.random.seed`` alone
    makes the result reproducible.

    Args:
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension.
        nb: int (n), number of points (and of strata per dimension).
    Returns:
        array (n, d) of sampled points.
    Raises:
        AssertionError: if the bounds are not 1-D arrays of equal size.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size
    dim = low_bounds.size

    # strata[i, j] is the index of the stratum point i occupies in dimension j.
    strata = np.empty((nb, dim))
    for j in range(dim):
        strata[:, j] = np.random.permutation(nb)
    if nb == 0:
        # Nothing to sample; also avoids dividing the range by zero below.
        return strata

    width = (up_bounds - low_bounds) / nb
    # Uniform offset in [0, 1) positions each point inside its stratum.
    offset = np.random.rand(nb, dim)
    return low_bounds + width * (strata + offset)


def init_lhs(bounds, nb=10):
    """Random initialisation over a mixed categorical/continuous space.

    Despite the name, no stratification is performed here: categorical
    features are drawn with ``np.random.choice`` and continuous features with
    ``np.random.uniform``, each independently.

    Example input,
    bounds = [{'name': 'h1', 'type': 'categorical', 'domain': (0, 1, 2)},
            {'name': 'h2', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4)},
            {'name': 'h3', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)},
            {'name': 'x1', 'type': 'continuous', 'domain': (-1, 1)},
            {'name': 'x2', 'type': 'continuous', 'domain': (-1, 1)}]

    Args:
        bounds: sequence of feature dicts with a ``'type'`` key
            (``'categorical'`` or ``'continuous'``) and a ``'domain'`` key
            (the allowed values, or the ``(low, high)`` pair respectively).
        nb: number of samples to draw.
    Returns:
        array (nb, len(bounds)); column ``k`` holds the values of feature ``k``.
    Raises:
        RuntimeError: if a feature has an unsupported ``'type'``.
    """
    columns = []
    for feature in bounds:
        kind = feature["type"]
        if kind == "categorical":
            columns.append(np.random.choice(feature["domain"], size=nb))
        elif kind == "continuous":
            low, high = feature["domain"][0], feature["domain"][1]
            columns.append(np.random.uniform(low, high, size=nb))
        else:
            # Say which feature is wrong instead of raising an empty error.
            raise RuntimeError(
                f"unsupported feature type {kind!r} for feature {feature.get('name')!r}; "
                "expected 'categorical' or 'continuous'"
            )
    # One row per feature so far; transpose to one row per sample.
    return np.array(columns).T


# LHD partition.
# Split every feature dimension evenly and randomly assign the levels to the samples.
def lhd_v1(low_bounds: np.array, up_bounds: np.array, nb: int = 6) -> np.ndarray:
    """
    Latin hypercube design on a fixed grid.

    Every dimension takes the ``nb`` evenly spaced levels returned by
    ``np.linspace(low, up, nb)`` (both end points included), shuffled
    independently per dimension.

    Args:
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension.
        nb: int (n), number of samples / levels per dimension.
    Returns:
        array (n, d); each column is a permutation of that dimension's levels.
    Raises:
        AssertionError: if the bounds are not 1-D arrays of equal size.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size

    design = np.empty((nb, low_bounds.size))
    for dim, (low, up) in enumerate(zip(low_bounds, up_bounds)):
        levels = np.linspace(low, up, nb)
        np.random.shuffle(levels)  # in-place permutation of the grid levels
        design[:, dim] = levels
    return design


def lhd(low_bounds: np.array, up_bounds: np.array, nb_sample: int = 6) -> np.ndarray:
    """
    Latin hypercube design with a random position inside every stratum.

    The range of each dimension is split into ``nb_sample`` equal-width
    strata, one value is drawn uniformly from every stratum, and the values
    are then shuffled independently per dimension.

    Args:
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension.
        nb_sample: int (n), number of samples / strata per dimension.
    Returns:
        array (n, d) of sampled points.
    Raises:
        AssertionError: if the bounds are not 1-D arrays of equal size.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size
    n_dim = low_bounds.size

    design = np.empty((nb_sample, n_dim))
    for dim in range(n_dim):
        low, up = low_bounds[dim], up_bounds[dim]
        step = (up - low) / nb_sample

        # One uniform draw per stratum, taken in stratum order.
        column = np.empty(nb_sample)
        for k in range(nb_sample):
            column[k] = np.random.uniform(low + step * k, low + step * (k + 1))

        # Break the stratum order so the pairing across dimensions is random.
        np.random.shuffle(column)
        design[:, dim] = column
    return design


def init_lhd(bounds, nb=10):
    """Initialisation over a mixed space with stratified continuous features.

    Continuous features are sampled Latin-hypercube style: their range is cut
    into ``nb`` equal-width strata, one value is drawn uniformly per stratum
    and the values are shuffled. Categorical features are drawn independently
    with ``np.random.choice`` (not stratified).

    Example input,
    bounds = [{'name': 'h1', 'type': 'categorical', 'domain': (0, 1, 2)},
            {'name': 'h2', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4)},
            {'name': 'h3', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)},
            {'name': 'x1', 'type': 'continuous', 'domain': (-1, 1)},
            {'name': 'x2', 'type': 'continuous', 'domain': (-1, 1)}]

    Args:
        bounds: sequence of feature dicts with a ``'type'`` key
            (``'categorical'`` or ``'continuous'``) and a ``'domain'`` key
            (the allowed values, or the ``(low, high)`` pair respectively).
        nb: number of samples to draw.
    Returns:
        array (nb, len(bounds)); column ``k`` holds the values of feature ``k``.
    Raises:
        RuntimeError: if a feature has an unsupported ``'type'``.
    """
    def _uniform_shuffle(l: float, u: float, nb: int):
        # One uniform draw per equal-width stratum of [l, u], then shuffled.
        delta = (u - l) / nb
        col = np.array([np.random.uniform(l + delta * j, l + delta * (j + 1)) for j in range(nb)])
        np.random.shuffle(col)
        return col

    columns = []
    for feature in bounds:
        kind = feature["type"]
        if kind == "categorical":
            columns.append(np.random.choice(feature["domain"], size=nb))
        elif kind == "continuous":
            columns.append(_uniform_shuffle(feature["domain"][0], feature["domain"][1], nb))
        else:
            # Say which feature is wrong instead of raising an empty error.
            raise RuntimeError(
                f"unsupported feature type {kind!r} for feature {feature.get('name')!r}; "
                "expected 'categorical' or 'continuous'"
            )
    # One row per feature so far; transpose to one row per sample.
    return np.array(columns).T


def RBFkernel(Xn1: np.ndarray, Xn2: np.ndarray, params: np.array):
    """
    Anisotropic RBF kernel matrix between two sets of points.

    Entry ``[i, j]`` equals
    ``tau**2 * exp(-sum_k theta[k] * (Xn1[i, k] - Xn2[j, k])**2)``,
    evaluated through the expansion ``|a|^2 + |b|^2 - 2 a.b`` of the weighted
    squared distance.

    Args:
        Xn1: array (n1, d).
        Xn2: array (n2, d).
        params: array (d + 1,); the first d entries are the per-dimension
            weights ``theta`` and the last entry is the amplitude ``tau``.
    Returns:
        array (n1, n2), the kernel matrix.
    Raises:
        AssertionError: if the array ranks or sizes do not match the above.
    """
    assert Xn1.ndim == Xn2.ndim == 2 and params.ndim == 1
    assert Xn1.shape[1] == Xn2.shape[1] and params.size == Xn1.shape[1] + 1

    theta, tau = params[:-1], params[-1]

    # Weighted squared norms: a column vector for Xn1 and a row vector for Xn2.
    sq_norm_1 = np.sum(Xn1**2 * theta, axis=1).reshape(-1, 1)
    sq_norm_2 = np.sum(Xn2**2 * theta, axis=1)
    cross = np.dot(Xn1 * theta, Xn2.T)

    weighted_sq_dist = sq_norm_1 + sq_norm_2 - 2 * cross
    return tau**2 * np.exp(-weighted_sq_dist)


# Frechet derivative.
def frechet(Fn: np.ndarray, Psi: np.ndarray) -> float:
    """
    Design criterion ``trace(Fn.T @ pinv(Psi) @ Fn)``.

    Args:
        Fn: array (n, p), feature matrix of the n design points.
        Psi: array (n, n), square matrix whose pseudo-inverse is used.
    Returns:
        The trace as a scalar.
    Raises:
        AssertionError: if the arrays are not 2-D or ``Psi`` is not (n, n).
    """
    assert Fn.ndim == Psi.ndim == 2
    assert Fn.shape[0] == Psi.shape[0] == Psi.shape[1]

    Psi_pinv = pinv(Psi)
    information = Fn.T.dot(Psi_pinv).dot(Fn)
    return np.trace(information)


# Optimal Initial design.
def opd(kernel: callable, low_bounds: np.array, up_bounds: np.array, nb_sample: int = 6, nb_trial: int = 100) -> tuple:
    """
    Pick the best of ``nb_trial`` random Latin hypercube designs.

    Every trial draws a design with ``lhd``, expands it to the quadratic
    features ``[1, x, x**2]`` and scores it with ``frechet`` against
    ``kernel(Xn, Xn)``. The design with the largest score is kept, and each
    improvement is printed.

    Args:
        kernel: callable ``kernel(Xn, Xn)`` returning the (n, n) matrix passed
            to ``frechet``.
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension.
        nb_sample: number of points per design.
        nb_trial: number of candidate designs to try.
    Returns:
        Tuple ``(best_Xn, max_frechet)``: the best design, array
        (nb_sample, d), and its score. ``best_Xn`` is ``None`` and the score
        ``-inf`` when no trial improved on ``-inf`` (e.g. ``nb_trial <= 0``).
    Raises:
        AssertionError: if the bounds are not 1-D arrays of equal size.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size

    best_Xn = None
    max_frechet = -float('inf')
    for i in range(nb_trial):
        Xn = lhd(low_bounds, up_bounds, nb_sample)
        # Quadratic feature expansion: intercept, linear and squared terms.
        Fn = np.concatenate([np.ones((nb_sample, 1)), Xn, Xn**2], axis=1)
        Psi = kernel(Xn, Xn)
        metric = frechet(Fn, Psi)

        if max_frechet < metric:
            max_frechet = metric
            best_Xn = Xn
            print(f"[{i}/{nb_trial}], metric:{metric}")
    return best_Xn, max_frechet


def init_opd(kernel: callable, bounds, nb=10, nb_trial=100):
    """
    Pick the best of ``nb_trial`` designs drawn with ``init_lhd``.

    Each candidate is scored with ``frechet(Xn, kernel(Xn, Xn))``; the raw
    design matrix is used as the feature matrix (no feature expansion).

    Args:
        kernel: callable ``kernel(Xn, Xn)`` returning the matrix passed to
            ``frechet``.
        bounds: feature description forwarded to ``init_lhd``.
        nb: number of points per design.
        nb_trial: number of candidate designs to try.
    Returns:
        Tuple ``(best design, best score)``; ``(None, -inf)`` when no trial
        improved on ``-inf`` (e.g. ``nb_trial <= 0``).
    """
    best_score = -float("inf")
    best_design = None
    for _ in range(nb_trial):
        candidate = init_lhd(bounds, nb)
        gram = kernel(candidate, candidate)
        score = frechet(candidate, gram)

        if score > best_score:
            best_design, best_score = candidate, score
    return best_design, best_score


def opdnc(kernel: callable, low_bounds: np.array, up_bounds: np.array, nb_sample: int = 6, nb_trial: int = 2000) -> tuple:
    """
    Pick the best of ``nb_trial`` plain uniform random designs.

    Same search as scoring quadratic features ``[1, x, x**2]`` with
    ``frechet`` against ``kernel(Xn, Xn)``, but candidates are drawn with
    ``np.random.uniform`` inside the bounds rather than from a Latin
    hypercube. Each improvement is printed.

    Args:
        kernel: callable ``kernel(Xn, Xn)`` returning the (n, n) matrix passed
            to ``frechet``.
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension.
        nb_sample: number of points per design.
        nb_trial: number of candidate designs to try.
    Returns:
        Tuple ``(best_Xn, max_frechet)``: the best design, array
        (nb_sample, d), and its score. ``best_Xn`` is ``None`` and the score
        ``-inf`` when no trial improved on ``-inf`` (e.g. ``nb_trial <= 0``).
    Raises:
        AssertionError: if the bounds are not 1-D arrays of equal size.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size
    design_shape = (nb_sample, low_bounds.size)

    best_Xn = None
    max_frechet = -float('inf')
    for i in range(nb_trial):
        Xn = np.random.uniform(low_bounds, up_bounds, size=design_shape)
        # Quadratic feature expansion: intercept, linear and squared terms.
        Fn = np.concatenate([np.ones((nb_sample, 1)), Xn, Xn**2], axis=1)
        Psi = kernel(Xn, Xn)
        metric = frechet(Fn, Psi)

        if max_frechet < metric:
            max_frechet = metric
            best_Xn = Xn
            print(f"[{i}/{nb_trial}], metric:{metric}")
    return best_Xn, max_frechet


def init_opdnc(kernel: callable, bounds, nb=10, nb_trial=2000):
    """
    Pick the best of ``nb_trial`` designs drawn with ``init_lhs``.

    Each candidate is scored with ``frechet(Xn, kernel(Xn, Xn))``; the raw
    design matrix is used as the feature matrix (no feature expansion).
    Every improvement is printed.

    Args:
        kernel: callable ``kernel(Xn, Xn)`` returning the matrix passed to
            ``frechet``.
        bounds: feature description forwarded to ``init_lhs``.
        nb: number of points per design.
        nb_trial: number of candidate designs to try.
    Returns:
        Tuple ``(best design, best score)``; ``(None, -inf)`` when no trial
        improved on ``-inf`` (e.g. ``nb_trial <= 0``).
    """
    best_score = -float('inf')
    best_design = None
    for i in range(nb_trial):
        candidate = init_lhs(bounds, nb)
        gram = kernel(candidate, candidate)
        metric = frechet(candidate, gram)

        if metric > best_score:
            best_design, best_score = candidate, metric
            print(f"[{i}/{nb_trial}], metric:{metric}")
    return best_design, best_score

# Frechet derivative with one more sample.


def frechetplus(Fn: np.ndarray, Psi: np.ndarray, Fx: np.ndarray) -> float:
    """
    Criterion comparing a single candidate feature row against a design.

    With ``Fx_plus`` being ``Fx`` padded with ``n - 1`` zero rows and
    ``P = pinv(Psi)``, returns::

        trace(Fx_plus.T @ P @ Fn + Fn.T @ P @ Fx_plus) / trace(Fn.T @ P @ Fn) - 2

    Args:
        Fn: array (n, p), feature matrix of the current design.
        Psi: array (n, n), square matrix whose pseudo-inverse is used.
        Fx: array (1, p), feature row of the candidate point.
    Returns:
        The criterion value as a scalar.
    Raises:
        AssertionError: if the array ranks or shapes do not match the above.
    """
    assert Fn.ndim == Psi.ndim == Fx.ndim == 2
    assert Fn.shape[0] == Psi.shape[0] == Psi.shape[1] and Fn.shape[1] == Fx.shape[1] and Fx.shape[0] == 1
    n, p = Fn.shape

    # Place the candidate in the first row and pad with zeros up to n rows.
    Fx_plus = np.concatenate([Fx, np.zeros((n - 1, p))], axis=0)

    # The pseudo-inverse is the expensive part; compute it once and reuse it.
    Psi_pinv = pinv(Psi)

    below_trace = np.trace(Fn.T.dot(Psi_pinv).dot(Fn))
    above_trace = np.trace(Fx_plus.T.dot(Psi_pinv).dot(Fn) + Fn.T.dot(Psi_pinv).dot(Fx_plus))
    return above_trace / below_trace - 2


def opd2(kernel: callable, low_bounds: np.array, up_bounds: np.array, nb_sample: int = 6, nb_trial: int = 1000) -> tuple:
    """
    Iteratively improve a Latin hypercube design by point exchange.

    Starting from one ``lhd`` design, every iteration
      1. finds a candidate point ``X1`` inside the bounds that minimises
         ``-frechetplus`` for the current design (``scipy.optimize.minimize``),
      2. stops early if that minimum is negative,
      3. otherwise tries ``X1`` in place of every row and keeps the exchange
         with the largest ``frechet`` score if it beats the current design,
      4. if no exchange helped, overwrites a random row with ``X1`` with
         probability ``exp(-1 / T)`` (``T`` shrinks after successful updates).

    Args:
        kernel: callable ``kernel(Xn, Xn)`` returning the (n, n) matrix passed
            to ``frechet`` / ``frechetplus``.
        low_bounds: array (d), lower bound of every dimension.
        up_bounds: array (d), upper bound of every dimension.
        nb_sample: number of points in the design.
        nb_trial: maximum number of iterations.
    Returns:
        Tuple ``(Xn, value)``. On early stop ``value`` is the minimised
        objective (``-frechetplus``); otherwise it is the ``frechet`` score
        computed in the last iteration.
    Raises:
        AssertionError: if the bounds are not 1-D arrays of equal size.
    """
    assert low_bounds.ndim == up_bounds.ndim == 1
    assert low_bounds.size == up_bounds.size
    Xn = lhd(low_bounds, up_bounds, nb_sample)

    def _feature(Xn: np.ndarray):
        # Quadratic feature expansion: intercept, linear and squared terms.
        n, _ = Xn.shape
        return np.concatenate([np.ones((n, 1)), Xn, Xn**2], axis=1)

    def _get_X1(Xn: np.ndarray, kernel: callable):
        Fn = _feature(Xn)
        Psi = kernel(Xn, Xn)

        def min_f(X1: np.array):
            assert X1.ndim == 1
            Fx = _feature(X1[np.newaxis, :])
            return - frechetplus(Fn, Psi, Fx)

        res = minimize(min_f,
                       # Start from a random point inside the search box; a
                       # start drawn from [0, 1)^d can lie outside the bounds.
                       x0=np.random.uniform(low_bounds, up_bounds),
                       bounds=[(low, up) for low, up in zip(low_bounds, up_bounds)],
                       )
        return res.x, res.fun

    def _update_Xn(Xn: np.ndarray, X1: np.array, kernel: callable):
        updated = False
        max_frechet = frechet(_feature(Xn), kernel(Xn, Xn))
        best_Xn = Xn

        # Try the candidate in place of every row; keep the best improvement.
        for j in range(nb_sample):
            Xn_dummy = Xn.copy()
            Xn_dummy[j, :] = X1

            value = frechet(_feature(Xn_dummy), kernel(Xn_dummy, Xn_dummy))
            if max_frechet < value:
                max_frechet = value
                best_Xn = Xn_dummy
                updated = True
        return best_Xn, updated, max_frechet

    if nb_trial <= 0:
        # No iteration will run, so score the initial design directly instead
        # of returning a name that was never assigned.
        return Xn, frechet(_feature(Xn), kernel(Xn, Xn))

    # Approximate simulated annealing: when no exchange improved the design,
    # a random row is still replaced with some probability.
    T = 1e3
    T_alpha = 0.95
    for i in range(nb_trial):

        # Candidate X1 minimising -frechetplus, with Fx = _feature(X1).
        X1, metric = _get_X1(Xn, kernel)

        if metric < 0:
            print(f"{i} find it. metric:{metric}")
            return Xn, metric

        # Find the row of Xn to replace with X1.
        Xn_new, updated, max_frechet = _update_Xn(Xn, X1, kernel)

        if updated:
            Xn = Xn_new
            T *= T_alpha
            print("## updated")
        else:
            if np.random.rand() < np.exp(-1 / T):
                ind = np.random.choice(nb_sample)
                Xn[ind, :] = X1

        if i % 100 == 0 or i + 1 == nb_trial:
            print(f"[{i}/{nb_trial}] min_frechtplus:{metric}, Xn1:{Xn[1,:]}, T:{T}")
            # print(Xn)
    # NOTE(review): if the last iteration took the random-replacement branch,
    # max_frechet was computed before Xn was modified - confirm this is intended.
    return Xn, max_frechet


if __name__ == "__main__":
    # Manual smoke test: prints one design from lhs, init_lhs and init_lhd.
    # The commented-out sections below are earlier experiments kept for reference.

    # --- Experiment: minimise branin_hoo_func from many random starts. ---
    # init_points = random_with_bound(np.array([-5, 0]), np.array([10, 15]), 10000)
    # for point in init_points:
    #     res = minimize(branin_hoo_func, x0=point,
    #                    bounds=((-5, 10), (0, 15)),
    #                    method="l-bfgs-b")
    #     print("potimal x:", (-np.pi, 12.275), (np.pi, 2.275), (9.42478, 2.475), "y:", 0.397887)
    #     print("x:", res.x, "fun:", res.fun)
    # y = branin_hoo_func(init_points)
    # print(y.max(), y.min())

    # --- Experiment: kernel evaluation on two small point sets. ---
    # Xn1 = np.array([[1, 2], [2, 4]])
    # Xn2 = np.array([[1, 2], [3, 4], [5, 6]])
    # kernel(Xn1, Xn2, params=np.random.rand(6))

    # Two-dimensional box used by the continuous samplers below.
    low_bounds = np.array([1, 2])
    up_bounds = np.array([5, 6])
    # samples = opd2(lambda x1, x2: RBFkernel(x1, x2, np.random.rand(3)), low_bounds, up_bounds)

    # Latin hypercube sample with 4 points.
    samples = lhs(low_bounds, up_bounds, nb=4)
    print("## lhs:")
    print(samples)
    # samples = lhd(low_bounds, up_bounds, nb_sample=4)
    # print("## lhd:")
    # print(samples)
    # samples = opd(lambda x1, x2: RBFkernel(x1, x2, np.random.rand(3)), low_bounds, up_bounds, nb_sample=4)
    # print("## opd:")
    # print(samples)
    # samples = opdnc(lambda x1, x2: RBFkernel(x1, x2, np.random.rand(3)), low_bounds, up_bounds, nb_sample=4)
    # print("## opdnc:")
    # print(samples)

    # Mixed search space: three categorical and two continuous features.
    bounds = [{'name': 'h1', 'type': 'categorical', 'domain': (0, 1, 2)},
              {'name': 'h2', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4)},
              {'name': 'h3', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)},
              {'name': 'x1', 'type': 'continuous', 'domain': (-1, 1)},
              {'name': 'x2', 'type': 'continuous', 'domain': (-1, 1)}]
    # init_lhs with its default number of samples.
    sample = init_lhs(bounds)
    print(f"sample:{sample.shape} {sample}")

    # init_lhd with 4 samples.
    sample = init_lhd(bounds, 4)
    print(f"sample:{sample.shape} {sample}")
