import numpy as np
from scipy.optimize import minimize
from util import create_plot


# Base scale of the adaptive step size (see gamma_val).
GAMMA = 0.05
# Weight of the log-barrier term in lb_grad_x / lb_grad_y.
ETA = 5e-4
# Number of gradient iterations performed by the main loop.
T = 400

# Constraint threshold: a point is feasible when c(x, y) >= ALPHA.
ALPHA = 0.85
# Square evaluation grid over [0, 1) x [0, 1) with spacing 0.002 per axis.
_GRID_AXIS = np.arange(0, 1, .002)
X, Y = np.meshgrid(_GRID_AXIS, _GRID_AXIS)

def v(x, y):
    """Evaluate the payoff v(x, y) = -(x - y)^2 + (x + y) - (x^2 + y^2).

    Uses only arithmetic operators, so it accepts scalars as well as
    NumPy arrays (evaluated elementwise).
    """
    mismatch_penalty = -(x - y)**2
    linear_reward = x + y
    size_penalty = x**2 + y**2
    return mismatch_penalty + linear_reward - size_penalty

def neg_v(x, y):
    """Return the negated payoff -v(x, y), i.e. v turned into a quantity to minimize."""
    payoff = v(x, y)
    return -payoff

def c(x, y):
    """Evaluate the constraint function c(x, y) = 1 - x*y.

    A point counts as feasible when c(x, y) >= ALPHA.
    """
    product = x * y
    return 1 - product

def _best_response(objective, slack, start):
    """Minimize a one-dimensional objective over [0, 1] subject to slack >= 0.

    Args:
        objective: callable taking a single scalar action and returning the
            scalar value to minimize.
        slack: callable taking a single scalar action; the action is
            feasible when the returned value is >= 0.
        start: initial guess handed to the solver.

    Returns:
        The result object returned by scipy.optimize.minimize.
    """
    return minimize(
        # minimize() passes a length-1 decision vector; unpack it so the
        # objective really returns a scalar instead of a 1-element array.
        lambda z: objective(z[0]),
        x0=[start],
        constraints=[{
            'type': 'ineq',  # means: fun(z) >= 0
            'fun': lambda z: slack(z[0]),
        }],
        bounds=[(0, 1)],
    )


def compute_nash_gap(x, y):
    """Return how much each player could gain by deviating unilaterally.

    Each player maximizes v over its own action in [0, 1] while the other
    action stays fixed, subject to the constraint c(., .) >= ALPHA.

    Args:
        x: current action of player 1.
        y: current action of player 2.

    Returns:
        A tuple (gap_x, gap_y) where gap_x = v(x_opt, y) - v(x, y) and
        gap_y = v(x, y_opt) - v(x, y), with x_opt / y_opt the constrained
        best responses found by the solver.

    Raises:
        RuntimeError: if either optimization reports failure. The message
            lists the solver message of every failed solve.
    """
    # Player 1 varies its own action x1 with y held fixed: c(x1, y) - ALPHA >= 0.
    result_opt_x = _best_response(
        lambda x1: neg_v(x1, y),
        lambda x1: c(x1, y) - ALPHA,
        x,
    )
    # Player 2 varies its own action x2 with x held fixed: c(x, x2) - ALPHA >= 0.
    result_opt_y = _best_response(
        lambda x2: neg_v(x, x2),
        lambda x2: c(x, x2) - ALPHA,
        y,
    )

    failures = [
        str(res.message)
        for res in (result_opt_x, result_opt_y)
        if not res.success
    ]
    if failures:
        raise RuntimeError("Optimization failed: " + "; ".join(failures))

    x_opt = result_opt_x.x[0]
    y_opt = result_opt_y.x[0]
    return v(x_opt, y) - v(x, y), v(x, y_opt) - v(x, y)

# Payoff and constraint surfaces sampled on the (X, Y) grid.
V = v(X, Y)
V_c = c(X, Y)
# Log barrier: log(c(x,y) - ALPHA), defined where c > ALPHA (feasible interior).
# The grid also covers infeasible points, where np.log yields nan (c < ALPHA)
# or -inf (c == ALPHA). Those entries are intentional, so only the
# accompanying RuntimeWarnings are silenced; the values are left untouched.
with np.errstate(divide='ignore', invalid='ignore'):
    LOGB = np.log(V_c - ALPHA)

def vdx(x, y):
    """Partial derivative of v with respect to x.

    d/dx [-(x-y)^2 + (x+y) - (x^2+y^2)] = -4x + 2y + 1
    """
    own_term = -4 * x
    cross_term = 2 * y
    return own_term + cross_term + 1

def vdy(x, y):
    """Partial derivative of v with respect to y.

    d/dy [-(x-y)^2 + (x+y) - (x^2+y^2)] = 2x - 4y + 1
    """
    own_term = -4 * y
    cross_term = 2 * x
    return own_term + cross_term + 1

def cdx(x, y):
    """Partial derivative of c with respect to x.

    d/dx [1 - x*y] = -y, so the result does not depend on x.
    """
    slope = -y
    return slope

def cdy(x, y):
    """Partial derivative of c with respect to y.

    d/dy [1 - x*y] = -x, so the result does not depend on y.
    """
    slope = -x
    return slope

def gamma_val(x, y):
    """Step size at (x, y): GAMMA scaled by the constraint slack c(x, y) - ALPHA.

    The step shrinks toward zero as the point approaches the boundary
    c(x, y) == ALPHA.
    """
    slack = c(x, y) - ALPHA
    return GAMMA * slack

def lb_grad_x(x, y):
    """Partial derivative in x of the barrier objective v(x,y) + ETA * log(c(x,y) - ALPHA)."""
    slack = c(x, y) - ALPHA
    # Chain rule on the log term: ETA * (dc/dx) / (c - ALPHA).
    barrier_term = ETA * cdx(x, y) / slack
    return vdx(x, y) + barrier_term

def lb_grad_y(x, y):
    """Partial derivative in y of the barrier objective v(x,y) + ETA * log(c(x,y) - ALPHA)."""
    slack = c(x, y) - ALPHA
    # Chain rule on the log term: ETA * (dc/dy) / (c - ALPHA).
    barrier_term = ETA * cdy(x, y) / slack
    return vdy(x, y) + barrier_term


# Histories with one entry per iterate, starting point included.
points, r_values, nash_gaps, c_values = [], [], [], []


def _record(px, py):
    """Append the iterate (px, py) and its payoff, Nash gap and constraint value."""
    points.append((px, py))
    r_values.append(v(px, py))
    nash_gaps.append(compute_nash_gap(px, py))
    c_values.append(c(px, py))


# Starting point of the iteration.
x, y = .05, .3
_record(x, y)

for _ in range(T):
    # Step size and both partial gradients are evaluated at the current
    # iterate before either coordinate is moved.
    gm = gamma_val(x, y)
    grad_x = lb_grad_x(x, y)
    grad_y = lb_grad_y(x, y)

    # Ascent step on each coordinate, projected back onto [0, 1].
    x, y = np.clip(x + gm * grad_x, 0, 1), np.clip(y + gm * grad_y, 0, 1)

    _record(x, y)


# Keyword options forwarded unchanged to util.create_plot.
# NOTE(review): elev/azim look like 3-D view angles and ylim like an axis
# range; confirm against create_plot's signature.
_plot_options = {
    'elev': 62,
    'azim': -140,
    'z_notation': False,
    'ylim': (0.82, 1.0),
}

# Pass the recorded trajectory and the grid surfaces to the plotting helper.
create_plot(
    'coop',
    points, nash_gaps, r_values, c_values,
    LOGB, X, Y, V, ALPHA,
    **_plot_options,
)