"""
baseline_b1.py

Budget-matched Phase-2-only baseline for 2D star-discrepancy:
- No LLM constructive phase; only SLSQP local optimization.
- Cycles through initializations: J-LHS, Halton, Hammersley, Fibonacci, Uniform.
- Stops exactly when the number of star-discrepancy evaluations reaches B_N.

Usage examples
--------------
# Single N with a uniform budget of 25_000 calls, 10 replicates
python baseline_b1.py --N 100 --budget 25000 --replicates 10

# Multiple N with per-N budgets loaded from JSON file
python baseline_b1.py --N 30 40 50 60 100 --budget-per-n budgets.json --replicates 10

budgets.json (example)
{
  "30": 18000,
  "40": 22000,
  "50": 26000,
  "60": 30000,
  "100": 45000
}

Notes
-----
- Assumes 2D (d=2) as in the paper’s B1 ablation.
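- Uses the exact 'star_discrepancy(points)' implementation defined in this file as the objective.
- Enforces [0,1]^2 bounds; the objective raises BudgetExhausted once the call budget is spent.
- Tracks the best-so-far point set and discrepancy; prints per-replicate results, aggregate
  statistics (median/mean/min/max), and a CSV-style summary of all runs.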

"""

from __future__ import annotations
import argparse
import json
import math
import sys
import itertools
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

import numpy as np
from scipy.optimize import minimize
from scipy.stats import qmc
from numba import njit

# --------------------------------------------------------------------------------------
# Exact star_discrepancy implementation (Numba-accelerated), inlined here so this
# baseline scores point sets with the same objective as the main method. To use a
# different implementation, replace the two definitions below with an import of it.
# --------------------------------------------------------------------------------------
@njit(cache=True)
def _calculate_single_box_discrepancy_numba(points_X_arg: np.ndarray,
                                           N_arg: int,
                                           D_arg: int,
                                           y_corner_arg: np.ndarray) -> float:
    """
    Local discrepancy of a single anchored box [0, y_1] x ... x [0, y_D].

    Two point counts are compared with the box volume:
      * closed count: points with x_k <= y_k in every dimension k;
      * open count:   points with x_k <  y_k in every dimension k.
    The larger of |closed/N - volume| and |open/N - volume| is returned.

    Parameters
    ----------
    points_X_arg : 2-D float array indexed as [point, dimension].
    N_arg : number of leading rows to scan; also the divisor, so it must be nonzero.
    D_arg : number of dimensions to compare per point.
    y_corner_arg : 1-D array holding the upper corner y of the box.
    """
    box_volume = 1.0
    for dim in range(D_arg):
        box_volume = box_volume * y_corner_arg[dim]

    closed_hits = 0  # every coordinate <= the corner
    open_hits = 0    # every coordinate <  the corner
    for row in range(N_arg):
        inside = True
        touches_face = False
        dim = 0
        # Stop scanning a point as soon as one coordinate exceeds the corner.
        while inside and dim < D_arg:
            coord = points_X_arg[row, dim]
            limit = y_corner_arg[dim]
            if coord > limit:
                inside = False
            elif coord == limit:
                touches_face = True
            dim += 1
        if inside:
            closed_hits += 1
            if not touches_face:
                open_hits += 1

    closed_gap = abs(closed_hits / N_arg - box_volume)
    open_gap = abs(open_hits / N_arg - box_volume)
    return max(closed_gap, open_gap)

def star_discrepancy(points_X: np.ndarray) -> float:
    """
    Return the L-infinity star discrepancy D* of a point set (lower is better).

    Note: this is the raw discrepancy value, *not* a score such as 1 / (1 + D*).

    The code enumerates every corner of the grid formed, per dimension, by the
    distinct (clipped) point coordinates plus 1.0. At each corner a Numba
    kernel computes the local discrepancy of the closed and open anchored box.
    The maximum over all corners is returned. Work is
    O(prod_j |grid_j| * N * D), i.e. up to O((N + 1)^D * N * D).

    Parameters
    ----------
    points_X : array_like
        Shape (N, D); a 1-D input of length N is treated as N points in 1-D.
        Coordinates are clipped to [0, 1] before evaluation. The caller's array
        is never modified.

    Returns
    -------
    float
        The maximum local discrepancy, or 1.0 for an empty point set.
    """
    # One float64 conversion covers lists, non-float64 arrays and float64 arrays.
    points_X_np = np.asarray(points_X, dtype=np.float64)
    if points_X_np.ndim == 1:
        points_X_np = points_X_np.reshape(-1, 1)

    N, D = points_X_np.shape
    if N == 0:
        return 1.0

    # np.clip allocates a new array; the Numba kernel also needs C-contiguous memory.
    points_X_clipped = np.ascontiguousarray(np.clip(points_X_np, 0.0, 1.0))

    # Candidate corner coordinates per dimension: the distinct point coordinates
    # plus the upper boundary 1.0 (np.union1d returns a sorted, unique array).
    # Each list is therefore non-empty, so no emptiness check is needed.
    grid_lines_per_dim = [
        np.union1d(points_X_clipped[:, j], np.array([1.0]))
        for j in range(D)
    ]

    y_corner = np.empty(D, dtype=np.float64)  # buffer reused for every corner
    max_discrepancy_val = 0.0
    for y_corner_tuple in itertools.product(*grid_lines_per_dim):
        y_corner[:] = y_corner_tuple
        local_discrepancy = _calculate_single_box_discrepancy_numba(
            points_X_clipped, N, D, y_corner
        )
        if local_discrepancy > max_discrepancy_val:
            max_discrepancy_val = local_discrepancy

    return max_discrepancy_val


# -----------------------------
# Initializers (2D)
# -----------------------------

def init_lhs_jittered(N: int, rng: np.random.Generator) -> np.ndarray:
    """Return an (N, 2) randomized Latin-hypercube design drawn with ``rng``.

    Built on SciPy's ``qmc.LatinHypercube`` with its default settings.
    """
    sampler = qmc.LatinHypercube(d=2, seed=rng)
    design = sampler.random(n=N)
    return design

def init_halton(N: int, rng: np.random.Generator) -> np.ndarray:
    """Return the first N points of a scrambled 2-D Halton sequence seeded from ``rng``."""
    halton = qmc.Halton(d=2, scramble=True, seed=rng)
    points = halton.random(n=N)
    return points

def init_hammersley(N: int, rng: np.random.Generator) -> np.ndarray:
    """
    Return N jittered points of a 2-D Hammersley-type set.

    Before jitter, point i (0-based) is
        x_i = (i + 0.5) / N
        y_i = base-2 radical inverse of (i + 1)
    Each coordinate then receives a uniform offset in [-1/(20N), 1/(20N)) drawn
    from ``rng``, and the result is wrapped into [0, 1) with a modulo.
    """
    index = np.arange(N, dtype=np.int64)
    x = (index + 0.5) / N

    # Vectorized base-2 radical inverse: peel bits of (i + 1) off the least
    # significant end and mirror them behind the binary point.
    remaining = index + 1
    y = np.zeros(index.shape, dtype=np.float64)
    weight = 0.5
    while remaining.any():
        y += (remaining & 1) * weight
        remaining = remaining >> 1
        weight *= 0.5

    pts = np.column_stack([x, y])
    # Tiny jitter breaks the exact lattice symmetry so restarts differ.
    pts += (rng.random(size=pts.shape) - 0.5) * (1.0 / (10 * N))
    return np.mod(pts, 1.0)

def init_fibonacci(N: int, rng: np.random.Generator) -> np.ndarray:
    """
    Return N jittered points of a shifted Fibonacci (golden-ratio) lattice.

    Before jitter, point i (0-based) is
        x_i = (i + 0.5) / N
        y_i = frac(i * phi + 0.5 / N),  with phi = (sqrt(5) - 1) / 2
    A uniform offset in [-1/(20N), 1/(20N)) drawn from ``rng`` is added to
    every coordinate, and the result is clipped to [0, 1].
    """
    golden_conj = (np.sqrt(5.0) - 1.0) / 2.0
    k = np.arange(N, dtype=np.float64)
    half_cell = 0.5 / N

    lattice = np.empty((N, 2), dtype=np.float64)
    lattice[:, 0] = (k + 0.5) / N
    lattice[:, 1] = np.mod(k * golden_conj + half_cell, 1.0)

    # Small jitter so repeated restarts do not reproduce the same lattice.
    jitter = (rng.random(size=lattice.shape) - 0.5) * (1.0 / (10 * N))
    return np.clip(lattice + jitter, 0.0, 1.0)

def init_uniform(N: int, rng: np.random.Generator) -> np.ndarray:
    """Return an (N, 2) array of independent uniform draws from ``rng.random``."""
    shape = (N, 2)
    return rng.random(size=shape)


# Registry of starting-design generators. phase2_only_best looks families up
# here by name; each entry maps (N, rng) -> an (N, 2) array of points.
INIT_FAMILIES: Dict[str, Callable[[int, np.random.Generator], np.ndarray]] = dict(
    lhs=init_lhs_jittered,
    halton=init_halton,
    hammersley=init_hammersley,
    fibonacci=init_fibonacci,
    uniform=init_uniform,
)


# -----------------------------
# Budgeted objective wrapper
# -----------------------------

class BudgetExhausted(Exception):
    """Raised by BudgetedObjective once its star_discrepancy call budget is spent."""

class BudgetedObjective:
    """
    Call-counting wrapper around ``star_discrepancy`` for use as an optimizer objective.

    An instance is called with a flat vector of length 2*N. The vector is
    reshaped to (N, 2) and clipped to [0, 1]. Each completed evaluation
    increments ``calls_used`` and may update the best-so-far record. Once
    ``calls_used`` reaches ``budget_calls``, ``BudgetExhausted`` is raised.
    The optimizer is therefore aborted after exactly ``budget_calls``
    evaluations.

    Attributes
    ----------
    N : number of 2-D points encoded in each input vector.
    budget_calls : maximum number of evaluations (coerced with int()).
    calls_used : evaluations completed so far.
    best_val : lowest discrepancy seen; inf until an evaluation beats it.
    best_x : flattened, clipped points that achieved ``best_val`` (None until then).
    """

    def __init__(self, N: int, budget_calls: int):
        self.N = N
        self.budget_calls = int(budget_calls)
        self.calls_used = 0
        self.best_val = float("inf")
        self.best_x = None  # flat copy of the best (clipped) points

    def __call__(self, x_flat: np.ndarray) -> float:
        # Refuse to evaluate at all once the budget is already spent.
        if self.calls_used >= self.budget_calls:
            raise BudgetExhausted()

        # The optimizer enforces bounds, but clip anyway in case of small numerical drift.
        points = np.clip(x_flat.reshape(self.N, 2), 0.0, 1.0)
        value = star_discrepancy(points)
        self.calls_used += 1

        if value < self.best_val:
            self.best_val = float(value)
            self.best_x = np.ascontiguousarray(points).ravel()

        # The evaluation that spends the last unit of budget still updates the
        # best-so-far record above, but its value is NOT returned. Raising here
        # aborts the optimizer instead of letting it request more evaluations.
        if self.calls_used >= self.budget_calls:
            raise BudgetExhausted()

        return float(value)


# -----------------------------
# Core routine (Phase-2-only)
# -----------------------------

@dataclass
class RunConfig:
    """Settings for one Phase-2-only run.

    Field names and defaults match the parameters of phase2_only_best. This
    class is not used by the code visible in this file.
    """

    N: int  # number of 2-D points
    budget_calls: int  # total star_discrepancy evaluations allowed
    maxiter: int = 30000  # matches phase2_only_best's maxiter default
    ftol: float = 1e-15  # matches phase2_only_best's ftol default
    seed: int = 0  # matches phase2_only_best's seed default
    families: Tuple[str, ...] = ("lhs", "halton", "hammersley", "fibonacci", "uniform")

def phase2_only_best(N: int, budget_calls: int, seed: int = 0,
                     families: Tuple[str, ...] = ("lhs", "halton", "hammersley", "fibonacci", "uniform"),
                     maxiter: int = 30000, ftol: float = 1e-15) -> Dict[str, object]:
    """
    Budget-matched Phase-2-only local optimization (multi-start SLSQP).

    Starting designs are produced by cycling through ``families`` (names in
    INIT_FAMILIES). Each start is refined with SLSQP on the flattened 2N-vector
    under [0, 1] box bounds. Every restart is given the budget that is still
    unspent. The loop ends once the total number of star_discrepancy calls
    reaches ``budget_calls``.

    Parameters
    ----------
    N : number of 2-D points.
    budget_calls : total number of star_discrepancy evaluations allowed.
    seed : seed for the single NumPy Generator shared by all restarts.
    families : non-empty sequence of initializer names, cycled in order.
    maxiter, ftol : forwarded to SLSQP's ``options``.

    Returns
    -------
    dict with
      - best_discrepancy: lowest value seen (inf if no evaluation ran)
      - best_points: (N, 2) array achieving it (None if no evaluation ran)
      - calls_used: total star_discrepancy evaluations spent
      - restarts: number of SLSQP runs started

    Raises
    ------
    ValueError
        If ``families`` is empty or contains a name missing from INIT_FAMILIES.
    """
    families = tuple(families)
    # Validate up front. An empty sequence would make next(itertools.cycle(...))
    # raise a bare StopIteration, and an unknown name would otherwise fail only
    # mid-run with a KeyError.
    if not families:
        raise ValueError("families must name at least one initializer")
    unknown = [fam for fam in families if fam not in INIT_FAMILIES]
    if unknown:
        raise ValueError(
            f"unknown initializer families {unknown}; available: {list(INIT_FAMILIES)}"
        )

    rng = np.random.default_rng(seed)
    bounds = [(0.0, 1.0)] * (N * 2)

    best_val_global = float("inf")
    best_pts_global = None
    calls_used_global = 0
    restarts = 0

    fam_cycle = itertools.cycle(families)

    while calls_used_global < budget_calls:
        fam = next(fam_cycle)

        # Starting design for this restart, flattened for the optimizer.
        x0 = INIT_FAMILIES[fam](N, rng).ravel()

        # Each restart may spend only what is left of the global budget.
        remaining = budget_calls - calls_used_global
        obj = BudgetedObjective(N=N, budget_calls=remaining)

        try:
            minimize(
                obj,
                x0,
                method="SLSQP",
                bounds=bounds,
                options={"maxiter": maxiter, "ftol": ftol, "iprint": 0},
            )
            # Converged (or stopped) before the budget ran out; obj still
            # holds the best point it evaluated.
        except BudgetExhausted:
            # Expected stop: the last permitted evaluation has been recorded.
            pass

        calls_used_global += obj.calls_used
        restarts += 1

        if obj.best_x is not None and obj.best_val < best_val_global:
            best_val_global = obj.best_val
            best_pts_global = obj.best_x.reshape(N, 2).copy()

    return {
        "best_discrepancy": float(best_val_global),
        "best_points": best_pts_global,
        "calls_used": int(calls_used_global),
        "restarts": int(restarts),
    }


# -----------------------------
# CLI & multi-replicate harness
# -----------------------------

def _load_budget_map(spec: str) -> Dict[int, int]:
    """
    Parse the --budget-per-n argument into an {N: budget} mapping.

    ``spec`` is read as a file path when it names an existing regular file, and
    otherwise parsed as inline JSON. The JSON must be an object whose keys and
    values both convert with int(), e.g. {"30": 18000}.

    Raises
    ------
    OSError
        If the file exists but cannot be read.
    ValueError
        On malformed JSON (json.JSONDecodeError), a non-object top level, or
        keys/values that int() rejects.
    TypeError
        If a value has a type int() cannot convert (e.g. JSON null).
    """
    import os  # local import: only this helper needs it

    # os.path.isfile() returns False instead of raising for names the OS
    # rejects (e.g. a long inline JSON string hitting "file name too long").
    # A bare open() would raise an OSError other than FileNotFoundError there.
    if os.path.isfile(spec):
        with open(spec, "r", encoding="utf-8") as f:
            raw = json.load(f)
    else:
        raw = json.loads(spec)
    if not isinstance(raw, dict):
        raise ValueError(f"expected a JSON object mapping N to budget, got {type(raw).__name__}")
    return {int(k): int(v) for k, v in raw.items()}


def main():
    """
    Command-line entry point: run the Phase-2-only baseline for each requested N.

    For every N, ``--replicates`` independent runs (seeds seed0, seed0+1, ...)
    are executed with that N's call budget. Per-replicate and aggregate results
    are printed, followed by a CSV-style summary of all runs. Invalid arguments
    stop the program (argparse usage error or SystemExit) before any
    optimization starts.
    """
    ap = argparse.ArgumentParser(description="Budget-matched Phase-2-only baseline for 2D star discrepancy")
    ap.add_argument("--N", type=int, nargs="+", required=True, help="One or more N values (e.g., 30 40 50 60 100)")
    ap.add_argument("--replicates", type=int, default=1, help="Number of independent repetitions per N")
    ap.add_argument("--seed0", type=int, default=12345, help="Base seed; replicate r uses seed0+r")
    ap.add_argument("--budget", type=int, default=None, help="Uniform budget (calls) for all N (if no --budget-per-n)")
    ap.add_argument("--budget-per-n", type=str, default=None,
                    help="JSON path or inline JSON mapping N->budget (strings or ints), e.g. '{\"30\":18000,...}'")
    ap.add_argument("--families", type=str, nargs="+", default=["lhs", "halton", "hammersley", "fibonacci", "uniform"],
                    help="Subset/order of initializers to cycle through")
    ap.add_argument("--maxiter", type=int, default=30000)
    ap.add_argument("--ftol", type=float, default=1e-15)
    args = ap.parse_args()

    # Reject invalid arguments before any (potentially long) optimization runs.
    if any(n < 1 for n in args.N):
        ap.error("--N values must be positive integers")
    if args.replicates < 1:
        # With zero replicates the aggregate statistics below would fail on empty arrays.
        ap.error("--replicates must be at least 1")
    unknown = [fam for fam in args.families if fam not in INIT_FAMILIES]
    if unknown:
        ap.error(f"unknown --families {unknown}; available: {list(INIT_FAMILIES)}")

    budget_map: Dict[int, int] = {}
    if args.budget_per_n is not None:
        try:
            budget_map = _load_budget_map(args.budget_per_n)
        except (OSError, ValueError, TypeError) as exc:
            ap.error(f"could not parse --budget-per-n: {exc}")

    # Resolve every budget up front so a missing or invalid entry fails fast,
    # instead of after earlier N values have already been run.
    budgets: Dict[int, int] = {}
    for N in args.N:
        B_N = budget_map.get(N, args.budget)
        if B_N is None:
            raise SystemExit(f"No budget provided for N={N}. Use --budget or --budget-per-n.")
        if B_N < 1:
            raise SystemExit(f"Budget for N={N} must be a positive number of calls, got {B_N}.")
        budgets[N] = int(B_N)

    reps = int(args.replicates)
    results = []

    print("\n=== Phase-2-only baseline (budget-matched by star-discrepancy calls) ===")
    print(f"Families order: {args.families}\n")

    for N in args.N:
        B_N = budgets[N]
        vals = []
        used = []

        print(f"[N={N}] Budget B_N = {B_N} calls, Replicates = {reps}")
        for rep in range(reps):
            seed = args.seed0 + rep
            out = phase2_only_best(
                N=N,
                budget_calls=B_N,
                seed=seed,
                families=tuple(args.families),
                maxiter=args.maxiter,
                ftol=args.ftol,
            )
            vals.append(out["best_discrepancy"])
            used.append(out["calls_used"])
            print(f"  Rep {rep:02d}: best D* = {out['best_discrepancy']:.6f} (calls_used={out['calls_used']}, restarts={out['restarts']})")

        vals_arr = np.array(vals, dtype=np.float64)
        used_arr = np.array(used, dtype=np.int64)

        print(f"-> N={N}: median D* = {np.median(vals_arr):.6f}, mean D* = {np.mean(vals_arr):.6f}, "
              f"min D* = {np.min(vals_arr):.6f}, max D* = {np.max(vals_arr):.6f}")
        print(f"          calls_used: median={int(np.median(used_arr))}, min={int(np.min(used_arr))}, max={int(np.max(used_arr))}\n")

        results.append((N, vals_arr, used_arr))

    # Compact CSV-style summary of every replicate.
    print("N,rep,best_D_star,calls_used")
    for (N, vals_arr, used_arr) in results:
        for r, (v, u) in enumerate(zip(vals_arr, used_arr)):
            print(f"{N},{r},{v:.8f},{u}")

# Run the command-line interface only when this file is executed as a script,
# so importing it (e.g. to reuse phase2_only_best) has no side effects.
if __name__ == "__main__":
    main()
