import torch
from torch import Tensor


def cubic_1d_power_series(t: Tensor) -> Tensor:
    """Return the cubic monomial basis ``[1, t, t**2, t**3]`` for each element of ``t``.

    Args:
        t: Input tensor; its elements are expanded along a new trailing axis.

    Returns:
        ``torch.linalg.vander(t, N=4)``, i.e. ``t``'s shape with a trailing axis
        of size 4 holding increasing powers 0..3.
    """
    return torch.linalg.vander(t, N=4)


def cubic_1d_power_series_jac(t: Tensor) -> Tensor:
    """Return the derivative of the cubic monomial basis, ``[0, 1, 2*t, 3*t**2]``.

    This is the element-wise derivative of ``[1, t, t**2, t**3]`` with respect to ``t``.

    Args:
        t: Input tensor of shape ``(*, n)``; any leading batch dimensions are
            supported, as in ``torch.linalg.vander``.

    Returns:
        Tensor of shape ``(*, n, 4)`` with the same dtype and device as ``t``.
    """
    # [1, t, t**2] scaled by the power-rule factors -> [1, 2t, 3t**2].
    lower_powers = torch.linalg.vander(t, N=3)
    factors = torch.tensor([1.0, 2.0, 3.0], device=t.device, dtype=t.dtype)
    jac = lower_powers * factors  # broadcasts over every leading dimension
    # The derivative of the constant term is 0. Build the zero column from
    # jac's own shape so that batched (non 1-D) inputs also work.
    zero_col = jac.new_zeros(jac.shape[:-1] + (1,))
    return torch.cat((zero_col, jac), dim=-1)


def cubic_3d_power_series(u: Tensor) -> Tensor:
    """Evaluate the 64-term tensor-product cubic basis at a batch of 3-D points.

    Args:
        u: Tensor whose columns ``u[:, 0]``, ``u[:, 1]``, ``u[:, 2]`` are the x, y
            and z coordinates (presumably shape ``(batch_size, 3)`` -- TODO confirm).

    Returns:
        Tensor of shape ``(batch_size, 64)``. With the monomial basis from
        ``cubic_1d_power_series``, entry ``16*k + 4*i + j`` is
        ``z**k * x**i * y**j`` for ``k, i, j`` in ``0..3`` (z-major, then x, then y).
    """
    # Axis layout for broadcasting: z on axis 1, x on axis 2, y on axis 3.
    x = cubic_1d_power_series(u[:, 0])[:, None, :, None]
    y = cubic_1d_power_series(u[:, 1])[:, None, None, :]
    z = cubic_1d_power_series(u[:, 2])[:, :, None, None]
    # flatten(start_dim=1) matches view(batch, -1) for non-empty input, but it also
    # handles batch_size == 0, where view(0, -1) cannot infer the -1 dimension.
    return (x * y * z).flatten(start_dim=1)  # (batch_size, 64)


def cubic_3d_power_series_jac_and_value(u: Tensor) -> Tensor:
    """Evaluate the 64-term tensor-product cubic basis and its gradient at 3-D points.

    Args:
        u: Tensor whose columns ``u[:, 0]``, ``u[:, 1]``, ``u[:, 2]`` are the x, y
            and z coordinates (presumably shape ``(batch_size, 3)`` -- TODO confirm).

    Returns:
        Tensor of shape ``(batch_size, 64, 4)``. The last axis holds
        ``[d/dx, d/dy, d/dz, value]`` for each basis term. The 64 terms are
        ordered z-major, then x, then y: term ``16*k + 4*i + j`` is the product of
        the k-th z factor, the i-th x factor and the j-th y factor.
    """
    # Axis layout for broadcasting: z on axis 1, x on axis 2, y on axis 3.
    x = cubic_1d_power_series(u[:, 0])[:, None, :, None]
    dx = cubic_1d_power_series_jac(u[:, 0])[:, None, :, None]
    y = cubic_1d_power_series(u[:, 1])[:, None, None, :]
    dy = cubic_1d_power_series_jac(u[:, 1])[:, None, None, :]
    z = cubic_1d_power_series(u[:, 2])[:, :, None, None]
    dz = cubic_1d_power_series_jac(u[:, 2])[:, :, None, None]

    # Product rule on a separable basis: each partial derivative replaces only
    # the factor of the corresponding coordinate with its derivative.
    # flatten(start_dim=1) matches view(batch, -1) for non-empty input, but it also
    # handles batch_size == 0, where view(0, -1) cannot infer the -1 dimension.
    dx_series = (dx * y * z).flatten(start_dim=1)  # (batch_size, 64)
    dy_series = (x * dy * z).flatten(start_dim=1)  # (batch_size, 64)
    dz_series = (x * y * dz).flatten(start_dim=1)  # (batch_size, 64)
    series = (x * y * z).flatten(start_dim=1)  # (batch_size, 64)
    return torch.stack([dx_series, dy_series, dz_series, series], dim=-1)  # (batch_size, 64, 4)
