import torch

def compute_invariants_displacements(x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4):
    """Return the 34 power-sum invariants of four 3-D points.

    Each invariant is ``sum_i x_i**a * y_i**b * z_i**c`` taken over the four
    points ``(x_i, y_i, z_i)``, for one exponent triple ``(a, b, c)`` of total
    degree 1 to 4.  Every entry is a sum over the points, so the result is
    unchanged when the four points are relabelled.

    Only ``+``, ``*`` and ``**`` (with integer exponents) are applied to the
    arguments.  They may therefore be plain Python numbers or array-like
    objects such as torch tensors.

    The returned list is ordered by total degree (1..4).  Within a degree, the
    entries are ordered by descending power of ``z``, then by ascending power
    of ``x``; the power of ``y`` takes the remainder.  The list therefore
    starts ``sum z, sum y, sum x, sum z**2, sum y*z, sum x*z, sum y**2, ...``
    and ends with ``sum x**4``.
    """
    points = ((x1, y1, z1), (x2, y2, z2), (x3, y3, z3), (x4, y4, z4))

    def monomial(coords, powers):
        # Build the product left to right (x, then y, then z).  Zero powers
        # are skipped and a power of one uses the bare value.  For example,
        # (1, 1, 2) evaluates as (x * y) * z**2, with no spurious x**0 or
        # x**1 factors.
        term = None
        for value, power in zip(coords, powers):
            if power == 0:
                continue
            factor = value if power == 1 else value ** power
            term = factor if term is None else term * factor
        return term

    invariants = []
    for degree in range(1, 5):
        for pz in range(degree, -1, -1):
            for px in range(degree - pz + 1):
                powers = (px, degree - pz - px, pz)
                t1, t2, t3, t4 = (monomial(p, powers) for p in points)
                # Explicit left-to-right sum; avoids sum()'s implicit 0 start.
                invariants.append(t1 + t2 + t3 + t4)
    return invariants


def compute_invariants_displacements_wrapper(vectors: torch.Tensor) -> torch.Tensor:
    """
    Batched front end for ``compute_invariants_displacements``.

    Parameters
    ----------
    vectors : torch.Tensor
        Shape (batch_size, 4, 3): four 3-component vectors per batch item.
        The components are ordered (x, y, z) along the last axis.  The tensor
        may live on any device; the invariants are computed with ordinary
        tensor arithmetic, so the result keeps the input's device and dtype.

    Returns
    -------
    torch.Tensor
        Shape (batch_size, 34).  Column k holds the k-th invariant for every
        batch item.

    Raises
    ------
    ValueError
        If ``vectors`` is not 3-dimensional with trailing shape (4, 3).
    """
    if not (vectors.ndim == 3 and vectors.shape[1:] == (4, 3)):
        raise ValueError(
            f"Expected tensor of shape (batch_size, 4, 3), got {vectors.shape}"
        )

    # Flatten the (point, component) grid into the positional order the
    # scalar routine expects: x1, y1, z1, x2, y2, z2, ..., x4, y4, z4.
    # Each slice is a (batch_size,) view of the input.
    flat_coords = [
        vectors[:, point, axis] for point in range(4) for axis in range(3)
    ]
    per_invariant = compute_invariants_displacements(*flat_coords)

    # Each of the 34 entries has shape (batch_size,); place them side by side
    # as columns.
    return torch.stack(per_invariant, dim=1)