import numpy as np
import gurobipy as gp
from gurobipy import GRB
from ZO_gradient_estimator import estimate_gradient, estimate_gradient_inner_product, uniform_sphere_samples
def first_order_opt(f, h, df, dh, x_start, learning_rate=0.01, r=0.1, epsilon=1e-5, TB=10, max_iter=1000, tol=1e-6, k=1.0):
    """Minimise f while steering h towards h(x) <= 0, using exact gradients.

    Each iteration steps along ``-(grad_f + lambda * grad_h)`` with the
    closed-form multiplier

        lambda = max(k * h(x) - <grad_h, grad_f>, 0) / ||grad_h||^2,

    which (for nonzero grad_h) makes the step direction d satisfy
    ``<grad_h, d> <= -k * h(x)``.

    Args:
        f: objective, called as ``f(x)`` on a NumPy array.
        h: scalar constraint function, called as ``h(x)``.
        df: gradient oracle for f.
        dh: gradient oracle for h.
        x_start: initial point (converted with ``np.array``).
        learning_rate: step size.
        r, epsilon, TB: unused here; accepted so the signature matches the
            zeroth-order variants in this module.
        max_iter: maximum number of iterations.
        tol: stop once ``||x_new - x|| < tol``.
        k: gain applied to h(x) in the multiplier numerator.

    Returns:
        ``(x, f(x), history)``. ``history`` has one ``(f(x), max(h(x), 1e-8))``
        tuple per iteration, recorded before the step is applied.
    """
    x = np.array(x_start)
    history = []

    for _ in range(max_iter):
        grad_f = df(x)
        grad_h = dh(x)

        h_val = h(x)
        numerator = max(k * h_val - np.dot(grad_h, grad_f), 0)
        denominator = np.linalg.norm(grad_h) ** 2
        # With grad_h == 0 the term lambda_ * grad_h vanishes for any finite
        # lambda_, so use 0 rather than producing nan/inf from 0/0 or c/0.
        lambda_ = numerator / denominator if denominator > 0 else 0.0

        x_new = x - learning_rate * (grad_f + lambda_ * grad_h)

        # Violation is floored at 1e-8 (keeps the value strictly positive).
        history.append((f(x), np.max([h_val, 1e-8])))

        if np.linalg.norm(x_new - x) < tol:
            break

        x = x_new

    return x, f(x), history

def ZO_baseline(f, h, x_start, learning_rate=0.01, r=0.1, epsilon=1e-5, TB=10, max_iter=1000, tol=1e-6, k=1.0):
    """Zeroth-order version of the closed-form-multiplier safe gradient step.

    Gradients of f and h are replaced by ``estimate_gradient(fn, x, r, TB)``
    (project helper; presumably a random-direction estimate with smoothing
    radius r and TB samples -- confirm in ZO_gradient_estimator). The
    multiplier is then

        lambda = max(k * h(x) - <g_h, g_f>, 0) / ||g_h||^2.

    Args:
        f: objective, called as ``f(x)``.
        h: scalar constraint function, called as ``h(x)``; the step drives
            it towards h(x) <= 0.
        x_start: initial point (converted with ``np.array``).
        learning_rate: step size.
        r: radius forwarded to ``estimate_gradient``.
        epsilon: unused here; kept for signature parity with ZOFL.
        TB: number of samples forwarded to ``estimate_gradient``.
        max_iter: maximum number of iterations.
        tol: stop once ``||x_new - x|| < tol``.
        k: gain applied to h(x) in the multiplier numerator.

    Returns:
        ``(x, f(x), history)``. ``history`` has one ``(f(x), max(h(x), 0))``
        tuple per iteration, recorded before the step is applied.
    """
    x = np.array(x_start)
    history = []

    for _ in range(max_iter):
        grad_f = estimate_gradient(f, x, r, TB)
        grad_h = estimate_gradient(h, x, r, TB)

        h_val = h(x)
        numerator = max(k * h_val - np.dot(grad_h, grad_f), 0)
        denominator = np.linalg.norm(grad_h) ** 2
        # A zero gradient estimate makes lambda_ * grad_h vanish for any finite
        # lambda_; use 0 instead of nan/inf from dividing by zero.
        lambda_ = numerator / denominator if denominator > 0 else 0.0

        x_new = x - learning_rate * (grad_f + lambda_ * grad_h)

        # Built-in max clips the violation at 0 (np.max(v, 0) would treat 0 as
        # an axis and return v unclipped).
        history.append((f(x), max(h_val, 0.0)))

        if np.linalg.norm(x_new - x) < tol:
            break

        x = x_new

    return x, f(x), history

def ZOFL(f, h, x_start, learning_rate=0.01, r=0.1, epsilon=1e-5, TB=10, max_iter=1000, tol=1e-6, k=1.0):
    """Zeroth-order safe gradient step with a multiplier built from inner-product estimates.

    Like ZO_baseline, except the inner products in the multiplier come from
    ``estimate_gradient_inner_product(h, x, v, epsilon)`` (project helper;
    presumably a finite-difference estimate of <grad h(x), v> with step
    epsilon -- confirm in ZO_gradient_estimator):

        lambda = max(k * h(x) - <grad h, g_f>, 0) / <grad h, g_h>.

    Args:
        f: objective, called as ``f(x)``.
        h: scalar constraint function, called as ``h(x)``; the step drives
            it towards h(x) <= 0.
        x_start: initial point (converted with ``np.array``).
        learning_rate: step size.
        r: radius forwarded to ``estimate_gradient``.
        epsilon: step forwarded to ``estimate_gradient_inner_product``.
        TB: number of samples forwarded to ``estimate_gradient``.
        max_iter: maximum number of iterations.
        tol: stop once ``||x_new - x|| < tol``.
        k: gain applied to h(x) in the multiplier numerator.

    Returns:
        ``(x, f(x), history)``. ``history`` has one ``(f(x), max(h(x), 0))``
        tuple per iteration, recorded before the step is applied.
    """
    x = np.array(x_start)
    history = []

    for _ in range(max_iter):
        grad_f = estimate_gradient(f, x, r, TB)
        grad_h = estimate_gradient(h, x, r, TB)

        h_val = h(x)
        dot_grad_h_f = estimate_gradient_inner_product(h, x, grad_f, epsilon)
        dot_grad_h_h = estimate_gradient_inner_product(h, x, grad_h, epsilon)
        numerator = max(k * h_val - dot_grad_h_f, 0)
        # An estimated denominator of exactly 0 would give nan/inf (or raise
        # ZeroDivisionError for a Python float); apply no correction instead.
        # NOTE(review): a negative estimate still gives a negative multiplier;
        # confirm whether it should be clipped as well.
        lambda_ = numerator / dot_grad_h_h if dot_grad_h_h != 0 else 0.0

        x_new = x - learning_rate * (grad_f + lambda_ * grad_h)

        # Built-in max clips at 0; np.max(v, 0) would treat 0 as an axis.
        history.append((f(x), max(h_val, 0.0)))

        if np.linalg.norm(x_new - x) < tol:
            break

        x = x_new

    return x, f(x), history

def ZOFL_midpoint(f, h, x_start, learning_rate=0.01, r=0.1, epsilon=1e-5, TB=10, max_iter=1000, tol=1e-6, k=1.0):
    """Midpoint (two-stage) variant of ZOFL.

    Each iteration proceeds as follows:
      1. Compute the ZOFL direction at x.
      2. Take a half step to x_mid.
      3. Recompute the direction at x_mid.
      4. Take the full step from x along that midpoint direction.

    Both stages reuse the same sampled ``directions`` from
    ``uniform_sphere_samples`` (project helper). The multiplier at a point p
    is

        lambda = max(k * h(p) - <grad h, g_f>, 0) / <grad h, g_h>,

    with the inner products from ``estimate_gradient_inner_product``.

    Args:
        f: objective, called as ``f(x)``.
        h: scalar constraint function, called as ``h(x)``; the step drives
            it towards h(x) <= 0.
        x_start: initial point (converted with ``np.array``); ``len(x_start)``
            is used as the sampling dimension.
        learning_rate: full step size (the first stage uses half of it).
        r: radius forwarded to ``estimate_gradient``.
        epsilon: step forwarded to ``estimate_gradient_inner_product``.
        TB: number of sampled directions per iteration.
        max_iter: maximum number of iterations.
        tol: stop once ``||x_new - x|| < tol``.
        k: gain applied to h in the multiplier numerator.

    Returns:
        ``(x, f(x), history)``. ``history`` has one ``(f(x), max(h(x), 0))``
        tuple per iteration, evaluated at x (not x_mid) before the step.
    """
    x = np.array(x_start)
    history = []

    def _safe_direction(point, grad_f, grad_h):
        """Return (h(point), grad_f + lambda * grad_h) using the ZOFL multiplier."""
        h_point = h(point)
        dot_grad_h_f = estimate_gradient_inner_product(h, point, grad_f, epsilon)
        dot_grad_h_h = estimate_gradient_inner_product(h, point, grad_h, epsilon)
        numerator = max(k * h_point - dot_grad_h_f, 0)
        # Avoid nan/inf (or ZeroDivisionError) when the estimate is exactly 0.
        # NOTE(review): negative estimates still yield a negative multiplier.
        lambda_ = numerator / dot_grad_h_h if dot_grad_h_h != 0 else 0.0
        return h_point, grad_f + lambda_ * grad_h

    for _ in range(max_iter):
        directions = uniform_sphere_samples(len(x_start), TB)

        # Stage 1: direction at x, half step to the midpoint.
        grad_f = estimate_gradient(f, x, r, TB, directions)
        grad_h = estimate_gradient(h, x, r, TB, directions)
        h_val, direction = _safe_direction(x, grad_f, grad_h)
        x_mid = x - learning_rate / 2 * direction

        # Stage 2: direction at the midpoint (same sampled directions), full step from x.
        grad_f_mid = estimate_gradient(f, x_mid, r, TB, directions)
        grad_h_mid = estimate_gradient(h, x_mid, r, TB, directions)
        _, direction_mid = _safe_direction(x_mid, grad_f_mid, grad_h_mid)
        x_new = x - learning_rate * direction_mid

        # Built-in max clips at 0; np.max(v, 0) would treat 0 as an axis.
        history.append((f(x), max(h_val, 0.0)))

        if np.linalg.norm(x_new - x) < tol:
            break

        x = x_new

    return x, f(x), history

def ZOGDA(f, h, x_start, learning_rate_x=0.01, learning_rate_lambda=0.01, r=0.1, TB=10, max_iter=1000, tol=1e-6):
    """Zeroth-order gradient descent-ascent on the Lagrangian f + lambda * h.

    Each iteration does two things:
      - Dual ascent: ``lambda <- max(0, lambda + learning_rate_lambda * h(x))``.
      - Primal descent along the estimated ``grad_f + lambda * grad_h``.
    Both gradient estimates share one set of sampled directions from
    ``uniform_sphere_samples`` (project helper).

    Args:
        f: objective, called as ``f(x)``.
        h: scalar constraint function, called as ``h(x)`` (h(x) <= 0 is
            feasible; lambda grows while h(x) > 0).
        x_start: initial point (converted with ``np.array``); ``len(x_start)``
            is used as the sampling dimension.
        learning_rate_x: primal step size.
        learning_rate_lambda: dual step size.
        r: radius forwarded to ``estimate_gradient``.
        TB: number of sampled directions per iteration.
        max_iter: maximum number of iterations.
        tol: stop once ``||x_new - x|| < tol``.

    Returns:
        ``(x, f(x), history)``. ``history`` has one ``(f(x), max(h(x), 0))``
        tuple per iteration, recorded before the primal step.
    """
    x = np.array(x_start)
    history = []
    lambda_ = 0

    for _ in range(max_iter):
        directions = uniform_sphere_samples(len(x_start), TB)
        grad_f = estimate_gradient(f, x, r, TB, directions)
        grad_h = estimate_gradient(h, x, r, TB, directions)

        h_val = h(x)
        # Projected dual ascent keeps the multiplier non-negative.
        lambda_ = max(0, lambda_ + learning_rate_lambda * h_val)

        x_new = x - learning_rate_x * (grad_f + lambda_ * grad_h)

        # Built-in max clips at 0; np.max(v, 0) would treat 0 as an axis.
        history.append((f(x), max(h_val, 0.0)))

        if np.linalg.norm(x_new - x) < tol:
            break

        x = x_new

    return x, f(x), history

def ConEx(f, h, x_start, learning_rate_x=0.01, learning_rate_lambda=0.01, r=0.1, TB=10, max_iter=1000, tol=1e-6, theta=0.1):
    """Zeroth-order constraint-extrapolation (ConEx-style) primal-dual method.

    Each iteration does three things:
      - Forms the extrapolated constraint value
        ``s = (1 + theta) * h(x_k) - theta * h(x_{k-1})``.
      - Takes a projected dual ascent step ``lambda <- max(0, lambda + learning_rate_lambda * s)``.
      - Steps the primal variable along the estimated ``grad_f + lambda * grad_h``.

    The returned point is the running (ergodic) average of the visited iterates.

    Args:
        f: objective, called as ``f(x)``.
        h: scalar constraint function, called as ``h(x)`` (h(x) <= 0 is
            feasible).
        x_start: initial point (converted with ``np.array``).
        learning_rate_x: primal step size.
        learning_rate_lambda: dual step size.
        r: radius forwarded to ``estimate_gradient``.
        TB: number of samples forwarded to ``estimate_gradient``.
        max_iter: maximum number of iterations.
        tol: stop once ``||x_new - x|| < tol``.
        theta: extrapolation weight on the constraint value.

    Returns:
        ``(mean_x, f(mean_x), history)``. ``history`` has one
        ``(f(mean_x), max(h(mean_x), 0))`` tuple per iteration. If
        ``max_iter == 0`` the start point is returned.
    """
    x = np.array(x_start)
    history = []
    lambda_ = 0
    h_val_old = h(x)
    # Start the average at x so max_iter == 0 returns x_start instead of 0;
    # the first loop iteration overwrites it with x anyway.
    mean_x = x.copy()

    for iteration in range(max_iter):
        # Running average of x_0 .. x_iteration.
        mean_x = (1 / (iteration + 1)) * x + (iteration / (iteration + 1)) * mean_x
        grad_f = estimate_gradient(f, x, r, TB)
        grad_h = estimate_gradient(h, x, r, TB)

        h_val = h(x)
        # Constraint extrapolation: the dual step uses s, not the raw h(x);
        # this is the only place theta enters.
        s = (1 + theta) * h_val - theta * h_val_old
        lambda_ = max(0, lambda_ + learning_rate_lambda * s)

        x_new = x - learning_rate_x * (grad_f + lambda_ * grad_h)

        # Built-in max clips at 0; np.max(v, 0) would treat 0 as an axis.
        history.append((f(mean_x), max(h(mean_x), 0.0)))

        if np.linalg.norm(x_new - x) < tol:
            break

        h_val_old = h_val
        x = x_new

    return mean_x, f(mean_x), history

def ConEx_meta(
    f,
    h,
    x_start,
    learning_rate_x=0.01,
    learning_rate_lambda=0.01,
    r=0.1,
    TB=10,
    max_iter_outer=1000,
    max_iter_inner=1000,
    tol=1e-6,
    theta=0.1,
    mu=0.0,
    mu_h=None,
):
    """Proximal-point outer loop that repeatedly warm-starts ConEx.

    Each outer iteration anchors at the current point ``a`` and hands ConEx
    two regularised problems:
      - objective ``f(x) + mu * ||x - a||^2``
      - constraint ``h(x) + mu_h * ||x - a||^2``
    It then takes the point ConEx returns as the next iterate.

    Args:
        f: objective, called as ``f(x)``.
        h: scalar constraint function, called as ``h(x)``.
        x_start: initial point (converted to a float NumPy array).
        learning_rate_x, learning_rate_lambda, r, TB, theta: forwarded to ConEx.
        max_iter_outer: number of outer (proximal) iterations; there is no
            early stop at this level.
        max_iter_inner: ``max_iter`` forwarded to each ConEx call.
        tol: forwarded to ConEx only.
        mu: proximal weight added to f.
        mu_h: proximal weight added to h; defaults to ``mu`` when None.

    Returns:
        ``(x, f(x), history)``, where ``history`` holds one
        ``(f(x), abs(h(x)))`` tuple per outer iteration.
    """
    constraint_weight = mu if mu_h is None else mu_h

    def _with_proximal_term(fn, weight, center):
        # Bind center explicitly so each wrapper keeps its own anchor.
        def wrapped(point):
            offset = point - center
            return fn(point) + weight * float(np.dot(offset, offset))

        return wrapped

    iterate = np.array(x_start, dtype=float)
    outer_log = []

    for _ in range(max_iter_outer):
        center = iterate.copy()
        iterate = ConEx(
            _with_proximal_term(f, mu, center),
            _with_proximal_term(h, constraint_weight, center),
            iterate,
            learning_rate_x=learning_rate_x,
            learning_rate_lambda=learning_rate_lambda,
            r=r,
            TB=TB,
            max_iter=max_iter_inner,
            tol=tol,
            theta=theta,
        )[0]
        outer_log.append((f(iterate), abs(h(iterate))))

    return iterate, f(iterate), outer_log