import numpy as np
from scipy.optimize import minimize
from scipy.optimize import minimize
from util import create_plot_N


# Scale factor for the adaptive step size (multiplied by the smallest link
# slack in ``gamma``).
GAMMA = 0.01
# Weight of the log-barrier terms added in ``lb_grad``.
ETA = 0.05
# Number of iterations of the main update loop.
T = 120

# Number of players (length of the joint action ``xs``).
N = 5
# 1-D grid over [0, 1) with spacing 0.002, repeated N times.
# NOTE(review): with a single input ``np.meshgrid`` returns a one-element
# sequence, so the repetition yields N references to the *same* array; the
# name is not used anywhere in the visible code.
Xs = N * np.meshgrid(np.arange(0, 1, 0.002), indexing="ij")

# Per-link cost functions: ps[k](load) is the cost of link k when it carries a
# total flow of ``load``.
ps = [
    lambda x: 5 * x**2 + x,
    lambda x: 4 * x**2 + 2 * x + 1,
    lambda x: 5 * x**2,
    lambda x: 2 * x**2 + x,
    lambda x: 10 * x,
]
# Derivatives of the link costs: dps[k](load) == d ps[k] / d load.
# These feed the analytic gradient (``partial_grad_utility``) and must stay in
# sync with ``ps``, which the utility, potential and Nash-gap code use directly.
dps = [
    lambda x: 10 * x + 1,  # d/dx (5x^2 + x)
    lambda x: 8 * x + 2,  # d/dx (4x^2 + 2x + 1)
    lambda x: 10 * x,  # d/dx (5x^2)
    lambda x: 4 * x + 1,  # d/dx (2x^2 + x)
    lambda x: 10,  # d/dx (10x)
]

# Route of each player: rs[i] is the tuple of link indices that player i's flow
# traverses.
rs = [
    (0, 1, 3),
    (0, 2, 4),
    (1, 4),
    (2, 3),
    (3, 4),
]

# Capacity of each link, indexed like ``ps``. Only links flagged True in
# ``active_constr`` receive a log-barrier term in ``lb_grad``.
capacities = [1000, 2.7, 1000, 1000, 7]
active_constr = [False, True, False, False, True]

# Slope of the linear valuation ``u``.
# NOTE: the module-level name ``gamma`` bound here is rebound further down the
# file by ``def gamma``, so this value 10 is not what later code sees.
beta = 42
gamma = 10


def u(x):
    """Linear valuation of a flow ``x``: ``beta * x``."""
    return beta * x


def utility(i, x):
    """Return player ``i``'s payoff at the joint action ``x``.

    The payoff is ``u(x[i])`` minus the sum, over every link ``k`` on player
    ``i``'s route ``rs[i]``, of the link cost ``ps[k]`` evaluated at the total
    flow that all players place on that link.

    Args:
        i: Player index.
        x: Sequence with one flow value per player (aligned with ``rs``).

    Returns:
        The scalar payoff.
    """
    # Players are enumerated by pairing each flow with its route, so the player
    # count comes from ``x``/``rs`` rather than the number of links.
    costs = sum(
        ps[k](sum(xj for xj, route in zip(x, rs) if k in route)) for k in rs[i]
    )
    return u(x[i]) - costs


def constr(x):
    """Return the capacity slack of every link at the joint action ``x``.

    The slack of link ``k`` is ``capacities[k]`` minus the total flow of all
    players whose route contains ``k``; a negative value means the link is
    over capacity.

    Args:
        x: Sequence with one flow value per player (aligned with ``rs``).

    Returns:
        List of slacks, one per link (indexed like ``ps``/``capacities``).
    """
    # Links are indexed by k; players by pairing each flow with its route, so
    # the two index spaces are not conflated.
    link_loads = [
        sum(xj for xj, route in zip(x, rs) if k in route) for k in range(len(ps))
    ]
    return [capacities[k] - load for k, load in enumerate(link_loads)]


def partial_grad_constr(i, x):
    """Return d(slack_k)/d(x[i]) for every link ``k``.

    A link's slack is its capacity minus its load, so increasing player
    ``i``'s flow lowers the slack by one on each link of ``rs[i]`` and leaves
    all other links unchanged. ``x`` is accepted for signature symmetry with
    the other gradient helpers and is not read.
    """
    route = rs[i]
    # -int(True) == -1 and -int(False) == 0, matching the per-link derivative.
    return [-int(k in route) for k in range(len(ps))]


def potential_value(x):
    """Return the potential function at the joint action ``x``.

    The potential is the total valuation ``sum(u(x_i))`` minus the cost
    ``ps[k]`` of every link evaluated at that link's total load. Each link is
    counted once, whether or not several players share it.

    Args:
        x: Sequence with one flow value per player (aligned with ``rs``).

    Returns:
        The scalar potential value.
    """
    util = sum(u(xi) for xi in x)
    costs = 0
    for k, link_cost in enumerate(ps):
        # Total flow on link k from every player whose route uses it.
        load = sum(xj for xj, route in zip(x, rs) if k in route)
        costs += link_cost(load)
    return util - costs


def partial_grad_utility(i, x):
    """Return d utility(i, x) / d x[i].

    This is the valuation slope ``beta``, because ``u(x) = beta * x``, minus
    the sum of the link-cost derivatives ``dps[k]`` over the links on player
    ``i``'s route, each evaluated at that link's total load.

    Args:
        i: Player index.
        x: Sequence with one flow value per player (aligned with ``rs``).

    Returns:
        The scalar partial derivative.
    """
    grad_u = beta  # derivative of the linear valuation u(x) = beta * x

    grad_costs = 0
    for k in rs[i]:
        # Total flow on link k; pair flows with routes so the player count does
        # not depend on the number of links.
        load = sum(xj for xj, route in zip(x, rs) if k in route)
        grad_costs += dps[k](load)

    return grad_u - grad_costs


def lb_grad(i, x):
    """Return player ``i``'s gradient of utility plus log-barrier terms.

    For every link ``k`` flagged in ``active_constr`` the barrier
    ``ETA * log(slack_k)`` contributes
    ``ETA * d(slack_k)/d(x[i]) / slack_k``. A zero slack on an active link
    raises ``ZeroDivisionError``. A negative slack flips the sign of that
    link's term.
    """
    dslack = partial_grad_constr(i, x)
    slacks = constr(x)
    barrier_terms = (
        ETA * dslack[k] / slacks[k]
        for k in range(len(slacks))
        if active_constr[k]
    )
    return partial_grad_utility(i, x) + sum(barrier_terms)


def gamma(x):
    """Return the step size at ``x``: ``GAMMA`` times the smallest link slack.

    The minimum is taken over all links, active or not.
    NOTE(review): if any link is over capacity the slack is negative and so is
    the returned step, which reverses the update direction. Confirm this is
    intended.
    """
    tightest_slack = min(constr(x))
    return GAMMA * tightest_slack


def compute_nash_gap(xs):
    """Return each player's best-response improvement at the joint action ``xs``.

    For player ``i``, all other flows are held fixed and ``x[i]`` is
    re-optimised with ``scipy.optimize.minimize`` over ``[0, 10]``. The
    optimisation is subject to the capacity of every link on player ``i``'s
    route ``rs[i]``. The player's gap is the optimal utility minus
    ``utility(i, xs)``.

    Args:
        xs: Sequence with one flow value per player (aligned with ``rs``).

    Returns:
        List of per-player gaps, in player order.

    Raises:
        RuntimeError: If the optimiser reports failure for some player.
    """
    gaps = []
    for i in range(len(xs)):
        # ``i`` is bound as a default so the objective never depends on the
        # loop variable's late binding.
        def neg_ui(z, i=i):
            # list() copies any sequence and never truncates z[0] to an int
            # dtype (as ndarray.copy() of an integer array would).
            x_new = list(xs)
            x_new[i] = z[0]
            return -utility(i, x_new)

        # Only links on player i's route depend on z. Constraints on other
        # links are constant in z: they cannot change the optimum, but they
        # make the problem infeasible whenever such a link is already over
        # capacity.
        cons = []
        for k in rs[i]:
            # Load on link k from everyone except player i.
            others = sum(xj for xj, route in zip(xs, rs) if k in route) - xs[i]
            cons.append(
                {
                    "type": "ineq",
                    "fun": lambda z, k=k, others=others: capacities[k]
                    - (others + z[0]),
                }
            )

        res = minimize(neg_ui, x0=[xs[i]], bounds=[(0, 10)], constraints=cons)
        if not res.success:
            raise RuntimeError(f"Optimization failed for player {i}: {res.message}")
        gaps.append(-res.fun - utility(i, xs))
    return gaps


# Main loop. Each iteration:
#   1. Every player takes a gradient step on the log-barrier objective, all
#      computed from the previous iterate, and the new flows are clipped to
#      [0, 10].
#   2. The mean utility, the largest Nash gap and the loads on the active
#      links are recorded.
# The recorded series are plotted at the end.
xs = [0.5] * N
r_values = []
nash_gaps = []
c_values = []

for t in range(T):
    # The step size depends only on the previous iterate; compute it once.
    step = gamma(xs)
    xs = [np.clip(xs[i] + step * lb_grad(i, xs), 0, 10) for i in range(N)]
    us = [utility(i, xs) for i in range(N)]
    cs = constr(xs)
    # Load on each active link = capacity minus remaining slack.
    csa = [capacities[k] - c for k, c in enumerate(cs) if active_constr[k]]

    p = potential_value(xs)
    gaps = compute_nash_gap(xs)
    r_values.append(np.mean(us))
    nash_gaps.append(max(gaps))
    c_values.append(csa)

create_plot_N(nash_gaps, r_values, c_values, capacities[1], capacities[4])
