import numpy as np
from util import *
import random


def get_rand_vec(dims):
    """Return a standard-normal sample of shape ``dims`` rescaled to unit norm.

    The sample is drawn with ``np.random.standard_normal`` and divided by its
    own norm as computed by ``np.linalg.norm``.
    """
    sample = np.random.standard_normal(dims)
    length = np.linalg.norm(sample)
    return sample / length


# Target accuracy shared by the routines in this module: smoothing radii,
# the curvature scale sqrt(rho * epsilon) and the gradient-norm thresholds
# are all derived from it.
epsilon = 1e-2

# Running count of function queries. The gradient estimators add to it and the
# optimizers reset it to 0 just before returning. (A `global` statement is
# only meaningful inside a function, so none is needed at module level.)
function_query = 0


# coordinate-wise gradient estimator
def nabla_ff(f, x, mu, index):
    """Estimate the gradient of ``f`` at ``x`` with coordinate-wise central differences.

    Args:
        f: Callable invoked as ``f(point, index)``; ``point`` is passed as a
            plain list. Presumably returns a scalar (not visible here).
        x: Current point; any sequence of numbers (list, tuple or 1-D array).
        mu: Finite-difference step applied to one coordinate at a time.
        index: Sample indices, forwarded unchanged to ``f``.

    Returns:
        A list of ``len(x)`` entries, the i-th being
        ``(f(x + mu*e_i, index) - f(x - mu*e_i, index)) / (2 * mu)``.

    Side effects:
        Adds ``len(x) * 2 * len(index)`` to the module-level ``function_query``.
    """
    global function_query
    # Work on a plain list so that the `+` below really splices one perturbed
    # coordinate into the point; on an ndarray `+` would broadcast instead of
    # concatenating.
    point = list(x)
    der_x = []
    for i, x_i in enumerate(point):
        forward = f(point[:i] + [x_i + mu] + point[i + 1:], index)
        backward = f(point[:i] + [x_i - mu] + point[i + 1:], index)
        der_x.append((forward - backward) / (2 * mu))
    function_query = function_query + len(point) * 2 * len(index)
    return der_x


# random gradient estimator
def nabla_ff_rand(f, x, mu, index):
    """Two-point gradient estimate of ``f`` at ``x`` along one random unit direction.

    The direction is sampled componentwise from Uniform(-1, 1) and then scaled
    to unit Euclidean norm.

    Args:
        f: Callable invoked as ``f(point, index)``.
        x: Current point (sequence of length ``len(x)``).
        mu: Step taken along the direction on either side of ``x``.
        index: Sample indices, forwarded unchanged to ``f``.

    Returns:
        ``(f(x + mu*u) - f(x - mu*u)) * u / (2 * mu)`` as an array.

    Side effects:
        Adds ``2 * len(index)`` to the module-level ``function_query``.
    """
    global function_query
    direction = np.random.uniform(-1, 1, len(x))
    direction = direction / np.linalg.norm(direction)

    value_plus = f(x + mu * direction, index)
    value_minus = f(x - mu * direction, index)
    estimate = (value_plus - value_minus) * direction / (2 * mu)

    function_query = function_query + 2 * len(index)
    return estimate


# gaussian gradient estimator
def nabla_ff_gaussian(f, x, mu, index):
    """Forward-difference gradient estimate of ``f`` at ``x`` along a Gaussian direction.

    Args:
        f: Callable invoked as ``f(point, index)``.
        x: Current point (sequence of length ``len(x)``).
        mu: Smoothing step along the sampled direction.
        index: Sample indices, forwarded unchanged to ``f``.

    Returns:
        ``(f(x + mu*u) - f(x)) / mu * u`` with ``u`` drawn from N(0, I).

    Side effects:
        Adds ``2 * len(index)`` to the module-level ``function_query``.
    """
    global function_query
    direction = np.random.normal(0, 1, len(x))

    shifted_value = f(x + mu * direction, index)
    base_value = f(x, index)
    estimate = (shifted_value - base_value) / mu * direction

    function_query = function_query + 2 * len(index)
    return estimate


# zeroth-order negative curvature finding in stochastic setting
def zo_ncf_weak(f, num, x_0, L, rho):
    """Single attempt at finding a negative-curvature direction of ``f`` at ``x_0``.

    Starting from ``x_0`` plus a small random perturbation, the iterate is
    repeatedly moved along ``-(grad(x) - grad(x_0))``, with both gradients
    estimated by ``nabla_ff`` on one freshly sampled index. As soon as the next
    iterate would be at distance >= ``r`` from ``x_0`` the normalised
    displacement of the *current* iterate is returned.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for the sampled index.
        x_0: Point around which curvature is probed.
        L: Constant entering the step size (presumably a smoothness constant).
        rho: Constant entering ``delta = sqrt(rho * epsilon)`` (presumably a
            Hessian-Lipschitz constant).

    Returns:
        A unit vector ``(x - x_0) / ||x - x_0||`` on success, otherwise the
        all-ones vector of length ``len(x_0)`` which callers treat as
        "nothing found".
    """
    d = len(x_0)
    p = 0.01
    C_0 = 0.1
    delta = np.sqrt(rho * epsilon)
    log_term = np.log(d / p)

    # Step size, iteration budget, perturbation size and escape radius.
    eta = delta / (np.power(C_0, 2) * L**2 * log_term)
    T = ((C_0**2 * log_term) / (eta * delta)).astype(int)
    sigma = (eta**2 * delta**3) / (np.power(d / p, 3 * C_0) * rho)
    r = np.power(d / p, C_0) * sigma

    # Fixed smoothing step for the coordinate-wise estimator.
    mu_t = 0.01
    x = np.array(x_0) + sigma * get_rand_vec(d)
    for _ in range(T):
        index = np.random.randint(0, num, 1)
        grad_here = np.array(nabla_ff(f, list(x), mu_t, index))
        grad_anchor = np.array(nabla_ff(f, list(x_0), mu_t, index))
        x_new = x - eta * (grad_here - grad_anchor)
        if np.linalg.norm(x_new - x_0) >= r:
            return (x - x_0) / np.linalg.norm(x - x_0)
        x = x_new

    # Budget exhausted without leaving the ball: report the failure sentinel.
    return np.array(np.ones(len(x)))


def zo_ncf_online(f, num, x_0, L, rho, p):
    """Search for a verified negative-curvature direction of ``f`` at ``x_0``.

    Runs ``zo_ncf_weak`` up to ``int(log(1/p))`` times. Every candidate that is
    not the failure sentinel is checked with a finite-difference curvature
    estimate ``z = v . (grad(x_0 + v) - grad(x_0)) / ||v||^2`` computed on a
    batch of ``m`` sampled indices; the candidate is accepted when
    ``z <= -3 * delta / 4``.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for sampled indices.
        x_0: Point at which curvature is probed.
        L: Constant entering the verification batch size.
        rho: Constant entering ``delta = sqrt(rho * epsilon)``.
        p: Value in (0, 1) controlling the number of attempts and the
            verification batch size.

    Returns:
        The accepted direction from ``zo_ncf_weak``, or the all-ones vector of
        length ``len(x_0)`` when no candidate passes verification.
    """
    delta = np.sqrt(rho * epsilon)
    d = len(x_0)
    smoothing = 0.01
    attempts = np.log(1 / p).astype(int)
    for _ in range(attempts):
        v_j = zo_ncf_weak(f, num, x_0, L, rho)
        # zo_ncf_weak signals failure with the all-ones vector. Test for exactly
        # that vector (the same check the callers use) instead of requiring
        # every component to differ from 1, which would also discard a valid
        # direction that merely has one component equal to 1.
        if (v_j == np.ones(d)).all():
            continue

        m = ((L**2 * np.log(1 / p)) / delta**2).astype(int)
        v_new = delta / (d * rho) * v_j
        index = np.random.randint(0, num, m)
        h_v_est = np.array(nabla_ff(f, list(x_0 + v_new), smoothing, index)) - \
            np.array(nabla_ff(f, list(x_0), smoothing, index))
        z_j = np.dot(v_new, h_v_est) / np.linalg.norm(v_new)**2

        if z_j <= -3 * delta / 4:
            return v_j

    # No attempt produced a verified direction.
    return np.array(np.ones(d))


# zo-sgd-ncf
def zo_sgd_ncf(f, num, batch_size, x_0, iters, p, L, rho):
    """Zeroth-order SGD with a negative-curvature escape step.

    Each iteration takes a gradient step using the coordinate-wise estimator
    on a sampled batch. When the estimated gradient norm is at most
    ``3 * epsilon / 4`` a direction is requested from ``zo_ncf_online``; if
    none is found the loop stops, otherwise the direction (with an independent
    random sign per coordinate) is added to the new iterate.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for sampled indices.
        batch_size: Number of indices sampled per gradient estimate.
        x_0: Starting point (list or 1-D array).
        iters: Maximum number of iterations.
        p: Forwarded to ``zo_ncf_online``.
        L: Constant used for the step size ``1 / (3 * L)`` and smoothing radius.
        rho: Forwarded to ``zo_ncf_online``.

    Returns:
        ``(function_query_complexity, values)``: the global query counter and
        ``f`` evaluated on ``full_index``, both logged every 10th iteration.

    Side effects:
        Prints progress and the final point; resets the module-level
        ``function_query`` to 0 before returning.
    """
    global function_query
    x = x_0
    d = len(x_0)
    x_list = []
    values = []
    function_query_complexity = []
    full_index = np.random.randint(0, num, num)
    eta = 1 / (3 * L)
    # Smoothing radius does not change between iterations.
    muu = epsilon / L / np.sqrt(d) / 10
    count = 0
    for t in range(iters):
        index = np.random.randint(0, num, batch_size)
        # nabla_ff splices coordinates with list concatenation, so hand it a
        # list even when the caller supplied x_0 as an array.
        der = np.array(nabla_ff(f, list(x), muu, index))
        x_new = np.array(x) - eta * der

        if t % 10 == 0:
            # Record the iterate itself (previously the list was appended to
            # itself).
            x_list.append(x)
            values.append(f(x, full_index))
            function_query_complexity.append(function_query)

        if np.linalg.norm(der) <= 3 * epsilon / 4:
            v = zo_ncf_online(f, num, x, L, rho, p)
            count = count + 1
            print('times of escaping saddle points: ', count)
            if (v == np.array(np.ones(d))).all():
                # All-ones is the "no direction found" sentinel: stop here.
                break
            for j in range(d):
                flag = random.choice((-1, 1))
                v[j] = flag * v[j]
            x_new = np.array(x_new) + v

        x = list(x_new)

    print(x)
    function_query = 0
    return function_query_complexity, values


# zo-scsg-ncf_rand
def zo_scsg_ncf_rand(f, num, batch_size, mini_batch_size, x_0, epoch, p, L, rho):
    """SCSG-style zeroth-order optimizer using the random-direction estimator.

    Per epoch: a batch gradient at the anchor ``x_tilde`` is estimated with
    ``nabla_ff``; then a geometrically distributed number of inner steps is
    run, each averaging per-sample variance-reduced directions built from
    ``nabla_ff_rand``. Afterwards the gradient norm at the anchor is re-estimated
    on a fresh batch; if it is at most ``3 * epsilon / 4`` a direction is
    requested from ``zo_ncf_online`` at the last inner iterate. The all-ones
    sentinel ends the run, any other direction is added (with a random sign per
    coordinate) to the new anchor.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for sampled indices.
        batch_size: Size of the anchor and verification batches.
        mini_batch_size: Size of the inner-loop mini batch.
        x_0: Starting point.
        epoch: Maximum number of epochs.
        p: Forwarded to ``zo_ncf_online``.
        L: Constant used for the step size ``1 / (10 * L)`` and smoothing radius.
        rho: Forwarded to ``zo_ncf_online``.

    Returns:
        ``(function_query_complexity, values)`` logged at the start of each epoch.

    Side effects:
        Prints progress and the final anchor; resets the module-level
        ``function_query`` to 0 before returning.
    """
    global function_query
    d = len(x_0)
    full_index = np.random.randint(0, num, num)
    eta = 1 / (10 * L)
    x_list, values, function_query_complexity = [], [], []
    x_tilde = x_0
    count = 0
    for _ in range(epoch):
        # Log the state at the beginning of the epoch.
        x_list.append(x_tilde)
        values.append(f(x_tilde, full_index))
        function_query_complexity.append(function_query)

        mu = epsilon / L / np.sqrt(d) / 10
        anchor_batch = np.random.randint(0, num, batch_size)
        anchor_grad = np.array(nabla_ff(f, list(x_tilde), mu, anchor_batch))
        inner_steps = np.random.geometric(mini_batch_size / batch_size)

        x = x_tilde
        for _step in range(inner_steps):
            mini_batch = np.random.randint(0, num, mini_batch_size)
            n_samples = len(mini_batch)
            direction = np.zeros(d)
            for s in range(n_samples):
                sample = mini_batch[s:s + 1]
                grad_here = np.array(nabla_ff_rand(f, list(x), mu, sample))
                grad_anchor = np.array(nabla_ff_rand(f, list(x_tilde), mu, sample))
                direction = direction + (grad_here - grad_anchor + anchor_grad) / n_samples
            x = x - eta * direction

        candidate = x
        verify_batch = np.random.randint(0, num, batch_size)
        verify_grad = np.array(nabla_ff(f, list(x_tilde), mu, verify_batch))
        if np.linalg.norm(verify_grad) <= 3 * epsilon / 4:
            escape = zo_ncf_online(f, num, x, L, rho, p)
            count += 1
            print('times of escaping saddle points: ', count)
            if (escape == np.array(np.ones(d))).all():
                # Sentinel: no direction found, stop without moving the anchor.
                break
            for j in range(d):
                sign = random.choice((-1, 1))
                escape[j] = sign * escape[j]
            candidate = np.array(candidate) + escape

        x_tilde = list(candidate)

    print(x_tilde)
    function_query = 0
    return function_query_complexity, values


# zo-scsg-ncf_coord
def zo_scsg_ncf_coord(f, num, batch_size, mini_batch_size, x_0, epoch, p, L, rho):
    """SCSG-style zeroth-order optimizer using the coordinate-wise estimator.

    Per epoch: a batch gradient at the anchor ``x_tilde`` is estimated with
    ``nabla_ff``; a geometrically distributed number of inner steps follows,
    each using the variance-reduced direction
    ``grad(x) - grad(x_tilde) + anchor gradient`` on one mini batch. The
    gradient norm at the anchor is then re-estimated on a fresh batch; if it is
    at most ``3 * epsilon / 4`` a direction is requested from ``zo_ncf_online``
    at the last inner iterate. The all-ones sentinel ends the run, any other
    direction is added (with a random sign per coordinate) to the new anchor.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for sampled indices.
        batch_size: Size of the anchor and verification batches.
        mini_batch_size: Size of the inner-loop mini batch.
        x_0: Starting point.
        epoch: Maximum number of epochs.
        p: Forwarded to ``zo_ncf_online``.
        L: Constant used for the step size ``1 / (4 * L)`` and smoothing radius.
        rho: Forwarded to ``zo_ncf_online``.

    Returns:
        ``(function_query_complexity, values)`` logged at the start of each epoch.

    Side effects:
        Prints progress and the final anchor; resets the module-level
        ``function_query`` to 0 before returning.
    """
    global function_query
    d = len(x_0)
    full_index = np.random.randint(0, num, num)
    eta = 1 / (4 * L)
    x_list, values, function_query_complexity = [], [], []
    x_tilde = x_0
    count = 0
    for _ in range(epoch):
        # Log the state at the beginning of the epoch.
        x_list.append(x_tilde)
        values.append(f(x_tilde, full_index))
        function_query_complexity.append(function_query)

        mu = epsilon / L / np.sqrt(d) / 10
        anchor_batch = np.random.randint(0, num, batch_size)
        anchor_grad = np.array(nabla_ff(f, list(x_tilde), mu, anchor_batch))
        inner_steps = np.random.geometric(mini_batch_size / batch_size)

        x = x_tilde
        for _step in range(inner_steps):
            mini_batch = np.random.randint(0, num, mini_batch_size)
            grad_here = np.array(nabla_ff(f, list(x), mu, mini_batch))
            grad_anchor = np.array(nabla_ff(f, list(x_tilde), mu, mini_batch))
            direction = grad_here - grad_anchor + anchor_grad
            x = x - eta * direction

        candidate = x
        verify_batch = np.random.randint(0, num, batch_size)
        verify_grad = np.array(nabla_ff(f, list(x_tilde), mu, verify_batch))
        if np.linalg.norm(verify_grad) <= 3 * epsilon / 4:
            escape = zo_ncf_online(f, num, x, L, rho, p)
            count += 1
            print('times of escaping saddle points: ', count)
            if (escape == np.array(np.ones(d))).all():
                # Sentinel: no direction found, stop without moving the anchor.
                break
            for j in range(d):
                sign = random.choice((-1, 1))
                escape[j] = sign * escape[j]
            candidate = np.array(candidate) + escape

        x_tilde = list(candidate)

    print(x_tilde)
    function_query = 0
    return function_query_complexity, values


# zo-spider-ncf
def zo_spider_ncf(f, num, batch_size, mini_batch_size, x_0, epoch_j, epoch_size, p, L, rho):
    """SPIDER-style zeroth-order optimizer taking normalised steps.

    The direction ``v`` is re-estimated from a full batch whenever the global
    step counter is a multiple of ``epoch_size``; on every other step it is
    updated recursively with the difference of the mini-batch gradient
    estimates at the current and previous iterates. Each step moves by ``eta``
    along ``-v / ||v||``.

    NOTE: the variant that mixed in a ``zo_ncf_online`` direction and stopped
    early on a small ``||v||`` is disabled; this function always runs all
    ``epoch_j * int(delta * L / (rho * epsilon) / 10)`` steps.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for sampled indices.
        batch_size: Size of the batch used to refresh ``v``.
        mini_batch_size: Size of the mini batch used for recursive updates.
        x_0: Starting point.
        epoch_j: Number of outer epochs (one log entry each).
        epoch_size: Refresh period of ``v`` in global steps.
        p: Unused by the current implementation.
        L: Constant used for the step size ``1 / (15 * L)`` and inner loop length.
        rho: Constant entering ``delta = sqrt(rho * epsilon)``.

    Returns:
        ``(function_query_complexity, values)`` logged at the start of each epoch.

    Side effects:
        Prints the final point; resets the module-level ``function_query`` to 0
        before returning.
    """
    global function_query
    d = len(x_0)
    mu = epsilon / np.power(d, 0.25)
    delta = np.sqrt(rho * epsilon)
    full_index = np.random.randint(0, num, num)
    eta = 1 / (15 * L)
    steps_per_epoch = int(delta * L / (rho * epsilon) / 10)

    x_list, values, function_query_complexity = [], [], []
    previous = x_0
    current = x_0
    for j in range(epoch_j):
        x_list.append(current)
        values.append(f(current, full_index))
        function_query_complexity.append(function_query)

        for k in range(steps_per_epoch):
            global_step = j * steps_per_epoch + k
            if global_step % epoch_size == 0:
                # Periodic refresh from a large batch.
                batch = np.random.randint(0, num, batch_size)
                v = np.array(nabla_ff(f, list(current), mu, batch))
            else:
                # Recursive update from the gradient difference on a mini batch.
                mini_batch = np.random.randint(0, num, mini_batch_size)
                grad_current = np.array(nabla_ff(f, list(current), mu, mini_batch))
                grad_previous = np.array(nabla_ff(f, list(previous), mu, mini_batch))
                v = grad_current - grad_previous + v

            previous = current
            current = list(np.array(current) - eta * (v / np.linalg.norm(v)))

    print(current)
    function_query = 0
    return function_query_complexity, values


# zo-stochastic-cubic-regularization-newton
def zo_scrn(f, num, x_0, iters, batch_g, batch_h, L, rho):
    """Zeroth-order Newton-type method with sampled gradient and Hessian estimates.

    Each iteration averages ``batch_g`` Gaussian-smoothing gradient estimates
    and ``batch_h`` Hessian estimates of the form
    ``(f(x + mu*u) + f(x - mu*u) - 2*f(x)) / (2*mu^2) * (u u^T - I)`` with
    ``u ~ N(0, I)``, then takes the step ``x - (H + rho/2 * r * I)^{-1} g``.
    The regularisation weight ``r`` is currently fixed to 0, so the step is a
    plain Newton step on the estimates.

    Args:
        f: Callable invoked as ``f(point, index)``.
        num: Exclusive upper bound for sampled indices.
        x_0: Starting point.
        iters: Number of iterations.
        batch_g: Number of samples for the gradient estimate.
        batch_h: Number of samples for the Hessian estimate.
        L: Constant entering the smoothing radius ``epsilon / L / (d + 3)``.
        rho: Weight of the (currently inactive) regularisation term.

    Returns:
        ``(function_query_complexity, values)`` logged at the start of each
        iteration.

    Raises:
        numpy.linalg.LinAlgError: if the estimated matrix is singular.

    Side effects:
        Prints the final point; resets the module-level ``function_query`` to 0
        before returning.
    """
    global function_query
    d = len(x_0)
    x = x_0
    mu = epsilon / L / (d + 3)
    x_list = []
    values = []
    function_query_complexity = []
    full_index = np.random.randint(0, num, num)
    identity = np.identity(d)
    for k in range(iters):
        x_list.append(x)
        values.append(f(x, full_index))
        function_query_complexity.append(function_query)

        # Gradient: average of single-sample Gaussian-smoothing estimates.
        g = np.zeros(d)
        index_g = np.random.randint(0, num, batch_g)
        for i in range(batch_g):
            g = g + np.array(nabla_ff_gaussian(f, list(x), mu, np.array(index_g[i:i+1])) / batch_g)

        # Hessian: average of single-sample estimates built from the symmetric
        # second difference f(x + mu*u) + f(x - mu*u) - 2*f(x).
        h = np.zeros((d, d))
        index_h = np.random.randint(0, num, batch_h)
        for j in range(batch_h):
            sample = index_h[j:j + 1]
            u = np.random.normal(0, 1, d)
            second_diff = f(x + mu * u, sample) + f(x - mu * u, sample) - 2 * f(x, sample)
            h = h + (second_diff / (2 * mu**2) * (np.outer(u, u) - identity)) / batch_h
        # Three evaluations of f per Hessian sample, each on a single index,
        # counted the same way the gradient estimators count theirs.
        function_query = function_query + 3 * batch_h

        # r = np.linalg.norm(x)/1e4
        r = 0
        x = list(np.array(x) - np.dot(np.linalg.inv(h + rho / 2 * r * identity), g))

    print(x)
    function_query = 0
    return function_query_complexity, values









