import os

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linprog

# ----------------------------
# Helpers
# ----------------------------
def prices(k: int) -> np.ndarray:
    """Return the uniform price grid P = {1/k, 2/k, ..., 1} as a float array of length k."""
    steps = np.arange(k, dtype=float) + 1.0  # 1.0, 2.0, ..., k (exact integers)
    return steps / k

def demand_on_grid(P: np.ndarray, f) -> np.ndarray:
    """Evaluate the demand function f at every price in P.

    f may be vectorized (accepts a NumPy array) or scalar-only. The array call
    is tried first; if it raises, f is applied element-wise instead.

    Returns:
      A float array with the same shape as P. A scalar result (e.g. constant
      demand) is broadcast so there is one demand value per price.

    Raises:
      ValueError: if f's array result cannot be broadcast to P's shape.
      Any exception f itself raises when called on an individual price.
    """
    try:
        out = np.asarray(f(P), dtype=float)
    except Exception:
        # f is not array-aware (e.g. uses math.exp or `if x < t:`), so fall
        # back to element-wise evaluation. otypes=[float] matters: without it
        # np.vectorize infers the output dtype from f's FIRST result, so a
        # piecewise f whose first value is the int 1 would truncate later
        # fractional demands (0.5 -> 0).
        out = np.vectorize(f, otypes=[float])(P)
    # Broadcast (and copy, so the result is writable) to one value per price.
    return np.array(np.broadcast_to(out, np.shape(P)), dtype=float)

def best_single_price_profit(k: int, c: float, f):
    """Best profit obtainable by committing to a single grid price.

    Evaluates (x - c) * f(x) at every x in P = {1/k, ..., 1} and returns the
    pair (max profit, price attaining it). Because the grid is ascending and
    np.argmax returns the first maximizer, ties resolve to the lowest price.
    """
    grid = prices(k)
    profit = (grid - float(c)) * demand_on_grid(grid, f)
    best = int(np.argmax(profit))
    return float(profit[best]), float(grid[best])

# ----------------------------
# Build the CCE LP pieces for a given (c1, c2) (rebuilt on every solve)
# ----------------------------
def build_cce_problem(k: int, c1: float, c2: float, f):
    """
    Assemble LP data for coarse correlated equilibria (CCE) of a two-player
    Bertrand game on the discrete price grid P = {1/k, ..., 1}.

    Decision variable: D[a, b] = Prob(p1 = P[a], p2 = P[b]), flattened
    row-major to index a*k + b (k^2 entries). The strictly lower-priced player
    earns (price - own cost) * f(price); on a tie each earns half of that.

    Constraints:
      * equality: the entries of D sum to 1;
      * inequality (CCE, both players): for every fixed price q in P,
        E_D[u_i] >= E_D[u_i if player i unilaterally switches to q],
        written as A_ub @ D <= 0 (2k rows);
      * bounds: 0 <= D[a, b] <= 1.

    Returns:
      P, u1_flat, u2_flat, A_eq, b_eq, A_ub, b_ub, bounds
    """
    P = prices(k)
    demand = demand_on_grid(P, f)
    profit1 = (P - float(c1)) * demand
    profit2 = (P - float(c2)) * demand

    own = P[:, None]     # player 1's price, varies along rows (index a)
    rival = P[None, :]   # player 2's price, varies along columns (index b)
    tie = own == rival

    # Payoff matrices indexed [a, b].
    u1 = np.where(own < rival, profit1[:, None], np.where(tie, 0.5 * profit1[:, None], 0.0))
    u2 = np.where(rival < own, profit2[None, :], np.where(tie, 0.5 * profit2[None, :], 0.0))
    u1_flat = u1.reshape(-1)  # index = a*k + b
    u2_flat = u2.reshape(-1)

    A_eq = np.ones((1, k * k))
    b_eq = np.array([1.0])

    # Player 1 switching to P[q] earns u1[q, b] against rival price b. That row
    # repeated for every a-block is np.tile(u1[q], k); stacking all q at once
    # gives np.tile(u1, (1, k)) with shape (k, k^2).
    dev1 = np.tile(u1, (1, k))
    # Player 2 switching to P[q] earns u2[a, q] against rival price a, constant
    # across b within each a-block: np.repeat(u2[:, q], k), i.e. row q of this.
    dev2 = np.repeat(u2.T, k, axis=1)

    # E[u_i] - E[dev] >= 0  <=>  -(u_i - dev) @ D <= 0
    A_ub = np.vstack([-(u1_flat[None, :] - dev1), -(u2_flat[None, :] - dev2)])  # (2k, k^2)
    b_ub = np.zeros(2 * k)
    bounds = [(0.0, 1.0)] * (k * k)

    return P, u1_flat, u2_flat, A_eq, b_eq, A_ub, b_ub, bounds

def solve_cce_maximizing_player(k: int, c1: float, c2: float, f, maximize_player: int):
    """
    Find a CCE distribution D that maximizes E_D[u_i] for i = maximize_player,
    subject to the CCE constraints for BOTH players (see build_cce_problem).

    Returns:
      (D_star, U1, U2): D_star is the optimal joint distribution reshaped to
      (k, k) and indexed [p1 grid index, p2 grid index]; U1 and U2 are both
      players' expected utilities under D_star.

    Raises:
      ValueError: if maximize_player is not 1 or 2 (checked before any LP work).
      RuntimeError: if the HiGHS LP solve does not report success.
    """
    # Validate first so a bad argument does not pay for building the dense
    # (2k, k^2) constraint matrix.
    if maximize_player not in (1, 2):
        raise ValueError("maximize_player must be 1 or 2")

    _, u1_flat, u2_flat, A_eq, b_eq, A_ub, b_ub, bounds = build_cce_problem(k, c1, c2, f)

    # linprog minimizes, so negate the chosen player's utility to maximize it.
    objective = -u1_flat if maximize_player == 1 else -u2_flat

    res = linprog(
        c=objective,
        A_ub=A_ub, b_ub=b_ub,
        A_eq=A_eq, b_eq=b_eq,
        bounds=bounds,
        method="highs",
    )
    if not res.success:
        raise RuntimeError(
            f"LP failed (maximize_player={maximize_player}, c1={c1}, c2={c2}): {res.message}"
        )

    D_flat = res.x
    U1 = float(u1_flat @ D_flat)
    U2 = float(u2_flat @ D_flat)
    return D_flat.reshape((k, k)), U1, U2

# ----------------------------
# Main plotting routine
# ----------------------------

def plot_two_panels_vs_gap(
    k=100, c1=0.0, c2_vals=None, f=None, eps=1e-12,
    save=True, out_dir="outputs", basename="two_panels_vs_gap", dpi=300
):
    """
    Sweep player 2's cost and plot normalized CCE utilities against the cost gap.

    For each c2 in c2_vals:
      - Solve the CCE that maximizes Player 1 utility -> (U1, U2) under that CCE
      - Solve the CCE that maximizes Player 2 utility -> (U1, U2) under that CCE
    Each U_i is divided by max_{x in P} (x - c_i) f(x), player i's best
    single-price profit on the grid; the ratio is NaN when that denominator
    is within eps of zero.

    Args:
      k: grid size; prices are {1/k, ..., 1}.
      c1: player 1's cost.
      c2_vals: iterable of player 2 costs (default 0.0, 0.2, ..., 0.8). Any
        iterable is accepted, including one-shot generators.
      f: demand function f(x); required.
      eps: threshold below which a ratio denominator is treated as zero.
      save: if True, write PNG and PDF copies of both plots into
        <script dir>/<out_dir>; filenames include k, c1 and the c2 range.
      out_dir, basename, dpi: output folder, filename prefix, image resolution.

    Returns:
      dict with "gaps" (c2 - c1 per sweep point), the four ratio lists
      ("r1_under_max1", "r2_under_max1", "r1_under_max2", "r2_under_max2")
      and "denom1".

    Raises:
      ValueError: if f is None.
      RuntimeError: propagated from solve_cce_maximizing_player if an LP fails.
    """
    if f is None:
        raise ValueError("Please pass a demand function f(x).")

    if c2_vals is None:
        c2_vals = np.arange(0.0, 0.8 + 1e-9, 0.2)
    # Materialize once: c2_vals is consumed by the solve loop AND later by the
    # filename tagger, so a one-shot iterator would already be exhausted when
    # the tag is built (silently producing a "c2_empty" filename).
    c2_vals = np.asarray(list(c2_vals), dtype=float)

    def _script_dir():
        # Directory of this file; falls back to the CWD when __file__ is
        # undefined (e.g. interactive sessions).
        try:
            return os.path.dirname(os.path.abspath(__file__))
        except NameError:
            return os.getcwd()

    def _ensure_dir(path):
        os.makedirs(path, exist_ok=True)
        return path

    def _tag_num(x: float) -> str:
        # 0.0 -> "0", 1.25 -> "1p25", -0.5 -> "m0p5"
        s = f"{float(x):g}"
        return s.replace("-", "m").replace(".", "p")

    def _tag_c2_range(vals) -> str:
        vals = np.asarray(list(vals), dtype=float)
        if vals.size == 0:
            return "c2_empty"
        vmin = float(np.min(vals))
        vmax = float(np.max(vals))
        step = float(vals[1] - vals[0]) if vals.size >= 2 else 0.0
        if vals.size >= 3:
            # check approximate uniform spacing
            diffs = np.diff(vals)
            if np.max(np.abs(diffs - diffs[0])) > 1e-9:
                step = np.nan
        if np.isnan(step):
            return f"c2_{_tag_num(vmin)}_to_{_tag_num(vmax)}_nonuniform_{vals.size}"
        return f"c2_{_tag_num(vmin)}_to_{_tag_num(vmax)}_step_{_tag_num(step)}"

    def _ratio(utility, denom):
        # Utility normalized by the single-price benchmark; NaN when it is ~0.
        return utility / denom if abs(denom) > eps else np.nan

    # Resolve (and create) the output directory before the LP sweep so a bad
    # path fails immediately rather than after every solve has run.
    out_path = _ensure_dir(os.path.join(_script_dir(), out_dir)) if save else None

    denom1, p1_star = best_single_price_profit(k, c1, f)

    gaps = []
    r1_under_max1, r2_under_max1 = [], []
    r1_under_max2, r2_under_max2 = [], []

    for c2 in c2_vals:
        gap = float(c2 - c1)
        gaps.append(gap)

        denom2, p2_star = best_single_price_profit(k, c2, f)

        # CCE maximizing Player 1
        _, U1_m1, U2_m1 = solve_cce_maximizing_player(k, c1, c2, f, maximize_player=1)
        r1m1 = _ratio(U1_m1, denom1)
        r2m1 = _ratio(U2_m1, denom2)
        r1_under_max1.append(r1m1)
        r2_under_max1.append(r2m1)

        # CCE maximizing Player 2
        _, U1_m2, U2_m2 = solve_cce_maximizing_player(k, c1, c2, f, maximize_player=2)
        r1m2 = _ratio(U1_m2, denom1)
        r2m2 = _ratio(U2_m2, denom2)
        r1_under_max2.append(r1m2)
        r2_under_max2.append(r2m2)

        # .4g (not .1f) so finely spaced sweeps such as step 0.049 stay distinguishable.
        print(
            f"c2={c2:.4g} (gap={gap:.4g}) | "
            f"denom1={denom1:.6f} (p*={p1_star:.2f}), denom2={denom2:.6f} (p*={p2_star:.2f})\n"
            f"  Max P1 CCE: U1={U1_m1:.6f} r1={r1m1} | U2={U2_m1:.6f} r2={r2m1}\n"
            f"  Max P2 CCE: U1={U1_m2:.6f} r1={r1m2} | U2={U2_m2:.6f} r2={r2m2}\n"
        )

    base_tag = f"{basename}_k_{_tag_num(k)}_c1_{_tag_num(c1)}_{_tag_c2_range(c2_vals)}"

    def _plot_panel(player, r1_vals, r2_vals):
        # Draw (and optionally save) the panel for the CCE maximizing `player`.
        plt.figure()
        plt.plot(gaps, r1_vals, marker="o",
                 label=f"Player 1 ratio (under CCE maximizing P{player})")
        plt.plot(gaps, r2_vals, marker="o", label="Player 2 ratio (under same CCE)")
        plt.xlabel("Cost gap c2 - c1")
        plt.ylabel(r"Ratio: utility / max$_{x\in\mathcal{P}} (x-c_i)f(x)$")
        plt.title(f"CCE maximizing Player {player} (k={k}, c1={c1})")
        plt.grid(True)
        plt.legend()

        if save:
            png = os.path.join(out_path, f"{base_tag}_maxP{player}.png")
            pdf = os.path.join(out_path, f"{base_tag}_maxP{player}.pdf")
            plt.savefig(png, dpi=dpi, bbox_inches="tight")
            plt.savefig(pdf, dpi=dpi, bbox_inches="tight")
            print(f"Saved plot:\n  {png}\n  {pdf}")

        plt.show()

    _plot_panel(1, r1_under_max1, r2_under_max1)
    _plot_panel(2, r1_under_max2, r2_under_max2)

    return {
        "gaps": gaps,
        "r1_under_max1": r1_under_max1,
        "r2_under_max1": r2_under_max1,
        "r1_under_max2": r1_under_max2,
        "r2_under_max2": r2_under_max2,
        "denom1": denom1,
    }


# ----------------------------
# Example usage
# ----------------------------
if __name__ == "__main__":
    # Demo sweep: a 100-point price grid, c1 = 0, and c2 from 0 up to (not
    # including) 1 in steps of 0.049.
    def example_demand(x):
        return np.exp(-x)  # example; replace with your demand

    plot_two_panels_vs_gap(
        k=100,
        c1=0.0,
        c2_vals=np.arange(0.0, 1, 0.049),
        f=example_demand,
    )
