import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick 
import mpmath as mp

# ---------- Core helpers ----------

def f_product(x: np.ndarray) -> float:
    """Return the product of all entries of ``x`` as a Python float."""
    total = np.prod(x)
    return float(total)

def grad_f_norm_sq_at(x: np.ndarray) -> float:
    """
    Return the squared Euclidean norm of the gradient of f(x) = prod(x).

    The i-th gradient entry is the product of all coordinates except x_i.
    When no coordinate is zero this is evaluated as prod(x) / x_i. That
    shortcut is 0/0 (NaN) if some x_i is exactly zero, so zeros are handled
    explicitly:
      - exactly one zero: only that entry of the gradient is non-zero and it
        equals the product of the remaining coordinates;
      - two or more zeros: every partial product contains a zero, so the
        gradient vanishes.
    """
    x = np.asarray(x)
    zero_mask = x == 0
    n_zeros = int(np.count_nonzero(zero_mask))

    if n_zeros == 0:
        # Same arithmetic as before for the regular case.
        P = np.prod(x)
        return float(np.sum((P / x) ** 2))
    if n_zeros > 1:
        return 0.0
    return float(np.prod(x[~zero_mask]) ** 2)

def balanced_solution(L: int, y: float) -> np.ndarray:
    """
    Return the length-L vector whose entries all equal y**(1/L).

    Raises ValueError when y <= 0, since the real L-th root is only taken
    for positive y here.
    """
    if not y > 0:
        if y <= 0:
            raise ValueError("This implementation assumes y>0 for a real balanced solution.")
    root = y ** (1.0 / L)
    out = np.empty(L, dtype=np.float64)
    out.fill(root)
    return out

# ---------- Projection onto M = {z: prod(z) = y} ----------

def project_positive_orthant_bisection(x: np.ndarray, y: float,
                                       tol: float = 1e-12, max_iter: int = 200) -> np.ndarray:
    """
    Project x onto {z : prod(z) = y} for y > 0 and x_i > 0 by bisection on a multiplier α.

    Stationarity of 0.5*||z - x||^2 + α*(prod(z) - y) gives, per coordinate,
    z_i^2 - x_i z_i + α y = 0 on the constraint set, i.e.
      z_i(α) = (x_i ± sqrt(x_i^2 - 4 α y)) / 2
    and the same sign is used for every coordinate.
    Branch selection:
      - If prod(x) < y: need to increase product -> α < 0, use PLUS root.
      - If prod(x) >= y:
          Let P1 = prod(x) / 2^L (the product when every z_i = x_i/2).
          * If y >= P1: use PLUS root with α in [0, α_max] (product decreases with α).
          * If y <  P1: use MINUS root with α in [0, α_max] (product increases with α).
    In each case the mapping α -> prod(z(α)) is monotone, so bisection is valid.

    Parameters
      x        : 1-D array with strictly positive entries.
      y        : target product, must be > 0.
      tol      : relative tolerance; iteration stops once
                 |prod(z) - y| <= max(1e-14, tol * y).
      max_iter : cap on the number of bisection steps.

    Returns the array z(α) for the final α.

    Raises
      ValueError    : if y <= 0 or some x_i <= 0 (also for NaN inputs).
      OverflowError : if the search for a lower bracket overflows before the
                      product reaches y (e.g. y is infinite or near the float limit).

    Note: if y is not reached on the selected branch inside [0, α_max], the
    bisection drifts to a bracket endpoint and the returned product may differ
    from y; no error is raised in that case.
    """
    x = np.asarray(x, dtype=np.float64)
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not (y > 0 and np.all(x > 0)):
        raise ValueError("project_positive_orthant_bisection requires y > 0 and all x_i > 0.")

    L = x.size
    P = float(np.prod(x))
    P_half = P / (2.0 ** L)
    # Largest α with a non-negative radicand for every coordinate. The small
    # margin must not push the bound below zero when min(x)^2 / (4y) is tiny.
    alpha_max = max(float(np.min(x * x) / (4.0 * y)) - 1e-18, 0.0)

    def phi(alpha, sgn):
        # Product of z(α) on the given branch; NaN signals "α too large".
        rad = x * x - 4.0 * alpha * y
        if np.any(rad < 0):
            return np.nan
        z = 0.5 * (x + sgn * np.sqrt(rad))
        return float(np.prod(z))

    # Choose branch and initial bracket
    if P < y:
        # Need α < 0, PLUS branch; phi grows without bound as α -> -∞ and equals P at α=0
        sgn = +1
        lo, hi = -1.0, 0.0
        # expand 'lo' until phi(lo) >= y
        while True:
            val = phi(lo, sgn)
            if not np.isfinite(val):
                # phi only grows as lo decreases, so once it overflows no
                # finite bracket exists; doubling further would loop forever.
                raise OverflowError("Overflow while bracketing the multiplier; y is too large.")
            if val >= y:
                break
            lo *= 2.0
    else:
        # α in [0, α_max]; PLUS branch when y >= P_half, MINUS branch otherwise
        sgn = +1 if y >= P_half else -1
        lo, hi = 0.0, alpha_max

    threshold = max(1e-14, tol * y)

    # Bisection on monotone phi
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        vm = phi(mid, sgn)
        finite = np.isfinite(vm)
        converged = finite and abs(vm - y) <= threshold
        # Once the midpoint coincides with an endpoint the bracket cannot
        # shrink any further, so further iterations would not change α.
        stalled = mid == lo or mid == hi
        if converged or stalled:
            alpha = mid
            break
        if not finite:
            # radicand went (slightly) negative: α is too large
            hi = mid
            continue
        if sgn == +1:
            # plus branch: phi decreases with α
            move_lo = vm > y
        else:
            # minus branch: phi increases with α
            move_lo = vm < y
        if move_lo:
            lo = mid
        else:
            hi = mid
    else:
        alpha = 0.5 * (lo + hi)

    # Clip guards against a marginally negative radicand from rounding.
    rad = np.clip(x * x - 4.0 * alpha * y, 0.0, None)
    return 0.5 * (x + sgn * np.sqrt(rad))

def project_to_product_manifold(x: np.ndarray,
                                y: float,
                                z_prev: np.ndarray | None = None,
                                tol: float = 1e-15,
                                max_iter: int = 50) -> np.ndarray:
    """
    Euclidean projection of x onto {z: prod(z)=y}.
    - If y==0: set the smallest |coordinate| to 0 (exact).
    - If y>0 and x>0 elementwise: exact positive-orthant projection via bisection.
    - Otherwise: simple uniform scaling fallback to satisfy prod(z)=y.
    """
    x = x.astype(np.float64, copy=True)
    L = x.size
    eps = 1e-15

    if y == 0.0:
        z = x.copy()
        idx = np.argmin(np.abs(z))
        z[idx] = 0.0
        return z

    if y > 0 and np.all(x > 0):
        try:
            return project_positive_orthant_bisection(x, y, tol=tol, max_iter=300)
        except Exception:
            pass  # fall through to fallback

    # Fallback: uniform scaling (robust and exact when no zeros)
    P = np.prod(x)
    if abs(P) < eps:
        # nudge tiny entries so scaling works
        x = np.where(np.abs(x) < eps, np.sign(x) * eps, x)
        P = np.prod(x)
    alpha = (y / P) ** (1.0 / L)
    z = alpha * x
    return z

_log10 = np.log10

def _pow10_fmt(x, _pos=None):
    return rf"$10^{{{_log10(x):.2g}}}$"

def _two_log_ticks(ax, ydata):
    ydata = np.asarray(ydata, float)
    ydata = ydata[ydata > 0]
    if ydata.size == 0:
        return
    ymin, ymax = ydata.min(), ydata.max()
    if ymin == ymax:
        ymin *= 0.8
        ymax *= 1.2
    ticks = [ymin, ymax]
    ax.yaxis.set_major_locator(mtick.FixedLocator(ticks))
    ax.yaxis.set_major_formatter(mtick.FuncFormatter(_pow10_fmt))
    ax.yaxis.set_minor_locator(mtick.NullLocator())

# ---------- Simulation (no plotting) ----------

def simulate_experiment(L=5, y=2.0, steps=200, sigma=1e-2, seed=0, mode=-1):
    """
    Run GD on 0.5 * (f(x)-y)^2 with f(x) = prod(x) and return time series without plotting.

    Parameters
      L     : number of coordinates.
      y     : target product; must be > 0.
      steps : number of gradient-descent iterations.
      sigma : scale of the multiplicative Gaussian perturbation of the
              balanced solution used as the starting point.
      seed  : seed for numpy's default_rng.
      mode  : multiplies the 5e-3 offset added to the step size
              2 / ||grad f(x*)||^2 (x* is the balanced solution).

    Returns dict with keys: times, par_errors, perp_norms, lam_gaps, meta.
      par_errors : ||z_t - x*||, where z_t = project_to_product_manifold(x_t, y).
      perp_norms : ||x_t - z_t||.
      lam_gaps   : sum((y / z_t)^2) - ||grad f(x*)||^2.
    All three series are floored at 1e-300, so non-positive values (including
    negative gaps) appear as 1e-300.

    Raises ValueError if y <= 0.
    """
    if y <= 0:
        raise ValueError("This script assumes y>0.")
    rng = np.random.default_rng(seed)

    # Balanced solution and step size
    x_star = balanced_solution(L, y)
    df_norm_sq_star = grad_f_norm_sq_at(x_star)
    gamma = 5e-3
    eta = 2.0 / df_norm_sq_star + mode * gamma

    # Init near balanced
    noise = sigma * rng.standard_normal(L)
    x = x_star * (1.0 + noise)

    par_errors, perp_norms, lam_gaps, times = [], [], [], []

    z_prev = x_star.copy()
    eps = 1e-15
    tiny = 1e-300

    for t in range(steps):
        P = f_product(x)
        grad = (P - y) * (P / x)
        x = x - eta * grad
        # Keep every coordinate at least eps away from zero so the next
        # P / x stays finite. copysign also moves exact zeros (to +eps or
        # -eps by sign bit); np.sign(0) * eps would leave them at zero.
        too_small = np.abs(x) < eps
        x[too_small] = np.copysign(eps, x[too_small])

        z = project_to_product_manifold(x, y, z_prev=z_prev)
        z_prev = z

        par_err = np.linalg.norm(z - x_star)
        perp = x - z
        perp_norm = np.linalg.norm(perp)
        lam_t = float(np.sum((y / z) ** 2))  # since prod(z)=y on the manifold

        par_errors.append(max(par_err, tiny))
        perp_norms.append(max(perp_norm, tiny))
        lam_gaps.append(max(lam_t - df_norm_sq_star, tiny))
        times.append(t)

    return {
        "times": np.array(times),
        "par_errors": np.array(par_errors),
        "perp_norms": np.array(perp_norms),
        "lam_gaps": np.array(lam_gaps),
        "meta": {"L": L, "y": y, "steps": steps, "sigma": sigma, "seed": seed, "mode": mode, "eta": eta}
    }

# ---------- multi-seed overlay plot ----------

def run_multi_experiments(L=5, y=2.0, steps=200, sigma=1e-2, seeds=(0,1,2,3,4), mode=-1):
    """
    Simulate one run per entry of `seeds` and overlay them in a 1x3 figure.

    The three panels show, on log-scaled y axes, the "par_errors",
    "perp_norms" and "lam_gaps" series returned by simulate_experiment.
    Each run is one line (colours come from matplotlib's default cycle) and
    carries a "seed <n>" label; no legend is drawn. The figure is saved as a
    PDF in the current working directory and then shown.
    """
    fs = 18
    series_keys = ("par_errors", "perp_norms", "lam_gaps")
    panel_titles = (
        r"$\|\theta_t^{\parallel}-\theta^{\parallel}_*\|$",
        r"$|\theta_t^{\perp}|$",
        r"$\lambda(\theta^{\parallel}_t) - \lambda(\theta^{\parallel}_*)$",
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Every curve is kept per panel so the tick helper sees the global range.
    collected = {key: [] for key in series_keys}

    for seed in seeds:
        run = simulate_experiment(L=L, y=y, steps=steps, sigma=sigma, seed=seed, mode=mode)
        steps_axis = run["times"]
        for ax, key in zip(axes, series_keys):
            curve = run[key]
            ax.plot(steps_axis, curve, linewidth=3, label=f"seed {seed}")
            collected[key].append(curve)

    # Log scale, title and shared cosmetics for each panel
    for ax, title in zip(axes, panel_titles):
        ax.set_yscale('log')
        ax.set_title(title, fontsize=fs)
        ax.set_xlabel("Iteration (t)", fontsize=fs)
        ax.tick_params(labelsize=fs)
        ax.grid(True)

    # Two ticks per panel: min and max over all runs (placeholder if no runs)
    for ax, key in zip(axes, series_keys):
        curves = collected[key]
        pooled = np.concatenate(curves) if curves else np.array([1.0])
        _two_log_ticks(ax, pooled)

    fig.tight_layout()

    regime_names = {-1: "subcrit", 0: "crit", 1: "supcrit"}
    name = regime_names.get(mode, "mode")
    seeds_str = "-".join(str(s) for s in seeds)
    fig.savefig(f'{steps}_{name}_seeds[{seeds_str}]_new.pdf')
    plt.show()

if __name__ == "__main__":
    # Example single-seed (original behavior), kept for reference only:
    #   run_experiment(L=5, y=1, steps=1000, sigma=1e-1, seed=0, mode=1)

    # We used seeds 0,1,2,3,4 for the critical and supercritical experiments.
    # Seed 5 was used in place of 0 for the subcritical experiment, since
    # seed 0 did not start out with sufficiently large sharpness to be a
    # valid demonstration of our theorem.

    # Example multi-seed overlay:
    run_multi_experiments(
        L=5,
        y=1,
        steps=10000,
        sigma=1e-1,
        seeds=[0, 1, 2, 3, 4],
        mode=1,
    )
