#!/usr/bin/env python3
"""
baseline_e1_e2.py

Time-calibrated SciPy baselines for E1 & E2 (2D and 3D star discrepancy).

- Uses your exact star_discrepancy() implementation (Numba-accelerated) so
  results are directly comparable to Phase 1 / Phase 2 evaluators.
- Runs SLSQP starting from various initial point sets:
    * fibonacci  : Fibonacci / lattice-style baseline
    * sobol      : Sobol sequence baseline
    * mpmc       : placeholder (you fill: load MPMC points)
    * clement    : placeholder (you fill: load Clément et al. points)
    * phase1     : placeholder (you fill: load LLM Phase-1 outputs)
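- No family cycling, no jitter: one optimization per (N, dim, family, replicate).
  Deterministic initializers (e.g. fibonacci, which ignores the RNG) therefore
  yield identical replicates; only RNG-seeded families such as sobol vary.
- `maxiter` is chosen automatically so that all runs fit into a given wall-clock
  budget (default 24 hours), based on the mean time of star_discrepancy on 10
  random point sets for each N.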

Usage examples
--------------

# E1: 2D, Ns=16 40 60 100, 4 families, 3 replicates, 24h budget
python baseline_e1_e2.py --dim 2 --N 16 40 60 100 \
    --families fibonacci sobol mpmc clement \
    --replicates 3 --total-seconds 86400

# E2: 2D Phase-2-only from Phase-1 outputs (you implement init_phase1)
python baseline_e1_e2.py --dim 2 --N 16 40 60 100 \
    --families phase1 \
    --replicates 3 --total-seconds 86400

# 3D baseline (same pattern)
python baseline_e1_e2.py --dim 3 --N 4 8 16 \
    --families fibonacci sobol mpmc clement \
    --replicates 3 --total-seconds 86400

You can also override maxiter manually:
    --maxiter 500

The script prints:
- Per-run summaries
- Global CSV lines of the form:
    dim,N,family,rep,success,nit,nfev,best_D_star
"""

from __future__ import annotations

import argparse
import itertools
import os
import time
import traceback
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

import numpy as np
from numba import njit
from scipy.optimize import minimize
from scipy.stats import qmc

# =============================================================================
# Your exact star_discrepancy implementation (general D)
# =============================================================================

@njit(cache=True)
def _calculate_single_box_discrepancy_numba(points_X_arg: np.ndarray,
                                            N_arg: int,
                                            D_arg: int,
                                            y_corner_arg: np.ndarray) -> float:
    """
    Calculates the local discrepancy for a single d-dimensional anchored box.
    Box is defined by [0, y_corner_arg[0]] x ... x [0, y_corner_arg[D-1]].

    Both the closed box [0, y] and the half-open box [0, y) are evaluated:
    a point with any coordinate exactly equal to the corresponding corner
    component counts toward the closed box but not the half-open one.

    Parameters
    ----------
    points_X_arg : 2-D float array of points; rows 0..N_arg-1 and columns
        0..D_arg-1 are read. (star_discrepancy passes a clipped,
        C-contiguous float64 array.)
    N_arg : number of points; also the normalizer of the empirical measure.
    D_arg : dimension.
    y_corner_arg : length-D_arg array holding the upper corner y of the box.

    Returns
    -------
    float
        max(|closed_count / N - vol(y)|, |open_count / N - vol(y)|).
    """
    # Calculate volume of the box
    volume = 1.0
    for k_dim in range(D_arg):
        volume *= y_corner_arg[k_dim]

    # Count points within the box [0, y_corner_arg]
    count_in_box = 0
    # Points inside the closed box that touch its upper boundary in >= 1 dimension
    count_on_line = 0
    for i_point in range(N_arg):  # Iterate through each point
        point_is_in_box = True
        point_is_on_line = False
        for k_dim in range(D_arg):  # Iterate through each dimension
            if points_X_arg[i_point, k_dim] > y_corner_arg[k_dim]:
                # Outside the closed box in this dimension: no need to check the rest.
                point_is_in_box = False
                break
            elif points_X_arg[i_point, k_dim] == y_corner_arg[k_dim]:
                # On the upper face: inside [0, y] but not inside [0, y).
                point_is_on_line = True

        if point_is_in_box:
            count_in_box += 1
            if point_is_on_line:
                count_on_line += 1

    # Closed-box count vs. half-open-box count (closed minus boundary points);
    # the worse of the two is this corner's local discrepancy.
    return max(abs(count_in_box / N_arg - volume),
               abs((count_in_box - count_on_line) / N_arg - volume))


def star_discrepancy(points_X: np.ndarray) -> float:
    """
    L-infinity star discrepancy of a point set P in [0,1]^D.

    Every anchored box whose corner y lies on the grid formed, in each
    dimension, by the points' distinct coordinates plus 1.0 is evaluated with
    _calculate_single_box_discrepancy_numba (which checks both the closed and
    the half-open box). The largest local discrepancy is returned.

    Parameters
    ----------
    points_X : np.ndarray or array-like
        Points of shape (N, D); a 1-D input is treated as N points in 1-D.
        Coordinates are clipped to [0, 1] before evaluation.

    Returns
    -------
    float
        The star discrepancy, or 1.0 for an empty point set.

    NOTE: This returns the *discrepancy*, not the 1/(1+D*) "score" variant.
    Cost: prod_j(#distinct coords in dim j + 1) <= (N+1)^D box evaluations,
    each O(N * D).
    """
    # Input validation and preparation (float64 for Numba compatibility and precision)
    if not isinstance(points_X, np.ndarray):
        points = np.array(points_X, dtype=np.float64)
    elif points_X.dtype != np.float64:
        points = points_X.astype(np.float64)
    else:
        points = points_X

    if points.ndim == 1:
        points = points.reshape(-1, 1)

    N, D = points.shape
    if N == 0:
        return 1.0

    # np.clip returns a new array; the Numba kernel is given a C-contiguous one.
    points = np.ascontiguousarray(np.clip(points, 0.0, 1.0))

    # Candidate corner coordinates per dimension: sorted distinct point
    # coordinates plus 1.0 (union1d already deduplicates and sorts). Never
    # empty, because 1.0 is always included.
    upper = np.array([1.0], dtype=points.dtype)
    grid_lines_per_dim = [np.union1d(points[:, j], upper) for j in range(D)]

    max_discrepancy_val = 0.0
    y_corner = np.empty(D, dtype=points.dtype)  # buffer reused for every corner
    for corner in itertools.product(*grid_lines_per_dim):
        y_corner[:] = corner
        local_discrepancy = _calculate_single_box_discrepancy_numba(points, N, D, y_corner)
        if local_discrepancy > max_discrepancy_val:
            max_discrepancy_val = local_discrepancy

    return float(max_discrepancy_val)


# =============================================================================
# Initializers (2D / 3D) — no jitter
# =============================================================================

def init_fibonacci(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Simple Fibonacci / Kronecker lattice style construction for D=2 or D=3.

    For D=2:
        x_i = (i + 0.5) / N
        y_i = frac( (i + 0.5) * phi )

    For D=3:
        x_i = (i + 0.5) / N
        y_i = frac( (i + 0.5) * phi )
        z_i = frac( (i + 0.5) * (sqrt(2) - 1) )

    For D>3 (not needed here), we could extend with more incommensurate slopes.
    """
    if D not in (2, 3):
        raise ValueError(f"init_fibonacci only implemented for D in {{2,3}}, got D={D}")

    i = np.arange(N, dtype=np.float64)
    pts = np.zeros((N, D), dtype=np.float64)

    # First dim: simple stratification
    pts[:, 0] = (i + 0.5) / N

    phi = (np.sqrt(5.0) - 1.0) / 2.0  # ~0.618...

    if D >= 2:
        pts[:, 1] = np.mod((i + 0.5) * phi, 1.0)
    if D == 3:
        alpha2 = np.sqrt(2.0) - 1.0
        pts[:, 2] = np.mod((i + 0.5) * alpha2, 1.0)

    return np.clip(pts, 0.0, 1.0)


def init_sobol(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Sobol sequence baseline using scipy.stats.qmc.Sobol.
    """
    # Use seed if provided (for replicability across runs)
    seed = None if rng is None else int(rng.integers(0, 2**31 - 1))
    engine = qmc.Sobol(d=D, scramble=True, seed=seed)
    pts = engine.random(N)
    return np.clip(pts, 0.0, 1.0)


# --- Placeholders: you fill these in to read from your stored point sets. ---

def init_mpmc(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Placeholder initializer that should return an (N, D) array in [0,1]^D
    built from your GNN / MPMC outputs.

    Suggested pattern (pseudocode):
        path = f"/path/to/mpmc/N{N}_D{D}.npy"
        pts = np.load(path)
        assert pts.shape == (N, D)
        return pts

    Right now this raises, so you remember to implement it.
    """
    raise NotImplementedError("init_mpmc() is a placeholder. Please implement loading of MPMC point sets.")


def init_clement(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Placeholder initializer for Clément et al.'s best-known point sets.

    Suggested pattern:
        path = f"/path/to/clement/N{N}_D{D}.npy"
        pts = np.load(path)
        assert pts.shape == (N, D)
        return pts
    """
    raise NotImplementedError("init_clement() is a placeholder. Please implement loading of Clément et al. point sets.")


def init_phase1(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Placeholder initializer for using Phase-1 program outputs as seeds
    for Phase-2-only baselines (E2).

    Suggested pattern:
        # e.g. pick the best Phase-1 candidate for this N,D
        path = f"/path/to/phase1_outputs/N{N}_D{D}_best.npy"
        pts = np.load(path)
        assert pts.shape == (N, D)
        return pts
    """
    raise NotImplementedError("init_phase1() is a placeholder. Please implement loading of Phase-1 outputs.")


# Registry mapping --families names to initializer functions. Each entry is
# called as init_fn(N, D, rng=rng) by run_single_baseline().
INIT_FAMILIES: Dict[str, Callable[..., np.ndarray]] = dict(
    fibonacci=init_fibonacci,
    sobol=init_sobol,
    mpmc=init_mpmc,
    clement=init_clement,
    phase1=init_phase1,
)


# =============================================================================
# Time-based calibration of maxiter
# =============================================================================

def benchmark_star_discrepancy(N: int,
                               D: int,
                               num_samples: int = 10,
                               rng: np.random.Generator | None = None) -> float:
    """
    Measure average wall-clock time of star_discrepancy on random (N,D) points,
    ignoring the first call to avoid including Numba compilation.
    """
    if rng is None:
        rng = np.random.default_rng(12345)

    # Warm-up to trigger Numba compilation
    X0 = rng.random((N, D))
    _ = star_discrepancy(X0)

    times: List[float] = []
    for k in range(num_samples):
        X = rng.random((N, D))
        t0 = time.perf_counter()
        _ = star_discrepancy(X)
        t1 = time.perf_counter()
        times.append(t1 - t0)

    return float(np.mean(times))


def calibrate_maxiter(N_list: List[int],
                      D: int,
                      families: List[str],
                      replicates: int,
                      total_seconds: float = 24 * 3600.0,
                      calls_per_iter_estimate: float = 1.0,
                      num_benchmark_samples: int = 10,
                      include_fd_gradient: bool = True) -> int:
    """
    Choose a global SLSQP maxiter so that all runs fit in total_seconds.

    Cost model, per run and per SLSQP iteration, for a given N:

        calls_per_iter(N) = calls_per_iter_estimate
                            + (N * D if include_fd_gradient else 0)
        cost_iter(N)      = calls_per_iter(N) * t_call(N)

    The total is bounded conservatively by

        total_time ≈ num_runs * maxiter * max_N cost_iter(N)

    where t_call(N) is the mean star_discrepancy time measured by
    benchmark_star_discrepancy, and num_runs = len(N_list) * len(families) * replicates.

    Why the N*D term: run_single_baseline() calls SLSQP without a jacobian,
    so SciPy estimates every gradient by finite differences. That costs about
    one extra objective call per optimization variable (N*D of them) on
    every iteration. Leaving it out underestimates the cost by a factor of
    roughly N*D + 1 and overruns the budget. Pass include_fd_gradient=False
    to recover the plain "calls_per_iter_estimate calls per iteration" model.

    Actual wall-clock can be lower, because smaller N are cheaper and SLSQP
    may stop before maxiter. It can also be higher, because SciPy overhead is
    not modeled and line searches may need more than calls_per_iter_estimate
    evaluations per iteration.

    Returns
    -------
    int
        maxiter >= 1. Returns 100 as a fallback when there are no runs or the
        measured per-iteration cost is not positive.
    """
    runs_per_N = len(families) * replicates
    num_runs = len(N_list) * runs_per_N
    if num_runs <= 0:
        return 100  # fallback: nothing to run, so nothing to benchmark

    rng = np.random.default_rng(2025)

    t_call_max = 0.0
    cost_iter_max = 0.0  # most expensive single-run iteration over all N
    for N in N_list:
        t_call = benchmark_star_discrepancy(N, D, num_samples=num_benchmark_samples, rng=rng)
        calls_per_iter = calls_per_iter_estimate + (N * D if include_fd_gradient else 0)
        t_call_max = max(t_call_max, t_call)
        cost_iter_max = max(cost_iter_max, calls_per_iter * t_call)

    if not cost_iter_max > 0.0:
        return 100  # fallback: timer resolution too coarse or degenerate estimate

    maxiter = max(1, int(total_seconds / (num_runs * cost_iter_max)))

    print("=== Calibration summary ===")
    print(f"  dim D           = {D}")
    print(f"  N_list          = {N_list}")
    print(f"  families        = {families}")
    print(f"  replicates      = {replicates}")
    print(f"  total_seconds   = {total_seconds:.1f}")
    print(f"  t_call_max      = {t_call_max:.6f} s")
    print(f"  num_runs        = {num_runs}")
    print(f"  calls/iter est. = {calls_per_iter_estimate}")
    print(f"  FD gradient     = {include_fd_gradient} (+N*D calls/iter)")
    print(f"  cost/iter max   = {cost_iter_max:.6f} s")
    print(f"  -> maxiter      = {maxiter}")
    print("===========================")
    return maxiter


# =============================================================================
# Core optimization routine
# =============================================================================

@dataclass
class RunResult:
    """
    Outcome of one SLSQP baseline run (one row of the CSV summary).

    Equality compares the metadata and scores only. ``x_opt`` is excluded
    because comparing ndarrays inside the generated ``__eq__`` raises
    "truth value of an array ... is ambiguous" for distinct arrays.
    """
    dim: int            # dimension D
    N: int              # number of points
    family: str         # initializer family name (key of INIT_FAMILIES)
    rep: int            # replicate index
    success: bool       # SLSQP res.success
    nit: int            # SLSQP iterations (res.nit)
    nfev: int           # objective evaluations reported by SciPy (res.nfev)
    best_D_star: float  # star discrepancy of x_opt
    x_opt: np.ndarray = field(compare=False)  # final (N, D) points, clipped to [0, 1]


def run_single_baseline(N: int,
                        D: int,
                        family: str,
                        rep: int,
                        maxiter: int,
                        ftol: float = 1e-15) -> RunResult:
    """
    Run a single SLSQP local optimization from a given family for (N, D).

    - Starts directly from the family-specific point set (no jitter).
    - Minimizes star_discrepancy over the N*D flattened coordinates. No
      jacobian is supplied, so SciPy uses finite-difference gradients.
    - Enforces [0,1]^D bounds, and clips again before every evaluation.

    Parameters
    ----------
    N, D : number of points and dimension.
    family : key of INIT_FAMILIES selecting the initializer.
    rep : replicate index; it only changes the RNG handed to the initializer.
    maxiter, ftol : forwarded to SLSQP's options.

    Returns
    -------
    RunResult
        Optimizer status plus the final (clipped) points and their discrepancy.

    Raises
    ------
    ValueError
        Unknown family, or an initializer returning the wrong shape.
    NotImplementedError
        Propagated from placeholder initializers.
    """
    if family not in INIT_FAMILIES:
        raise ValueError(f"Unknown family '{family}'. Available: {list(INIT_FAMILIES.keys())}")

    # Deterministic seed per (D, rep). It does not depend on N or family, so
    # all N of one replicate share e.g. the same sobol scramble seed. It also
    # repeats across dimensions once rep >= 10 (D=2, rep=10 == D=3, rep=0).
    # Only initializers that use rng are affected.
    rng = np.random.default_rng(10_000 + 100 * D + 10 * rep)

    init_fn = INIT_FAMILIES[family]
    # asarray lets loaders return lists or non-float64 arrays.
    pts0 = np.asarray(init_fn(N, D, rng=rng), dtype=np.float64)
    if pts0.shape != (N, D):
        raise ValueError(f"Initializer '{family}' returned shape {pts0.shape}, expected {(N, D)}")

    x0 = pts0.reshape(-1)
    bounds = [(0.0, 1.0)] * x0.size

    def objective(x_flat: np.ndarray) -> float:
        # Defensive clip in case the optimizer evaluates marginally outside the bounds.
        return star_discrepancy(np.clip(x_flat.reshape(N, D), 0.0, 1.0))

    res = minimize(
        objective,
        x0,
        method="SLSQP",
        bounds=bounds,
        options={"maxiter": int(maxiter), "ftol": float(ftol), "iprint": 0},
    )

    x_opt = np.clip(res.x.reshape(N, D), 0.0, 1.0)
    best_D_star = float(star_discrepancy(x_opt))

    print(f"[D={D}, N={N}, family={family}, rep={rep}] "
          f"success={res.success}, nit={res.nit}, nfev={res.nfev}, "
          f"best D*={best_D_star:.6f}")

    return RunResult(
        dim=D,
        N=N,
        family=family,
        rep=rep,
        success=bool(res.success),
        nit=int(res.nit),
        nfev=int(res.nfev),
        best_D_star=best_D_star,
        x_opt=x_opt,
    )


# =============================================================================
# CLI harness
# =============================================================================

def _implemented_families(families: List[str], N_probe: int, D: int) -> List[str]:
    """
    Return ``families`` minus those whose initializer is an unimplemented placeholder.

    Each initializer is called once with (N_probe, D) and a throwaway RNG.
    Only NotImplementedError drops a family. Any other exception keeps it, so
    main()'s per-run error handling reports it as usual.
    """
    runnable: List[str] = []
    for family in families:
        try:
            INIT_FAMILIES[family](N_probe, D, rng=np.random.default_rng(0))
        except NotImplementedError as e:
            print(f"[D={D}, family={family}] SKIPPED due to NotImplementedError: {e}")
            continue
        except Exception:  # noqa: BLE001 - real errors are reported per run in main()
            pass
        runnable.append(family)
    return runnable


def main():
    """
    CLI entry point: parse arguments, calibrate (or accept) maxiter, run every
    (N, family, replicate) baseline, and print a CSV summary.

    Families whose initializer is still a NotImplementedError placeholder are
    dropped before calibration. The wall-clock budget is therefore split only
    among runs that will actually execute.
    """
    ap = argparse.ArgumentParser(
        description="SciPy SLSQP baselines for E1/E2 (2D and 3D star discrepancy)"
    )
    ap.add_argument("--dim", type=int, required=True,
                    help="Dimension D (2 or 3). Run separately for 2D and 3D.")
    ap.add_argument("--N", type=int, nargs="+", required=True,
                    help="One or more N values, e.g., 16 40 60 100.")
    ap.add_argument("--families", type=str, nargs="+",
                    default=["fibonacci", "sobol"],
                    choices=sorted(INIT_FAMILIES),
                    help=("Which initial families to use. "
                          "Options include: fibonacci sobol mpmc clement phase1"))
    ap.add_argument("--replicates", type=int, default=1,
                    help="Number of independent repetitions per (N, family).")
    ap.add_argument("--total-seconds", type=float, default=24 * 3600.0,
                    help="Total wall-clock budget (seconds) for calibration of maxiter.")
    ap.add_argument("--maxiter", type=int, default=None,
                    help="If provided, use this maxiter instead of time-based calibration.")
    ap.add_argument("--ftol", type=float, default=1e-15,
                    help="SLSQP ftol parameter.")
    ap.add_argument("--save-dir", type=str, default=None,
                    help="Optional directory to save final point sets as .npy.")
    args = ap.parse_args()

    D = int(args.dim)
    if D not in (2, 3):
        raise SystemExit(f"--dim must be 2 or 3, got {D}")

    N_list = [int(N) for N in args.N]
    if any(N < 1 for N in N_list):
        raise SystemExit(f"--N values must be positive integers, got {N_list}")

    replicates = int(args.replicates)
    if replicates < 1:
        raise SystemExit(f"--replicates must be >= 1, got {replicates}")

    # Drop placeholder families up front so they do not inflate num_runs
    # (and thereby shrink maxiter) during calibration.
    families = _implemented_families(list(args.families), N_list[0], D)
    if not families:
        raise SystemExit("No runnable families: every requested initializer is unimplemented.")

    if args.save_dir is not None:
        os.makedirs(args.save_dir, exist_ok=True)

    # Calibrate maxiter if not provided
    if args.maxiter is None:
        maxiter = calibrate_maxiter(
            N_list=N_list,
            D=D,
            families=families,
            replicates=replicates,
            total_seconds=float(args.total_seconds),
            calls_per_iter_estimate=1.0,  # adjust if you want to be more conservative
            num_benchmark_samples=10,
        )
    else:
        maxiter = int(args.maxiter)
        print(f"Using user-specified maxiter={maxiter}")

    results: List[RunResult] = []

    for N in N_list:
        for family in families:
            for rep in range(replicates):
                tag = f"[D={D}, N={N}, family={family}, rep={rep}]"
                try:
                    res = run_single_baseline(
                        N=N,
                        D=D,
                        family=family,
                        rep=rep,
                        maxiter=maxiter,
                        ftol=float(args.ftol),
                    )
                except NotImplementedError as e:
                    # Friendly message if a placeholder init is still unimplemented
                    print(f"{tag} SKIPPED due to NotImplementedError: {e}")
                    continue
                except Exception as e:
                    # Top-level boundary: report (with traceback) and keep going.
                    print(f"{tag} FAILED with error: {e}")
                    traceback.print_exc()
                    continue

                results.append(res)

                # Optional: save point sets. A save failure does not invalidate the run.
                if args.save_dir is not None:
                    path = os.path.join(args.save_dir, f"D{D}_N{N}_{family}_rep{rep}.npy")
                    try:
                        np.save(path, res.x_opt)
                    except OSError as e:
                        print(f"{tag} could not save point set to {path}: {e}")

    # Compact CSV-style summary (easy to grep / import)
    print("\n# dim,N,family,rep,success,nit,nfev,best_D_star")
    for r in results:
        print(f"{r.dim},{r.N},{r.family},{r.rep},"
              f"{int(r.success)},{r.nit},{r.nfev},{r.best_D_star:.10f}")


if __name__ == "__main__":
    # Run the CLI harness only when executed as a script, not when imported.
    main()
