import numpy as np

def _sorted_trace(t_col, l_col):
    """Return one repeat's jointly non-NaN (time, loss) samples sorted by time, or None if < 2."""
    valid = ~(np.isnan(t_col) | np.isnan(l_col))
    if np.count_nonzero(valid) < 2:
        return None  # np.interp needs at least two points to interpolate
    t, l = t_col[valid], l_col[valid]
    # np.interp requires xp to be increasing; a stable sort makes the choice
    # among duplicate timestamps deterministic (it does not remove them).
    order = np.argsort(t, kind="stable")
    return t[order], l[order]


def aggregate(
    times: np.ndarray,
    losses: np.ndarray,
    *,
    num_pts: int = 200,
    error: str = "iqr",
    grid: str = "linear",
):
    """
    Aggregate irregular loss-vs-time traces stored one repeat per column.

    Parameters
    ----------
    times   : (steps, repeats) array
        Wall-clock timestamps (NaN where a step is absent).
    losses  : (steps, repeats) array
        Loss values aligned with `times` (NaN where absent). Must have the
        same shape as `times`.
    num_pts : int
        Number of points in the common time grid.
    error   : {"iqr", "std", "sem"}
        Spread measure for the error band. "iqr" gives the 25th/75th
        percentiles; "std" and "sem" give mean -/+ the population standard
        deviation (ddof=0), respectively divided by sqrt(count) for "sem".
    grid    : {"linear", "log"}
        Spacing of the common grid. "linear" runs from 0 to `t_max`; "log"
        runs from the earliest positive valid timestamp (at least machine
        eps) to `t_max`.

    `t_max` is the smallest final timestamp among repeats that have at least
    two samples where both time and loss are non-NaN, so every such repeat
    covers the right end of the grid. Repeats with fewer valid samples are
    left out of the statistics.

    Returns
    -------
    t_grid, mean_curve, err_lo, err_hi
        1-D arrays of length `num_pts`. Grid points reached by no repeat are
        NaN.

    Raises
    ------
    ValueError
        If the inputs are not 2-D arrays of identical shape, if `grid` or
        `error` is not recognised, or if a log grid is requested but no
        positive end time exists.
    """
    import warnings

    times = np.asarray(times, dtype=float)
    losses = np.asarray(losses, dtype=float)
    if times.ndim != 2 or times.shape != losses.shape:
        raise ValueError(
            "times and losses must be 2-D arrays of identical shape (steps, repeats)"
        )
    if grid not in ("linear", "log"):
        raise ValueError("grid must be 'linear' or 'log'")
    if error not in ("iqr", "std", "sem"):
        raise ValueError("error must be 'iqr', 'std', or 'sem'")

    # --- 1.  Clean each repeat (columns of the inputs) ----------------------
    traces = [_sorted_trace(t_col, l_col) for t_col, l_col in zip(times.T, losses.T)]
    usable = [tr for tr in traces if tr is not None]

    # --- 2.  Build a global time grid ---------------------------------------
    if usable:
        # Latest time that every usable repeat reaches with a valid loss.
        t_max = min(t[-1] for t, _ in usable)
    else:
        # Nothing can be interpolated; keep the grid defined from raw times.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            t_max = np.nanmin(np.nanmax(times, axis=0))

    if grid == "linear":
        t_grid = np.linspace(0.0, t_max, num_pts)
    else:
        # `not (t_max > 0)` also rejects NaN.
        if not t_max > 0:
            raise ValueError("log grid requires positive timestamps")
        eps = np.finfo(float).eps
        # Traces are sorted, so the first positive entry is each one's earliest.
        starts = [t[t > 0][0] for t, _ in usable if np.any(t > 0)]
        t_start = max(eps, min(starts)) if starts else eps
        t_grid = np.geomspace(min(t_start, t_max), t_max, num_pts)

    # --- 3.  Interpolate each repeat onto the grid ---------------------------
    curves = np.full((times.shape[1], num_pts), np.nan)
    for row, trace in zip(curves, traces):
        if trace is not None:
            row[:] = np.interp(t_grid, *trace, left=np.nan, right=np.nan)

    # --- 4.  Column-wise statistics, ignoring NaNs ---------------------------
    with warnings.catch_warnings():
        # Grid columns reached by no repeat are all-NaN; leave them NaN quietly.
        warnings.simplefilter("ignore", RuntimeWarning)
        mean_curve = np.nanmean(curves, axis=0)
        if error == "iqr":
            err_lo, err_hi = np.nanpercentile(curves, [25, 75], axis=0)
        else:
            spread = np.nanstd(curves, axis=0)
            if error == "sem":
                spread = spread / np.sqrt(np.sum(~np.isnan(curves), axis=0))
            err_lo, err_hi = mean_curve - spread, mean_curve + spread

    return t_grid, mean_curve, err_lo, err_hi