"""
genz_eval_d32.py

Cross-domain evaluation (Part A1) for d = 32 using the standard Genz suite:
- Oscillatory
- Product Peak
- Corner Peak
- Gaussian Peak
- Continuous
- Discontinuous

Methods compared (seed-paired):
1) Baseline: SciPy Sobol (scramble=True) with random_base2(m)
2) Evolved: Custom Sobol via C++ (direction numbers + LTM + base-2 digital shift)

Reference integrals are estimated via high-N scrambled Sobol to compute MSE.

Notes on Genz definitions:
We use the canonical forms (see e.g., PyApprox docs) and fix (c, w) deterministically
for reproducibility. The “discontinuous” test uses thresholds u1=u2=0.5.

Run:
  python genz_eval_d32.py --replicates 2000 --nmin 5 --nmax 13 --ref-m 20 --ref-reps 256

Outputs:
- Printed tables of median MSE per N and one-sided Wilcoxon signed-rank p-values
- Optional CSV dumps via --out

"""

from __future__ import annotations
import os
import sys
import math
import csv
import ctypes
import argparse
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, List

import numpy as np
from numpy.typing import NDArray
from scipy.stats import qmc, norm, wilcoxon


# ----------------------------
# Genz suite (d=32 only)
# ----------------------------

def genz_coeffs(d: int) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
    """
    Build the fixed Genz parameters used by every integrand in this file.

    The unnormalised weights are c_hat_i = 10^(-15 * (i/d)^2) for i = 1..d;
    they are rescaled so that the returned ``c`` sums to one. Every entry of
    ``w`` is 0.5.

    Parameters
    ----------
    d : number of dimensions.

    Returns
    -------
    c : (d,) normalised weights.
    w : (d,) offsets, all equal to 0.5.
    """
    ratio = np.arange(1, d + 1, dtype=np.float64) / d
    raw_weights = 10.0 ** (-15.0 * ratio ** 2)
    weights = raw_weights / np.sum(raw_weights)
    offsets = np.full(d, 0.5, dtype=np.float64)
    return weights, offsets


def f_oscillatory(u: NDArray[np.float64], c: NDArray[np.float64], w: NDArray[np.float64]) -> NDArray[np.float64]:
    """Genz oscillatory integrand: cos(2*pi*w[0] + sum_i c_i u_i), one value per row of ``u``."""
    phase = 2.0 * np.pi * w[0]  # only the first offset enters the phase
    return np.cos(phase + u @ c)


def f_product_peak(u: NDArray[np.float64], c: NDArray[np.float64], w: NDArray[np.float64]) -> NDArray[np.float64]:
    """Genz product-peak integrand: prod_i 1 / (c_i^-2 + (u_i - w_i)^2), one value per row of ``u``."""
    offsets = u - w
    per_dim_denominator = c ** (-2.0) + offsets ** 2.0
    return np.prod(1.0 / per_dim_denominator, axis=1)


def f_corner_peak(u: NDArray[np.float64], c: NDArray[np.float64], w: NDArray[np.float64]) -> NDArray[np.float64]:
    """Genz corner-peak integrand: (1 + sum_i c_i u_i)^-(d+1), with d taken from ``u.shape[1]``.

    ``w`` is accepted for signature parity with the other integrands and is not used.
    """
    exponent = -(u.shape[1] + 1)
    return (1.0 + u @ c) ** exponent


def f_gaussian_peak(u: NDArray[np.float64], c: NDArray[np.float64], w: NDArray[np.float64]) -> NDArray[np.float64]:
    """Genz Gaussian-peak integrand: exp(-sum_i c_i^2 (u_i - w_i)^2), one value per row of ``u``."""
    weighted_sq_dist = (c ** 2.0) * (u - w) ** 2.0
    return np.exp(-np.sum(weighted_sq_dist, axis=1))


def f_continuous(u: NDArray[np.float64], c: NDArray[np.float64], w: NDArray[np.float64]) -> NDArray[np.float64]:
    """Genz continuous integrand: exp(-sum_i c_i |u_i - w_i|), one value per row of ``u``."""
    weighted_abs_dist = c * np.abs(u - w)
    return np.exp(-np.sum(weighted_abs_dist, axis=1))


def f_discontinuous(u: NDArray[np.float64], c: NDArray[np.float64], w: NDArray[np.float64]) -> NDArray[np.float64]:
    """Genz discontinuous integrand with thresholds (0.5, 0.5) on the first two coordinates.

    Rows with u_1 <= 0.5 and u_2 <= 0.5 evaluate to exp(sum_i c_i u_i); all
    other rows evaluate to 0. ``w`` is accepted for signature parity and is
    not used.
    """
    inside = np.logical_and(u[:, 0] <= 0.5, u[:, 1] <= 0.5)
    out = np.zeros(len(u), dtype=np.float64)
    # Only the rows inside the region are exponentiated.
    out[inside] = np.exp(u[inside] @ c)
    return out


# Registry of the six Genz integrands, keyed by the names used in printed
# tables and CSV output. Every entry has the signature fun(u, c, w) and
# returns one value per row of u.
GENZ_FUNCS: Dict[str, Callable[[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]],
                                NDArray[np.float64]]] = {
    "oscillatory": f_oscillatory,
    "product-peak": f_product_peak,
    "corner-peak": f_corner_peak,
    "gaussian-peak": f_gaussian_peak,
    "continuous": f_continuous,
    "discontinuous": f_discontinuous,
}


# ----------------------------
# Evolved (C++) Sobol generator (LTM + digital shift)
# ----------------------------

class DimensionParameters(ctypes.Structure):
    """ctypes record describing one Sobol dimension for the native generator.

    Fields:
      s   -- number of leading ``m_i`` entries that are populated
             (presumably the primitive-polynomial degree; confirm against
             the C++ side).
      a   -- unsigned 32-bit parameter (presumably the polynomial
             coefficient encoding; confirm against the C++ side).
      m_i -- fixed block of 30 unsigned 32-bit slots for the initial
             direction numbers.

    NOTE(review): the layout must match the struct declared by the shared
    library; that declaration is not visible from this file.
    """

    _fields_ = [
        ("s", ctypes.c_int),
        ("a", ctypes.c_uint32),
        ("m_i", ctypes.c_uint32 * 30),
    ]


def _load_sobol_lib() -> ctypes.CDLL:
    """Load the native Sobol shared library and declare its entry point.

    The library path comes from the ``SOBOL_SO_PATH`` environment variable,
    falling back to a repository-relative default. The prototype of
    ``generate_sobol_points`` is registered so ctypes converts arguments
    correctly; the function returns nothing.
    """
    default_path = "openevolve-star-discrepancy/examples/qmc/sobol_generator.so"
    lib = ctypes.CDLL(os.environ.get("SOBOL_SO_PATH", default_path))

    uint32_ptr = ctypes.POINTER(ctypes.c_uint32)
    generate = lib.generate_sobol_points  # CDLL caches the resolved symbol
    generate.argtypes = [
        ctypes.c_int,                          # n_points
        ctypes.c_int,                          # n_dimensions
        ctypes.POINTER(DimensionParameters),   # input_sobol_params
        ctypes.POINTER(ctypes.c_double),       # output_points
        uint32_ptr,                            # ltm_elements_flat
        uint32_ptr,                            # digital_shifts
    ]
    generate.restype = None
    return lib


def construct_sobol_sequence_evolved() -> List[Dict[str, object]]:
    """
    Return the evolved direction-number parameters for dimensions 2..32.

    Dimension 1 is implicit and has no entry, so the list has 31 items.
    Each item is a dict with keys ``'s'``, ``'a'`` and ``'m_i'``, where
    ``'m_i'`` is a list of length ``s``. A fresh list of fresh dicts is built
    on every call.
    """
    # One row per dimension, in order: (s, a, m_i).
    table = (
        (1, 0, [1]),                          # Dim 2
        (2, 1, [1, 3]),                       # Dim 3
        (3, 1, [1, 3, 5]),                    # Dim 4 (changed in the paper)
        (3, 2, [1, 3, 7]),                    # Dim 5 (changed in the paper)
        (4, 1, [1, 1, 3, 7]),                 # Dim 6 (changed in the paper)
        (4, 4, [1, 3, 5, 13]),                # Dim 7
        (5, 2, [1, 1, 5, 5, 17]),             # Dim 8
        (5, 4, [1, 1, 5, 5, 5]),              # Dim 9
        (5, 7, [1, 1, 7, 11, 19]),            # Dim 10
        (5, 11, [1, 1, 5, 1, 1]),             # Dim 11
        (5, 13, [1, 1, 1, 3, 11]),            # Dim 12
        (5, 14, [1, 3, 5, 5, 31]),            # Dim 13
        (6, 1, [1, 3, 3, 9, 7, 49]),          # Dim 14
        (6, 13, [1, 1, 1, 15, 21, 21]),       # Dim 15
        (6, 16, [1, 3, 1, 13, 27, 49]),       # Dim 16
        (6, 19, [1, 1, 1, 15, 7, 5]),         # Dim 17
        (6, 22, [1, 3, 1, 15, 13, 25]),       # Dim 18
        (6, 25, [1, 1, 5, 5, 19, 61]),        # Dim 19
        (7, 1, [1, 3, 7, 11, 23, 15, 103]),   # Dim 20
        (7, 4, [1, 3, 7, 13, 13, 15, 69]),    # Dim 21
        (7, 7, [1, 1, 3, 13, 7, 35, 63]),     # Dim 22
        (7, 8, [1, 3, 5, 9, 1, 25, 53]),      # Dim 23
        (7, 14, [1, 3, 1, 13, 9, 35, 107]),   # Dim 24
        (7, 19, [1, 3, 1, 5, 27, 61, 31]),    # Dim 25
        (7, 21, [1, 1, 5, 11, 19, 41, 61]),   # Dim 26
        (7, 28, [1, 3, 5, 3, 3, 13, 69]),     # Dim 27
        (7, 31, [1, 1, 7, 13, 1, 19, 1]),     # Dim 28
        (7, 32, [1, 3, 7, 5, 13, 19, 59]),    # Dim 29
        (7, 37, [1, 1, 3, 9, 25, 29, 41]),    # Dim 30
        (7, 41, [1, 3, 5, 13, 23, 1, 55]),    # Dim 31
        (7, 42, [1, 3, 7, 3, 13, 59, 17]),    # Dim 32
    )
    return [{'s': s, 'a': a, 'm_i': m_i} for s, a, m_i in table]


def sobol_points_cpp(
    sobol_params: List[Dict[str, object]],
    n_points: int,
    n_dimensions: int,
    ltm: NDArray[np.uint32],
    digital_shifts: NDArray[np.uint32],
    lib: ctypes.CDLL
) -> NDArray[np.float64]:
    """
    Generate Sobol points via the C++ library with LTM and digital shifts.

    Parameters
    ----------
    sobol_params : per-dimension dicts with keys ``"s"``, ``"a"``, ``"m_i"``
        for dimensions 2..n (length ``n_dimensions - 1``).
    n_points : number of points to generate.
    n_dimensions : number of dimensions (>= 2).
    ltm : shape ``(n_dimensions, 30, 30)``, lower-triangular 0/1 with unit
        diagonal.
    digital_shifts : shape ``(n_dimensions,)``, uint32 base-2 digital shifts.
    lib : loaded library exposing ``generate_sobol_points``.

    Returns
    -------
    Array of shape ``(n_points, n_dimensions)``, dtype float64.

    Raises
    ------
    ValueError
        If any size or shape is inconsistent. These checks guard a native
        call that would otherwise read or write out of bounds, so they are
        real exceptions rather than ``assert`` (which ``python -O`` strips).
    """
    max_terms = 30  # capacity of DimensionParameters.m_i and the LTM side length

    n_points = int(n_points)
    n_dimensions = int(n_dimensions)
    if n_points < 0:
        raise ValueError(f"n_points must be non-negative, got {n_points}")
    if n_dimensions < 2:
        raise ValueError(f"n_dimensions must be >= 2, got {n_dimensions}")
    if len(sobol_params) != n_dimensions - 1:
        raise ValueError(
            f"sobol_params must have {n_dimensions - 1} entries "
            f"(dimensions 2..{n_dimensions}), got {len(sobol_params)}"
        )

    params_array = (DimensionParameters * (n_dimensions - 1))()
    for i, p in enumerate(sobol_params):
        s_val = int(p["s"])
        m_i_list = list(p["m_i"])
        if not 0 <= s_val <= max_terms:
            raise ValueError(f"sobol_params[{i}]['s'] must be in [0, {max_terms}], got {s_val}")
        if len(m_i_list) < s_val:
            raise ValueError(
                f"sobol_params[{i}]['m_i'] has {len(m_i_list)} entries but s={s_val}"
            )
        entry = params_array[i]  # view into the ctypes array, not a copy
        entry.s = s_val
        entry.a = int(p["a"])
        # Slots beyond s are zero-padded.
        for j in range(max_terms):
            entry.m_i[j] = m_i_list[j] if j < s_val else 0

    # Private C-contiguous copies: the native code gets a raw pointer, so the
    # caller's arrays stay isolated from anything it might write.
    ltm_buf = np.array(ltm, dtype=np.uint32, order="C")
    if ltm_buf.shape != (n_dimensions, max_terms, max_terms):
        raise ValueError(
            f"ltm must have shape {(n_dimensions, max_terms, max_terms)}, got {ltm_buf.shape}"
        )
    shift_buf = np.array(digital_shifts, dtype=np.uint32, order="C")
    if shift_buf.shape != (n_dimensions,):
        raise ValueError(
            f"digital_shifts must have shape {(n_dimensions,)}, got {shift_buf.shape}"
        )

    # Zero-initialised output buffer that the library fills in place.
    pts = np.zeros((n_points, n_dimensions), dtype=np.float64)

    # Hand NumPy buffers to C directly instead of unpacking every element
    # into a new ctypes array (28,800 LTM entries per call at d = 32).
    uint32_ptr = ctypes.POINTER(ctypes.c_uint32)
    lib.generate_sobol_points(
        n_points,
        n_dimensions,
        params_array,
        pts.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        ltm_buf.ctypes.data_as(uint32_ptr),
        shift_buf.ctypes.data_as(uint32_ptr),
    )
    return pts


def sample_ltm_and_shifts(rng: np.random.Generator, d: int) -> Tuple[NDArray[np.uint32], NDArray[np.uint32]]:
    """Draw random lower-triangular 0/1 matrices (unit diagonal) and base-2 digital shifts.

    Parameters
    ----------
    rng : NumPy random generator. It is advanced by exactly two
        ``integers`` calls: the matrix bits first, then the shift bits.
    d : number of dimensions.

    Returns
    -------
    ltm : (d, 30, 30) uint32, entries in {0, 1}, zero above the diagonal and
        one on it.
    shifts : (d,) uint32, each a uniformly random 30-bit integer.
    """
    n_bits = 30

    ltm_bits = rng.integers(0, 2, size=(d, n_bits, n_bits), dtype=np.uint32)
    # np.tril acts on the last two axes, so the whole stack is done at once.
    ltm_bits = np.tril(ltm_bits)
    diag = np.arange(n_bits)
    ltm_bits[:, diag, diag] = 1  # unit diagonal

    # Assemble each 30-bit shift from its bits. The accumulation runs in
    # uint64 and the result (< 2**30) is then narrowed to uint32.
    shift_bits = rng.integers(0, 2, size=(d, n_bits), dtype=np.uint32)
    place_values = 2 ** np.arange(n_bits, dtype=np.uint64)
    shifts = (shift_bits @ place_values).astype(np.uint32)
    return ltm_bits, shifts


# ----------------------------
# Baseline generator (SciPy)
# ----------------------------

def sobol_points_scipy(n_points: int, d: int, seed: int) -> NDArray[np.float64]:
    """Return ``n_points`` scrambled SciPy Sobol points in [0, 1)^d.

    Uses ``qmc.Sobol(scramble=True)``, which SciPy documents as LMS + digital
    shift scrambling, and draws the points with ``random_base2``.

    Parameters
    ----------
    n_points : number of points; must be a positive power of two.
    d : number of dimensions.
    seed : seed for the scrambling.

    Raises
    ------
    ValueError
        If ``n_points`` is not a positive power of two.
    """
    n = int(n_points)
    # Exact integer test: rejects 0, negatives and non-integers with a clear
    # message and avoids a float log2 round-trip.
    if n != n_points or n < 1 or n & (n - 1):
        raise ValueError("n_points must be a power of two.")
    m = n.bit_length() - 1  # 2**m == n
    engine = qmc.Sobol(d=d, scramble=True, seed=seed)
    return engine.random_base2(m=m)


# ----------------------------
# Integration helpers
# ----------------------------

def integrate_mc(fun: Callable[[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]], NDArray[np.float64]],
                 pts: NDArray[np.float64],
                 c: NDArray[np.float64],
                 w: NDArray[np.float64]) -> float:
    """Equal-weight quadrature estimate: the mean of ``fun(pts, c, w)`` as a Python float."""
    return float(np.mean(fun(pts, c, w)))


def reference_integral(fun: Callable, d: int, c: NDArray[np.float64], w: NDArray[np.float64],
                       m_ref: int = 20, reps: int = 256, seed0: int = 12345) -> float:
    """
    Reference value of the integral from scrambled SciPy Sobol points.

    Each of the ``reps`` estimates uses 2^m_ref points generated with seed
    ``seed0 + k`` (k = 0..reps-1); the returned value is the mean of those
    estimates.
    """
    n_ref = 2 ** m_ref
    per_scramble = (
        integrate_mc(fun, sobol_points_scipy(n_ref, d, seed=seed0 + k), c, w)
        for k in range(reps)
    )
    estimates = np.fromiter(per_scramble, dtype=np.float64, count=reps)
    return float(np.mean(estimates))


# ----------------------------
# Evaluation loop
# ----------------------------

@dataclass
class EvalConfig:
    """Settings for one benchmark run; consumed by ``run_evaluation``.

    Point counts are given as base-2 exponents: the run sweeps
    N = 2^nmin .. 2^nmax.
    """
    d: int = 32             # number of dimensions
    nmin: int = 5           # 2^5 = 32
    nmax: int = 13          # 2^13 = 8192
    replicates: int = 2000  # number of randomizations per N
    m_ref: int = 20         # 2^20 reference points per scramble
    ref_reps: int = 256     # number of independent scrambles for reference
    seed0: int = 12345      # base seed; replicate `rep` uses seed0 + rep
    out_csv: str | None = None  # optional: path to write per-replicate estimates and squared errors


def run_evaluation(cfg: EvalConfig) -> None:
    """Run the seed-paired baseline-vs-evolved comparison and print result tables.

    For every integrand in ``GENZ_FUNCS`` and every N = 2^m with
    m in [cfg.nmin, cfg.nmax], ``cfg.replicates`` randomizations are drawn
    for both generators using the same seed. The squared error against a
    precomputed reference integral is recorded per replicate. Each table row
    reports the median squared error of each method, their relative
    difference, and a one-sided Wilcoxon signed-rank p-value
    (H1: evolved < baseline).

    If ``cfg.out_csv`` is set, per-replicate estimates and squared errors are
    also written there. The file is closed even if the evaluation raises.
    """
    d = cfg.d
    c, w = genz_coeffs(d)
    lib = _load_sobol_lib()
    evolved_params = construct_sobol_sequence_evolved()

    # Iterate in registry order so this list cannot drift from GENZ_FUNCS.
    integrands = list(GENZ_FUNCS)

    # Precompute references per integrand (same across N).
    # NOTE(review): reference seeds start at seed0 + 7777 while replicate
    # seeds run seed0 .. seed0 + replicates - 1, so the two ranges overlap
    # once replicates exceeds 7777 -- confirm this is acceptable.
    print("Computing reference integrals ...")
    refs: Dict[str, float] = {}
    for name in integrands:
        ref_val = reference_integral(
            GENZ_FUNCS[name], d, c, w,
            m_ref=cfg.m_ref, reps=cfg.ref_reps, seed0=cfg.seed0 + 7777,
        )
        refs[name] = ref_val
        print(f"  {name:14s} ref ≈ {ref_val:.8e}")

    # Optional CSV output; try/finally releases the handle on failure too.
    csv_file = open(cfg.out_csv, "w", newline="") if cfg.out_csv else None
    try:
        csv_writer = None
        if csv_file is not None:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(["integrand", "N", "rep", "method", "estimate", "squared_error"])

        print("\n=== Results: median MSE over replicates, per N (d=32) ===")
        for name in integrands:
            f = GENZ_FUNCS[name]
            ref_val = refs[name]

            print(f"\n[{name}]")
            print(f"{'N':>6} | {'MSE_baseline':>14} | {'MSE_evolved':>13} | {'Δ% (evo vs base)':>16} | {'Wilcoxon p(one-sided)':>24}")
            print("-" * 86)

            for m in range(cfg.nmin, cfg.nmax + 1):
                N = 2 ** m
                mse_base = np.empty(cfg.replicates, dtype=np.float64)
                mse_evol = np.empty(cfg.replicates, dtype=np.float64)

                for rep in range(cfg.replicates):
                    # Pair seeds across methods
                    seed = cfg.seed0 + rep

                    # Baseline: SciPy Sobol scramble
                    pts_base = sobol_points_scipy(N, d, seed=seed)
                    est_base = integrate_mc(f, pts_base, c, w)
                    mse_base[rep] = (est_base - ref_val) ** 2

                    # Evolved: C++ Sobol with LTM + digital shift
                    rng = np.random.default_rng(seed=seed)
                    ltm, shifts = sample_ltm_and_shifts(rng, d)
                    pts_evo = sobol_points_cpp(evolved_params, N, d, ltm, shifts, lib=lib)
                    est_evo = integrate_mc(f, pts_evo, c, w)
                    mse_evol[rep] = (est_evo - ref_val) ** 2

                    if csv_writer is not None:
                        csv_writer.writerow([name, N, rep, "baseline", est_base, mse_base[rep]])
                        csv_writer.writerow([name, N, rep, "evolved", est_evo, mse_evol[rep]])

                # Medians (robust to outliers) for the table, matching the
                # printed header; the test below uses the paired samples.
                med_base = float(np.median(mse_base))
                med_evo = float(np.median(mse_evol))
                # The 1e-300 term only guards against division by exactly zero.
                rel_delta = 100.0 * (med_evo - med_base) / (med_base + 1e-300)

                # One-sided Wilcoxon signed-rank, H1: MSE_evolved < MSE_baseline.
                # alternative="less" already yields the one-sided p-value.
                p_val = wilcoxon(mse_evol, mse_base, alternative="less", zero_method="wilcox").pvalue

                print(f"{N:6d} | {med_base:14.6e} | {med_evo:13.6e} | {rel_delta:16.2f}% | {p_val:24.3e}")
    finally:
        if csv_file is not None:
            csv_file.close()


# ----------------------------
# CLI
# ----------------------------

def parse_args() -> EvalConfig:
    """Parse command-line options from ``sys.argv`` into an ``EvalConfig``.

    ``d`` is not exposed on the command line, so the returned config keeps
    the dataclass default for it.
    """
    parser = argparse.ArgumentParser(description="Genz d=32 evaluation: baseline vs evolved (seed-paired)")

    # (flag, default, help) for the integer-valued options, in display order.
    int_options = (
        ("--replicates", 10000, "randomizations per N (default 10000)"),
        ("--nmin", 5, "min m (2^m points), default 5 => 32"),
        ("--nmax", 13, "max m (2^m points), default 13 => 8192"),
        ("--ref-m", 21, "reference points 2^m (default 21)"),
        ("--ref-reps", 1000, "reference scrambles (default 1000)"),
        ("--seed0", 12345, "base seed"),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument("--out", type=str, default=None, help="optional CSV path to dump per-replicate results")

    ns = parser.parse_args()
    return EvalConfig(
        nmin=ns.nmin,
        nmax=ns.nmax,
        replicates=ns.replicates,
        m_ref=ns.ref_m,
        ref_reps=ns.ref_reps,
        seed0=ns.seed0,
        out_csv=ns.out,
    )


if __name__ == "__main__":
    config = parse_args()
    # This script is the d = 32 benchmark, so the dimension is pinned here
    # regardless of how the config object was built.
    config.d = 32
    run_evaluation(config)
