import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linprog
import os

def max_utility_symmetric_cce_bertrand(k, c, f):
    """
    Solve the symmetric two-player CCE linear program on a discrete price grid.

    Both players have marginal cost ``c`` and prices come from
    P = {1/k, 2/k, ..., 1}. The decision variables are ``pi[j]``, the
    probability that both players play price ``P[j]`` (perfectly correlated,
    so every realised profile is a tie).

    With g(p) = (p - c) * f(p), the LP maximises

        U = 0.5 * sum_j pi[j] * g(P[j])

    subject to sum_j pi[j] = 1, 0 <= pi[j] <= 1, and one constraint per
    constant deviation q in P:

        U >= g(q) * (sum_{p > q} pi[p] + 0.5 * pi[q])

    Parameters
    ----------
    k : int
        Number of grid prices.
    c : float
        Common marginal cost.
    f : callable
        Demand function. It is evaluated on the whole grid array first. If
        that raises, it is applied element-wise via ``np.vectorize``.

    Returns
    -------
    U_star : float
        Optimal objective value (negated ``linprog`` minimum).
    P : ndarray
        The price grid.
    pi : ndarray
        Optimal probabilities (``res.x`` from ``linprog``).
    g : ndarray
        The values g(P).

    Raises
    ------
    RuntimeError
        If ``linprog`` reports that it did not succeed.
    """
    prices = np.arange(1, k + 1, dtype=float) / k

    try:
        demand = np.asarray(f(prices), dtype=float)
    except Exception:
        # f presumably only handles scalars; apply it one price at a time.
        demand = np.vectorize(f)(prices).astype(float)

    gain = (prices - float(c)) * demand

    # Row j is the deviation to q = prices[j], rearranged into "<= 0" form:
    #   g(q) * (sum_{p > q} pi_p + 0.5 * pi_q) - 0.5 * sum_p pi_p * g(p) <= 0
    # Every row starts from the -U coefficients ...
    lhs = np.tile(-0.5 * gain, (k, 1))
    # ... gains g(q) on every strictly higher price (strict upper triangle) ...
    above_rows, above_cols = np.triu_indices(k, 1)
    lhs[above_rows, above_cols] += gain[above_rows]
    # ... and 0.5 * g(q) on the tied price (the diagonal).
    diag = np.arange(k)
    lhs[diag, diag] += 0.5 * gain

    result = linprog(
        c=-0.5 * gain,  # linprog minimises, so minimise -U
        A_ub=lhs,
        b_ub=np.zeros(k),
        A_eq=np.ones((1, k)),  # probabilities sum to one
        b_eq=np.array([1.0]),
        bounds=[(0.0, 1.0)] * k,
        method="highs",
    )
    if not result.success:
        raise RuntimeError(f"LP failed for k={k}: {result.message}")

    return -result.fun, prices, result.x, gain


def best_single_price_profit(k, c, f):
    """
    Find the best single price on the grid P = {1/k, 2/k, ..., 1}.

    Evaluates (x - c) * f(x) at every grid price and returns the largest
    value together with the price attaining it. If several prices tie, the
    first (lowest) one is returned.

    Parameters
    ----------
    k : int
        Grid size.
    c : float
        Marginal cost.
    f : callable
        Demand function. It is called on the whole grid array first. If that
        raises, it is applied element-wise via ``np.vectorize``.

    Returns
    -------
    (float, float)
        ``(best_value, argmax_price)``.
    """
    grid = np.arange(1, k + 1, dtype=float) / k
    try:
        demand = np.asarray(f(grid), dtype=float)
    except Exception:
        # f presumably only accepts scalars; evaluate it point by point.
        demand = np.vectorize(f)(grid).astype(float)

    profit = (grid - float(c)) * demand
    best = int(profit.argmax())
    return float(profit[best]), float(grid[best])


def plot_ratio_vs_k(c, f, ks=range(10, 101, 10), eps=1e-12,
                    save=True, out_dir="outputs_exp_e", basename="ratio_vs_k", dpi=300,
                    show=True):
    """
    Plot, for each grid size k,

        ratio(k) = (max symmetric-CCE utility at k) / (max_{x in P} (x-c) f(x))

    with a dotted reference line at 1/e, and optionally save the figure.

    Parameters
    ----------
    c : float
        Common marginal cost.
    f : callable
        Demand function, passed through to ``max_utility_symmetric_cce_bertrand``
        and ``best_single_price_profit``.
    ks : iterable of int
        Grid sizes to evaluate.
    eps : float
        If |max_x (x-c) f(x)| <= eps, the ratio for that k is NaN instead of
        a division by (nearly) zero.
    save : bool
        If True, write PNG and PDF copies whose filename encodes c in a
        filename-safe way (e.g. 0.9 -> "0p9", -0.5 -> "m0p5").
    out_dir : str
        Output directory, relative to this script's directory (or the current
        working directory when ``__file__`` is unavailable). Created if missing.
    basename : str
        File stem; the final name is ``<basename>_c_<tag>.{png,pdf}``.
    dpi : int
        Resolution passed to ``savefig``.
    show : bool
        If True (default), call ``plt.show()``. If False, close the figure after
        saving so repeated batch calls do not accumulate open figures.

    Returns
    -------
    tuple of lists
        ``(ks_list, ratios, cce_utils, best_profits)``: the evaluated k values
        and, per k, the ratio, the CCE utility, and the best single-price value.
    """

    def _script_dir():
        try:
            return os.path.dirname(os.path.abspath(__file__))
        except NameError:
            # __file__ is not defined, e.g. in an interactive session.
            return os.getcwd()

    def _tag_num(x: float) -> str:
        # 0.0 -> "0", 1.25 -> "1p25", -0.5 -> "m0p5"
        return f"{float(x):g}".replace("-", "m").replace(".", "p")

    ks_list = list(ks)
    ratios = []
    cce_utils = []
    best_profits = []

    for k in ks_list:
        U_star, _, _, _ = max_utility_symmetric_cce_bertrand(k=k, c=c, f=f)
        best_val, _ = best_single_price_profit(k=k, c=c, f=f)

        # Guard against dividing by a (numerically) zero benchmark.
        ratio = U_star / best_val if abs(best_val) > eps else np.nan

        cce_utils.append(U_star)
        best_profits.append(best_val)
        ratios.append(ratio)

    fig, ax = plt.subplots()
    # Label the data series so the legend describes it, not just the 1/e line.
    ax.plot(ks_list, ratios, marker="o", label="CCE / best single price")

    # Dotted reference line at 1/e
    ax.axhline(y=1 / np.e, linestyle=":", linewidth=1.5, label=r"$1/e$")

    ax.set_xlabel("k (price grid size; P = {1/k, ..., 1})")
    ax.set_ylabel(r"Ratio:  (max symmetric-CCE utility) / max$_{x\in\mathcal{P}} (x-c)f(x)$")
    ax.set_title(f"Ratio vs k (c={c})")
    ax.grid(True)
    ax.legend()

    if save:
        out_path = os.path.join(_script_dir(), out_dir)
        os.makedirs(out_path, exist_ok=True)
        fname_base = f"{basename}_c_{_tag_num(c)}"
        png_path = os.path.join(out_path, f"{fname_base}.png")
        pdf_path = os.path.join(out_path, f"{fname_base}.pdf")
        # Save this figure explicitly rather than whatever is "current".
        fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
        fig.savefig(pdf_path, dpi=dpi, bbox_inches="tight")
        print(f"Saved plot to:\n  {png_path}\n  {pdf_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return ks_list, ratios, cce_utils, best_profits


if __name__ == "__main__":
    # Example demand function: f(x) = exp(-x). Replace it with any demand
    # curve; it may accept a NumPy array or, failing that, a single scalar.
    def f(x):
        return np.exp(-x)

    c = 0.9  # marginal cost shared by both players

    ks, ratios, cce_utils, best_profits = plot_ratio_vs_k(c=c, f=f, ks=range(10, 101, 1))

    # Set to True to also print a summary table of the computed values.
    # k is printed 3 wide so that k=100 still lines up with the header.
    print_table = False
    if print_table:
        print("  k | U_CCE      | max_g      | ratio")
        print("----+------------+------------+-------")
        for k, U, gmax, r in zip(ks, cce_utils, best_profits, ratios):
            print(f"{k:3d} | {U:10.6f} | {gmax:10.6f} | {r:6.3f}")
