import numpy as np
import random

from util import *


def get_rand_vec(dims):
    """Return a random vector of length ``dims`` with unit Euclidean norm.

    A standard Gaussian sample is drawn from numpy's global RNG and rescaled
    by its own norm, so the result points in a direction uniformly
    distributed on the unit sphere.
    """
    sample = np.random.standard_normal(dims)
    length = np.linalg.norm(sample)
    unit = sample / length
    return unit


# Target gradient-norm tolerance shared by the routines below: it sets the
# small-gradient thresholds, finite-difference steps, smoothing radius and
# the NCF curvature target.
epsilon = 1e-2

# Running count of objective evaluations charged by the estimators and by
# rspi's direct comparisons.  Functions update it via `global function_query`
# and the experiment drivers reset it to 0 when they finish.
# (A `global` statement at module scope is a no-op, so none is needed here.)
function_query = 0


# coordinate-wise gradient estimator
def nabla_ff(f, x, mu):
    """Estimate the gradient of ``f`` at ``x`` with central finite differences.

    Coordinate ``i`` is estimated as
    ``(f(x + mu * e_i) - f(x - mu * e_i)) / (2 * mu)``.

    Args:
        f: objective; always called with a plain Python list of coordinates.
        x: evaluation point (list, tuple or 1-D numpy array).
        mu: finite-difference half-step.

    Returns:
        List with one derivative estimate per coordinate.

    Side effects:
        Adds ``2 * len(x)`` to the module-level ``function_query`` counter.
    """
    global function_query
    # Work on a list copy: `x[:i] + [..] + x[i + 1:]` means "replace
    # coordinate i" only for lists; on a numpy array `+` would broadcast.
    x = list(x)
    der_x = []
    for i, x_i in enumerate(x):
        f_plus = f(x[:i] + [x_i + mu] + x[i + 1:])
        f_minus = f(x[:i] + [x_i - mu] + x[i + 1:])
        der_x.append((f_plus - f_minus) / (2 * mu))
    function_query = function_query + len(x) * 2
    return der_x


# Gaussian-smoothing gradient estimator (two-point, random-direction version)
def nabla_ff_gaussian_1(f, x, sigma):
    """Two-point Gaussian-smoothing gradient estimate of ``f`` at ``x``.

    Draws a perturbation ``u ~ N(0, sigma**2 I)`` and returns
    ``(f(x + u) - f(x)) / sigma**2 * u``.  Its expectation is the gradient of
    the smoothed objective ``E[f(x + u)]`` (``E[f(x) u] = 0``).

    Args:
        f: objective; called with ``x + u`` (a numpy array) and with ``x``
            exactly as passed in.
        x: evaluation point (list or 1-D numpy array).
        sigma: standard deviation of each coordinate of the perturbation.

    Returns:
        numpy array of length ``len(x)``.

    Side effects:
        Adds 2 to the module-level ``function_query`` counter.
    """
    global function_query
    # `scale` of np.random.normal is the standard deviation.  It must be
    # sigma, not sigma**2, to match the 1/sigma**2 normalisation below;
    # otherwise the estimate is shrunk by a factor of sigma**2.
    u = np.random.normal(0, sigma, len(x))
    der_x = (f(x + u) - f(x)) / sigma**2 * u
    function_query = function_query + 2

    return der_x


# zeroth-order negative curvature finding in the deterministic setting
def ncf(f, x, L, rho):
    """Zeroth-order negative-curvature finding around the point ``x``.

    Starts from a random perturbation ``y`` of norm ``sigma`` and runs the
    three-term recursion ``y_{t+1} = 2 * M(y_t) - y_{t-1}``.  In this
    recursion, ``M(y) = -(g(x_0 + y) - g(x_0)) / L + (1 - 3*delta/(4*L)) * y``
    and ``g`` is the finite-difference gradient ``nabla_ff``.  As soon as the
    tracked point ``x_0 + y_{t+1} - M(y_t)`` is at least ``rr`` away from
    ``x_0``, the unit vector from ``x_0`` to it is returned.

    Args:
        f: objective, evaluated through nabla_ff with lists of coordinates.
        x: base point (list or 1-D array).
        L: presumably the gradient-Lipschitz constant; scales the recursion
            and the iteration count.
        rho: presumably the Hessian-Lipschitz constant; with the module-level
            ``epsilon`` it sets ``delta = sqrt(rho * epsilon)``.

    Returns:
        numpy array of length ``len(x)``.  It is a unit direction if the
        escape radius was reached within ``iterss`` steps.  Otherwise it is
        the all-ones vector, which callers compare against as a
        "no direction found" sentinel.

    Side effects:
        Each loop iteration calls nabla_ff twice, adding ``4 * len(x)`` to the
        global ``function_query``.  Draws from numpy's RNG via get_rand_vec.
    """
    # initialization
    d = len(x)
    p = 0.01  # presumably a failure-probability parameter; enters via d/p
    delta = np.sqrt(rho*epsilon)
    C_1 = 0.5
    iterss = np.power(C_1, 2)*np.log(d/p)*np.sqrt(L)/np.sqrt(delta)
    iterss = iterss.astype(int)  # truncates toward zero; may be 0
    # Initial perturbation radius; becomes inf (numpy warning) if iterss == 0,
    # in which case the loop below does not run and the sentinel is returned.
    sigma = np.power(d / p, -2*C_1) * delta/np.power(iterss, 2)/rho
    rr = np.power(len(x)/p, C_1)*sigma  # escape radius tested in the loop

    x_0 = np.array(x)
    xi = sigma * get_rand_vec(len(x))  # random start of norm sigma
    # NOTE(review): this x is only used for len(x) before the loop reassigns it.
    x = np.array(x) + xi
    y_0 = np.array(np.zeros(len(x)))  # previous recursion iterate y_{t-1}
    y = xi  # current recursion iterate y_t
    v = np.array(np.ones(len(x)))  # unused: reassigned before the final return
    for t in range(iterss):
        # mu = np.linalg.norm(y)
        mu = 1e-3  # fixed finite-difference step for nabla_ff
        # M(y): difference of gradient estimates approximates H y, so this is
        # roughly (I - H/L - 3*delta/(4*L) I) y.
        # NOTE(review): the gradient at x_0 is loop-invariant but is
        # re-estimated (and re-counted in function_query) every iteration.
        M_y = - 1/L*(np.array(nabla_ff(f, list(x_0+y), mu)) - np.array(nabla_ff(f, list(x_0), mu))) + (1 - 3*delta/(4*L))*y
        yy = 2 * M_y - y_0
        x = x_0 + yy - M_y

        y_0 = y
        y = yy
        norm_x_to_x_0 = np.linalg.norm(x - x_0)
        if norm_x_to_x_0 >= rr:
            # Escaped the radius: report the unit direction from x_0.
            v_negative_curvature = (x - x_0)/np.linalg.norm(x - x_0)
            return v_negative_curvature

    # No escape within iterss steps: return the all-ones sentinel.
    v = np.array(np.ones(len(x)))
    return v


def positive_or_negative():
    """Return +1 or -1 with equal probability, using the `random` module."""
    return 1 if random.random() < 0.5 else -1


def experiment_zo_ncf_gd(f, x_0, iters, L, rho):
    """Zeroth-order gradient descent with negative-curvature escape steps.

    Each iteration estimates the gradient with central differences
    (nabla_ff, step ``epsilon / L / sqrt(d)``) and forms the gradient step
    ``x - der_x / (2L)``.  When the estimated gradient norm is at most
    ``(3/4) * epsilon``, ncf() is run from the stepped point:

    * If it returns its all-ones "nothing found" sentinel, the run stops.
    * Otherwise the point is moved by the returned unit direction with a
      random global sign.

    Args:
        f: objective, called with lists of coordinates.
        x_0: starting point (list, tuple or 1-D array).
        iters: maximum number of iterations.
        L: passed to ncf; also sets the step size 1/(2L) and the
            finite-difference step.
        rho: passed through to ncf.

    Returns:
        ``(function_query_complexity, values)``: per iteration, the global
        query counter and ``f`` at the iterate before the step.

    Side effects:
        Prints escape counts and the final iterate; resets the global
        ``function_query`` to 0 on exit.
    """
    global function_query
    x = list(x_0)
    d = len(x)
    values = []
    function_query_complexity = []
    count = 0
    # Loop invariants.
    muu = epsilon / L / np.sqrt(d)
    eta = 1 / (2 * L)
    no_direction = np.ones(d)  # ncf's "no direction found" sentinel
    for _ in range(iters):
        der_x = np.array(nabla_ff(f, x, muu))
        x_new = np.array(x) - eta * der_x
        values.append(f(x))
        function_query_complexity.append(function_query)
        if np.linalg.norm(der_x) <= (3 / 4) * epsilon:
            v = ncf(f, x_new, L, rho)
            # NOTE(review): for d == 1 a genuine direction [1.0] is
            # indistinguishable from this sentinel.
            if np.array_equal(v, no_direction):
                break
            count = count + 1
            print('times of escaping saddle points: ', count)
            # Flip the whole direction with one random sign: -v is as good a
            # negative-curvature direction as v, whereas flipping coordinates
            # independently generally destroys it.
            x_new = x_new + positive_or_negative() * v

        x = list(x_new)

    print(x)
    function_query = 0
    return function_query_complexity, values


# uniform distribution over a ball of radius r
def uniform_distribution_over_unit_ball(x_0, r):
    """Sample a vector uniformly from the ball of radius ``r`` centred at 0.

    Only ``len(x_0)`` is used (as the dimension ``d``); the result is an
    offset, not a point around ``x_0``.  The direction is a normalised
    standard Gaussian (uniform on the sphere).  The radius is
    ``r * U**(1/d)`` with ``U ~ Uniform(0, 1)``, which spreads the samples
    uniformly over the ball's volume.  (Normalising a uniform-cube sample,
    as before, only lands on the sphere of radius ``r``, and with a
    non-uniform direction.)

    Returns:
        numpy array of length ``len(x_0)`` with norm at most ``r``.
    """
    d = len(x_0)
    if d == 0:
        return np.zeros(0)
    direction = np.random.standard_normal(d)
    direction = direction / np.linalg.norm(direction)
    radius = r * np.random.uniform(0, 1) ** (1 / d)
    return radius * direction


# zeroth-order perturbed sgd
def zpsgd(f, x_0, iters, batch_size, rho, L):
    """Zeroth-order perturbed SGD.

    Each iteration does three things:

    * averages ``batch_size`` estimates from nabla_ff_gaussian_1, with
      smoothing parameter ``sigma = sqrt(epsilon / (rho * d))``;
    * draws a perturbation with uniform_distribution_over_unit_ball at
      radius ``r = e / 100``;
    * steps ``x <- x - eta * (g + xi)`` with ``eta = 2 / L``.

    Args:
        f: objective, called with lists of coordinates.
        x_0: starting point.
        iters: number of iterations.
        batch_size: gradient estimates averaged per iteration.
        rho: enters the smoothing parameter sigma.
        L: sets the step size.

    Returns:
        ``(function_query_complexity, values)``: per iteration, the global
        query counter and ``f`` at the iterate before the step.

    Side effects:
        Prints the final iterate; resets the global ``function_query`` to 0.
    """
    global function_query
    d = len(x_0)
    x = x_0
    sigma = np.sqrt(epsilon / (rho * d))
    r = np.exp(1) / 100
    # NOTE(review): 2/L sits at the edge of the stable step-size range for an
    # L-smooth objective; confirm this is intended.
    eta = 2 / L
    values = []
    function_query_complexity = []
    for _ in range(iters):
        values.append(f(x))
        function_query_complexity.append(function_query)

        # Mini-batch average of the Gaussian-smoothing estimates.
        g = np.zeros(d)
        for _ in range(batch_size):
            g = g + np.array(nabla_ff_gaussian_1(f, list(x), sigma) / batch_size)
        xi = uniform_distribution_over_unit_ball(x, r)
        x = list(np.array(x) - eta * (g + xi))

    print(x)
    function_query = 0

    return function_query_complexity, values


def nabla_f(f, x, h=0.01):
    """Central finite-difference gradient of ``f`` at ``x`` with step ``h``.

    Coordinate ``i`` is ``(f(x + h * e_i) - f(x - h * e_i)) / (2 * h)``.

    Args:
        f: objective; always called with a plain Python list of coordinates.
        x: evaluation point (list, tuple or 1-D numpy array).
        h: finite-difference half-step (default 0.01, the former constant).

    Returns:
        List with one derivative estimate per coordinate.

    Side effects:
        Adds ``2 * len(x)`` to the module-level ``function_query`` counter.
    """
    global function_query
    # Copy to a list so the slice-and-concatenate below replaces coordinate
    # i instead of broadcasting when x is a numpy array.
    x = list(x)
    der_x = []
    for i, x_i in enumerate(x):
        f_plus = f(x[:i] + [x_i + h] + x[i + 1:])
        f_minus = f(x[:i] + [x_i - h] + x[i + 1:])
        der_x.append((f_plus - f_minus) / (2 * h))

    function_query = function_query + 2 * len(x)
    return der_x


def experiment_pagd(f, x_0, iters, L):
    """Perturbed gradient descent driven by finite-difference gradients.

    Each iteration records ``f(x)`` and the global query count, then
    estimates the gradient with nabla_f.  When the gradient norm is at most
    ``(3/4) * g_thresh``, it adds a random perturbation of norm ``r`` and
    re-estimates the gradient at the perturbed point.  It then steps
    ``x <- x - der_x / (4L)``.

    NOTE(review): despite the name, no momentum/acceleration term is applied.

    Args:
        f: objective, called with lists of coordinates.
        x_0: starting point (list, tuple or 1-D array).
        iters: number of iterations.
        L: presumably the gradient-Lipschitz constant; sets eta = 1 / (4L).

    Returns:
        ``(function_query_complexity, values)``, one entry per iteration.

    Side effects:
        Prints the final iterate; resets the global ``function_query`` to 0.
    """
    global function_query
    #  initialization
    gamma = 1
    t_thresh = -1
    # With t_thresh = -1 this is 0, and `i - t_noise > t_thresh` then holds
    # on every iteration, so perturbation is gated by the gradient norm alone.
    t_noise = - t_thresh - 1
    g_thresh = np.exp(1) * gamma / 100
    r = np.exp(1) / 1000
    eta = 1 / (4 * L)

    x = list(x_0)
    values = []
    function_query_complexity = []
    for i in range(iters):
        values.append(f(x))
        function_query_complexity.append(function_query)
        der_x = np.array(nabla_f(f, x))

        if (3 / 4) * g_thresh >= np.linalg.norm(der_x) and (i - t_noise > t_thresh):
            # Random kick of norm r, then re-estimate the gradient there.
            x = list(np.array(x) + r * get_rand_vec(len(x)))
            der_x = np.array(nabla_f(f, x))
            t_noise = i

        x = list(np.array(x) - eta * der_x)

    print(x)
    function_query = 0
    return function_query_complexity, values


###################################################################################################################
def dfpi(f, x, c, L, T_dfpi=20, r=0.01):
    """Derivative-free power iteration for a negative-curvature direction.

    Starting from a random unit vector ``s``, repeats ``T_dfpi`` times:

        s <- normalize(s - (1/L) * Hs)

    Here ``Hs`` is approximated by ``(g(x + r*s) - g(x - r*s)) / (2r)``,
    where ``g`` is nabla_ff with step ``c``.  This is power iteration on
    ``I - H/L``.  Assuming the Hessian's eigenvalues lie in ``[-L, L]``, it
    drifts toward the eigenvector of the most negative eigenvalue.

    Args:
        f: objective, evaluated through nabla_ff.
        x: base point (list or 1-D array).
        c: finite-difference step passed to nabla_ff.
        L: scales the power-iteration step (eta = 1 / L).
        T_dfpi: number of power iterations (previously hard-coded to 20; the
            discarded epsilon-based formula has been removed).
        r: radius of the symmetric Hessian-vector difference.

    Returns:
        Unit-norm numpy array of length ``len(x)``.

    Side effects:
        Each iteration adds ``4 * len(x)`` to the global ``function_query``
        via nabla_ff; draws from numpy's RNG once for the start vector.
    """
    d = len(x)
    s = np.random.uniform(-1, 1, d)
    s = s / np.linalg.norm(s)
    eta = 1 / L
    for _ in range(T_dfpi):
        g_pos = np.array(nabla_ff(f, list(x + r * s), c))
        g_neg = np.array(nabla_ff(f, list(x - r * s), c))
        s = s - eta * (g_pos - g_neg) / (2 * r)
        s = s / np.linalg.norm(s)

    return s


def rspi(f, x_0, iters, L, sigma_1, sigma_2, T_sigma_1, ratio):
    """Random search with derivative-free power iteration.

    Each iteration compares ``f`` at ``x`` and at ``x +/- sigma_1 * s_1`` for
    a random unit direction ``s_1``, and moves to the best of the three.

    If staying put wins, dfpi() supplies a direction ``s_2`` and the same
    three-point comparison is repeated along it with step ``sigma_2``.  After
    each such call, the finite-difference step ``c`` handed to dfpi becomes
    ``0.1 / (k + 1) ** 0.1``.

    ``sigma_1`` is multiplied by ``ratio`` on iterations ``k`` that are
    multiples of ``T_sigma_1``, as long as it exceeds 1e-5.

    Returns:
        ``(function_query_complexity, values)``: per iteration, the global
        query counter and ``f`` at the iterate before the update.

    Side effects:
        Prints progress; adds 3 to the global ``function_query`` per
        three-point comparison; resets it to 0 on exit.
    """
    global function_query

    def best_of_three(center, step, direction):
        # Candidates are evaluated in the order stay, forward, backward.
        # np.argmin returns the first minimum, so ties prefer staying.
        candidates = (center, center + step * direction, center - step * direction)
        scores = [f(list(point)) for point in candidates]
        winner = np.argmin(scores)
        return winner, candidates[winner]

    c_init = 0.1
    c = c_init
    x = x_0
    dim = len(x_0)
    values = []
    function_query_complexity = []
    ncf_calls = 0
    for k in range(iters):
        values.append(f(x))
        function_query_complexity.append(function_query)
        print('iteration:', k)
        s_1 = np.random.uniform(-1, 1, dim)
        s_1 = s_1 / np.linalg.norm(s_1)

        winner, x_next = best_of_three(x, sigma_1, s_1)
        function_query = function_query + 3
        if winner == 0:
            # No descent along s_1: look for negative curvature instead.
            ncf_calls += 1
            print('times of ncf:', ncf_calls)
            s_2 = dfpi(f, x, c, L)
            c = c_init / np.power(k + 1, 0.1)
            _, x_next = best_of_three(x, sigma_2, s_2)
            function_query = function_query + 3
        x = x_next

        if k % T_sigma_1 == 0 and sigma_1 > 1e-5:
            sigma_1 = sigma_1 * ratio
        x = list(x)

    print(x)
    function_query = 0
    return function_query_complexity, values
