from ucimlrepo import fetch_ucirepo
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import time
import os

# =========================
# Parameters
# =========================
# Step size for the w-updates in isotron() and solve_sketched_regularized();
# both functions read this module-level name directly.
eta = 0.1
# Exponent p of the loss sum_i |f(a_i^T w) - b_i|^p; also passed to the
# Lewis-weight sampler.
p = 3
C = 1.0  # NOTE(review): not referenced anywhere else in this file.

# Regularization of the sketched objective (eps_reg * sum_i sw_i |a_i^T w|^p):
#   eps_reg_mode == "fixed"          -> eps_reg = eps_reg_fixed
#   eps_reg_mode == "opt_over_pstar" -> eps_reg = OPT / P*  (from the full solve)
eps_reg_mode = "fixed"
eps_reg_fixed = 0.1

# Sketch sizes as fractions of the row count n (m = ceil(frac * n)).
fractions = (0.10, 0.15, 0.20, 0.25, 0.30)

# Repetitions, RNG seed and iteration budgets.
trials = 10
seed = 0
max_iter_full = 50       # full-data solve (result is cached on disk)
max_iter_sketched = 5    # each sketched solve

# On-disk cache of the full solve (w*, f*, OPT, P*, runtime), one file per p.
CACHE_FILE = f"full_solution_cache_p={p}.npz"

# Results table, one file per p.
OUT_CSV = f"compare_sampling_p={p}.csv"


# =========================
# Data loading / preprocessing
# =========================
communities_and_crime = fetch_ucirepo(id=183)

# Response vector: first target column, as float64.
b = communities_and_crime.data.targets.iloc[:, 0].to_numpy(dtype=np.float64)

# Design matrix: drop the community-name column, treat "?" as missing,
# coerce every remaining column to numeric (unparseable -> NaN), then
# median-impute and standardize each column.
A = (
    communities_and_crime.data.features
    .drop(["communityname"], axis=1)
    .replace("?", np.nan)
    .apply(pd.to_numeric, errors="coerce")
)
A = SimpleImputer(strategy="median").fit_transform(A)
A = StandardScaler(with_mean=True, with_std=True).fit_transform(A).astype(np.float64)


# =========================
# Utilities
# =========================
def grad_abs_p(x, p, delta=1e-12):
    """Coordinatewise (sub)gradient of |x|^p.

    For p == 1 this is sign(x), with entries of magnitude below `delta`
    mapped to 0. Otherwise it is p * sign(x) * |x|^(p-1).
    """
    arr = np.asarray(x, dtype=float)
    sgn = np.sign(arr)
    if p != 1:
        return p * sgn * (np.abs(arr) ** (p - 1))
    # p == 1: zero out the near-zero dead zone.
    return np.where(np.abs(arr) < delta, 0.0, sgn)


def P(A, w, p):
    """Return sum_i |(A w)_i|^p as a Python float."""
    return float((np.abs(A @ w) ** p).sum())


def eval_f_1lip_monotone(z, knots, levels):
    """Evaluate the piecewise-linear link given by (knots, levels) at z.

    Between knots the value is linearly interpolated. Outside the knot range
    it is held at the first/last level. A single knot gives a constant
    function, and no knots gives the zero function.
    """
    zf = np.asarray(z, dtype=float)
    n_knots = len(knots)
    if n_knots >= 2:
        return np.interp(zf, knots, levels, left=levels[0], right=levels[-1])
    if n_knots == 1:
        return np.full_like(zf, float(levels[0]))
    return np.zeros_like(zf)


def loss_full(A, b, w, knots, levels, p):
    """Return sum_i |f((A w)_i) - b_i|^p for the link f described by (knots, levels)."""
    predictions = eval_f_1lip_monotone(A @ w, knots, levels)
    residual = predictions - b
    return float(np.sum(np.abs(residual) ** p))


def eps_prime(loss_hat_full, OPT, Pstar):
    """Excess loss normalized by OPT + P*, clipped below at 0.

    When OPT + P* is (numerically) zero, returns 0.0 if there is no excess
    loss and inf otherwise.
    """
    scale = OPT + Pstar
    if scale <= 1e-18:
        return 0.0 if loss_hat_full <= OPT else np.inf
    excess = (loss_hat_full - OPT) / scale
    return max(0.0, excess)


def summarize_quantiles(x):
    """Return (median, 25th percentile, 75th percentile) of x as Python floats."""
    arr = np.asarray(x, float)
    med = np.median(arr)
    lo, hi = (np.percentile(arr, q) for q in (25, 75))
    return float(med), float(lo), float(hi)


# =========================
# Fast fitter for f: monotone + 1-Lipschitz + f(0)=0
# =========================
def pava_l2_project(y):
    """Return the Euclidean projection of y onto nondecreasing sequences.

    Pool-adjacent-violators: scan left to right keeping a stack of blocks
    (level, weight, first index, last index), and merge the top two blocks
    while the left one sits above the right one. A merged block's level is
    the weight-averaged level of its parts.
    """
    vals = np.asarray(y, float)
    blocks = []  # entries: (level, weight, first, last)
    for pos, val in enumerate(vals):
        blocks.append((val, 1.0, pos, pos))
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            lvl_right, wt_right, _, last_right = blocks.pop()
            lvl_left, wt_left, first_left, _ = blocks.pop()
            combined = wt_left + wt_right
            pooled = (wt_left * lvl_left + wt_right * lvl_right) / combined
            blocks.append((pooled, combined, first_left, last_right))

    projected = np.empty(len(vals), float)
    for lvl, _, first, last in blocks:
        projected[first:last + 1] = lvl
    return projected


def proj_lipschitz(u, dz):
    """
    Euclidean projection of u onto {u : u_{i+1} - u_i <= dz_i for all i}.

    Subtracting the running sum of the gaps ("ramp") turns the constraint
    into "nonincreasing". That case is handled by applying PAV to the
    negated vector and negating back.
    """
    u_arr = np.asarray(u, float)
    gaps = np.asarray(dz, float)
    ramp = np.zeros(len(u_arr))
    ramp[1:] = np.cumsum(gaps)
    flat = -pava_l2_project(-(u_arr - ramp))
    return flat + ramp


def proj_monotone_lip(u0, dz, max_iter=50, tol=1e-10):
    """Approximately project u0 onto {nondecreasing} ∩ {u_{i+1} - u_i <= dz_i}.

    Uses Dykstra's alternating projections, keeping one correction vector
    per constraint set. Stops after `max_iter` sweeps, or earlier once a
    sweep moves the iterate by less than `tol` in Euclidean norm.

    The returned vector is the output of the last Lipschitz projection, so
    monotonicity may hold only approximately if the loop has not converged.
    """
    x = u0.copy()
    inc_mono = np.zeros_like(x)
    inc_lip = np.zeros_like(x)
    sweep = 0
    while sweep < max_iter:
        start = x  # x is rebound below, never mutated in place
        shifted = x + inc_mono
        x = pava_l2_project(shifted)
        inc_mono = shifted - x

        shifted = x + inc_lip
        x = proj_lipschitz(shifted, dz)
        inc_lip = shifted - x

        sweep += 1
        if np.linalg.norm(x - start) < tol:
            break
    return x


def fit_f_monotone_1lip_f0(
    z, y, p=2.0, sample_weight=None,
    anchor_weight_scale=1e6,
    step_size=0.5,
    max_pg_iter=200,
    proj_iter=50,
):
    """
    Fit a link function f at the points z, subject to:
      - nondecreasing
      - 1-Lipschitz
      - f(0)=0 (imposed softly via a heavily weighted (0, 0) pseudo-point)
    by approximately minimizing sum_i w_i |f(z_i) - y_i|^p.

    The minimization uses projected-gradient steps on the fitted values at
    the sorted points. Each projection is done by proj_monotone_lip.

    Parameters
    ----------
    z : array-like, 1-D, length n
        Index values at which f is fitted.
    y : array-like, length n
        Targets.
    p : float
        Loss exponent. p == 1 uses the sign subgradient.
    sample_weight : array-like of length n, or None
        Nonnegative weights with a positive sum. None means all ones.
    anchor_weight_scale : float
        Weight of the (0, 0) pseudo-point relative to max(sample_weight).
    step_size : float
        Gradient step. It is applied to grad / (sum of all weights,
        including the anchor).
    max_pg_iter : int
        Maximum number of projected-gradient iterations.
    proj_iter : int
        Dykstra iterations per projection.

    Returns: (knots, levels)
        knots are the sorted unique values of z plus 0.0. levels are the
        fitted f values at those knots, post-processed to be nondecreasing
        and to satisfy levels[k] - levels[k-1] <= knots[k] - knots[k-1].

    Raises
    ------
    ValueError
        If sample_weight has the wrong length, a negative entry, or a
        nonpositive sum.
    """
    z = np.asarray(z, float)
    y = np.asarray(y, float)
    n = len(y)

    if sample_weight is None:
        sw = np.ones(n, dtype=float)
    else:
        sw = np.asarray(sample_weight, float)
        if sw.shape[0] != n:
            raise ValueError("sample_weight length mismatch")
        if np.any(sw < 0) or sw.sum() <= 0:
            raise ValueError("Invalid sample_weight")

    # f(0)=0 is imposed softly: append a pseudo-observation (z=0, y=0) whose
    # weight is anchor_weight_scale times the largest sample weight
    # (or 1e6 when there are no samples).
    anchor_weight = float(sw.max() * anchor_weight_scale) if sw.size > 0 else 1e6
    z_aug = np.concatenate([z, [0.0]])
    y_aug = np.concatenate([y, [0.0]])
    sw_aug = np.concatenate([sw, [anchor_weight]])

    # Work in sorted-z order; the stable sort keeps tied points in input order.
    # dz holds the gaps that bound consecutive increments of f (1-Lipschitz).
    order = np.argsort(z_aug, kind="mergesort")
    zs, ys, ws = z_aug[order], y_aug[order], sw_aug[order]
    dz = np.diff(zs)
    dz = np.maximum(dz, 0.0)

    # u[i] is the fitted value f(zs[i]); start from the targets themselves.
    # NOTE(review): ws_sum includes anchor_weight. The step for an ordinary
    # point is step_size * w_i * d|e_i|^p/de_i / ws_sum, which is tiny with
    # the default anchor_weight_scale=1e6. The fit is then probably driven
    # mostly by the (unweighted) projection of the initial u = ys, with
    # sample_weight having little influence. Confirm this is intended.
    u = ys.copy()
    ws_sum = float(ws.sum()) + 1e-12

    for _ in range(max_pg_iter):
        e = u - ys
        # Gradient of w_i |e_i|^p with respect to u_i (sign subgradient for p == 1).
        if p == 1:
            grad = ws * np.sign(e)
        else:
            grad = ws * (p * np.sign(e) * (np.abs(e) ** (p - 1)))

        u_half = u - step_size * grad / ws_sum
        u_new = proj_monotone_lip(u_half, dz, max_iter=proj_iter)

        if np.linalg.norm(u_new - u) < 1e-9:
            u = u_new
            break
        u = u_new

    # Collapse tied z values into a single knot carrying the mean fitted value.
    knots, levels = [], []
    i = 0
    m = len(zs)
    while i < m:
        j = i + 1
        while j < m and zs[j] == zs[i]:
            j += 1
        knots.append(zs[i])
        levels.append(float(np.mean(u[i:j])))
        i = j

    knots = np.array(knots, float)
    levels = np.array(levels, float)

    # Final cleanup, since the iterative projection may be inexact: enforce
    # nondecreasing levels, then clamp each increment to the knot gap
    # (slope <= 1). The clamp only lowers levels and keeps them nondecreasing.
    levels = np.maximum.accumulate(levels)
    for k in range(1, len(levels)):
        levels[k] = min(levels[k], levels[k - 1] + (knots[k] - knots[k - 1]))

    return knots, levels


# =========================
# Isotron (full) and sketched solver
# =========================
def isotron(A, b, p, w0=None, max_iter=200, tol=1e-9, verbose=False):
    """
    Full-data Isotron-style solver for b ≈ f(A w).

    Each iteration:
      1. fits the link f (nondecreasing, 1-Lipschitz, f(0)≈0) to z = A w;
      2. evaluates obj = sum_i |b_i - f(z_i)|^p at the current w;
      3. steps w <- w + (eta / n) * A^T g, where g = d|r|^p/dr at the
         residual r = b - f(z) and eta is the module-level step size.
    The loop stops after max_iter iterations, when ||w_new - w|| < tol, or
    when obj changes by less than tol between iterations.

    Parameters
    ----------
    A : ndarray, shape (n, d)
    b : ndarray, shape (n,)
    p : float
        Loss exponent.
    w0 : array-like of length d, or None
        Starting point (copied). None means zeros.
    max_iter, tol, verbose
        Iteration budget, stopping tolerance, per-iteration printing.

    Returns
    -------
    (best_w, (knots, levels), best_obj)
        The iterate with the smallest evaluated objective, the link fitted
        at that same iterate, and its objective. Because all three describe
        one point, loss_full(A, b, best_w, knots, levels, p) == best_obj.
        If no iteration runs, returns (start point, (zeros(1), zeros(1)), inf).
    """
    n, d = A.shape
    w = np.zeros(d) if w0 is None else np.array(w0, dtype=float)

    best_obj = np.inf
    best_w = w.copy()
    best_f = (np.array([0.0]), np.array([0.0]))
    prev_obj = None

    for t in range(max_iter):
        z = A @ w
        knots, levels = fit_f_monotone_1lip_f0(z, b, p=p, sample_weight=None)

        fz = eval_f_1lip_monotone(z, knots, levels)
        r = b - fz
        g = grad_abs_p(r, p=p)

        w_new = w + (eta / n) * (A.T @ g)

        dw = np.linalg.norm(w_new - w)
        obj = float(np.sum(np.abs(r) ** p))

        if obj < best_obj:
            best_obj = obj
            # obj and (knots, levels) were computed at w, not w_new: record
            # the iterate they belong to so the returned triple is consistent.
            best_w = w.copy()
            best_f = (knots.copy(), levels.copy())

        if verbose:
            print(f"[full] iter {t+1:3d}: obj={obj:.6e}, ||dw||={dw:.3e}")

        if dw < tol:
            break
        if prev_obj is not None and abs(obj - prev_obj) < tol:
            break

        prev_obj = obj
        w = w_new

    return best_w, best_f, best_obj


def solve_sketched_regularized(A_s, b_s, sw, p, eps_reg, max_iter=200, tol=1e-9, verbose=False):
    """
    Isotron-style solver on a weighted row sample, with a |z|^p penalty.

    The objective, evaluated at the current iterate w with z = A_s w, is

        sum_i sw_i |b_i - f(z_i)|^p + eps_reg * sum_i sw_i |z_i|^p.

    f is refitted every iteration (weighted by sw) under the monotone,
    1-Lipschitz, f(0)≈0 constraints. The update is

        w <- w + (eta / sum(sw)) * A_s^T (sw * g_data - eps_reg * sw * g_reg),

    where g_data = d|r|^p/dr at r = b_s - f(z), g_reg = d|z|^p/dz, and eta is
    the module-level step size. Iteration starts from w = 0 and stops after
    max_iter iterations, when ||w_new - w|| < tol, or when the objective
    changes by less than tol.

    Returns
    -------
    (best_w, (knots, levels), best_obj)
        The iterate with the smallest evaluated objective, the link fitted at
        that same iterate, and its objective. All three refer to one point.
    """
    _, d = A_s.shape
    w = np.zeros(d)

    best_obj = np.inf
    best_w = w.copy()
    best_f = (np.array([0.0]), np.array([0.0]))

    sw_sum = float(sw.sum()) + 1e-12
    prev_obj = None

    for t in range(max_iter):
        z = A_s @ w

        knots, levels = fit_f_monotone_1lip_f0(z, b_s, p=p, sample_weight=sw)
        fz = eval_f_1lip_monotone(z, knots, levels)
        r = b_s - fz

        g_data = sw * grad_abs_p(r, p=p)
        g_reg = sw * grad_abs_p(z, p=p)

        w_new = w + (eta / sw_sum) * (A_s.T @ (g_data - eps_reg * g_reg))

        dw = np.linalg.norm(w_new - w)

        obj = float(np.sum(sw * (np.abs(r) ** p)) + eps_reg * np.sum(sw * (np.abs(z) ** p)))
        if obj < best_obj:
            best_obj = obj
            # obj and (knots, levels) were computed at w, not w_new: record
            # the iterate they belong to so the returned triple is consistent.
            best_w = w.copy()
            best_f = (knots.copy(), levels.copy())

        if verbose:
            print(f"[sketch] iter {t+1:3d}: obj={obj:.6e}, ||dw||={dw:.3e}")

        if dw < tol:
            break
        if prev_obj is not None and abs(obj - prev_obj) < tol:
            break

        prev_obj = obj
        w = w_new

    return best_w, best_f, best_obj


# =========================
# Sampling: Lewis-style and Uniform baselines
# =========================
def get_sampling_weights(A, p):
    """Row-sampling probabilities proportional to (approximate) l_p Lewis weights.

    p == 2: leverage scores from a reduced QR factorization of A.
    Other p in [1, 4): six rounds of the fixed-point update
        w_i <- (a_i^T (A^T W^{1-2/p} A)^{-1} a_i)^{p/2},
    computed as tau_i(W^{1/2-1/p} A)^{p/2} * w_i^{1-p/2} (tau = leverage
    score), starting from w = d/n.

    Weights are floored at 1e-18 and normalized to sum to 1.
    Raises AssertionError if p is outside [1, 4).
    """
    M = np.asarray(A, dtype=float)
    n, d = M.shape
    if not 1 <= p < 4:
        raise AssertionError("require p in [1,4)")

    if p == 2:
        Q, _ = np.linalg.qr(M, mode="reduced")
        lw = (Q * Q).sum(axis=1)
    else:
        rounds = 6
        lw = np.ones(n) * d / n
        for _ in range(rounds):
            scaled = M * np.power(lw[:, None], 1 / 2 - 1 / p)
            Q, _ = np.linalg.qr(scaled, mode="reduced")
            lev = (Q * Q).sum(axis=1)
            lw = np.power(lev, p / 2) * np.power(lw, 1 - p / 2)

    lw = np.maximum(lw, 1e-18)
    return lw / lw.sum()


def sample_lewis(A, m, p, rng):
    """Draw m row indices with replacement, with probability proportional to Lewis weights.

    Returns (indices, importance weights), where weight_j = 1 / (m * prob[idx_j]).
    """
    probs = get_sampling_weights(A, p)
    rows = rng.choice(np.arange(A.shape[0]), size=m, replace=True, p=probs)
    weights = 1.0 / (m * probs[rows])
    return rows, weights.astype(np.float64)


def sample_uniform(A, m, rng):
    """Draw m row indices uniformly with replacement.

    Returns (indices, importance weights). Every weight equals n / m,
    computed through the same formula as the Lewis sampler.
    """
    n_rows = A.shape[0]
    uniform = np.ones(n_rows, dtype=float) / n_rows
    rows = rng.choice(np.arange(n_rows), size=m, replace=True, p=uniform)
    weights = 1.0 / (m * uniform[rows])
    return rows, weights.astype(np.float64)


# =========================
# Full-solve caching (compute once, then read)
# =========================
def _full_cache_matches(data, p, max_iter, n, d):
    """Return True if an opened cache archive was produced for this (p, max_iter, A.shape)."""
    required = (
        "p", "max_iter", "n", "d",
        "w_star", "knots_star", "levels_star", "OPT", "Pstar", "full_time",
    )
    # Legacy caches (no max_iter/n/d metadata) cannot be validated -> recompute.
    if not all(key in data.files for key in required):
        return False
    return (
        float(data["p"]) == float(p)
        and int(data["max_iter"]) == int(max_iter)
        and int(data["n"]) == int(n)
        and int(data["d"]) == int(d)
    )


def compute_or_load_full_solution(A, b, p, max_iter=200, cache_file=CACHE_FILE):
    """
    Return the full-data solution, reading it from `cache_file` when valid.

    A cache entry is reused only if it was computed for the same p (compared
    as a float, so non-integer p is not truncated), the same max_iter, and a
    design matrix of the same shape. Otherwise isotron() is run and the cache
    file is (over)written.

    Returns
    -------
    (w_star, (knots_star, levels_star), OPT, Pstar, full_time, loaded)
        OPT = loss_full at (w_star, f_star); Pstar = sum |A w_star|^p;
        full_time is the solve wall time in seconds (as stored, when loaded);
        loaded is True when the values came from the cache.
    """
    n, d = A.shape

    if os.path.exists(cache_file):
        # The context manager closes the underlying zip handle. Every entry is
        # a numeric array, so pickled objects are never needed.
        with np.load(cache_file, allow_pickle=False) as data:
            if _full_cache_matches(data, p, max_iter, n, d):
                return (
                    data["w_star"],
                    (data["knots_star"], data["levels_star"]),
                    float(data["OPT"]),
                    float(data["Pstar"]),
                    float(data["full_time"]),
                    True,
                )

    t0 = time.perf_counter()
    w_star, f_star, _ = isotron(A, b, p=p, max_iter=max_iter, verbose=True)
    full_time = time.perf_counter() - t0

    knots_star, levels_star = f_star
    OPT = loss_full(A, b, w_star, knots_star, levels_star, p=p)
    Pstar = P(A, w_star, p=p)

    np.savez(
        cache_file,
        p=float(p),
        max_iter=int(max_iter),
        n=int(n),
        d=int(d),
        w_star=w_star,
        knots_star=knots_star,
        levels_star=levels_star,
        OPT=float(OPT),
        Pstar=float(Pstar),
        full_time=float(full_time),
    )
    return w_star, (knots_star, levels_star), OPT, Pstar, full_time, False


# =========================
# Main experiment: compare Lewis vs Uniform; median + IQR plots; full runtime included
# =========================
def _run_sketch_trial(A, b, A_s, b_s, sw, p, eps_reg, max_iter, OPT, Pstar):
    """Solve one sketched problem and score its solution on the full data.

    Only the solver call is timed; sampling is excluded.
    Returns (runtime_sec, full_loss, eps_prime, relative_gap).
    """
    t0 = time.perf_counter()
    w_hat, (knots, levels), _ = solve_sketched_regularized(
        A_s, b_s, sw, p=p, eps_reg=eps_reg,
        max_iter=max_iter, verbose=False
    )
    runtime = time.perf_counter() - t0
    loss = loss_full(A, b, w_hat, knots, levels, p=p)
    rel_gap = (loss - OPT) / max(OPT, 1e-18)
    return runtime, loss, eps_prime(loss, OPT, Pstar), rel_gap


def _method_summary_columns(prefix, rec, full_solve_time):
    """Return one method's CSV columns (median/q25/q75 per metric, then speedup) in output order."""
    cols = {}
    for metric_key, stem in (("loss", "loss"), ("epsp", "eps_prime"), ("relgap", "rel_gap")):
        med, q25, q75 = summarize_quantiles(rec[metric_key])
        cols[f"{prefix}_{stem}_median"] = med
        cols[f"{prefix}_{stem}_q25"] = q25
        cols[f"{prefix}_{stem}_q75"] = q75
    rt_med, rt_q25, rt_q75 = summarize_quantiles(rec["rt"])
    cols[f"{prefix}_runtime_median_sec"] = rt_med
    cols[f"{prefix}_runtime_q25_sec"] = rt_q25
    cols[f"{prefix}_runtime_q75_sec"] = rt_q75
    cols[f"{prefix}_speedup_full_over_median"] = float(full_solve_time / max(rt_med, 1e-12))
    return cols


def _plot_median_iqr_pair(x, series, ylabel, title):
    """Plot median curves with shaded 25-75% bands; series is [(label, med, q25, q75), ...]."""
    plt.figure()
    for label, med, q25, q75 in series:
        (line,) = plt.plot(x, med, marker="o", label=label)
        # Tie each band's color to its line explicitly instead of relying on
        # the line and patch color cycles advancing in lockstep.
        plt.fill_between(x, q25, q75, alpha=0.2, color=line.get_color())
    plt.xlabel("sample fraction")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()


def run_compare_sampling(A, b, p, fractions, trials, seed, eps_reg,
                         eps_reg_mode, max_iter_full, max_iter_sketched, out_csv=None):
    """
    Compare Lewis-weight and uniform row sampling for the sketched solver.

    Steps:
      1. Obtain the full solution (cached), giving OPT and P*.
      2. For each fraction, draw m = ceil(frac * n) rows `trials` times with
         each sampler, solve the sketched problem, and score it on the full
         data (eps', relative gap, full loss, sketch runtime).
      3. Write median/IQR summaries to a CSV, print per-fraction lines, and
         show median+IQR plots plus a speedup plot.

    eps_reg is used when eps_reg_mode == "fixed". With "opt_over_pstar" the
    regularization is OPT / P* from the full solve instead.
    out_csv defaults to the module-level OUT_CSV.

    Raises ValueError for an unknown eps_reg_mode, before the full solve.
    Returns the summary DataFrame (one row per fraction).
    """
    if eps_reg_mode not in ("fixed", "opt_over_pstar"):
        raise ValueError("eps_reg_mode must be 'fixed' or 'opt_over_pstar'")
    if out_csv is None:
        out_csv = OUT_CSV

    rng0 = np.random.default_rng(seed)
    n = A.shape[0]

    print("\n[Compute/load full solution for OPT and P* (cached)]")
    _, _, OPT, Pstar, full_solve_time, loaded = compute_or_load_full_solution(
        A, b, p=p, max_iter=max_iter_full
    )
    ratio_OPT_over_Pstar = OPT / max(Pstar, 1e-18)

    if eps_reg_mode == "opt_over_pstar":
        eps_reg_used = ratio_OPT_over_Pstar
    else:
        eps_reg_used = float(eps_reg)

    print(
        f"\nFULL RESULTS:\n"
        f"  p                = {p}\n"
        f"  OPT              = {OPT:.6e}\n"
        f"  P*               = {Pstar:.6e}\n"
        f"  OPT/P*           = {ratio_OPT_over_Pstar:.6e}\n"
        f"  eps_reg_mode     = {eps_reg_mode}\n"
        f"  eps_reg_used     = {eps_reg_used:.6e}\n"
        f"  full_solve_time   = {full_solve_time:.3f} sec\n"
        f"  cache_status     = {'LOADED' if loaded else 'COMPUTED'}\n"
        f"  cache_file       = {CACHE_FILE}\n"
    )

    methods = ("lewis", "uniform")
    rows = []

    for frac in fractions:
        m = int(np.ceil(frac * n))
        metrics = {method: {"epsp": [], "relgap": [], "rt": [], "loss": []} for method in methods}

        for _ in range(trials):
            # One child RNG per trial, consumed by the Lewis draw and then the
            # uniform draw.
            rng = np.random.default_rng(rng0.integers(0, 2**32 - 1))
            for method in methods:
                if method == "lewis":
                    idx, sw = sample_lewis(A, m, p, rng)
                else:
                    idx, sw = sample_uniform(A, m, rng)
                rt, loss, epsp, relgap = _run_sketch_trial(
                    A, b, A[idx, :], b[idx], sw, p, eps_reg_used,
                    max_iter_sketched, OPT, Pstar,
                )
                rec = metrics[method]
                rec["rt"].append(rt)
                rec["loss"].append(loss)
                rec["epsp"].append(epsp)
                rec["relgap"].append(relgap)

        row = {
            "frac": float(frac),
            "m": int(m),
            "p": float(p),
            "eps_reg_used": float(eps_reg_used),
            "eps_reg_mode": eps_reg_mode,
            "OPT": float(OPT),
            "Pstar": float(Pstar),
            "OPT_over_Pstar": float(ratio_OPT_over_Pstar),
            "full_solve_time_sec": float(full_solve_time),
        }
        for method in methods:
            row.update(_method_summary_columns(method, metrics[method], full_solve_time))
        rows.append(row)

        print(
            f"frac={frac:.2f} (m={m:4d}) | "
            f"LEWIS: eps' med={row['lewis_eps_prime_median']:.2e}, "
            f"rel_gap med={row['lewis_rel_gap_median']:.2e}, "
            f"rt med={row['lewis_runtime_median_sec']:.4f}s | "
            f"UNIF: eps' med={row['uniform_eps_prime_median']:.2e}, "
            f"rel_gap med={row['uniform_rel_gap_median']:.2e}, "
            f"rt med={row['uniform_runtime_median_sec']:.4f}s | "
            f"full_rt={full_solve_time:.3f}s"
        )

    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)
    print(f"\nSaved table: {out_csv}")

    x = df["frac"].to_numpy(float)
    labels = (("lewis", "Lewis/QR sampling"), ("uniform", "Uniform sampling"))
    panels = (
        ("eps_prime_median", "eps_prime_q25", "eps_prime_q75",
         "eps' (median; 25–75%)", "eps' vs sample fraction"),
        ("rel_gap_median", "rel_gap_q25", "rel_gap_q75",
         "relative gap (median; 25–75%)", "relative gap vs sample fraction"),
        ("runtime_median_sec", "runtime_q25_sec", "runtime_q75_sec",
         "sketch runtime (sec) (median; 25–75%)", "sketch runtime vs sample fraction"),
    )
    for med_col, q25_col, q75_col, ylabel, title_stem in panels:
        series = [
            (
                label,
                df[f"{method}_{med_col}"].to_numpy(float),
                df[f"{method}_{q25_col}"].to_numpy(float),
                df[f"{method}_{q75_col}"].to_numpy(float),
            )
            for method, label in labels
        ]
        _plot_median_iqr_pair(x, series, ylabel, f"{title_stem} (p={p}, eps_reg={eps_reg_used:g})")

    # speedup (median)
    plt.figure()
    plt.plot(x, df["lewis_speedup_full_over_median"].to_numpy(float), marker="o", label="Lewis/QR sampling")
    plt.plot(x, df["uniform_speedup_full_over_median"].to_numpy(float), marker="o", label="Uniform sampling")
    plt.xlabel("sample fraction")
    plt.ylabel("speedup = full_time / sketch_median_time")
    plt.title(f"speedup vs sample fraction (p={p}, eps_reg={eps_reg_used:g})")
    plt.grid(True)
    plt.legend()

    plt.show()
    return df


# =========================
# Run
# =========================
# Execute the experiment with the module-level configuration and show the
# first rows of the summary table.
df = run_compare_sampling(
    A, b, p,
    fractions=fractions, trials=trials, seed=seed,
    eps_reg=eps_reg_fixed, eps_reg_mode=eps_reg_mode,
    max_iter_full=max_iter_full, max_iter_sketched=max_iter_sketched,
)

print("\n===== RESULTS TABLE (head) =====")
print(df.head(10))