import numpy as np
import random

from util import *


def get_rand_vec(dims):
    """Return a random vector of unit Euclidean norm.

    The vector is a standard-normal draw rescaled to length one, so its
    direction is uniformly distributed on the unit sphere.
    """
    sample = np.random.standard_normal(dims)
    length = np.linalg.norm(sample)
    return sample / length


# Tolerance shared by the optimizers below: they compare estimated gradient
# norms against 3/4 * epsilon and derive the curvature scale from
# sqrt(rho * epsilon).
epsilon = 1e-2

# Running count of zeroth-order function queries. The gradient estimators in
# this module add to it, and each optimizer records it once per iteration and
# resets it to 0 before returning.
function_query = 0


# central coordinate-wise gradient estimator
def nabla_central(f, x, mu):
    """Estimate the gradient of ``f`` at ``x`` by central differences.

    Coordinate ``i`` is estimated as
    ``(f(x + mu * e_i) - f(x - mu * e_i)) / (2 * mu)``, where both shifted
    points are passed to ``f`` as Python lists.

    Args:
        f: objective taking a list of coordinates.
        x: the evaluation point, as a Python list (perturbed points are built
            by list concatenation).
        mu: finite-difference step.

    Returns:
        A list with one partial-derivative estimate per coordinate. As a side
        effect, ``2 * len(x)`` is added to the global ``function_query``.
    """
    global function_query
    n_coords = len(x)
    grad = []
    for idx, coord in enumerate(x):
        forward = x[:idx] + [coord + mu] + x[idx + 1:]
        backward = x[:idx] + [coord - mu] + x[idx + 1:]
        grad.append((f(forward) - f(backward)) / (2 * mu))
    function_query = function_query + 2 * n_coords
    return grad


def negative_curvature_exploitation(f, x, v, s):
    """Negative-curvature exploitation (NCE) step.

    If the momentum ``v`` is already long (``||v|| >= s``), the iterate is
    kept. Otherwise the iterate moves a distance ``s`` along +v or -v,
    whichever gives the lower objective value (ties go to ``x + delta``).
    The momentum is always reset to zero.

    Args:
        f: objective; called with a Python list of coordinates.
        x: current iterate (array-like; combined with NumPy arithmetic).
        v: momentum vector (NumPy array).
        s: length of the exploitation move.

    Returns:
        ``(x_new, v_new)``, where ``v_new`` is ``np.zeros(len(v))``.
    """
    v_norm = np.linalg.norm(v)
    if v_norm >= s or v_norm == 0:
        # A zero momentum carries no direction; normalising it would give a
        # NaN step and poison the iterate, so keep x unchanged instead.
        x_new = x
    else:
        delta = s * v / v_norm
        if f(list(x - delta)) - f(list(x + delta)) >= 0:
            x_new = x + delta
        else:
            x_new = x - delta
    return x_new, np.zeros(len(v))


def zo_p_agd(f, x_0, iters, L, rho):
    """Zeroth-order accelerated gradient descent with random perturbations.

    Gradients are estimated with ``nabla_central`` (coordinate-wise central
    differences, step ``mu``). Each iteration:

    * adds a random point from the ball of radius ``r`` to ``x`` when the
      estimated gradient norm at ``x`` is at most 3/4 * epsilon and more than
      ``t_epoch`` iterations have passed since the last perturbation;
    * takes a momentum (AGD) step from the extrapolated point ``y``;
    * calls ``negative_curvature_exploitation`` when the objective looks
      "too concave" between the new ``x`` and ``y``.

    Args:
        f: objective; called with a Python list of coordinates.
        x_0: starting point (converted to a list at every iteration).
        iters: number of iterations.
        L: presumably the gradient-Lipschitz constant (step size is 1 / L).
        rho: presumably the Hessian-Lipschitz constant (enters kappa and s).

    Returns:
        ``(function_query_complexity, values)``: the global ``function_query``
        counter and f(x), both recorded at the start of every iteration.

    Prints ``t`` every 100 iterations and resets the global
    ``function_query`` to 0 before returning. Only ``nabla_central`` calls are
    counted; the f evaluations for logging, for the NCE test and inside
    ``negative_curvature_exploitation`` are not.
    """
    #initialization
    x = x_0
    y = x
    d = len(x)
    # finite-difference step for nabla_central
    mu = 0.001
    eta = 1 / L
    # NOTE(review): theta, gamma and s presumably follow the parameter choices
    # of the perturbed-AGD analysis (Jin, Netrapalli & Jordan, 2018); confirm.
    kappa = L / np.sqrt(rho * epsilon)
    theta = 1 / (4 * np.sqrt(kappa))
    gamma = theta**2 / eta
    # length of the NCE move
    s = gamma / (4 * rho)
    # minimum number of iterations between two perturbations
    t_epoch = int(np.sqrt(kappa))
    # r = epsilon
    # perturbation radius
    r = 1e-2
    v = np.zeros(d)
    # chosen so that a perturbation is already allowed at t = 0
    t_perturbation = -t_epoch-1

    values = []
    function_query_complexity = []
    global function_query

    for t in range(iters):
        x = list(x)
        values.append(f(x))
        function_query_complexity.append(function_query)
        if t % 100 == 0:
            print(t)
        der_x = nabla_central(f, list(x), mu)
        if np.linalg.norm(der_x) <= 3 / 4 * epsilon and t - t_perturbation > t_epoch:
            xi = uniform_distribution_over_unit_ball(x, r)
            # list + ndarray -> ndarray (element-wise)
            x = x + xi
            # restart the extrapolation point at the perturbed iterate
            y = x
            t_perturbation = t
        # AGD
        nabla_y = np.array(nabla_central(f, list(y), mu))
        x_new = y - eta * nabla_y
        v = x_new - x
        y = x_new + (1 - theta) * v
        # NOTE(review): nabla_y is the gradient at the y used for this step, but
        # x_new and y have already been advanced (y was overwritten just
        # above), so this test mixes a stale gradient with the new points. The
        # perturbed-AGD algorithm of Jin et al. (2018) states the test on
        # (x_t, y_t) with grad f(y_t) before the update; confirm which form is
        # intended.
        if f(list(x_new)) <= f(list(y)) + np.dot(nabla_y, x_new - y) - gamma / 2 * np.linalg.norm(y - x_new)**2:
            x_new, v = negative_curvature_exploitation(f, x_new, v, s)
            y = x_new + (1 - theta) * v

        # update
        x = x_new

    function_query = 0
    return function_query_complexity, values


def zo_p_agd_ancf(f, x_0, iters, L, rho):
    """Zeroth-order perturbed AGD with a negative-curvature-finding phase.

    Like ``zo_p_agd``, but after every perturbation (centred at ``x_tilde``)
    the next ``t_epoch`` iterations run the AGD recursion on the shifted
    gradient ``nabla(y) - zeta``. Here ``zeta`` is the gradient estimate at
    ``x_tilde``, so the shifted gradient approximates a Hessian-vector
    product. During that phase both ``x_new`` and ``y`` are projected back
    onto the sphere of radius ``r`` around ``x_tilde``.

    At ``t == t_perturbation + t_epoch``, the unit vector ``e`` from
    ``x_tilde`` to ``x`` is used as a candidate negative-curvature direction:
    ``x`` moves to ``x_tilde -/+ 0.1 * e`` when that lowers f, the momentum
    point is reset, and ``zeta`` is cleared. Outside the phase, the same NCE
    test as in ``zo_p_agd`` is applied.

    Args:
        f: objective; called with a Python list of coordinates.
        x_0: starting point.
        iters: number of iterations.
        L: presumably the gradient-Lipschitz constant (step size is 1 / L).
        rho: presumably the Hessian-Lipschitz constant.

    Returns:
        ``(function_query_complexity, values)`` recorded at the start of every
        iteration. Prints ``t`` every 100 iterations and resets the global
        ``function_query`` to 0 before returning. Only ``nabla_central`` calls
        are counted.
    """
    #initialization
    x = x_0
    y = x
    # centre of the current perturbation
    x_tilde = x_0
    # gradient estimate at x_tilde, subtracted during the NCF phase (0 outside it)
    zeta = np.zeros(len(x))
    d = len(x)
    mu = 0.01
    eta = 1 / L
    kappa = L / np.sqrt(rho * epsilon)
    theta = 1 / (4 * np.sqrt(kappa))
    gamma = theta**2 / eta
    s = gamma / (4 * rho)
    # length of the NCF phase and minimum spacing between perturbations
    t_epoch = int(np.sqrt(kappa)/2)
    r = 1e-2
    # r = epsilon
    v = np.zeros(d)
    # sentinel meaning "no perturbation yet"; also allows one at t = 0
    t_perturbation = -t_epoch-1

    values = []
    function_query_complexity = []
    global function_query

    for t in range(iters):
        values.append(f(list(x)))
        function_query_complexity.append(function_query)
        if t % 100 == 0:
            print(t)
        der_x = nabla_central(f, list(x), mu)
        # start an NCF phase: perturb around x_tilde and record its gradient
        if np.linalg.norm(der_x) <= 3 / 4 * epsilon and t - t_perturbation > t_epoch:
            x_tilde = x
            xi = uniform_distribution_over_unit_ball(x, r)
            x = x_tilde + xi
            y = x
            zeta = np.array(nabla_central(f, list(x_tilde), mu))
            t_perturbation = t
        # end of the NCF phase: probe along the direction that was found
        if t_perturbation != -t_epoch-1 and t - t_perturbation == t_epoch:
            # NOTE(review): divides by zero if x coincides with x_tilde.
            e = (x - x_tilde) / np.linalg.norm(x - x_tilde)
            # The theoretical step 1/4*sqrt(epsilon/rho) below was replaced by
            # a fixed 0.1.
            # if f(list(x_tilde - 1/4*np.sqrt(epsilon/rho)*e)) <= f(list(x_tilde + 1/4*np.sqrt(epsilon/rho)*e)):
            #     x = x_tilde - 1/4 * np.sqrt(epsilon/rho)*e
            # elif f(list(x_tilde + 1/4*np.sqrt(epsilon/rho)*e)) <= f(list(x)):
            #     x = x_tilde + 1/4 * np.sqrt(epsilon/rho)*e
            if f(list(x_tilde - 0.1*e)) <= f(list(x_tilde + 0.1*e)):
                x = x_tilde - 0.1*e
            elif f(list(x_tilde + 0.1*e)) <= f(list(x)):
                x = x_tilde + 0.1*e
            # otherwise x keeps its current (projected) value
            y = x
            zeta = np.zeros(d)
        # AGD and A_NCF
        nabla_y = np.array(nabla_central(f, list(y), mu))
        x_new = y - eta * (nabla_y - zeta)
        v = x_new - x
        y = x_new + (1 - theta) * v
        # inside the NCF phase: keep both points on the sphere of radius r
        # around x_tilde (NOTE(review): divides by zero if a point equals x_tilde)
        if t_perturbation != -t_epoch-1 and t - t_perturbation < t_epoch:
            y = x_tilde + r * (y - x_tilde) / (np.linalg.norm(y - x_tilde))
            x_new = x_tilde + r * (x_new - x_tilde) / (np.linalg.norm(x_new - x_tilde))
        # NCE
        # NOTE(review): as in zo_p_agd, nabla_y belongs to the previous y while
        # x_new and y have already been advanced; confirm the intended test.
        elif f(list(x_new)) <= f(list(y)) + np.dot(nabla_y, x_new - y) - gamma / 2 * np.linalg.norm(y - x_new)**2:
            x_new, v = negative_curvature_exploitation(f, x_new, v, s)
            y = x_new + (1 - theta) * v

        # update
        x = x_new

    function_query = 0
    return function_query_complexity, values

###################################################################################################################
# coordinate-wise central-difference gradient estimator (alias of nabla_central)
def nabla_ff(f, x, mu):
    """Estimate the gradient of ``f`` at ``x`` by coordinate-wise central differences.

    This used to be a verbatim copy of ``nabla_central``: same formula, same
    ``2 * len(x)`` increment of the global ``function_query``. It now delegates
    to it, so both names stay in sync.

    Args:
        f: objective taking a list of coordinates.
        x: evaluation point as a Python list.
        mu: finite-difference step.

    Returns:
        A list with one partial-derivative estimate per coordinate.
    """
    return nabla_central(f, x, mu)


# zeroth-order negative curvature finding in deterministic setting
def ncf(f, x, L, rho):
    """Zeroth-order negative-curvature finding (NCF) around ``x``.

    Let g(z) denote the ``nabla_ff`` finite-difference gradient estimate.
    Starting from a random displacement ``y`` of norm ``sigma``, the routine
    runs the three-term recursion

        M(y)    = -(g(x_0 + y) - g(x_0)) / L + (1 - 3 * delta / (4 * L)) * y
        y_{k+1} = 2 * M(y_k) - y_{k-1}

    where g(x_0 + y) - g(x_0) approximates the Hessian-vector product H y.
    As soon as the point ``x_0 + y_{k+1} - M(y_k)`` is at least ``rr`` away
    from ``x_0``, its normalised displacement is returned.

    Args:
        f: objective; called with a Python list of coordinates.
        x: point at which to look for negative curvature (array-like).
        L: presumably the gradient-Lipschitz constant.
        rho: presumably the Hessian-Lipschitz constant. Together with
            ``epsilon`` it sets the curvature scale
            ``delta = sqrt(rho * epsilon)``.

    Returns:
        A unit vector of length ``len(x)`` when a direction was found.
        Otherwise ``np.ones(len(x))``, which callers treat as a "not found"
        sentinel.
    """
    d = len(x)
    p = 0.01
    delta = np.sqrt(rho * epsilon)
    C_1 = 0.8
    iterss = int(np.power(C_1, 2) * np.log(d / p) * np.sqrt(L) / np.sqrt(delta))
    if iterss < 1:
        # No iteration would run, and sigma below would divide by zero.
        return np.ones(d)
    sigma = np.power(d / p, -2 * C_1) * delta / np.power(iterss, 2) / rho
    rr = np.power(d / p, C_1) * sigma

    x_0 = np.array(x)
    # fixed finite-difference step (an adaptive ||y|| step was tried before)
    mu = 1e-3
    shrink = 1 - 3 * delta / (4 * L)
    y_prev = np.zeros(d)
    y = sigma * get_rand_vec(d)
    for _ in range(iterss):
        # NOTE(review): g(x_0) is loop-invariant but re-estimated every step,
        # costing 2 * d counted queries per iteration. Hoisting it would change
        # the reported query complexity, so it is kept as is.
        hvp = np.array(nabla_ff(f, list(x_0 + y), mu)) - np.array(nabla_ff(f, list(x_0), mu))
        M_y = -1 / L * hvp + shrink * y
        y_next = 2 * M_y - y_prev
        x_k = x_0 + y_next - M_y

        y_prev = y
        y = y_next
        displacement = x_k - x_0
        dist = np.linalg.norm(displacement)
        if dist >= rr:
            return displacement / dist

    return np.ones(d)


def zo_gd_ncf(f, x_0, iters, L, rho):
    """Zeroth-order gradient descent with negative-curvature finding.

    Each iteration takes a gradient step of size 1 / L, using the ``nabla_ff``
    finite-difference estimate with step ``epsilon / L / sqrt(d)``. Whenever
    the estimated gradient norm is at most 3/4 * epsilon, ``ncf`` is asked for
    a negative-curvature direction ``v``. The iterate then moves to
    ``x - 0.1 v`` or ``x + 0.1 v`` if that does not increase f.

    Args:
        f: objective; called with a Python list of coordinates.
        x_0: starting point.
        iters: number of iterations.
        L: presumably the gradient-Lipschitz constant.
        rho: presumably the Hessian-Lipschitz constant (forwarded to ``ncf``).

    Returns:
        ``(function_query_complexity, values)``: the global ``function_query``
        counter and f(x), recorded at the start of every iteration. The counter
        is reset to 0 before returning.
    """
    global function_query
    d = len(x_0)
    eta = 1 / L
    # finite-difference step for the gradient estimate
    mu = epsilon / L / np.sqrt(d)
    # ncf returns all-ones when no direction was found
    not_found = np.ones(d)

    x = x_0
    values = []
    function_query_complexity = []
    for _ in range(iters):
        x = list(x)
        values.append(f(x))
        function_query_complexity.append(function_query)
        der_x = np.array(nabla_ff(f, list(x), mu))
        x = np.array(x) - eta * der_x
        if np.linalg.norm(der_x) <= (3 / 4) * epsilon:
            v = ncf(f, x, L, rho)
            if not (v == not_found).all():
                # evaluate each probe once; x is kept if neither helps
                f_minus = f(list(x - 0.1 * v))
                f_plus = f(list(x + 0.1 * v))
                if f_minus <= f_plus:
                    x = x - 0.1 * v
                elif f_plus <= f(list(x)):
                    x = x + 0.1 * v

    function_query = 0
    return function_query_complexity, values


###################################################################################################################
# uniform distribution over a ball with radius r
def uniform_distribution_over_unit_ball(x_0, r):
    """Draw a point uniformly from the origin-centred ball of radius ``r``.

    The dimension is ``len(x_0)``; the values in ``x_0`` are not used. Add the
    result to a point to perturb it.

    A uniform sample needs two ingredients:

    * an isotropic direction, obtained here by normalising a standard
      Gaussian vector;
    * a radius distributed as ``r * U**(1/d)``, so that probability mass grows
      like radius**d.

    Normalising a cube-uniform vector would over-weight the cube diagonals.
    Scaling by ``r * U`` would pile samples up near the centre.
    """
    d = len(x_0)
    if d == 0:
        return np.zeros(0)
    direction = np.random.standard_normal(d)
    direction /= np.linalg.norm(direction)
    radius = r * np.random.uniform(0, 1) ** (1.0 / d)
    return radius * direction


def nabla_f(f, x, h=0.01):
    """Estimate the gradient of ``f`` at ``x`` by central differences with step ``h``.

    This is a thin wrapper around ``nabla_central``. The default ``h=0.01``
    matches the previously hard-coded step, so existing two-argument calls
    are unchanged. Adds ``2 * len(x)`` to the global ``function_query``.

    Args:
        f: objective taking a list of coordinates.
        x: evaluation point as a Python list.
        h: finite-difference step.

    Returns:
        A list with one partial-derivative estimate per coordinate.
    """
    return nabla_central(f, x, h)


def pagd(f, x_0, iters, L):
    """Zeroth-order gradient descent with random perturbations.

    Takes plain gradient steps of size 1 / L, using ``nabla_f``
    finite-difference gradients. When the estimated gradient norm is at most
    3/4 * g_thresh and more than ``t_thresh`` iterations have passed since the
    last perturbation, two things happen before the step:

    * the iterate is moved by a random vector of norm ``r``;
    * the gradient is re-estimated at the moved point.

    Args:
        f: objective; called with a Python list of coordinates.
        x_0: starting point (any sequence; copied to a list).
        iters: number of iterations.
        L: presumably the gradient-Lipschitz constant (step size 1 / L).

    Returns:
        ``(function_query_complexity, values)`` recorded at the start of every
        iteration. Prints the final iterate and resets the global
        ``function_query`` to 0.
    """
    global function_query
    gamma = 1
    t_thresh = 10
    # lets the first perturbation happen at i = 0
    t_noise = -t_thresh - 1
    g_thresh = np.exp(1) * gamma / 100
    r = 0.001
    eta = 1 / L

    # nabla_f builds shifted points by list concatenation, so the iterate must
    # be a list: a NumPy array would be broadcast instead of spliced.
    x = list(x_0)
    values = []
    function_query_complexity = []
    for i in range(iters):
        values.append(f(x))
        function_query_complexity.append(function_query)
        der_x = np.array(nabla_f(f, x))

        if (3 / 4) * g_thresh >= np.linalg.norm(der_x) and (i - t_noise > t_thresh):
            x = list(np.array(x) + r * get_rand_vec(len(x)))
            der_x = np.array(nabla_f(f, x))
            t_noise = i

        x = list(np.array(x) - eta * der_x)

    print(x)
    function_query = 0
    return function_query_complexity, values


###################################################################################################################
def nabla_spsa(f, x, c):
    """Estimate the gradient of ``f`` at ``x`` with one SPSA difference.

    A Rademacher (+1/-1) vector ``delta`` is drawn. The single difference
    ``f(x + c*delta) - f(x - c*delta)`` is divided by ``2 * c * delta_i`` to
    give every coordinate's estimate. This costs two function queries
    regardless of dimension and adds 2 to the global ``function_query``.

    Returns:
        A list with one estimate per coordinate.
    """
    global function_query
    signs = 2 * np.random.binomial(n=1, p=0.5, size=len(x)) - 1
    step = c * signs
    dif = f(list(x + step)) - f(list(x - step))
    estimate = [dif / (2 * c * sign) for sign in signs]
    function_query = function_query + 2
    return estimate


def dfpi_spsa(f, x, c, L):
    """Power iteration driven by SPSA gradient estimates.

    ``(g(x + r s) - g(x - r s)) / (2 r)`` approximates the Hessian-vector
    product ``H s``. Each step therefore applies ``s <- s - H s / L`` and
    renormalises, which is a power iteration on ``I - H / L``. When L bounds
    the Hessian spectrum, this amplifies the direction of the most negative
    eigenvalue.

    Args:
        f: objective; called with a Python list of coordinates.
        x: point at which the Hessian is probed.
        c: SPSA perturbation size passed to ``nabla_spsa``.
        L: step-size scale (step is 1 / L).

    Returns:
        The final unit direction after 100 iterations.
    """
    # T_dfpi = int(1 / epsilon**(2/3) * L * np.log(d))
    n_steps = 100
    r = 0.01
    eta = 1 / L
    direction = np.random.uniform(-1, 1, len(x))
    direction = direction / np.linalg.norm(direction)
    for _ in range(n_steps):
        grad_plus = np.array(nabla_spsa(f, list(x + r * direction), c))
        grad_minus = np.array(nabla_spsa(f, list(x - r * direction), c))
        direction = direction - eta * (grad_plus - grad_minus) / (2 * r)
        direction = direction / np.linalg.norm(direction)
    return direction


def dfpi(f, x, c, L):
    """Power iteration driven by central-difference gradient estimates.

    ``(g(x + r s) - g(x - r s)) / (2 r)`` approximates the Hessian-vector
    product ``H s``. Each step therefore applies ``s <- s - H s / L`` and
    renormalises, which is a power iteration on ``I - H / L``.

    Args:
        f: objective; called with a Python list of coordinates.
        x: point at which the Hessian is probed.
        c: finite-difference step passed to ``nabla_central``.
        L: step-size scale (step is 1 / L).

    Returns:
        The final unit direction after 20 iterations.
    """
    # T_dfpi = int(1 / epsilon**(2/3) * L * np.log(d))
    n_steps = 20
    r = 0.01
    eta = 1 / L
    direction = np.random.uniform(-1, 1, len(x))
    direction = direction / np.linalg.norm(direction)
    for _ in range(n_steps):
        grad_plus = np.array(nabla_central(f, list(x + r * direction), c))
        grad_minus = np.array(nabla_central(f, list(x - r * direction), c))
        direction = direction - eta * (grad_plus - grad_minus) / (2 * r)
        direction = direction / np.linalg.norm(direction)
    return direction


def _rspi_best_of_three(f, x, step):
    """Return the best of ``x``, ``x + step`` and ``x - step`` under ``f``.

    The points are evaluated in that order. ``np.argmin`` keeps the first
    minimum, so ``x`` wins any tie. Adds 3 to the global ``function_query``.

    Returns:
        ``(best_point, index)``: index 0 means ``x``, 1 means ``x + step``
        and 2 means ``x - step``.
    """
    global function_query
    candidates = [x, x + step, x - step]
    scores = [f(list(point)) for point in candidates]
    index = np.argmin(scores)
    function_query = function_query + 3
    return candidates[index], index


def _rspi_run(f, x_0, iters, L, sigma_1, sigma_2, T_sigma_1, ratio,
              c_init, find_direction, verbose):
    """Shared main loop of ``rspi`` and ``rspi_spsa``.

    Each iteration works as follows:

    1. Probe ``x +/- sigma_1 * s_1`` along a random unit direction ``s_1``.
    2. If neither probe improves on ``x``, obtain a direction
       ``s_2 = find_direction(f, x, c, L)`` and probe ``x +/- sigma_2 * s_2``.

    ``sigma_1`` is multiplied by ``ratio`` whenever ``k % T_sigma_1 == 0``.
    ``c`` decays as ``c_init / (k + 1) ** 0.1`` after each direction search.
    """
    global function_query
    d = len(x_0)
    x = x_0
    c = c_init
    values = []
    function_query_complexity = []
    index_ncf = 0
    for k in range(iters):
        values.append(f(x))
        function_query_complexity.append(function_query)
        if verbose:
            print('iteration:', k)
        s_1 = np.random.uniform(-1, 1, d)
        s_1 = s_1 / np.linalg.norm(s_1)
        x, index = _rspi_best_of_three(f, x, sigma_1 * s_1)
        if index == 0:
            # The random direction did not help: search for a
            # negative-curvature direction and probe along it instead.
            index_ncf += 1
            if verbose:
                print('times of ncf:', index_ncf)
            s_2 = find_direction(f, x, c, L)
            c = c_init / np.power(k + 1, 0.1)
            x, _ = _rspi_best_of_three(f, x, sigma_2 * s_2)

        # NOTE(review): this already holds at k == 0, so sigma_1 first shrinks
        # right after the first iteration; confirm that is intended.
        if k % T_sigma_1 == 0:
            sigma_1 = sigma_1 * ratio
        x = list(x)

    print(x)
    function_query = 0
    return function_query_complexity, values


def rspi(f, x_0, iters, L, sigma_1, sigma_2, T_sigma_1, ratio):
    """Random search with a ``dfpi`` negative-curvature fallback.

    Args:
        f: objective; called with a Python list of coordinates.
        x_0: starting point.
        iters: number of iterations.
        L: step-size scale forwarded to ``dfpi``.
        sigma_1: initial random-search step (shrunk by ``ratio``).
        sigma_2: step along the negative-curvature direction.
        T_sigma_1: period of the ``sigma_1`` shrinking.
        ratio: shrink factor for ``sigma_1``.

    Returns:
        ``(function_query_complexity, values)`` recorded at the start of every
        iteration. Prints progress and the final iterate, and resets the
        global ``function_query`` to 0.
    """
    return _rspi_run(f, x_0, iters, L, sigma_1, sigma_2, T_sigma_1, ratio,
                     c_init=0.1, find_direction=dfpi, verbose=True)


def rspi_spsa(f, x_0, iters, L, sigma_1, sigma_2, T_sigma_1, ratio):
    """Random search with a ``dfpi_spsa`` negative-curvature fallback.

    Same arguments and return value as ``rspi``. It uses ``c_init = 0.15``
    and prints only the final iterate.
    """
    return _rspi_run(f, x_0, iters, L, sigma_1, sigma_2, T_sigma_1, ratio,
                     c_init=0.15, find_direction=dfpi_spsa, verbose=False)


