import numpy as np
from scipy.stats import qmc
from scipy.stats import norm
import ctypes
from typing import List, Dict, Any

from scipy.stats import wilcoxon, false_discovery_control
import os
import tempfile
import subprocess
import re
from pathlib import Path


# ---------------------------------------------------------------------------
# LatNetBuilder helpers: rank-1 lattice and digital net
# ---------------------------------------------------------------------------

# Path to the LatNetBuilder executable. A relative path such as this one is
# resolved against the current working directory when subprocess launches it.
LATNETBUILDER_EXE = "latnetsoft/bin/latnetbuilder"  # Adjust if not on PATH
# Parent directory for LatNetBuilder output folders; None means the helpers
# below fall back to tempfile.gettempdir().
LATNET_TMP_ROOT = None  # Set to a directory if you don't want tempfile.gettempdir()

# In-process memoization of LatNetBuilder results, filled and read by the
# construct_latnetbuilder_* functions below so the external tool is not rerun
# for a request that has already been answered in this Python process.
_latnet_lattice_generator_cache: dict[tuple[int, int], np.ndarray] = {}
_latnet_digital_net_points_cache: dict[tuple[int, int], np.ndarray] = {}
_latnet_sobol_params_cache: dict[int, List[Dict[str, Any]]] = {}


def _parse_latnetbuilder_sobol_direction_numbers(output_path: str) -> List[List[int]]:
    """
    Extract Sobol' direction-number rows from a LatNetBuilder output file.

    Rows are read starting on the line after the first line that contains
    both "sobol digital net" and "direction numbers" (case-insensitive).
    Reading stops at the first line that starts with "merit" or contains
    "elapsed". Blank lines are skipped. Every other line contributes the list
    of (possibly negative) integers it contains, and lines with no integers
    are ignored.

    Raises:
        FileNotFoundError: if `output_path` does not exist.
        ValueError: if the header is missing or no integer rows follow it.
    """
    src = Path(output_path)
    if not src.exists():
        raise FileNotFoundError(f"LatNetBuilder output file not found: {src}")

    with src.open("r") as handle:
        text_lines = list(handle)

    def _is_header(candidate: str) -> bool:
        low = candidate.lower()
        return "sobol digital net" in low and "direction numbers" in low

    first_row = next(
        (pos + 1 for pos, candidate in enumerate(text_lines) if _is_header(candidate)),
        None,
    )
    if first_row is None:
        preview = "".join(text_lines[:40])
        raise ValueError(
            f"Could not find 'Sobol Digital Net - Direction numbers' header in {src}.\n"
            f"First lines:\n{'-'*60}\n{preview}\n{'-'*60}"
        )

    direction_rows: List[List[int]] = []
    for raw in text_lines[first_row:]:
        content = raw.strip()
        if content == "":
            continue
        low = content.lower()
        # Footer lines (merit value, timing) terminate the table.
        if low.startswith("merit") or "elapsed" in low:
            break
        numbers = list(map(int, re.findall(r"-?\d+", content)))
        if numbers:
            direction_rows.append(numbers)

    if not direction_rows:
        preview = "".join(text_lines[first_row:first_row + 20])
        raise ValueError(
            f"No direction-number rows found after header in {src}.\n"
            f"Context:\n{'-'*60}\n{preview}\n{'-'*60}"
        )

    return direction_rows


def construct_latnetbuilder_sobol_params(
    dim: int,
    power_of_n: int,
    latnetbuilder_exe: str = LATNETBUILDER_EXE,
    tmp_root: str | None = LATNET_TMP_ROOT,
) -> List[Dict[str, Any]]:
    """
    Run LatNetBuilder's Sobol' digital-net search and return direction-number
    parameters for dimensions 2..dim in the format of construct_sobol_sequence().

    Args:
        dim: number of dimensions of the net.
        power_of_n: the search is run for 2**power_of_n points (``-s 2^power_of_n``).
        latnetbuilder_exe: path to the LatNetBuilder executable.
        tmp_root: parent directory for LatNetBuilder output; None means
            tempfile.gettempdir().

    Returns:
        A list of dim - 1 dicts, one per dimension 2..dim. Each dict starts as a
        copy of the matching construct_sobol_sequence() entry. Its 's' and 'm_i'
        are replaced by the LatNetBuilder row for that dimension, and its 'a' is
        kept from the base table. Results are cached per (dim, power_of_n), and
        the cached list object is returned as-is, so callers should not mutate it.

    Raises:
        RuntimeError: if LatNetBuilder exits with a non-zero status.
        FileNotFoundError: if LatNetBuilder does not write output.txt.
        ValueError: if fewer than `dim` direction-number rows are found, or
            construct_sobol_sequence() does not return dim - 1 entries.
    """
    # The CBC search result depends on the number of points, so both the cache
    # key and the output directory must include power_of_n. Keying on dim
    # alone would reuse the first sample size's parameters for every later one.
    key = (int(dim), int(power_of_n))
    if key in _latnet_sobol_params_cache:
        return _latnet_sobol_params_cache[key]

    if tmp_root is None:
        tmp_root = tempfile.gettempdir()
    out_dir = os.path.join(tmp_root, f"latnet_sobol_d{dim}_m{power_of_n}")
    os.makedirs(out_dir, exist_ok=True)

    # This command produces the "Sobol Digital Net - Direction numbers"
    # section parsed below.
    cmd = [
        latnetbuilder_exe,
        "-t", "net",
        "-c", "sobol",
        "-s", f"2^{power_of_n}",                 # total points; large enough to define dir. numbers
        "-d", str(dim),
        "-e", "random-CBC:2000",
        "-f", "projdep:t-value",
        "-q", "inf",
        "-w", "order-dependent:0:0,1,1",
        "-o", out_dir,
    ]

    print(f"[LatNetBuilder] Sobol digital net (direction numbers): {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(
            f"LatNetBuilder failed when constructing Sobol digital net parameters "
            f"for dim={dim}, n=2^{power_of_n}.\n"
            f"Command: {' '.join(cmd)}\n"
            f"stdout:\n{e.stdout.decode(errors='ignore')}\n"
            f"stderr:\n{e.stderr.decode(errors='ignore')}"
        ) from e

    output_path = os.path.join(out_dir, "output.txt")
    if not os.path.exists(output_path):
        raise FileNotFoundError(
            f"LatNetBuilder did not produce output.txt at {output_path}. "
            "Check the command and LatNetBuilder installation."
        )

    rows = _parse_latnetbuilder_sobol_direction_numbers(output_path)
    if len(rows) < dim:
        raise ValueError(
            f"LatNetBuilder returned only {len(rows)} direction-number rows for dim={dim}."
        )

    # construct_sobol_sequence() holds entries for dimensions 2..dim.
    base_params = construct_sobol_sequence()
    if len(base_params) != dim - 1:
        raise ValueError(
            f"construct_sobol_sequence() returned {len(base_params)} entries "
            f"but expected dim-1={dim-1}."
        )

    # rows[0] belongs to dimension 1, which the C++ generator does not take as
    # input, so dimension j+2 pairs with rows[j+1].
    # NOTE(review): 'a' is kept from the base table while 's'/'m_i' come from
    # LatNetBuilder. This is only consistent if both use the same primitive
    # polynomials in the same order; verify.
    new_params: List[Dict[str, Any]] = [
        {**base, "s": len(row_mi), "m_i": row_mi}
        for base, row_mi in zip(base_params, rows[1:dim])
    ]

    _latnet_sobol_params_cache[key] = new_params
    return new_params


def _parse_latnetbuilder_generating_vector(output_path: str, dim: int) -> np.ndarray:
    """
    Extract the first `dim` components of a rank-1 lattice generating vector
    from a LatNetBuilder output file.

    Layouts are tried in order:
      1. A '# lattice' block. After the header, the first non-comment line
         holds the dimension s, the second holds n (ignored), and the next s
         lines hold the vector components. Trailing '# ...' annotations on
         these lines are ignored.
      2. A line containing a trigger keyword ("generating vector", ...).
         Integers are collected from it onward until a blank or comment line.
      3. A dash-separated vector such as "1-45-101-...".
      4. Any non-comment line containing at least `dim` integers.

    Returns:
        int64 array of length `dim`.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if a '# lattice' block declares fewer than `dim`
            dimensions, or no layout yields `dim` integers.
    """
    path = Path(output_path)
    if not path.exists():
        raise FileNotFoundError(f"LatNetBuilder output file not found: {path}")

    with path.open("r") as f:
        lines = f.readlines()

    def ints_from_line(line: str) -> List[int]:
        return [int(x) for x in re.findall(r"-?\d+", line)]

    def ints_before_comment(line: str) -> List[int]:
        # Value lines may carry an annotation such as "45  # a_2"; digits
        # inside the annotation are not data.
        return ints_from_line(line.split("#", 1)[0])

    # ------------------------------------------------------------------
    # Case 1: Standard "# lattice" format
    # ------------------------------------------------------------------
    header_idx = next(
        (i for i, line in enumerate(lines)
         if line.strip().lower().startswith("# lattice")),
        None,
    )

    if header_idx is not None:
        non_comment = [
            line for line in lines[header_idx + 1:]
            if line.strip() and not line.strip().startswith("#")
        ]
        if len(non_comment) >= 2:
            s_vals = ints_before_comment(non_comment[0])
            if s_vals:
                s = s_vals[0]
                # The second value line is n; ignored here.
                # Raise rather than fall through: the heuristics below would
                # otherwise return unrelated integers from elsewhere in the file.
                if s < dim:
                    raise ValueError(
                        f"LatNetBuilder lattice file {output_path} has only {s} dimensions, requested {dim}"
                    )
                vec_vals: List[int] = []
                for line in non_comment[2:2 + s]:
                    vec_vals.extend(ints_before_comment(line))
                if len(vec_vals) >= dim:
                    return np.array(vec_vals[:dim], dtype=np.int64)

    # ------------------------------------------------------------------
    # Case 2: Lines mentioning "generating vector"/"best lattice"/etc.
    # ------------------------------------------------------------------
    trigger_keywords = (
        "generating vector",
        "best lattice",
        "generating values",
        "vector a",
        "generator",
        "generating vector a",
    )

    for i, line in enumerate(lines):
        low = line.lower()
        if not any(k in low for k in trigger_keywords):
            continue
        collected: List[int] = []
        for ln in lines[i:]:
            stripped = ln.strip()
            # Blank/comment lines are skipped before any value is seen and
            # terminate the run afterwards.
            if not stripped or stripped.startswith("#"):
                if collected:
                    break
                continue
            collected.extend(ints_from_line(ln))
            if len(collected) >= dim:
                return np.array(collected[:dim], dtype=np.int64)

    # ------------------------------------------------------------------
    # Case 3: dash-separated vector like "1-45-101-...".
    # ------------------------------------------------------------------
    dash_pattern = re.compile(r"\d+(?:-\d+)+")
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        m = dash_pattern.search(stripped)
        if m:
            parts = m.group(0).split("-")
            if len(parts) >= dim:
                return np.array([int(p) for p in parts[:dim]], dtype=np.int64)

    # ------------------------------------------------------------------
    # Case 4: any non-comment line with >= dim integers
    # ------------------------------------------------------------------
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        vals = ints_from_line(stripped)
        if len(vals) >= dim:
            return np.array(vals[:dim], dtype=np.int64)

    snippet = "".join(lines[:40])
    raise ValueError(
        f"Could not find a generating vector of length {dim} in {path}.\n"
        f"{'-'*60}\n{snippet}\n{'-'*60}\n"
        "Please inspect the pattern above and adjust the parser if needed."
    )


def construct_latnetbuilder_lattice_generator(
    n_points: int,
    dim: int,
    latnetbuilder_exe: str = LATNETBUILDER_EXE,
    tmp_root: str | None = LATNET_TMP_ROOT,
) -> np.ndarray:
    """
    Call LatNetBuilder to construct a rank-1 lattice rule in `dim` dimensions
    with `n_points` points, and return its generating vector.

    Search settings: ``-e fast-CBC -f CU:P2 -q 2 -w product:0.1``.
    Results are cached per (n_points, dim). The executable path and tmp_root
    are not part of the cache key.

    Returns:
        int64 array of length `dim`, as parsed from LatNetBuilder's output.txt.

    Raises:
        ValueError: if n_points or dim is not positive, or output.txt cannot
            be parsed.
        RuntimeError: if LatNetBuilder exits with a non-zero status.
        FileNotFoundError: if LatNetBuilder does not write output.txt.
    """
    if int(n_points) < 1 or int(dim) < 1:
        raise ValueError(
            f"n_points and dim must be positive, got n_points={n_points}, dim={dim}."
        )

    key = (int(n_points), int(dim))
    if key in _latnet_lattice_generator_cache:
        return _latnet_lattice_generator_cache[key]

    if tmp_root is None:
        tmp_root = tempfile.gettempdir()
    out_dir = os.path.join(tmp_root, f"latnet_n{n_points}_d{dim}")
    os.makedirs(out_dir, exist_ok=True)

    cmd = [
        latnetbuilder_exe,
        "-t", "lattice",
        "-c", "ordinary",
        "-s", str(n_points),
        "-d", str(dim),
        "-e", "fast-CBC",
        "-f", "CU:P2",
        "-q", "2",
        "-w", "product:0.1",
        "-o", out_dir,
    ]

    print(f"[LatNetBuilder] Rank-1 lattice: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        # Chain the original error so the process failure stays in the traceback.
        raise RuntimeError(
            f"LatNetBuilder failed for lattice n={n_points}, d={dim}.\n"
            f"Command: {' '.join(cmd)}\n"
            f"stdout:\n{e.stdout.decode(errors='ignore')}\n"
            f"stderr:\n{e.stderr.decode(errors='ignore')}"
        ) from e

    output_path = os.path.join(out_dir, "output.txt")
    if not os.path.exists(output_path):
        raise FileNotFoundError(
            f"LatNetBuilder did not produce output.txt at {output_path}. "
            "Check the command and LatNetBuilder installation."
        )

    gen_vec = _parse_latnetbuilder_generating_vector(output_path, dim)
    _latnet_lattice_generator_cache[key] = gen_vec
    return gen_vec


def generate_rank1_lattice_points(generator: np.ndarray, n_points: int) -> np.ndarray:
    """
    Build the unshifted rank-1 lattice point set {(i * a mod n) / n : i = 0..n-1}.

    Args:
        generator: generating vector a, converted to int64. Its length is the
            dimension d.
        n_points: number of points n.

    Returns:
        Float array of shape (n_points, d) with entries in [0, 1). No random
        shift is applied; add one separately for RQMC.
    """
    a = np.asarray(generator, dtype=np.int64)
    row_index = np.arange(n_points, dtype=np.int64)[:, np.newaxis]
    residues = np.mod(row_index * a.reshape(1, a.shape[0]), n_points)
    return residues / float(n_points)


# ------------------ Digital net parsing and generation ---------------------


def _parse_latnetbuilder_dnet_matrices(output_path: str) -> tuple[np.ndarray, int, int]:
    """
    Parse base-2 digital-net generating matrices from LatNetBuilder's
    output.txt in the '# dnet' format.

    After the '# dnet' header, comment and blank lines are skipped. The next
    four value lines hold base, dimension s, k, and r. They are followed by
    s lines of k integers each. Trailing '# ...' annotations on value lines
    are ignored.

    Returns:
        C_mats: int64 array of shape (s, k). Row j holds the k column integers
            of dimension j's generating matrix (presumably with r bits each;
            generate_digital_net_points scales by 2**-r).
        k_val:  number of columns (log2 of the number of points).
        r_val:  number of rows (bits of precision).

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if the header is missing, a header value is not an
            integer, the base is not 2, or a matrix line is short/missing.
    """
    path = Path(output_path)
    if not path.exists():
        raise FileNotFoundError(f"LatNetBuilder output file not found: {path}")

    with path.open("r") as f:
        lines = f.readlines()

    def ints_from_line(line: str) -> List[int]:
        # Drop any trailing '# ...' annotation so digits inside it are not
        # mistaken for data (they could otherwise pad a short matrix row).
        return [int(x) for x in re.findall(r"-?\d+", line.split("#", 1)[0])]

    # Find '# dnet' header
    header_idx = next(
        (i for i, line in enumerate(lines) if line.strip().lower().startswith("# dnet")),
        None,
    )

    if header_idx is None:
        snippet = "".join(lines[:40])
        raise ValueError(
            "No '# dnet' header found in LatNetBuilder output.\n"
            f"First lines:\n{'-'*60}\n{snippet}\n{'-'*60}\n"
            "Check output format or add an explicit dnet output option."
        )

    # Collect non-comment, non-empty lines after header
    non_comment = [
        line for line in lines[header_idx + 1:]
        if line.strip() and not line.strip().startswith("#")
    ]

    if len(non_comment) < 4:
        raise ValueError(
            f"Not enough lines after '# dnet' header in {output_path} "
            f"to read base, dimension, k, r."
        )

    header_vals: List[int] = []
    for field, line in zip(("base", "dimension", "k", "r"), non_comment[:4]):
        vals = ints_from_line(line)
        if not vals:
            raise ValueError(
                f"Expected an integer {field} in the '# dnet' block of {output_path}, "
                f"got '{line.strip()}'."
            )
        header_vals.append(vals[0])
    base_val, s_val, k_val, r_val = header_vals

    if base_val != 2:
        raise ValueError(f"Digital net base {base_val} != 2 is not supported.")

    # Next s_val lines: generating matrices (each line has k_val integers)
    C_list: List[List[int]] = []
    for line in non_comment[4:]:
        if len(C_list) >= s_val:
            break
        vals = ints_from_line(line)
        if not vals:
            continue
        if len(vals) < k_val:
            raise ValueError(
                f"Line '{line.strip()}' has fewer than k={k_val} integers when parsing dnet matrices."
            )
        C_list.append(vals[:k_val])

    if len(C_list) < s_val:
        raise ValueError(
            f"Expected {s_val} generating-matrix rows, got {len(C_list)} in {output_path}."
        )

    C_mats = np.array(C_list, dtype=np.int64)  # shape (s_val, k_val)
    return C_mats, k_val, r_val


def generate_digital_net_points(
    C_mats: np.ndarray, n_points: int, dim: int, r: int
) -> np.ndarray:
    """
    Evaluate the first `dim` coordinates of a base-2 digital net.

    Coordinate j of point i is the XOR of the column integers C_mats[j, b]
    over every bit b set in i, divided by 2**r.

    Args:
        C_mats: integer array of shape (s, k), one row of k column integers
            per dimension.
        n_points: must equal 2**k.
        dim: number of leading dimensions to evaluate (dim <= s).
        r: number of bits used in the final scaling by 2**-r.

    Returns:
        Float array of shape (n_points, dim).

    Raises:
        ValueError: if dim > s or n_points != 2**k.
    """
    mats = np.asarray(C_mats, dtype=np.int64)
    n_avail, n_cols = mats.shape

    if dim > n_avail:
        raise ValueError(
            f"dnet matrices have only {n_avail} dimensions, but dim={dim} requested."
        )
    if (1 << n_cols) != n_points:
        raise ValueError(
            f"n_points={n_points} is not 2^k with k={n_cols} from dnet file."
        )

    index = np.arange(n_points, dtype=np.int64)
    acc = np.zeros((n_points, dim), dtype=np.int64)
    # For each column, fold it into every point whose index has that bit set.
    for col in range(n_cols):
        has_bit = ((index >> col) & 1).astype(bool)
        acc[has_bit] ^= mats[:dim, col]
    return acc / float(1 << r)


def construct_latnetbuilder_digital_net_points(
    n_points: int,
    dim: int,
    k_exponent: int,
    latnetbuilder_exe: str = LATNETBUILDER_EXE,
    tmp_root: str | None = LATNET_TMP_ROOT,
) -> np.ndarray:
    """
    Call LatNetBuilder to construct a digital net (Sobol' construction) in `dim`
    dimensions with n_points = 2^k_exponent points, and return the point set.

    Returns:
        Float array of shape (n_points, dim), produced by
        generate_digital_net_points from the parsed '# dnet' matrices. Results
        are cached per (n_points, dim).

    Raises:
        ValueError: if n_points != 2**k_exponent, or the output cannot be
            parsed or disagrees with n_points.
        RuntimeError: if LatNetBuilder exits with a non-zero status.
        FileNotFoundError: if LatNetBuilder does not write output.txt.
    """
    # n_points and k_exponent are redundant. Check that they agree before the
    # (slow) external search rather than only after parsing its output.
    expected_points = 1 << int(k_exponent)
    if n_points != expected_points:
        raise ValueError(
            f"n_points={n_points} does not equal 2^k_exponent={expected_points}."
        )

    key = (int(n_points), int(dim))
    if key in _latnet_digital_net_points_cache:
        return _latnet_digital_net_points_cache[key]

    if tmp_root is None:
        tmp_root = tempfile.gettempdir()
    out_dir = os.path.join(tmp_root, f"latnet_dnet_n{n_points}_d{dim}")
    os.makedirs(out_dir, exist_ok=True)

    cmd = [
        latnetbuilder_exe,
        "-t", "net",
        "-c", "sobol",
        "-s", f"2^{k_exponent}",
        "-d", str(dim),
        "-e", "random-CBC:2000",
        "-f", "projdep:t-value",
        "-q", "inf",
        "-w", "order-dependent:0:1,1,1",
        "-o", out_dir,
    ]

    print(f"[LatNetBuilder] Digital net: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(
            f"LatNetBuilder failed for digital net n={n_points}, d={dim}.\n"
            f"Command: {' '.join(cmd)}\n"
            f"stdout:\n{e.stdout.decode(errors='ignore')}\n"
            f"stderr:\n{e.stderr.decode(errors='ignore')}"
        ) from e

    output_path = os.path.join(out_dir, "output.txt")
    if not os.path.exists(output_path):
        raise FileNotFoundError(
            f"LatNetBuilder did not produce output.txt at {output_path}. "
            "Check the command and LatNetBuilder installation."
        )

    C_mats, k_val, r_val = _parse_latnetbuilder_dnet_matrices(output_path)
    if (1 << k_val) != n_points:
        raise ValueError(
            f"dnet file implies 2^{k_val}={1<<k_val} points but n_points={n_points} requested."
        )

    base_points = generate_digital_net_points(C_mats, n_points, dim, r_val)
    _latnet_digital_net_points_cache[key] = base_points
    return base_points


# ---------------------------------------------------------------------------
# Sobol C++ helper (from your original code)
# ---------------------------------------------------------------------------

class DimensionParameters(ctypes.Structure):
    """
    ctypes mirror of the per-dimension Sobol' parameter struct passed to the
    C++ `generate_sobol_points` function (see get_sobol_points_cpp).

    Field order and C types define the memory layout and must match the C++
    declaration exactly. That declaration is not visible here: TODO confirm
    against the library source before changing anything.
    """
    _fields_ = [
        # Number of entries of m_i that are meaningful; get_sobol_points_cpp
        # copies this many values and zero-fills the rest.
        ("s", ctypes.c_int),
        # Taken from the 'a' key of each parameter dict; presumably the
        # primitive-polynomial coefficient encoding (TODO confirm).
        ("a", ctypes.c_uint32),
        ("m_i", ctypes.c_uint32 * 30)  # Fixed-capacity uint32_t[30] of initial direction numbers
    ]


def get_sobol_points_cpp(input_params, n_points, n_dimensions, ltm, shifts, lib_path=None):
    """
    Generate Sobol' points using the C++ shared library.

    Args:
        input_params: sequence of dicts with keys 's', 'a', 'm_i' for
            dimensions 2..n_dimensions. Only the first n_dimensions - 1
            entries are used.
        n_points: number of points to generate.
        n_dimensions: number of coordinates per point.
        ltm: array handed to the library as a flat uint32 buffer (the
            experiment below builds it with shape (n_dimensions, 30, 30)).
            It is converted to C-contiguous uint32 if needed.
        shifts: array of at least n_dimensions digital shifts, handed to the
            library as uint32. It is converted to C-contiguous uint32 if needed.
        lib_path: path to the shared library. Defaults to the
            SOBOL_GENERATOR_LIB environment variable, falling back to the
            original hard-coded path.

    Returns:
        np.ndarray of shape (n_points, n_dimensions) backed by the output
        buffer filled by the library.

    Raises:
        TypeError: if a parameter entry is not a dict.
        ValueError: if fewer than n_dimensions - 1 entries are given, an entry
            has more than 30 direction numbers, or 'm_i' is shorter than 's'.
        OSError: if the shared library cannot be loaded.
    """
    n_dir_params = n_dimensions - 1
    params_used = list(input_params)[:n_dir_params]
    if len(params_used) < n_dir_params:
        # Missing entries would otherwise reach C as zero-initialised structs.
        raise ValueError(
            f"Expected {n_dir_params} parameter entries for n_dimensions={n_dimensions}, "
            f"got {len(params_used)}."
        )

    # Validate everything before touching the shared library.
    for i, py_param in enumerate(params_used):
        if not isinstance(py_param, dict):
            # Entry i describes dimension i + 2 (dimension 1 is not passed in).
            raise TypeError(f"Parameter for dimension {i+2} must be a dictionary.")
        s_val = py_param.get('s')
        m_i_list = py_param.get('m_i')
        if s_val > _SOBOL_MAX_DIRECTION_NUMBERS:
            raise ValueError(
                f"Parameter for dimension {i+2} has s={s_val}; at most "
                f"{_SOBOL_MAX_DIRECTION_NUMBERS} direction numbers are supported."
            )
        if len(m_i_list) < s_val:
            raise ValueError(
                f"Parameter for dimension {i+2} has s={s_val} but only "
                f"{len(m_i_list)} m_i values."
            )

    if lib_path is None:
        lib_path = os.environ.get("SOBOL_GENERATOR_LIB", _DEFAULT_SOBOL_LIB_PATH)
    sobol_lib = _load_sobol_lib(lib_path)

    ctypes_params_array = (DimensionParameters * n_dir_params)()
    for i, py_param in enumerate(params_used):
        s_val = py_param['s']
        ctypes_params_array[i].s = s_val
        ctypes_params_array[i].a = int(py_param['a'])
        padded = [int(v) for v in py_param['m_i'][:s_val]]
        padded += [0] * (_SOBOL_MAX_DIRECTION_NUMBERS - s_val)
        ctypes_params_array[i].m_i[:] = padded

    # from_buffer reinterprets raw bytes, so ensure the dtype really is uint32
    # (a silent int64 reinterpretation would hand garbage to the C code).
    ltm_buf = np.ascontiguousarray(ltm, dtype=np.uint32)
    shifts_buf = np.ascontiguousarray(shifts, dtype=np.uint32)
    ctypes_ltm = (ctypes.c_uint32 * ltm_buf.size).from_buffer(ltm_buf)
    ctypes_shift = (ctypes.c_uint32 * n_dimensions).from_buffer(shifts_buf)

    ctypes_output_points = (ctypes.c_double * (n_points * n_dimensions))()

    sobol_lib.generate_sobol_points(
        n_points,
        n_dimensions,
        ctypes_params_array,
        ctypes_output_points,
        ctypes_ltm,
        ctypes_shift
    )
    np_output_points = np.ctypeslib.as_array(ctypes_output_points)
    return np_output_points.reshape((n_points, n_dimensions))


# Original hard-coded location of the compiled generator; override it with the
# SOBOL_GENERATOR_LIB environment variable or the `lib_path` argument.
_DEFAULT_SOBOL_LIB_PATH = "/projects/bdln/asadikov/openevolve-old/examples/qmc/sobol_generator.so"
# Capacity of DimensionParameters.m_i (ctypes.c_uint32 * 30).
_SOBOL_MAX_DIRECTION_NUMBERS = 30
# Libraries already loaded and configured, keyed by path.
_sobol_lib_cache: dict = {}


def _load_sobol_lib(lib_path: str) -> ctypes.CDLL:
    """Load the Sobol' shared library once per path and declare its C signature."""
    lib = _sobol_lib_cache.get(lib_path)
    if lib is None:
        lib = ctypes.CDLL(lib_path)
        lib.generate_sobol_points.argtypes = [
            ctypes.c_int,                                 # n_points
            ctypes.c_int,                                 # n_dimensions
            ctypes.POINTER(DimensionParameters),          # input_sobol_params
            ctypes.POINTER(ctypes.c_double),              # output_points
            ctypes.POINTER(ctypes.c_uint32),              # ltm_elements_flat
            ctypes.POINTER(ctypes.c_uint32)               # digital_shifts
        ]
        lib.generate_sobol_points.restype = None
        _sobol_lib_cache[lib_path] = lib
    return lib


def construct_sobol_sequence():
    """
    Return the built-in Sobol' direction-number table for dimensions 2..32.

    Each of the 31 entries is a new dict ``{'s': ..., 'a': ..., 'm_i': [...]}``
    in which ``s`` equals ``len(m_i)``. A fresh list is built on every call,
    so callers may modify the result.
    """
    # (a, initial direction numbers m_1..m_s), listed for dimensions 2, 3, ..., 32.
    table = (
        (0, (1,)),                              # Dimension 2
        (1, (1, 3)),                            # Dimension 3
        (1, (1, 3, 5)),                         # Dimension 4
        (2, (1, 3, 7)),                         # Dimension 5
        (1, (1, 1, 3, 7)),                      # Dimension 6
        (4, (1, 3, 5, 13)),                     # Dimension 7
        (2, (1, 1, 5, 5, 17)),                  # Dimension 8
        (4, (1, 1, 5, 5, 5)),                   # Dimension 9
        (7, (1, 1, 7, 11, 19)),                 # Dimension 10
        (11, (1, 1, 5, 1, 1)),                  # Dimension 11
        (13, (1, 1, 1, 3, 11)),                 # Dimension 12
        (14, (1, 3, 5, 5, 31)),                 # Dimension 13
        (1, (1, 3, 3, 9, 7, 49)),               # Dimension 14
        (13, (1, 1, 1, 15, 21, 21)),            # Dimension 15
        (16, (1, 3, 1, 13, 27, 49)),            # Dimension 16
        (19, (1, 1, 1, 15, 7, 5)),              # Dimension 17
        (22, (1, 3, 1, 15, 13, 25)),            # Dimension 18
        (25, (1, 1, 5, 5, 19, 61)),             # Dimension 19
        (1, (1, 3, 7, 11, 23, 15, 103)),        # Dimension 20
        (4, (1, 3, 7, 13, 13, 15, 69)),         # Dimension 21
        (7, (1, 1, 3, 13, 7, 35, 63)),          # Dimension 22
        (8, (1, 3, 5, 9, 1, 25, 53)),           # Dimension 23
        (14, (1, 3, 1, 13, 9, 35, 107)),        # Dimension 24
        (19, (1, 3, 1, 5, 27, 61, 31)),         # Dimension 25
        (21, (1, 1, 5, 11, 19, 41, 61)),        # Dimension 26
        (28, (1, 3, 5, 3, 3, 13, 69)),          # Dimension 27
        (31, (1, 1, 7, 13, 1, 19, 1)),          # Dimension 28
        (32, (1, 3, 7, 5, 13, 19, 59)),         # Dimension 29
        (37, (1, 1, 3, 9, 25, 29, 41)),         # Dimension 30
        (41, (1, 3, 5, 13, 23, 1, 55)),         # Dimension 31
        (42, (1, 3, 7, 3, 13, 59, 17)),         # Dimension 32
    )
    return [{'s': len(m_vals), 'a': a_val, 'm_i': list(m_vals)} for a_val, m_vals in table]


# ---------------------------------------------------------------------------
# Option pricer + experiment
# ---------------------------------------------------------------------------

def asian_option_pricer(s0, k, t, r, sigma, d, qmc_points):
    """
    Estimate an arithmetic-average Asian call price from a QMC point set.

    Each row of `qmc_points` is mapped through the standard normal inverse
    CDF to Brownian increments over `d` equal time steps of length t / d. The
    asset follows geometric Brownian motion starting at `s0`. The payoff is
    max(mean of the d monitored prices - k, 0), and the discounted sample mean
    is returned.

    `qmc_points` is expected to have `d` columns so that the d time points
    broadcast against the path columns. Note that norm.ppf maps coordinates
    equal to 0 or 1 to -inf / +inf.
    """
    step = t / d
    z = norm.ppf(qmc_points)
    w = np.cumsum(np.sqrt(step) * z, axis=1)
    times = np.linspace(step, t, d)
    drift = (r - 0.5 * sigma**2) * times
    paths = s0 * np.exp(drift + sigma * w)
    payoff = np.maximum(np.mean(paths, axis=1) - k, 0)
    return np.exp(-r * t) * np.mean(payoff)


# Arithmetic-average Asian call test cases, consumed by the experiment loop:
# S0 = initial asset price, K_strike = strike, T_exp = maturity,
# R_rate = risk-free rate, SIGMA = volatility, D_dims = number of monitoring
# dates (also the QMC dimension), C0_true = reference price used to compute
# bias and MSE (its source is not shown in this file).
# Case labels presumably mean: original, out-/at-/in-the-money, and
# high/low volatility.
asian_options = {
    "og":   {"S0": 50.0, "K_strike": 45.0, "T_exp": 1.0, "R_rate": 0.05, "SIGMA": 0.3, "D_dims": 32, "C0_true": 7.06451424679549},
    "otm":  {"S0": 50.0, "K_strike": 60.0, "T_exp": 1.0, "R_rate": 0.05, "SIGMA": 0.3, "D_dims": 32, "C0_true": 1.0161335048477829},
    "atm":  {"S0": 50.0, "K_strike": 52.5, "T_exp": 1.0, "R_rate": 0.05, "SIGMA": 0.3, "D_dims": 32, "C0_true": 2.9753224833133496},
    "itm":  {"S0": 50.0, "K_strike": 40.0, "T_exp": 1.0, "R_rate": 0.05, "SIGMA": 0.3, "D_dims": 32, "C0_true": 11.015933988360171},
    "hvol": {"S0": 50.0, "K_strike": 52.5, "T_exp": 1.0, "R_rate": 0.05, "SIGMA": 0.6, "D_dims": 32, "C0_true": 6.4274784214688045},
    "lvol": {"S0": 50.0, "K_strike": 52.5, "T_exp": 1.0, "R_rate": 0.05, "SIGMA": 0.1, "D_dims": 32, "C0_true": 0.6931527993078156}
}


# LatNetBuilder Sobol-direction-number baseline (digital net)
LATNET_DIM = 32   # all your asian_options use D_dims = 32
# LatNetBuilder direction numbers depend on the sample size, so they are built
# inside the experiment loop, once per power of two:
# latnet_sobol_params = construct_latnetbuilder_sobol_params(LATNET_DIM, n_pts_base)
# Built-in direction-number table used by the OpenEvolve Sobol' baseline.
sobol_params = construct_sobol_sequence()


# p / w aggregation kept for lattice baseline (as in your script)
# Per-option lists of Wilcoxon p-values (p_*) and statistics (w_*), one entry
# per sample size, converted to arrays after the loop. *_lnet: rank-1 lattice
# comparison; *_dnet: LatNetBuilder digital-net comparison; p_sobol/w_sobol
# are unused while the SciPy baseline is commented out.
p_lnet = {}
w_lnet = {}
p_dnet = {}
w_dnet = {}
p_sobol = {}
w_sobol = {}


def _draw_sobol_scrambling(rng, n_dims, n_bits=30):
    """
    Draw one random scramble for get_sobol_points_cpp: a digital shift per
    dimension, plus a lower-triangular 0/1 matrix with unit diagonal per dimension.

    `rng` is consumed in a fixed order: n_dims * n_bits shift bits, then
    n_dims * n_bits * n_bits matrix entries. Generators seeded identically
    therefore yield identical scrambles, which is what pairs the digital-net
    and OpenEvolve runs below for the Wilcoxon test.

    Returns:
        (ltm, digital_shifts): uint32 arrays of shapes (n_dims, n_bits, n_bits)
        and (n_dims,).
    """
    shift_bits = rng.integers(2, size=(n_dims, n_bits), dtype=np.uint32)
    digital_shifts = np.dot(shift_bits, 2 ** np.arange(n_bits, dtype=np.uint32))
    ltm = np.tril(rng.integers(2, size=(n_dims, n_bits, n_bits), dtype=np.uint32))
    # Vectorised form of setting ltm[d, i, i] = 1 for every dimension d and bit i.
    diag = np.arange(n_bits)
    ltm[:, diag, diag] = 1
    return ltm, digital_shifts


for option_name, option_params in asian_options.items():
    S0 = option_params["S0"]
    K_strike = option_params["K_strike"]
    T_exp = option_params["T_exp"]
    R_rate = option_params["R_rate"]
    SIGMA = option_params["SIGMA"]
    D_dims = option_params["D_dims"]
    C0_true = option_params["C0_true"]
    N = 10000  # independent randomizations per method and sample size

    p_lnet[option_name] = []
    w_lnet[option_name] = []
    p_dnet[option_name] = []
    w_dnet[option_name] = []
    # p_sobol[option_name] = []
    # w_sobol[option_name] = []

    for n_pts_base in range(5, 14):  # 2^5=32 up to 2^13=8192
        n_pts = 2 ** n_pts_base
        print(f"\n=== {option_name} | Number of points: {n_pts} ===")

        # ------------------ Sobol baseline (SciPy) --------------------
        # est_price = np.zeros(N)
        # est_mse = np.zeros(N)
        # for i in range(N):
        #     qmc_engine = qmc.Sobol(d=D_dims, scramble=True, seed=i)
        #     sobol_points_set = qmc_engine.random_base2(m=n_pts_base)
        #     estimated_price_sobol = asian_option_pricer(
        #         S0, K_strike, T_exp, R_rate, SIGMA, D_dims, sobol_points_set
        #     )
        #     est_price[i] = estimated_price_sobol
        #     est_mse[i] = (estimated_price_sobol - C0_true) ** 2

        # print(
        #     f"{option_name} SciPy Sobol  Bias: {(np.mean(est_price) - C0_true) ** 2:.3e} | "
        #     f"Variance: {np.var(est_price):.3e} | MSE: {np.mean(est_mse):.3e}"
        # )

        # ------------------ Rank-1 lattice baseline (random shift) -------------------
        lat_generator = construct_latnetbuilder_lattice_generator(
            n_points=n_pts,
            dim=D_dims,
            latnetbuilder_exe=LATNETBUILDER_EXE,
            tmp_root=LATNET_TMP_ROOT,
        )
        base_lattice_points = generate_rank1_lattice_points(lat_generator, n_pts)

        est_lat_price = np.zeros(N)
        est_lat_mse = np.zeros(N)
        for i in range(N):
            shift = np.random.default_rng(seed=i).random(D_dims)
            shifted_points = (base_lattice_points + shift) % 1.0
            estimated_price_lat = asian_option_pricer(
                S0, K_strike, T_exp, R_rate, SIGMA, D_dims, shifted_points
            )
            est_lat_price[i] = estimated_price_lat
            est_lat_mse[i] = (estimated_price_lat - C0_true) ** 2

        # The first reported quantity is the squared bias (mean - true)^2.
        print(
            f"{option_name} LatNet lattice Bias^2: {(np.mean(est_lat_price) - C0_true) ** 2:.3e} | "
            f"Variance: {np.var(est_lat_price):.3e} | MSE: {np.mean(est_lat_mse):.3e}"
        )

        # ------------------ LatNetBuilder Sobol digital-net baseline -------------------
        # Use this option's dimension so the parameter count matches the
        # n_dimensions passed to get_sobol_points_cpp below.
        latnet_sobol_params = construct_latnetbuilder_sobol_params(D_dims, n_pts_base)
        est_dnet_price = np.zeros(N)
        est_dnet_mse = np.zeros(N)
        for i in range(N):
            # Same seed as the OpenEvolve loop below -> identical scramble (paired samples).
            ltm_py, digital_shifts_py = _draw_sobol_scrambling(
                np.random.default_rng(seed=i), D_dims
            )
            dnet_points_set = get_sobol_points_cpp(
                latnet_sobol_params, n_pts, D_dims, ltm_py, digital_shifts_py
            )
            estimated_price_dnet = asian_option_pricer(
                S0, K_strike, T_exp, R_rate, SIGMA, D_dims, dnet_points_set
            )
            est_dnet_price[i] = estimated_price_dnet
            est_dnet_mse[i] = (estimated_price_dnet - C0_true) ** 2

        print(
            f"{option_name} LatNet Sobol Digital Net  Bias^2: {(np.mean(est_dnet_price) - C0_true) ** 2:.3e} | "
            f"Variance: {np.var(est_dnet_price):.3e} | MSE: {np.mean(est_dnet_mse):.3e}"
        )

        # ------------------ OpenEvolve Sobol baseline (C++ generator) -------------------
        est_oe_price = np.zeros(N)
        est_oe_mse = np.zeros(N)
        for i in range(N):
            ltm_py, digital_shifts_py = _draw_sobol_scrambling(
                np.random.default_rng(seed=i), D_dims
            )
            oe_points_set = get_sobol_points_cpp(
                sobol_params, n_pts, D_dims, ltm_py, digital_shifts_py
            )
            estimated_price_oe = asian_option_pricer(
                S0, K_strike, T_exp, R_rate, SIGMA, D_dims, oe_points_set
            )
            est_oe_price[i] = estimated_price_oe
            est_oe_mse[i] = (estimated_price_oe - C0_true) ** 2

        print(
            f"{option_name} OpenEvolve Sobol Bias^2: "
            f"{(np.mean(est_oe_price) - C0_true) ** 2:.3e} | "
            f"Variance: {np.var(est_oe_price):.3e} | MSE: {np.mean(est_oe_mse):.3e}"
        )

        # One-sided paired Wilcoxon signed-rank tests on per-seed squared errors.
        # alternative='less': the differences est_oe_mse - est_*_mse are shifted
        # below zero, i.e. the OpenEvolve Sobol errors tend to be smaller.
        res_lat = wilcoxon(est_oe_mse, est_lat_mse, alternative='less')
        print(
            f"LatNet lattice vs Sobol MSE Wilcoxon: "
            f"stat={res_lat.statistic}, p-value={res_lat.pvalue:.3e}"
        )
        p_lnet[option_name].append(res_lat.pvalue)
        w_lnet[option_name].append(res_lat.statistic)

        res_dnet = wilcoxon(est_oe_mse, est_dnet_mse, alternative='less')
        print(
            f"LatNet Sobol (digital net) vs OpenEvolve Sobol MSE Wilcoxon: "
            f"stat={res_dnet.statistic}, p-value={res_dnet.pvalue:.3e}"
        )
        p_dnet[option_name].append(res_dnet.pvalue)
        w_dnet[option_name].append(res_dnet.statistic)

    w_lnet[option_name] = np.array(w_lnet[option_name])
    p_lnet[option_name] = np.array(p_lnet[option_name])
    w_dnet[option_name] = np.array(w_dnet[option_name])
    p_dnet[option_name] = np.array(p_dnet[option_name])

# Summaries per sample size. Row i of each matrix corresponds to 2 ** (i + 5)
# points, and each column to one option (in asian_options order). The FDR
# adjustment (scipy's false_discovery_control, Benjamini-Hochberg by default)
# is applied across the options at a single sample size.
W_lnet = np.column_stack([w_lnet[name] for name in asian_options])
P_lnet = np.column_stack([p_lnet[name] for name in asian_options])
W_dnet = np.column_stack([w_dnet[name] for name in asian_options])
P_dnet = np.column_stack([p_dnet[name] for name in asian_options])

for comparison, P_mat, W_mat in (
    ("Lattice Net vs Sobol", P_lnet, W_lnet),
    ("Digital Net vs Sobol", P_dnet, W_dnet),
):
    for i, (p_row, w_row) in enumerate(zip(P_mat, W_mat)):
        print(f"\nNumber of points: {2 ** (i + 5)}")
        print(f"P-values ({comparison}):", p_row)
        print("Wilcoxon statistics:", w_row)
        print("FDR-corrected P-values:", false_discovery_control(p_row))