import numpy as np
import random

from .util import construct_g2
from .util import construct_g1
from .util import construct_g
from .util import construct_f

# Build the objective f used by the experiments below from the .util
# helpers. Only the wiring of parameters is visible here; the mathematical
# form of g1, g2, g and f lives in .util.
# NOTE(review): the (L, gamma, tau, nu, g, g1, g2) parameterization looks
# like the hard "octopus" test function of Du et al. (2017), "Gradient
# Descent Can Take Exponential Time to Escape Saddle Points" -- confirm
# against .util before relying on that.
tau = np.exp(1)
# L is also read by eta, ncf() and experiment_zo_ncf_gd(); presumably the
# gradient Lipschitz constant -- TODO confirm.
L = np.exp(1)
gamma = 1
# rho is read by ncf() when computing delta and sigma; presumably the
# Hessian Lipschitz constant -- TODO confirm.
rho = L

g2 = construct_g2(L, gamma, tau)
g1 = construct_g1(L, gamma, tau)
g = construct_g(g1, g2)
# nu depends on g1 evaluated at 2 * tau and is passed on to construct_f.
nu = - g1(2 * tau) + 4 * L * tau ** 2

f = construct_f(L, gamma, tau, nu, g, g1)


def get_rand_vec(dims):
    """Return a random direction of unit Euclidean norm.

    A standard-normal sample of size ``dims`` is drawn with
    ``np.random.standard_normal`` and divided by its norm.

    Args:
        dims: Size argument forwarded to ``np.random.standard_normal``.

    Returns:
        np.ndarray: The normalized sample.
    """
    gaussian = np.random.standard_normal(dims)
    length = np.linalg.norm(gaussian)
    return gaussian / length


# Target accuracy. experiment_zo_ncf_gd() calls ncf() once the estimated
# gradient norm is at most (3/4) * epsilon and scales its finite-difference
# step with it; ncf() derives delta = sqrt(rho * epsilon) from it.
epsilon = 1e-4


# Coordinate-wise (central finite-difference) gradient estimator.
def nabla_ff(x, mu):
    """Estimate the gradient of ``f`` at ``x`` using function values only.

    Each partial derivative is approximated by the central difference
    ``(f(x + mu*e_i) - f(x - mu*e_i)) / (2*mu)``. The total cost is
    ``2 * len(x)`` evaluations of ``f``.

    Args:
        x: Point at which to estimate the gradient. Any sequence of numbers
            is accepted (list, tuple or 1-D NumPy array); ``f`` always
            receives a plain list.
        mu: Finite-difference step size (must be non-zero).

    Returns:
        list: One estimated partial derivative per coordinate of ``x``.
    """
    # Work on a list copy so that each perturbation changes exactly one
    # coordinate. Concatenating slices of a NumPy array with a list would
    # broadcast-add instead, and a tuple would raise TypeError.
    base = list(x)
    grad = []
    for i, xi in enumerate(base):
        forward = base.copy()
        forward[i] = xi + mu
        backward = base.copy()
        backward[i] = xi - mu
        grad.append((f(forward) - f(backward)) / (2 * mu))
    return grad


# Zeroth-order negative curvature finding (NCF).
def ncf(x):
    """Look for a direction of negative curvature around ``x``.

    Only function values are used, through ``nabla_ff``. The search starts
    from a random offset ``y = xi`` of norm ``sigma``. It then iterates the
    three-term recursion ``y_{t+1} = 2 * M(y_t) - y_{t-1}``, where ``M(y)``
    replaces a Hessian-vector product with the difference of gradient
    estimates at ``x + y`` and ``x``. The candidate point is
    ``x_t = x + y_{t+1} - M(y_t)``. As soon as ``||x_t - x|| >= rr`` the
    normalized displacement is returned.

    Args:
        x: Point to search around (list or 1-D array of floats).

    Returns:
        np.ndarray: A unit-norm direction if the escape radius ``rr`` is
        reached within ``iterss`` iterations. Otherwise it returns an
        all-ones array of length ``len(x)``, which callers treat as a
        "no direction found" sentinel.
    """
    d = len(x)
    p = 0.01
    delta = np.sqrt(rho * epsilon)
    C_1 = 0.5
    # iterss is a NumPy scalar here, so .astype(int) truncates it.
    iterss = np.power(C_1, 2) * np.log(d / p) * np.sqrt(L) / np.sqrt(delta)
    iterss = iterss.astype(int)
    sigma = np.power(d / p, -2 * C_1) * delta / np.power(iterss, 2) / rho
    rr = np.power(d / p, C_1) * sigma  # escape radius

    x_0 = np.array(x)
    xi = sigma * get_rand_vec(d)
    y_0 = np.zeros(d)
    y = xi

    # Finite-difference step for nabla_ff; the same for every iteration.
    mu = 1e-2
    # x_0 is never modified in the loop, so its gradient estimate is
    # computed once instead of once per iteration. This saves 2*d
    # evaluations of f per iteration and assumes f is deterministic.
    grad_x0 = np.array(nabla_ff(list(x_0), mu))
    shrink = 1 - 3 * delta / (4 * L)

    for _ in range(iterss):
        M_y = -1 / L * (np.array(nabla_ff(list(x_0 + y), mu)) - grad_x0) + shrink * y
        yy = 2 * M_y - y_0
        x_t = x_0 + yy - M_y

        y_0 = y
        y = yy

        displacement = np.linalg.norm(x_t - x_0)
        if displacement >= rr:
            return (x_t - x_0) / displacement

    # Sentinel: no negative-curvature direction found within the budget.
    return np.ones(d)


# Gradient-descent step size used by experiment_zo_ncf_gd(): 1 / (2 * L).
eta = 1/(2 * L)


def positive_or_negative():
    """Return a random sign.

    Draws one value from ``random.random()`` and returns ``1`` when it is
    below 0.5, otherwise ``-1``.
    """
    return 1 if random.random() < 0.5 else -1


def experiment_zo_ncf_gd(x_0, iters):
    """Run zeroth-order gradient descent with an NCF-based escape step.

    Each iteration does the following:

    1. Take a gradient step of size ``eta`` using the finite-difference
       gradient estimate from ``nabla_ff``.
    2. If the estimated gradient norm is at most ``(3/4) * epsilon``, call
       ``ncf`` at the new point.
    3. If ``ncf`` returns its all-ones "nothing found" sentinel, stop early.
    4. Otherwise give each coordinate of the returned direction an
       independent random sign and add it to the iterate.

    Args:
        x_0: Starting point (any sequence of floats; copied into a list).
        iters: Maximum number of iterations.

    Returns:
        list: The value of ``f`` at the iterate after each completed
        iteration. The iteration that triggers the early stop is not
        recorded.
    """
    # nabla_ff expects the iterate as a list; converting here also accepts
    # tuple or NumPy-array starting points and avoids aliasing x_0.
    x = list(x_0)
    d = len(x)
    values = []
    count = 0
    # Finite-difference step depends only on constants and d: compute once.
    muu = epsilon / L / np.sqrt(d) / 10
    no_direction = np.ones(d)  # ncf's "nothing found" sentinel

    for _ in range(iters):
        der_x = np.array(nabla_ff(x, muu))
        x_new = np.array(x) - eta * der_x

        if np.linalg.norm(der_x) <= (3 / 4) * epsilon:
            v = ncf(x_new)
            if (v == no_direction).all():
                # ncf found no escape direction: stop early.
                break
            count = count + 1
            print('times of escaping saddle points: ', count)
            # Flip the sign of each coordinate independently at random.
            for j in range(d):
                flag = random.choice((-1, 1))
                v[j] = flag * v[j]
            x_new = np.array(x_new) + v

        x = list(x_new)
        values.append(f(x))

    return values

