import numpy as np


def sample(mu):
    """Draw from a normal distribution with mean ``mu`` and standard deviation 1.

    Uses NumPy's global random state, so results depend on ``np.random.seed``.

    Args:
        mu: Mean of the distribution (scalar or array-like; forwarded as-is to
            ``np.random.normal``).

    Returns:
        The value(s) drawn by ``np.random.normal``.
    """
    return np.random.normal(loc=mu, scale=1)


def update_mean_online(mean_support: np.ndarray, mean_hat: np.ndarray, new_val: np.ndarray):
    """Fold one new value into a running mean.

    Computes ``(mean_support * mean_hat + new_val) / (mean_support + 1)``, i.e.
    ``mean_support`` is treated as the number of values already averaged into
    ``mean_hat``.

    Args:
        mean_support: Count(s) backing the current mean.
        mean_hat: Current mean estimate(s).
        new_val: New value(s) to incorporate.

    Returns:
        The updated mean, with the usual NumPy broadcasting of the inputs.
    """
    # Kept as "reciprocal times sum" (not a division) so the floating-point
    # result is bit-for-bit what callers have always received.
    reciprocal_count = 1 / (1 + mean_support)
    weighted_total = mean_support * mean_hat + new_val
    return reciprocal_count * weighted_total


def kl(x, y):
    """Return the squared difference ``(x - y) ** 2``.

    NOTE(review): the name suggests a KL divergence; for unit-variance
    Gaussians that would be half of this value. Confirm whether the missing
    factor 1/2 is intentional.

    Args:
        x: First value (scalar or array).
        y: Second value (scalar or array).

    Returns:
        The element-wise squared difference.
    """
    gap = x - y
    return gap ** 2


def beta(delta: float, t: int):
    """Return ``log(1 / delta) + log(log(t) + 1)``.

    Within this file the value is used by ``stopping_rule`` as a threshold
    (divided by ``t``).

    Args:
        delta: Value whose reciprocal is logged; must be positive.
        t: Round counter; ``log(t) + 1`` must be positive for a finite result.

    Returns:
        The sum of the two logarithmic terms.
    """
    confidence_term = np.log(1 / delta)
    time_term = np.log(np.log(t) + 1)
    return confidence_term + time_term


def solve_optim(w: np.ndarray, mu: np.ndarray, eps: float):
    """Return the larger of the two pair scores ``(w[i] + w[j]) * eps ** 2``.

    The pairs are arms (0, 1) and arms (2, 3).

    Args:
        w: Per-arm weights; only indices 0-3 are read.
        mu: Unused by this function; kept for signature compatibility.
        eps: Scale parameter; enters the score squared.

    Returns:
        ``max`` of the score for pair (0, 1) and the score for pair (2, 3).
    """
    eps_squared = eps ** 2
    score_first_pair = (w[0] + w[1]) * eps_squared
    score_second_pair = (w[2] + w[3]) * eps_squared
    return max(score_first_pair, score_second_pair)


def get_answer(w: np.ndarray, mu: np.ndarray, eps: float):
    """Return the pair of means belonging to the higher-scoring arm pair.

    Each pair is scored as ``(w[i] + w[j]) * eps ** 2 / 2``. Pair (0, 1) is
    chosen only when its score is strictly greater; ties go to pair (2, 3).

    Args:
        w: Per-arm weights; only indices 0-3 are read.
        mu: Per-arm means; only indices 0-3 are read.
        eps: Scale parameter; enters the score squared.

    Returns:
        ``np.array([mu[0], mu[1]])`` or ``np.array([mu[2], mu[3]])``.
    """
    eps_squared = eps ** 2
    score_first_pair = (w[0] + w[1]) * eps_squared / 2
    score_second_pair = (w[2] + w[3]) * eps_squared / 2

    left, right = (0, 1) if score_first_pair > score_second_pair else (2, 3)
    return np.array([mu[left], mu[right]])


def compute_xt_stas(nt: np.ndarray, mu_hat: np.ndarray, t: int):
    """Return ``mu_hat[i] + sqrt(log(t) / nt[i])`` for arms 0 through 3.

    Args:
        nt: Per-arm counts; only indices 0-3 are read and each is used as a
            divisor.
        mu_hat: Per-arm mean estimates; only indices 0-3 are read.
        t: Round counter whose logarithm scales the bonus term.

    Returns:
        A list of four values, one per arm.
    """
    log_t = np.log(t)
    return [mu_hat[arm] + np.sqrt(log_t / nt[arm]) for arm in range(4)]


def compute_xt_ours(nt: np.ndarray, mu_hat: np.ndarray, t: int):
    """Return a two-element list per arm (arms 0 through 3).

    Both entries of each pair are the same value,
    ``mu_hat[i] + sqrt(log(t) / nt[i])``.

    NOTE(review): ``sequence_weights`` passes the two entries to
    ``find_closet`` as ``min_val`` and ``max_val``, so each "interval" is
    currently a single point. If a lower bound
    (``mu_hat[i] - sqrt(log(t) / nt[i])``) was intended for the first entry,
    this is a copy-paste slip -- confirm before changing.

    Args:
        nt: Per-arm counts; only indices 0-3 are read and each is used as a
            divisor.
        mu_hat: Per-arm mean estimates; only indices 0-3 are read.
        t: Round counter whose logarithm scales the bonus term.

    Returns:
        A list of four ``[value, value]`` lists.
    """
    log_t = np.log(t)

    def bound(arm):
        return mu_hat[arm] + np.sqrt(log_t / nt[arm])

    return [[bound(arm), bound(arm)] for arm in range(4)]


def stopping_rule(nt: np.ndarray, mu_hat: np.ndarray, delta: float, eps: float, t: int):
    """Report whether the stopping condition holds at round ``t``.

    The condition is ``solve_optim(nt / t, mu_hat, eps) >= beta(delta, t) / t``.

    Args:
        nt: Per-arm counts; divided by ``t`` before being passed on as weights.
        mu_hat: Per-arm mean estimates, forwarded to ``solve_optim``.
        delta: Forwarded to ``beta``.
        eps: Forwarded to ``solve_optim``.
        t: Round counter; used as a divisor on both sides.

    Returns:
        The (truthy/falsy) result of the comparison.
    """
    empirical_weights = nt / t
    objective = solve_optim(empirical_weights, mu_hat, eps)
    threshold = beta(delta, t) / t
    return objective >= threshold


def sticky_weights(nt: np.ndarray, mu_hat: np.ndarray, t: int):
    """Return weights that put all mass on arm pair (0, 1) or on pair (2, 3).

    The choice compares entries 0 and 2 of ``compute_xt_stas(nt, mu_hat, t)``:
    pair (0, 1) is selected only when entry 0 is strictly larger.

    Args:
        nt: Per-arm counts, forwarded to ``compute_xt_stas``.
        mu_hat: Per-arm mean estimates, forwarded to ``compute_xt_stas``.
        t: Round counter, forwarded to ``compute_xt_stas``.

    Returns:
        ``np.array([0.5, 0.5, 0, 0])`` or ``np.array([0, 0, 0.5, 0.5])``.
    """
    index_values = compute_xt_stas(nt, mu_hat, t)

    # Select the answer according to the total order: only the first arm of
    # each pair takes part in the comparison.
    first_pair_wins = index_values[0] > index_values[2]
    return np.array([0.5, 0.5, 0., 0.]) if first_pair_wins else np.array([0.0, 0., 0.5, 0.5])


def find_closet(min_val, max_val, target_val):
    """Project ``target_val`` onto the closed interval ``[min_val, max_val]``.

    Args:
        min_val: Lower end of the interval.
        max_val: Upper end of the interval.
        target_val: Value to project.

    Returns:
        A two-element list ``[closest, distance]`` where ``closest`` is the
        point of the interval nearest to ``target_val`` and ``distance`` is
        how far ``target_val`` lies from it (``0`` when it is inside the
        interval, endpoints included).
    """
    # Inclusive bounds: with strict inequalities a target sitting exactly on
    # ``min_val`` fell through to the last branch and came back as
    # ``max_val`` with a negative "distance".
    if min_val <= target_val <= max_val:
        return [target_val, 0]

    if target_val < min_val:
        return [min_val, min_val - target_val]

    return [max_val, target_val - max_val]


def sequence_weights(nt: np.ndarray, mu_hat: np.ndarray, t, prev_answer):
    """Choose weights and an answer that stays close to the previous answer.

    There are two candidate answers: one built from arms (0, 1) and one from
    arms (2, 3). Each arm contributes a pair taken from
    ``compute_xt_ours(nt, mu_hat, t)``, whose two entries are passed to
    ``find_closet`` as ``min_val`` and ``max_val``.

    Args:
        nt: Per-arm counts, forwarded to ``compute_xt_ours``.
        mu_hat: Per-arm mean estimates, forwarded to ``compute_xt_ours``.
        t: Round counter, forwarded to ``compute_xt_ours``.
        prev_answer: Previously returned answer (indexable with 0 and 1), or
            ``None`` when there is none yet.

    Returns:
        A tuple ``(weights, answer)``. ``weights`` is
        ``np.array([0.5, 0.5, 0, 0])`` when the answer comes from arms (0, 1)
        and ``np.array([0, 0, 0.5, 0.5])`` when it comes from arms (2, 3).
        ``answer`` is a two-element list (first call) or tuple (later calls).
    """
    X_t = compute_xt_ours(nt, mu_hat, t)

    if prev_answer is None:
        # Nothing to stay close to yet: pick one of the two candidates with
        # equal probability and answer with the second entry of each pair.
        if np.random.rand() <= 0.5:
            return np.array([0.5, 0.5, 0., 0.]), [X_t[0][1], X_t[1][1]]
        return np.array([0.0, 0.0, 0.5, 0.5]), [X_t[2][1], X_t[3][1]]

    # Candidate answers, one per arm pair. The second candidate must use arms
    # 2 and 3 (it previously used arms 1 and 2) so that it matches both the
    # weights returned for it and the ``prev_answer is None`` branch above.
    a1 = [X_t[0], X_t[1]]
    a2 = [X_t[2], X_t[3]]

    # Closest point to prev_answer within the first candidate
    closest_x_a1, dist_x_a1 = find_closet(a1[0][0], a1[0][1], prev_answer[0])
    closest_y_a1, dist_y_a1 = find_closet(a1[1][0], a1[1][1], prev_answer[1])
    closest_a1 = (closest_x_a1, closest_y_a1)
    # Distance of a candidate is the larger of its per-coordinate distances.
    dist_a1 = max(dist_x_a1, dist_y_a1)

    # Closest point to prev_answer within the second candidate
    closest_x_a2, dist_x_a2 = find_closet(a2[0][0], a2[0][1], prev_answer[0])
    closest_y_a2, dist_y_a2 = find_closet(a2[1][0], a2[1][1], prev_answer[1])
    closest_a2 = (closest_x_a2, closest_y_a2)
    dist_a2 = max(dist_x_a2, dist_y_a2)

    # Keep the strictly closer candidate; ties go to the second one.
    if dist_a1 < dist_a2:
        return np.array([0.5, 0.5, 0., 0.]), closest_a1

    return np.array([0.0, 0., 0.5, 0.5]), closest_a2


def forced_exploration(weights: np.array, t: int, n_arms=4) -> np.array:
    """Mix ``weights`` with the uniform distribution over ``n_arms`` arms.

    The mixing coefficient is ``gamma = 1 / (4 * sqrt(t))`` and the result is
    ``(1 - gamma) * weights + gamma * uniform``.

    Args:
        weights: Weights to be mixed; must broadcast against a length
            ``n_arms`` vector.
        t: Round counter; must be positive for a finite ``gamma``.
        n_arms: Number of arms the uniform distribution is spread over.

    Returns:
        The mixed weight vector.
    """
    mix = 1 / (4 * np.sqrt(t))
    uniform = np.full(n_arms, 1 / n_arms)
    return (1 - mix) * weights + mix * uniform
