import math
import numbers

from BACKEND import cp, sp

# Written with significant assistance from ChatGPT.
# Alternative constructions for the temporal parameters (centers tau and
# scales sigma) in the hidden layers.
# In practice these were found to perform worse than the simple construction
# from the paper.

def build_hierarchical_alpha_then_fill(k, T, alpha=1.1,
                                      L_phi=2.0,
                                      overlap=0.75,
                                      sigma_min=None,
                                      sigma_max=None):
    """
    Build k atom centers and scales: complete hierarchical levels, then fill.

    Level j holds n_j = max(1, ceil(alpha**j)) atoms placed on the midpoints
    of a uniform partition of [0, T] into n_j cells, each with scale
    T / (L_phi * n_j) clipped to [sigma_min, sigma_max]. Levels are emitted
    coarse-to-fine for as long as the running total stays <= k (at most 200
    levels). The r = k - M atoms that do not form a complete level are spread
    evenly over [-sigma_min, T + sigma_min] and share one scale.

    Parameters
    ----------
    k : int
        total number of atoms requested (integral-valued numbers are accepted
        and converted to int)
    T : float
        domain length
    alpha : float > 1
        refinement factor between levels (alpha=2 -> dyadic)
    L_phi : float > 0
        support length of phi in argument units (phi supported on (-1,1) => L_phi=2)
    overlap : float in (0,1]
        stride = overlap * sigma is the desired spacing relation for fill atoms
    sigma_min, sigma_max : optional floats
        min/max allowed scale for clipping; default to T / (8 k) (floored at
        T * 1e-12) and T / 2 respectively. sigma_min also sets how far the
        fill interval extends beyond [0, T] on each side.

    Returns
    -------
    tau_cp, sigma_cp : cupy arrays of length exactly k
        centers and scales for atoms (hierarchical atoms first, then fill)

    Raises
    ------
    ValueError
        if k is not a positive integer, T <= 0, alpha <= 1 or L_phi <= 0
    """
    # Reject non-integral k up front; it would otherwise only fail later with
    # an unrelated TypeError when used as a count.
    try:
        k_int = int(k)
    except (TypeError, ValueError, OverflowError):
        raise ValueError("k must be positive integer") from None
    if k_int != k or k_int <= 0:
        raise ValueError("k must be positive integer")
    k = k_int
    if T <= 0:
        raise ValueError("T must be positive")
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1")
    if L_phi <= 0:
        raise ValueError("L_phi must be positive")

    small_eps = 1e-12
    if sigma_max is None:
        sigma_max = T / 2.0
    if sigma_min is None:
        sigma_min = max(T / (8.0 * max(1, k)), T * 1e-12)

    # The fill interval is widened by sigma_min on each side.
    tau_pad = sigma_min

    # Collect level sizes n_j = max(1, ceil(alpha**j)) while the running total
    # stays <= k. Level 0 has one atom, so at least one level always fits.
    max_levels = 200
    level_sizes = []
    total = 0
    for j in range(max_levels):
        # Host-side ceil: a plain scalar does not need the array backend.
        n_j = max(1, math.ceil(alpha ** j))
        if total + n_j > k:
            break
        level_sizes.append(n_j)
        total += n_j

    taus = []
    sigs = []

    # Hierarchical complete part.
    for n_j in level_sizes:
        # Each atom's support (L_phi * sigma) spans one cell of length T / n_j,
        # before clipping.
        sigma_j = float(T / (L_phi * n_j))
        sigma_j = max(sigma_min, min(sigma_j, sigma_max))
        cell = T / n_j
        # Centers are the midpoints of the n_j cells.
        taus.extend(float((m + 0.5) * cell) for m in range(n_j))
        sigs.extend([float(sigma_j)] * n_j)

    # Fill the remainder uniformly on the expanded interval.
    r = k - total
    if r > 0:
        left = -float(tau_pad)
        right = float(T + tau_pad)
        interval_len = right - left
        if r == 1:
            centers_fill = [0.5 * (left + right)]
        else:
            centers_fill = cp.linspace(left, right, r).tolist()
        # NOTE(review): the stride used for sigma is interval_len / r, while
        # linspace (endpoints included) spaces centers by interval_len / (r - 1).
        # Kept as-is to preserve existing numerical output.
        stride = interval_len / float(r)
        sigma_fill = float(stride / max(overlap, small_eps))
        sigma_fill = max(sigma_min, min(sigma_fill, sigma_max))
        taus.extend(float(c) for c in centers_fill)
        sigs.extend([float(sigma_fill)] * r)

    # len(taus) == total + r == k by construction, so no trim/pad is needed.
    tau_cp = cp.array(taus, dtype=float)
    sigma_cp = cp.array(sigs, dtype=float)
    return tau_cp, sigma_cp


def _midpoint_atoms(n, left, length, overlap_frac, L_phi, sigma_min, sigma_max):
    """Return (centers, sigma) for n atoms on the cell midpoints of [left, left + length]."""
    spacing = length / float(n)   # distance between neighbouring centers
    # Support length S satisfies spacing = (1 - overlap_frac) * S.
    support = spacing / (1.0 - float(overlap_frac))
    sigma = max(sigma_min, min(float(support / L_phi), sigma_max))
    if n == 1:
        centers = [float(left + 0.5 * length)]
    else:
        first = left + 0.5 * spacing
        last = left + (n - 0.5) * spacing
        centers = cp.linspace(first, last, n).tolist()
    return [float(c) for c in centers], float(sigma)


def build_n0_base_then_fill_overlap_frac(k, T,
                                         n_0=1,
                                         base=2.0,
                                         pad=0.02,
                                         overlap_frac=0.5,
                                         L_phi=2.0,
                                         sigma_min=None,
                                         sigma_max=None,
                                         max_levels=200):
    """
    Build hierarchical atoms with n_ell = n_0 * round(base**ell) at level ell covering
    [-pad, T+pad] with support overlap equal to `overlap_frac` fraction of each support.
    Fill remaining slots uniformly. Return (tau_cp, sigma_cp) as CuPy arrays length k.

    At each level the centers are the midpoints of a uniform partition of
    [-pad, T+pad] into n_ell cells; the scale is
    (cell length / (1 - overlap_frac)) / L_phi, clipped to [sigma_min, sigma_max].
    Levels are added coarse-to-fine while the running total stays <= k; the
    r = k - M leftover atoms are placed by the same rule with n = r.

    Parameters
    ----------
    - k : int > 0 (any integral type, e.g. NumPy integer scalars, is accepted)
    - T : float > 0
    - n_0 : int >= 1 (any integral type is accepted)
    - base : float > 1
    - pad : float >= 0
    - overlap_frac : float in [0, 1)   # fraction of support that adjacent atoms overlap
    - L_phi : float (support length of phi in argument units, default 2 for phi supported on (-1,1))
    - sigma_min, sigma_max : optional floats to clip sigma values; default to
      T / (8 k) (floored at T * 1e-12) and T / 2 respectively
    - max_levels : safety cap on number of hierarchical levels

    Returns
    -------
    tau_cp, sigma_cp : cupy arrays, length exactly k
      Hierarchical atoms listed level-by-level (coarse->fine), then fill atoms.

    Raises
    ------
    ValueError
        if any argument violates the constraints listed above
    """
    # ---- input validation ----
    # numbers.Integral also admits integer scalars from array libraries,
    # which a plain isinstance(..., int) check rejects.
    if not (isinstance(k, numbers.Integral) and k > 0):
        raise ValueError("k must be a positive integer")
    if T <= 0:
        raise ValueError("T must be positive")
    if not (isinstance(n_0, numbers.Integral) and n_0 >= 1):
        raise ValueError("n_0 must be a positive integer")
    if base <= 1.0:
        raise ValueError("base must be > 1")
    if pad < 0:
        raise ValueError("pad must be >= 0")
    if not (0.0 <= overlap_frac < 1.0):
        raise ValueError("overlap_frac must satisfy 0 <= overlap_frac < 1")
    if L_phi <= 0:
        raise ValueError("L_phi must be positive")

    # Work with plain Python ints from here on (list repetition, counts).
    k = int(k)
    n_0 = int(n_0)

    if sigma_max is None:
        sigma_max = T / 2.0
    if sigma_min is None:
        sigma_min = max(T / (8.0 * max(1, k)), T * 1e-12)

    L_int = float(T + 2.0 * pad)   # expanded interval length
    taus_list = []
    sigs_list = []

    cumulative = 0
    ell = 0

    # Build hierarchical levels until adding next would exceed k or max_levels reached
    while ell < max_levels:
        n_ell = n_0 * max(1, int(round(base ** ell)))
        if cumulative + n_ell > k:
            break
        centers, sigma_ell = _midpoint_atoms(
            n_ell, -pad, L_int, overlap_frac, L_phi, sigma_min, sigma_max)
        taus_list.extend(centers)
        sigs_list.extend([sigma_ell] * n_ell)
        cumulative += n_ell
        ell += 1

    # Fill remainder uniformly on expanded interval, with same overlap fraction
    r = k - cumulative
    if r > 0:
        centers_fill, sigma_fill = _midpoint_atoms(
            r, -pad, L_int, overlap_frac, L_phi, sigma_min, sigma_max)
        taus_list.extend(centers_fill)
        sigs_list.extend([sigma_fill] * r)

    # The lists hold cumulative + r == k entries by construction, so no
    # trim/pad step is required.
    tau_cp = cp.array(taus_list, dtype=float)
    sigma_cp = cp.array(sigs_list, dtype=float)
    return tau_cp, sigma_cp
