import torch


def _ensure_tensor_2d(x):
    """Convert ``x`` to a tensor and promote 1-D input to a single column.

    Args:
        x: Anything ``torch.tensor`` accepts (nested lists, scalars, arrays)
            or an existing ``torch.Tensor``.

    Returns:
        A new tensor that does not share storage with ``x``. A 1-D input of
        shape ``[N]`` comes back as ``[N, 1]``; inputs of any other rank
        (including 0-D scalars) are returned with their rank unchanged.
    """
    if isinstance(x, torch.Tensor):
        # torch.tensor(existing_tensor) emits a "copy construct" UserWarning;
        # detach().clone() is the documented equivalent (detached copy).
        t = x.detach().clone()
    else:
        t = torch.tensor(x)
    if t.ndim == 1:
        # [N] -> [N, 1]. Note an empty list has shape [0] and becomes [0, 1].
        return t.unsqueeze(-1)
    return t


def collate_stack_same_shape(batch, x_key: str = "x", y_key: str = "y"):
    """Strict collate: stack items whose tensors all have identical shapes.

    Args:
        batch: Sequence of mapping-like items. Each item is indexed with
            ``x_key`` and ``y_key``; the values may be tensors or anything
            ``torch.tensor`` accepts. The docstring contract is
            x: [N_C, D] and y: [N_C].
        x_key: Key of the feature entry in each item (also used in the output).
        y_key: Key of the target entry in each item (also used in the output).

    Returns:
        Dict with ``x_key`` -> stacked features (x: [B, N_C, D] for the
        documented input) and ``y_key`` -> stacked targets with a trailing
        singleton dimension added (y: [B, N_C, 1]).

    Raises:
        IndexError: If ``batch`` is empty (there is no reference shape).
        RuntimeError: If any x (or y) shape differs from the first item's.
    """

    def _as_tensor(value):
        # torch.tensor(existing_tensor) warns and copies. torch.stack below
        # already allocates fresh storage, so a detached view is sufficient
        # and keeps the output free of autograd history, as before.
        if isinstance(value, torch.Tensor):
            return value.detach()
        return torch.tensor(value)

    xs = [_as_tensor(ex[x_key]) for ex in batch]  # [N_C, D]
    ys = [_as_tensor(ex[y_key]).unsqueeze(-1) for ex in batch]  # [N_C, 1]

    if not xs:
        # Same exception type the bare xs[0] lookup used to raise, but with
        # a message that names the actual problem.
        raise IndexError("Cannot collate an empty batch")

    # The labels are the literal "x"/"y" (not the configurable keys) so the
    # error text stays identical to what callers may already match on.
    for label, tensors in (("x", xs), ("y", ys)):
        first_shape = tensors[0].shape
        for t in tensors:
            if t.shape != first_shape:
                raise RuntimeError(
                    f"Mixed {label} shapes in batch. First {first_shape}, got {t.shape}"
                )

    return {x_key: torch.stack(xs, 0), y_key: torch.stack(ys, 0)}


def collate_full_same_shape(batch):
    """Strict collate for full tabular records with context, buffer, and targets.

    Args:
        batch: Sequence of mapping-like items, each indexed with the keys
            ``x``, ``y``, ``xb``, ``yb``, ``xt`` and ``yt``. Values may be
            tensors or anything ``torch.tensor`` accepts. The documented
            contract is x:[Nc,D], y:[Nc], xb:[Nb,D], yb:[Nb], xt:[Nt,D],
            yt:[Nt].

    Returns:
        Dict of stacked tensors under the same six keys. For the documented
        input: x:[B,Nc,D], y:[B,Nc,1], xb:[B,Nb,D], yb:[B,Nb,1],
        xt:[B,Nt,D], yt:[B,Nt,1]. The ``y*`` entries get a trailing
        singleton dimension before stacking.

    Raises:
        IndexError: If ``batch`` is empty (there is no reference shape).
        RuntimeError: If, for any key, an item's shape differs from the
            first item's shape for that key.
    """

    def _as_tensor(value):
        # torch.tensor(existing_tensor) warns and copies. torch.stack below
        # already allocates fresh storage, so a detached view is sufficient
        # and keeps the output free of autograd history, as before.
        if isinstance(value, torch.Tensor):
            return value.detach()
        return torch.tensor(value)

    # (key, add_trailing_dim) in the order fields are converted, validated
    # and returned.
    fields = (
        ("x", False),
        ("y", True),
        ("xb", False),
        ("yb", True),
        ("xt", False),
        ("yt", True),
    )

    # Phase 1: convert every field for every item, so a missing key surfaces
    # as a KeyError before any shape validation runs.
    columns = {}
    for key, add_trailing_dim in fields:
        tensors = [_as_tensor(ex[key]) for ex in batch]
        if add_trailing_dim:
            tensors = [t.unsqueeze(-1) for t in tensors]
        columns[key] = tensors

    if not columns["x"]:
        # Same exception type the bare xs[0] lookup used to raise, but with
        # a message that names the actual problem.
        raise IndexError("Cannot collate an empty batch")

    # Phase 2: every item must match the first item's shape, per field.
    for key, tensors in columns.items():
        first_shape = tensors[0].shape
        for t in tensors:
            if t.shape != first_shape:
                raise RuntimeError(
                    f"Mixed {key} shapes: first {first_shape}, got {t.shape}"
                )

    return {key: torch.stack(tensors, 0) for key, tensors in columns.items()}
