import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linprog
import os
def max_utility_symmetric_cce(k, n, c, f):
    """
    Maximize per-player utility over symmetric "all play the same price" CCEs.

    The LP has k variables:
        pi[j] = Prob(all n players play the same price p_j)

    Prices: P = {1/k, 2/k, ..., 1}
    Profit function on grid: g(p) = (p-c) f(p)

    If all play p, each player's utility is g(p)/n (tie among n).

    CCE (coarse) constraints for constant deviations to any q in P:
        U >= g(q) * Pr[p > q] + g(q)/n * Pr[p = q]
    where p ~ pi is the common recommended price.

    Parameters
    ----------
    k : int
        Number of grid prices.
    n : int
        Number of players (divisor for the tie split).
    c : float
        Marginal cost.
    f : callable
        Demand function. It is first called once on the whole price array;
        if that call raises, it is evaluated element-wise instead.

    Returns
    -------
    (U_max, P, pi_star, g)
        U_max   : optimal per-player utility
        P       : price grid, shape (k,)
        pi_star : optimal distribution over P returned by the LP solver
        g       : profit (p - c) f(p) evaluated on P

    Raises
    ------
    RuntimeError
        If the LP solver does not report success.
    """
    P = np.arange(1, k + 1, dtype=float) / k

    # Demand on the grid: try a vectorized call first, then fall back to
    # element-wise evaluation for scalar-only f. otypes=[float] matters:
    # without it np.vectorize infers the output dtype from the FIRST call, so
    # an f returning the int 1 there would truncate every later value to int.
    try:
        fp = np.asarray(f(P), dtype=float)
    except Exception:
        fp = np.vectorize(f, otypes=[float])(P)

    g = (P - float(c)) * fp
    inv_n = 1.0 / n

    # Objective: maximize U = (1/n) * sum_j pi[j] * g[j]; linprog minimizes -U.
    c_obj = -inv_n * g

    # Equality: sum pi = 1
    A_eq = np.ones((1, k))
    b_eq = np.array([1.0])

    # Row j encodes the deviation to q = P[j]:
    #   g(q) * (sum_{i>j} pi_i + pi_j / n) - (1/n) * sum_i pi_i * g(p_i) <= 0
    A_ub = np.zeros((k, k))
    A_ub[:, :] = -inv_n * g                      # -U part, identical in every row
    rows, cols = np.triu_indices(k, 1)
    A_ub[rows, cols] += g[rows]                  # + g(q) * Pr[p > q]
    diag = np.arange(k)
    A_ub[diag, diag] += inv_n * g                # + g(q)/n * Pr[p = q]
    b_ub = np.zeros(k)

    bounds = [(0.0, 1.0) for _ in range(k)]

    res = linprog(
        c=c_obj, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq,
        bounds=bounds, method="highs"
    )
    if not res.success:
        raise RuntimeError(f"LP failed (k={k}, n={n}): {res.message}")

    pi_star = res.x
    U_star = -res.fun
    return U_star, P, pi_star, g


def best_single_price_profit(k, c, f):
    """
    Return the best single-price profit on the grid P = {1/k, ..., 1}.

    Computes max_{x in P} (x - c) f(x) together with the price attaining it.

    Parameters
    ----------
    k : int
        Number of grid prices.
    c : float
        Marginal cost.
    f : callable
        Demand function. It is first called once on the whole price array;
        if that call raises, it is evaluated element-wise instead.

    Returns
    -------
    (max_profit, argmax_price) : tuple of float
        Ties are broken toward the lowest price, because np.argmax returns
        the first maximum.
    """
    P = np.arange(1, k + 1, dtype=float) / k
    # otypes=[float] stops np.vectorize from inferring an int dtype from the
    # first call and truncating later fractional demand values.
    try:
        fp = np.asarray(f(P), dtype=float)
    except Exception:
        fp = np.vectorize(f, otypes=[float])(P)
    g = (P - float(c)) * fp
    idx = int(np.argmax(g))
    return float(g[idx]), float(P[idx])


def plot_ratio_vs_n(k, c, f, n_min=2, n_max=10, eps=1e-12,
                    save=True, out_dir="outputs", basename="ratio_vs_n", dpi=300,
                    show=True):
    """
    Plot the ratio (max symmetric-CCE utility) / max_{x in P} (x - c) f(x)
    for n = n_min..n_max, optionally saving the figure as PNG and PDF.

    Saved files go to <script dir>/<out_dir>/ and are named
    ``{basename}_k_{k}_c_{c}_n_{n_min}_to_{n_max}.png`` / ``.pdf``, with k and
    c made filename-safe ("-" -> "m", "." -> "p", "+" removed).

    Parameters
    ----------
    k, c, f :
        Grid size, marginal cost and demand function (passed through to
        best_single_price_profit and max_utility_symmetric_cce).
    n_min, n_max : int
        Inclusive range of player counts.
    eps : float
        A denominator with |denom| <= eps yields a NaN ratio.
    save : bool
        Write PNG and PDF files.
    out_dir, basename, dpi :
        Output directory (relative to the script dir), file-name prefix and
        resolution for the saved files.
    show : bool
        Call plt.show(). If False, the figure is closed after (optional)
        saving, so repeated batch calls do not accumulate open figures.

    Returns
    -------
    (ns, ratios, utils, denom, argmax_p)
    """

    def _script_dir():
        # Directory of this file; falls back to the CWD when __file__ is
        # undefined (e.g. interactive sessions).
        try:
            return os.path.dirname(os.path.abspath(__file__))
        except NameError:
            return os.getcwd()

    def _tag_num(x: float) -> str:
        # 0.0 -> "0", 1.25 -> "1p25", -0.5 -> "m0p5", 1e6 -> "1e06"
        s = f"{float(x):g}"
        return s.replace("-", "m").replace(".", "p").replace("+", "")

    denom, argmax_p = best_single_price_profit(k=k, c=c, f=f)

    ns = list(range(n_min, n_max + 1))
    utils = [max_utility_symmetric_cce(k=k, n=n, c=c, f=f)[0] for n in ns]
    # Guard against a (near-)zero best single-price profit.
    ratios = [U / denom if abs(denom) > eps else np.nan for U in utils]

    fig, ax = plt.subplots()
    ax.plot(ns, ratios, marker="o")
    ax.set_xlabel("n (number of players)")
    ax.set_ylabel(r"Ratio: (max symmetric-CCE utility) / max$_{x\in\mathcal{P}} (x-c)f(x)$")
    ax.set_title(f"Ratio vs n (k={k}, c={c}, denom at p*={argmax_p:.2f})")
    ax.grid(True)
    ax.set_ylim(bottom=0)

    if save:
        out_path = os.path.join(_script_dir(), out_dir)
        os.makedirs(out_path, exist_ok=True)
        fname_base = (f"{basename}_k_{_tag_num(k)}_c_{_tag_num(c)}"
                      f"_n_{n_min}_to_{n_max}")
        png_path = os.path.join(out_path, f"{fname_base}.png")
        pdf_path = os.path.join(out_path, f"{fname_base}.pdf")
        fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
        fig.savefig(pdf_path, dpi=dpi, bbox_inches="tight")
        print(f"Saved plot to:\n  {png_path}\n  {pdf_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    # Summary table; header widths match the row format below (2 | 11 | 6).
    print(f"Denominator max_g = {denom:.6f} attained at p* = {argmax_p:.4f}")
    print(f"{'n':>2} | {'U_symCCE':>11} | {'ratio':>6}")
    print("---+-------------+--------")
    for n, U, r in zip(ns, utils, ratios):
        print(f"{n:2d} | {U:11.6f} | {r:6.3f}")

    return ns, ratios, utils, denom, argmax_p


if __name__ == "__main__":
    # Demo run: exponential demand on a 100-point price grid with cost 0.9.
    # Edit the demand curve, grid size, or cost below as needed.
    def f(x):
        """Demand curve e^{-x}; accepts scalars or NumPy arrays."""
        return np.exp(-x)

    k, c = 100, 0.9
    n_range = {"n_min": 2, "n_max": 10}

    plot_ratio_vs_n(k=k, c=c, f=f, **n_range)
