import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import binom, norm
import os
import time

# ----------- USER‑TUNABLE PARAMETERS ---------------------------------
n1, n2 = 1, 1            # agents in Group 1 and in Group 2
c_select = 0.5           # fraction of all n1 + n2 agents that is selected (top-c)
rho = 0.8                # bias: Group 2 valuations live on [0, rho], Group 1 on [0, 1]
num_bins = 101           # number of points on each valuation grid
tol = 1e-4               # stop once both mean |best response - policy| gaps drop below this
max_iter = 500           # hard cap on the number of best-response rounds
epsilon = 0              # tie-breaking effort; NOTE(review): not referenced anywhere in this script
step_constant = 0.1      # numerator of the relaxation weight step_constant / sqrt(t)
plot_step = max(1, max_iter // 10)  # store a policy snapshot every plot_step rounds
# ---------------------------------------------------------------------

start_time = time.time()

# ================================================================
# Piece‑wise linear policy helper
# ================================================================
class PWPolicy:
    """Piece-wise linear effort policy tabulated on a valuation grid.

    ``v`` holds the grid valuations and ``e`` the effort at each of them;
    efforts between grid points are obtained by linear interpolation.
    """

    def __init__(self, v_grid, e_vals):
        self.v = v_grid  # valuation grid (stored by reference, not copied)
        self.e = e_vals  # effort at each grid valuation

    def effort(self, vs):
        """Return the linearly interpolated effort at valuation(s) ``vs``."""
        return np.interp(vs, self.v, self.e)

    def inverse(self, ev):
        """Return the smallest valuation at which the policy reaches effort ``ev``.

        ``self.e`` must be sorted non-decreasing (``np.searchsorted`` relies
        on it).  Efforts at or below ``e[0]`` map to ``v[0]``; efforts above
        ``e[-1]`` map to ``v[-1]``.
        """
        pos = np.searchsorted(self.e, ev, side='left')
        if pos >= len(self.v):
            return self.v[-1]
        if pos == 0:
            return self.v[0]
        lo_e, hi_e = self.e[pos - 1], self.e[pos]
        lo_v, hi_v = self.v[pos - 1], self.v[pos]
        if hi_e == lo_e:
            # Flat segment: no unique inverse, take the right endpoint.
            return hi_v
        return lo_v + (ev - lo_e) * (hi_v - lo_v) / (hi_e - lo_e)

    def update(self, new_e):
        """Replace the tabulated efforts with ``new_e`` (no copy is made)."""
        self.e = new_e

# ================================================================
# Build valuation grids and initial policies
# ================================================================
vals_g1 = np.linspace(0.0, 1.0, num_bins)   # Group 1 valuations on [0, 1]
vals_g2 = np.linspace(0.0, rho, num_bins)   # Group 2 valuations on [0, rho]
alpha = n2 / (n1 + n2)                      # share of all agents that are in Group 2

# Threshold of the step policy used below as the infinite-population NE
# reference (pi1_inf / pi2_inf).
_theta_cap = 1 - c_select / (1 - alpha)
theta = (_theta_cap if rho < _theta_cap
         else rho * (1 - c_select) / (rho - rho * alpha + alpha))

# Logistic smoothing of the step at theta gives the starting policies.
k = 50.0


def step_policy(v):
    """Smoothed step: rises from ~0 to theta around v = theta (steepness k)."""
    return theta / (1 + np.exp(-k * (v - theta)))


pi1 = PWPolicy(vals_g1, step_policy(vals_g1))
pi2 = PWPolicy(vals_g2, step_policy(vals_g2))


def step_policy_inf(v):
    """Exact step: effort theta for v >= theta, otherwise 0."""
    return np.where(v >= theta, theta, 0.0)


pi1_inf = PWPolicy(vals_g1, step_policy_inf(vals_g1))
pi2_inf = PWPolicy(vals_g2, step_policy_inf(vals_g2))

# Fixed candidate-effort grid on [0, 1]; power > 1 packs more points near 0.
power = 1
fractions = np.linspace(0, 1, num_bins) ** power

# probability helpers
def exceed_prob(pol, e, max_v):
    """Probability that a rival playing ``pol`` reaches at least effort ``e``.

    The rival's valuation is treated as uniform on ``[0, max_v]``, so the
    answer is the mass of valuations above ``pol.inverse(e)``.  Efforts at or
    below the policy's lowest effort give 1.0; efforts above its highest give
    0.0.
    """
    lowest, highest = pol.e[0], pol.e[-1]
    if e <= lowest:
        return 1.0
    if e > highest:
        return 0.0
    v_star = pol.inverse(e)
    return (max_v - v_star) / max_v

def win_prob(e, grp, p1, p2, k):
    """Probability that fewer than ``k`` rivals out-bid a group-``grp`` agent.

    Every other Group 1 agent independently out-bids with probability ``p1``
    and every other Group 2 agent with probability ``p2``.  The agent itself
    is removed from its own group's head-count.  ``e`` is accepted for
    call-site symmetry but is not used.
    """
    rivals1, rivals2 = n1, n2
    if grp == 1:
        rivals1 -= 1
    elif grp == 2:
        rivals2 -= 1
    dist1 = binom.pmf(np.arange(rivals1 + 1), rivals1, p1)
    dist2 = binom.pmf(np.arange(rivals2 + 1), rivals2, p2)
    # Distribution of the total number of rivals that out-bid.
    total = np.convolve(dist1, dist2)
    return total[:k].sum()

# Number of agents selected: the top c_select fraction of all n1 + n2 agents.
# The tiny tolerance stops float products such as 0.29 * 100 ==
# 28.999999999999996 from being floored one short.
k_cut = int(np.floor(c_select * (n1 + n2) + 1e-9))
history = []               # per-round (d1, d2): mean |best response - policy|
history_inf = []           # per-round distance of each policy to its *_inf reference
stored_iters = [0]         # rounds at which policy snapshots were taken
stored_p1 = [pi1.e.copy()]
stored_p2 = [pi2.e.copy()]

# ================================================================
# Iterative best-response with relaxation (Gauss-Seidel order:
# Group 2 responds to the already-relaxed Group 1 policy)
# ================================================================
def _best_response(vals, grp, pol1, pol2, efforts, k_sel):
    """Return the monotone best-response effort at each valuation in ``vals``.

    A group-``grp`` agent with valuation ``v`` picks the effort ``e`` that
    maximises ``win_prob(...) * v - e``.  Every rival, in either group, is
    assumed to play its group's current policy: ``pol1`` for Group 1
    (valuations on [0, 1]) and ``pol2`` for Group 2 (valuations on [0, rho]).
    Own-group rivals are therefore priced through the same effort-dependent
    ``exceed_prob`` as cross-group rivals.

    Candidates are the entries of ``efforts`` that are no smaller than the
    best effort found at the previous valuation.  With ``vals`` in ascending
    order, this keeps the returned policy non-decreasing.
    """
    # Rival exceed probabilities depend only on the candidate effort, not on
    # the agent's own valuation, so tabulate them once per call.
    p1_tab = [exceed_prob(pol1, e, 1.0) for e in efforts]
    p2_tab = [exceed_prob(pol2, e, rho) for e in efforts]

    best = np.zeros(len(vals))
    floor_e = 0.0
    for i, v in enumerate(vals):
        # Baseline: keep the effort chosen at the previous valuation.
        best_e = floor_e
        best_pay = (win_prob(floor_e, grp,
                             exceed_prob(pol1, floor_e, 1.0),
                             exceed_prob(pol2, floor_e, rho), k_sel) * v
                    - floor_e)
        for e, p1, p2 in zip(efforts, p1_tab, p2_tab):
            if e < floor_e:
                continue
            pay = win_prob(e, grp, p1, p2, k_sel) * v - e
            if pay > best_pay:
                best_pay, best_e = pay, e
        floor_e = best_e
        best[i] = best_e
    return best


for t in range(1, max_iter + 1):
    # Relaxation weight a_t = step_constant / sqrt(t).
    a = step_constant / np.power(t, 1/2)

    # Candidate efforts: both current policies plus the fixed grid.
    effort_set = np.unique(np.concatenate((pi1.e, pi2.e, fractions)))

    # -------- Group 1 best-response, relax, update --------
    new1 = _best_response(vals_g1, 1, pi1, pi2, effort_set, k_cut)
    r1 = pi1.e + a * (new1 - pi1.e)
    d1 = np.abs(new1 - pi1.e).mean()  # mean gap between best response and policy
    pi1.update(r1)

    # -------- Group 2 best-response (sees the relaxed pi1), relax, update --------
    new2 = _best_response(vals_g2, 2, pi1, pi2, effort_set, k_cut)
    r2 = pi2.e + a * (new2 - pi2.e)
    d2 = np.abs(new2 - pi2.e).mean()
    pi2.update(r2)

    # record convergence
    history.append((d1, d2))

    # record distance to infinite NE
    d1_inf = np.abs(r1 - pi1_inf.e).mean()
    d2_inf = np.abs(r2 - pi2_inf.e).mean()
    history_inf.append((d1_inf, d2_inf))

    converged = d1 < tol and d2 < tol

    # snapshots (always keep the final iterate, including on early convergence)
    if t % plot_step == 0 or t == max_iter or converged:
        stored_iters.append(t)
        stored_p1.append(pi1.e.copy())
        stored_p2.append(pi2.e.copy())

    if converged:
        print(f"Converged at iteration {t}")
        break

    if t == max_iter:
        print(f"Distance of the output to the infinite NE: {d1_inf:.4f}, {d2_inf:.4f}")

else:
    print("Reached max_iter without convergence.")


# Report wall-clock time of the solver (excludes the plotting below).
elapsed = time.time() - start_time
print(f"Total runtime: {elapsed:.6f} seconds")

# ================================================================
# Visualise & Save Results
# ================================================================
output_dir = f"policy_plots_{n1}_{c_select}"
os.makedirs(output_dir, exist_ok=True)

# Group 2 columns are NaN-padded to the Group 1 length so both fit one sheet.
pad_v = [np.nan] * (len(vals_g1) - len(vals_g2))
pad_e = [np.nan] * (len(vals_g1) - len(pi2.e))
df_final = pd.DataFrame({
    "v_G1": vals_g1,
    "pi1_final": pi1.e,
    "v_G2": np.concatenate([vals_g2, pad_v]),
    "pi2_final": np.concatenate([pi2.e, pad_e]),
})

# Write the final policies to an Excel workbook.
excel_path = os.path.join(output_dir, f"NE_n{n1}.xlsx")
with pd.ExcelWriter(excel_path, engine="openpyxl") as writer:
    df_final.to_excel(writer, index=False, sheet_name="Policies")

print(f"Final policy saved to {excel_path}")

# Save one figure per stored policy snapshot.
for it, snap1, snap2 in zip(stored_iters, stored_p1, stored_p2):
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(vals_g1, snap1, label=r'$s_1$')
    ax.plot(vals_g2, snap2, '--', label=r'$s_2$')
    ax.set_xlabel(r'Valuation $v$', fontsize=20)
    ax.set_ylabel(r'Effort $e$', fontsize=20)
    ax.set_title(f'Policies at Iteration {it}')
    ax.legend(fontsize=20)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"policy_iter_{it:03}.png"))
    plt.close(fig)


def _plot_history(series, ylabel, title, filename):
    """Plot per-round (s_1, s_2) distances on a log axis, save, then show.

    ``series`` is a sequence of ``(group1_value, group2_value)`` pairs.  The
    figure is written to ``output_dir/filename``.
    """
    plt.figure(figsize=(6, 4))
    plt.plot([h[0] for h in series], label='$s_1$')
    plt.plot([h[1] for h in series], label='$s_2$')
    plt.yscale('log')
    plt.xlabel('Iteration', fontsize=20)
    plt.ylabel(ylabel, fontsize=20)
    plt.title(title, fontsize=20)
    plt.legend(fontsize=20)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.show()
    plt.close()


# Total number of agents.  The titles previously used 2 * n1, which is
# wrong whenever n1 != n2.
n_total = int(n1 + n2)

# convergence plot
_plot_history(history, r'$\Delta$',
              r'Policy convergence when $n=$' + str(n_total),
              "convergence.png")

# difference plot
_plot_history(history_inf, r'$\Delta_{\infty}$',
              r'Distance to $s$ when $n=$' + str(n_total),
              "diff_to_inf.png")