
from __future__ import annotations

import argparse
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple
import itertools

import numpy as np
from numba import njit
from scipy.optimize import minimize
from scipy.stats import qmc
import os
import sys


@njit(cache=True)
def _calculate_single_box_discrepancy_numba(points_X_arg: np.ndarray, 
                                           N_arg: int, 
                                           D_arg: int, 
                                           y_corner_arg: np.ndarray) -> float:
    """
    Local discrepancy of the anchored box whose upper corner is ``y_corner_arg``.

    Both the closed box [0, y] and the half-open box [0, y) are considered.
    They share the volume prod(y) and differ only in the points lying on an
    upper face (some coordinate exactly equal to the corner coordinate).
    The larger of |count_closed / N - vol| and |count_open / N - vol| is
    returned.

    Expects ``points_X_arg`` indexed as [point, dim] with N_arg rows and
    D_arg columns, and N_arg > 0 (the count is divided by N_arg; the caller
    handles the empty set).
    """
    box_volume = 1.0
    for d in range(D_arg):
        box_volume = box_volume * y_corner_arg[d]

    n_closed = 0   # points with coord <= corner in every dimension
    n_on_face = 0  # of those, points with coord == corner in some dimension
    for p in range(N_arg):
        inside = True
        touches_face = False
        for d in range(D_arg):
            coord = points_X_arg[p, d]
            upper = y_corner_arg[d]
            if coord > upper:
                inside = False
                break
            if coord == upper:
                touches_face = True
        if not inside:
            continue
        n_closed += 1
        if touches_face:
            n_on_face += 1

    # The half-open box excludes exactly the face points.
    closed_error = abs(n_closed / N_arg - box_volume)
    open_error = abs((n_closed - n_on_face) / N_arg - box_volume)
    return max(closed_error, open_error)

def star_discrepancy(points_X: np.ndarray) -> float:
    """
    Compute the exact L-infinity star discrepancy of a point set.

    Candidate box corners are all combinations of the distinct point
    coordinates (plus 1.0) in each dimension. For every corner,
    ``_calculate_single_box_discrepancy_numba`` evaluates the closed and the
    half-open anchored box. The maximum local discrepancy is returned.

    Parameters
    ----------
    points_X : array_like
        Point set with one point per row. A 1-D input is treated as N points
        in one dimension. Coordinates are clipped to [0, 1] before evaluation
        (on a copy; the caller's array is not modified).

    Returns
    -------
    float
        The star discrepancy itself (0 is ideal, larger is worse). This is
        NOT a transformed score. An empty point set returns 1.0.

    Notes
    -----
    The number of corners is the product of the per-dimension grid sizes
    (up to (N + 1) ** D), each costing O(N * D), so this is only practical
    for small N and D.
    """
    # float64 keeps the Numba kernel's signature stable; asarray avoids a
    # copy when the input already is a float64 ndarray.
    points = np.asarray(points_X, dtype=np.float64)
    if points.ndim == 1:
        points = points.reshape(-1, 1)

    N, D = points.shape
    if N == 0:
        # Every box has 0 points, so sup |0 - vol| = 1.
        return 1.0

    # np.clip returns a new array; the kernel is compiled for C-ordered input.
    points = np.ascontiguousarray(np.clip(points, 0.0, 1.0))

    # union1d already sorts and de-duplicates, so no separate np.unique is
    # needed. The grid therefore always contains at least 1.0.
    grid_lines_per_dim = [np.union1d(points[:, j], [1.0]) for j in range(D)]

    # One reusable buffer for the corner handed to the kernel.
    y_corner = np.empty(D, dtype=np.float64)
    max_discrepancy_val = 0.0
    for corner in itertools.product(*grid_lines_per_dim):
        y_corner[:] = corner
        local_discrepancy = _calculate_single_box_discrepancy_numba(points, N, D, y_corner)
        if local_discrepancy > max_discrepancy_val:
            max_discrepancy_val = local_discrepancy

    return max_discrepancy_val


def init_fibonacci(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Simple Fibonacci / Kronecker lattice style construction for D=2 or D=3.

    For D=2:
        x_i = (i + 0.5) / N
        y_i = frac( (i + 0.5) * phi )

    For D=3:
        x_i = (i + 0.5) / N
        y_i = frac( (i + 0.5) * phi )
        z_i = frac( (i + 0.5) * (sqrt(2) - 1) )

    For D>3 (not needed here), we could extend with more incommensurate slopes.
    """
    if D not in (2, 3):
        raise ValueError(f"init_fibonacci only implemented for D in {{2,3}}, got D={D}")

    i = np.arange(N, dtype=np.float64)
    pts = np.zeros((N, D), dtype=np.float64)

    # First dim: simple stratification
    pts[:, 0] = (i + 0.5) / N

    phi = (np.sqrt(5.0) - 1.0) / 2.0  # ~0.618...

    if D >= 2:
        pts[:, 1] = np.mod((i + 0.5) * phi, 1.0)
    if D == 3:
        alpha2 = np.sqrt(2.0) - 1.0
        pts[:, 2] = np.mod((i + 0.5) * alpha2, 1.0)

    return np.clip(pts, 0.0, 1.0)


def init_sobol(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Sobol sequence baseline using scipy.stats.qmc.Sobol.
    """
    # Use seed if provided (for replicability across runs)
    seed = None if rng is None else int(rng.integers(0, 2**31 - 1))
    engine = qmc.Sobol(d=D, scramble=True, seed=seed)
    pts = engine.random(N)
    return np.clip(pts, 0.0, 1.0)


def init_clement(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Placeholder initializer for Clément et al.'s best-known point sets.

    Suggested pattern:
        path = f"/path/to/clement/N{N}_D{D}.npy"
        pts = np.load(path)
        assert pts.shape == (N, D)
        return pts
    """
    path_to_points = f"openevolve-ablation/clement/clement_{N:03d}.txt"
    return np.loadtxt(path_to_points, dtype=float)


def init_phase1(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Placeholder initializer for using Phase-1 program outputs as seeds
    for Phase-2-only baselines (E2).

    Suggested pattern:
        # e.g. pick the best Phase-1 candidate for this N,D
        path = f"/path/to/phase1_outputs/N{N}_D{D}_best.npy"
        pts = np.load(path)
        assert pts.shape == (N, D)
        return pts
    """
    path_to_points = f"openevolve-ablation/openevolve_constructions/openevolve-direct-{N:04d}.txt"
    return np.loadtxt(path_to_points, dtype=float, delimiter=",")


def init_mpmc(N: int, D: int, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Placeholder initializer for the "mpmc" family.

    No source for these point sets is wired into this script. The function
    must exist because INIT_FAMILIES references it at import time. It raises
    NotImplementedError, which ``main`` reports as SKIPPED rather than FAILED.
    """
    raise NotImplementedError(
        f"init_mpmc has no point-set source configured (requested N={N}, D={D})"
    )


# Maps the family name used in params.txt / CLI to its initializer.
# Every initializer takes (N, D, rng=None) and returns an (N, D) array.
INIT_FAMILIES: Dict[str, Callable[..., np.ndarray]] = {
    "fibonacci": init_fibonacci,
    "sobol": init_sobol,
    "mpmc": init_mpmc,
    "clement": init_clement,
    "phase1": init_phase1,
}


def benchmark_star_discrepancy(N: int,
                               D: int,
                               num_samples: int = 10,
                               rng: np.random.Generator | None = None) -> float:
    """
    Return the mean wall-clock seconds of one ``star_discrepancy`` call.

    Each of ``num_samples`` timed calls uses a fresh uniform-random (N, D)
    point set drawn from ``rng`` (a fixed-seed generator when None). One
    untimed warm-up call is made first so that Numba compilation or
    cache loading is not counted.
    """
    gen = np.random.default_rng(12345) if rng is None else rng

    # Warm-up: excluded from the measurement.
    star_discrepancy(gen.random((N, D)))

    durations: List[float] = []
    for _ in range(num_samples):
        sample = gen.random((N, D))
        start = time.perf_counter()
        star_discrepancy(sample)
        durations.append(time.perf_counter() - start)

    return float(np.mean(durations))


def calibrate_maxiter(N_list: List[int],
                      D: int,
                      families: List[str],
                      replicates: int,
                      total_seconds: float = 96 * 3600.0,
                      calls_per_iter_estimate: float = 1.0,
                      num_benchmark_samples: int = 10) -> int:
    """
    Pick an optimizer ``maxiter`` so that the whole sweep fits a time budget.

    The slowest mean ``star_discrepancy`` call time over ``N_list`` (from
    ``benchmark_star_discrepancy``) is used as a conservative per-call cost.
    With num_runs = len(N_list) * len(families) * replicates:

        maxiter = int(total_seconds / (num_runs * t_call_max * calls_per_iter_estimate))

    The result is clamped to at least 1 and a summary is printed. If the
    measured time is non-positive or there are no runs, the function returns
    100 without printing.
    """
    bench_rng = np.random.default_rng(2025)

    # Conservative per-call cost: the worst case across all requested sizes.
    slowest = 0.0
    for n_points in N_list:
        mean_time = benchmark_star_discrepancy(
            n_points, D, num_samples=num_benchmark_samples, rng=bench_rng
        )
        slowest = max(slowest, mean_time)

    run_count = len(N_list) * len(families) * replicates
    if slowest <= 0.0 or run_count == 0:
        return 100  # fallback

    maxiter = max(1, int(total_seconds / (run_count * slowest * calls_per_iter_estimate)))

    summary_lines = [
        "=== Calibration summary ===",
        f"  dim D           = {D}",
        f"  N_list          = {N_list}",
        f"  families        = {families}",
        f"  replicates      = {replicates}",
        f"  total_seconds   = {total_seconds:.1f}",
        f"  t_call_max      = {slowest:.6f} s",
        f"  num_runs        = {run_count}",
        f"  calls/iter est. = {calls_per_iter_estimate}",
        f"  -> maxiter      = {maxiter}",
        "===========================",
    ]
    print("\n".join(summary_lines))
    return maxiter


# =============================================================================
# Core optimization routine
# =============================================================================

@dataclass()
class RunResult:
    """Outcome of one baseline optimization run (one row of the CSV summary)."""

    dim: int            # dimension D of the point set
    N: int              # number of points
    family: str         # initializer family used for the start point set
    rep: int            # replicate index
    success: bool       # optimizer's reported success flag
    nit: int            # optimizer iterations performed
    nfev: int           # objective evaluations reported by the optimizer
    best_D_star: float  # star discrepancy of the final (clipped) point set
    x_opt: np.ndarray   # final point set, reshaped to (N, dim) and clipped to [0, 1]


def run_single_baseline(N: int,
                        D: int,
                        family: str,
                        rep: int,
                        maxiter: int,
                        ftol: float = 1e-15) -> RunResult:
    """
    Run a single SLSQP local optimization of the star discrepancy for (N, D).

    The start point set comes from ``INIT_FAMILIES[family]``. All N*D
    coordinates are optimized jointly under [0, 1] box bounds, with
    ``star_discrepancy`` as the objective.

    Parameters
    ----------
    N, D : int
        Number of points and dimension.
    family : str
        Key of ``INIT_FAMILIES`` selecting the initializer.
    rep : int
        Replicate index. Together with D it seeds the RNG passed to the
        initializer.
    maxiter : int
        SLSQP iteration limit.
    ftol : float
        SLSQP ``ftol`` option.

    Returns
    -------
    RunResult
        Optimizer status counters plus the final point set (clipped to
        [0, 1]) and its star discrepancy.

    Raises
    ------
    ValueError
        If ``family`` is unknown or the initializer returns a shape other
        than (N, D). Exceptions raised by the initializer itself (for example
        a missing input file) propagate unchanged.
    """
    if family not in INIT_FAMILIES:
        raise ValueError(f"Unknown family '{family}'. Available: {list(INIT_FAMILIES.keys())}")

    # Deterministic seed. It varies with D and rep only (not with N or
    # family); kept as-is so results stay reproducible across versions.
    rng = np.random.default_rng(10_000 + 100 * D + 10 * rep)

    init_fn = INIT_FAMILIES[family]
    pts0 = init_fn(N, D, rng=rng)
    if pts0.shape != (N, D):
        raise ValueError(f"Initializer '{family}' returned shape {pts0.shape}, expected {(N, D)}")

    x0 = pts0.reshape(-1)
    bounds = [(0.0, 1.0)] * x0.size

    def objective(x_flat: np.ndarray) -> float:
        # Defensive clip: score only points inside [0, 1] even if the
        # optimizer probes marginally outside the bounds.
        return star_discrepancy(np.clip(x_flat.reshape(N, D), 0.0, 1.0))

    res = minimize(
        objective,
        x0,
        method="SLSQP",
        bounds=bounds,
        options={"maxiter": int(maxiter), "ftol": float(ftol), "iprint": 0},
    )

    # Re-evaluate on the clipped optimum so the reported value matches x_opt.
    x_opt = np.clip(res.x.reshape(N, D), 0.0, 1.0)
    best_D_star = float(star_discrepancy(x_opt))

    print(f"[D={D}, N={N}, family={family}, rep={rep}] "
          f"success={res.success}, nit={res.nit}, nfev={res.nfev}, "
          f"best D*={best_D_star:.6f}")

    return RunResult(
        dim=D,
        N=N,
        family=family,
        rep=rep,
        success=bool(res.success),
        nit=int(res.nit),
        nfev=int(res.nfev),
        best_D_star=best_D_star,
        x_opt=x_opt,
    )


def main():
    """
    Run one (family, N) baseline configuration selected by an array-job index.

    The 1-based index in the ``SGE_TASK_ID`` environment variable selects a
    line of ``openevolve-ablation/params.txt``. Each line is expected to read
    ``<family>-<N>`` (e.g. ``sobol-64``). The run uses D = 2 and one
    replicate, and ``maxiter`` is calibrated against a 96-hour budget. The
    optimized point set is saved as ``.npy`` under
    ``openevolve-ablation/outputs``, and a CSV-style summary is printed.

    Raises
    ------
    SystemExit
        If ``SGE_TASK_ID`` is missing, not an integer, or out of range.
    """
    import traceback  # only needed to report failed runs

    save_dir = "openevolve-ablation/outputs"
    # ndmin=1 keeps a one-line params file indexable; otherwise loadtxt
    # returns a 0-d array and params[idx] fails.
    params = np.loadtxt("openevolve-ablation/params.txt", dtype=str, ndmin=1)

    # Task IDs are 1-based. Fail loudly instead of raising an opaque TypeError
    # (unset variable) or letting an ID of 0 wrap silently to the last line.
    task_id = os.getenv("SGE_TASK_ID")
    try:
        slurm_idx = int(task_id) - 1
    except (TypeError, ValueError):
        raise SystemExit(
            f"SGE_TASK_ID must be a 1-based integer task index, got {task_id!r}"
        ) from None
    if not 0 <= slurm_idx < len(params):
        raise SystemExit(
            f"SGE_TASK_ID={task_id} is out of range 1..{len(params)} "
            "for openevolve-ablation/params.txt"
        )
    param = str(params[slurm_idx])

    D = 2
    replicates = 1
    # Split on the last '-' so the trailing token is always N.
    family, n_str = param.rsplit("-", 1)

    N_list = [int(n_str)]
    families = [family]
    total_seconds = 96 * 3600
    ftol = 1e-15
    maxiter = None  # set to an int to skip calibration

    # Calibrate maxiter if not provided
    if maxiter is None:
        maxiter = calibrate_maxiter(
            N_list=N_list,
            D=D,
            families=families,
            replicates=replicates,
            total_seconds=float(total_seconds),
            # Presumably ~N*D+1 objective calls per finite-difference gradient,
            # doubled for headroom; adjust to be more or less conservative.
            calls_per_iter_estimate=2 * (N_list[0] * D + 1),
            num_benchmark_samples=20,
        )
    else:
        maxiter = int(maxiter)
        print(f"Using user-specified maxiter={maxiter}")

    results: List[RunResult] = []

    for N in N_list:
        for family in families:
            for rep in range(replicates):
                try:
                    res = run_single_baseline(
                        N=N,
                        D=D,
                        family=family,
                        rep=rep,
                        maxiter=maxiter,
                        ftol=float(ftol),
                    )
                    results.append(res)

                    # Optional: save point sets
                    if save_dir is not None:
                        os.makedirs(save_dir, exist_ok=True)
                        fname = f"D{D}_N{N}_{family}_rep{rep}.npy"
                        path = os.path.join(save_dir, fname)
                        np.save(path, res.x_opt)
                except NotImplementedError as e:
                    # Friendly message if a placeholder init is still unimplemented
                    print(f"[D={D}, N={N}, family={family}, rep={rep}] "
                          f"SKIPPED due to NotImplementedError: {e}")
                except Exception as e:
                    # Keep the sweep going, but keep the traceback for debugging.
                    print(f"[D={D}, N={N}, family={family}, rep={rep}] "
                          f"FAILED with error: {e}")
                    traceback.print_exc()

    # Compact CSV-style summary (easy to grep / import)
    print("\n# dim,N,family,rep,success,nit,nfev,best_D_star")
    for r in results:
        print(f"{r.dim},{r.N},{r.family},{r.rep},"
              f"{int(r.success)},{r.nit},{r.nfev},{r.best_D_star:.10f}")


if __name__ == "__main__":
    # Script entry point: run the single baseline selected by SGE_TASK_ID.
    main()
