import torch
import triton
import triton.language as tl

from typing import Optional, Union


def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    ``size`` may be passed as separate integers (``seeded_uniform(2, 3, ...)``)
    or, like ``torch.rand``, as a single shape sequence
    (``seeded_uniform((2, 3), ...)``).

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d:

    * 1d ``(n_cols,)`` is treated as a single row, so ``seeds`` must have
      exactly one element.
    * 2d ``(n_rows, n_cols)``: row ``i`` is generated from ``seeds[i]``.
    * 3d ``(n_rows, n_3d, n_cols)``: the additional seeds needed will be
      derived automatically in a deterministic fashion (slice ``[i, j]``
      uses ``seeds[i] ^ j``):
      [
          row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
      ]

    Args:
        *size: Shape of the output tensor (1 to 3 dimensions).
        seeds: 1D tensor with one seed per output row.
        out: Optional preallocated output. Its shape must equal ``size`` and
            its last dimension must be contiguous (unit stride).
        dtype: Forwarded to ``torch.empty`` when ``out`` is None.
        device: Forwarded to ``torch.empty`` when ``out`` is None.
        pin_memory: Forwarded to ``torch.empty`` when ``out`` is None.

    Returns:
        The filled output tensor (``out`` itself when it was provided).

    Raises:
        ValueError: if ``size`` has 0 or more than 3 dimensions, ``out`` does
            not match ``size``, ``out``'s last dimension is strided, or
            ``seeds`` is not 1D with one element per output row.
    """
    # Accept the torch.rand-style call with a single shape sequence.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])
    n_dims = len(size)

    if n_dims == 0:
        raise ValueError("seeded_uniform requires at least a 1D size")
    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    # Normalize every supported rank to (rows, 3d-slices, cols). For the
    # unused dimensions the grid size is 1, so their stride is never read.
    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Nothing to generate. Launching anyway is not just wasteful:
    # n_rows == 0 or n_3d == 0 gives an empty launch grid, and for
    # n_cols == 0 next_power_of_2 yields 0, so n_slices below would be 0
    # and trip the kernel's `0 < n_slices` static assert.
    if out.numel() == 0:
        return out

    # The kernel addresses columns as `row_start + col`, i.e. it assumes a
    # unit stride along the last dimension. A strided `out` would receive
    # values at the wrong (possibly out-of-bounds) locations.
    if n_cols > 1 and out.stride(-1) != 1:
        raise ValueError("the last dimension of out must be contiguous")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out


@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Each program instance fills one (row, 3D-slice) pair: program_id(0)
    selects the row and program_id(1) selects the 3D slice, so the kernel
    expects a launch grid shaped like (n_rows, n_3d). A single tl.rand4x call
    produces four blocks of `block_size` values; block k is written to
    columns [k * block_size, (k + 1) * block_size). At most
    n_slices * block_size columns are covered, so the caller must choose
    n_slices and block_size such that n_slices * block_size >= n_cols;
    any remaining columns would be left unwritten.

    Args:
        out_ptr: The output tensor. Columns are addressed with unit stride
            (there is no column-stride argument), so its last dimension is
            assumed to be contiguous.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor. Not read by the
            kernel body; the row count comes from the launch grid.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D. Not read by the kernel body either.
        n_cols: The number of columns in the output tensor; stores to
            columns >= n_cols are masked off.
        n_slices: The number of philox outputs to use (1 to 4).
        block_size: The number of columns produced by each philox output.
    """
    # Checked at compile time: tl.rand4x yields exactly four outputs.
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index and the index along the second (3D) dimension.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    # Philox offsets; the same offsets feed all four outputs of rand4x.
    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current row.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    # Derive a deterministic per-slice seed (row seed XOR slice index);
    # slice 0 keeps the row seed unchanged.
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    # NOTE(review): these offsets are built from program ids, which Triton
    # returns as 32-bit values; confirm the arithmetic cannot overflow for
    # the largest expected outputs.
    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
                            three_d_idx * out_3d_stride)
    # Philox output k is stored at columns [k * block_size,
    # (k + 1) * block_size); the masks drop columns >= n_cols.
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    # n_slices is a constexpr, so the branches below are resolved when the
    # kernel is compiled.
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)
