from pathlib import Path

import torch
import torch.utils.cpp_extension
from torch import Tensor

import torch.lib

from .gemv_lib import lib, lib_ops


# Declare the operator schema. The implicitly-concatenated pieces below form
# one string; the Meta and CUDA kernels registered for "fast_gemv" must accept
# exactly these arguments in this order.
lib.define(
    "fast_gemv("
    "Tensor x, Tensor W, "
    "int w_zero, int num_threads_per_row, int cols_per_warp"
    ") -> Tensor"
)


def _check_gemv_operand(name: str, t: Tensor) -> None:
    """Raise ``ValueError`` unless *t* is a contiguous, 2-D, uint8 CUDA tensor.

    Explicit ``raise`` (instead of ``assert``) keeps the validation active under
    ``python -O``; otherwise malformed tensors would be handed straight to the
    CUDA kernel.

    Args:
        name: Argument name used in the error message.
        t: Tensor to validate.

    Raises:
        ValueError: If any requirement is violated.
    """
    if not t.is_cuda:
        raise ValueError(f"fast_gemv: {name} must be a CUDA tensor, got device {t.device}")
    if t.ndim != 2:
        raise ValueError(
            f"fast_gemv: {name} must be 2-D, got {t.ndim}-D tensor of shape {tuple(t.shape)}"
        )
    if t.dtype != torch.uint8:
        raise ValueError(f"fast_gemv: {name} must have dtype torch.uint8, got {t.dtype}")
    if not t.is_contiguous():
        raise ValueError(f"fast_gemv: {name} must be contiguous")


def fast_gemv(x: Tensor, W: Tensor, w_zero: int, num_threads_per_row: int, cols_per_warp: int) -> Tensor:
    """Validate the operands and dispatch the registered ``fast_gemv`` custom op.

    Args:
        x: Contiguous 2-D uint8 CUDA tensor.
        W: Contiguous 2-D uint8 CUDA tensor on the same device as ``x``.
        w_zero: Forwarded unchanged to the kernel; presumably the zero point
            of the quantized ``W`` — TODO confirm against the kernel source.
        num_threads_per_row: Forwarded unchanged; presumably a launch/tiling
            parameter of the kernel (its valid range is not checked here).
        cols_per_warp: Forwarded unchanged; presumably a launch/tiling
            parameter of the kernel (its valid range is not checked here).

    Returns:
        The tensor produced by the op. The Meta kernel in this module allocates
        it as int32 with shape ``(x.shape[0], W.shape[1])``.

    Raises:
        ValueError: If ``x`` or ``W`` is not a contiguous 2-D uint8 CUDA tensor,
            or if they live on different devices.
    """
    _check_gemv_operand("x", x)
    _check_gemv_operand("W", W)
    # Both operands must be on the same GPU: the output is allocated on
    # x.device and the kernel presumably reads W from that same device.
    if x.device != W.device:
        raise ValueError(
            f"fast_gemv: x and W must be on the same device, got {x.device} and {W.device}"
        )
    return lib_ops.fast_gemv(x, W, w_zero, num_threads_per_row, cols_per_warp)


@torch.library.impl(lib, "fast_gemv", "Meta")
def _(x: Tensor, W: Tensor, w_zero: int, num_threads_per_row: int, cols_per_warp: int) -> Tensor:
    # Meta-key implementation: it performs no arithmetic and only allocates an
    # uninitialized int32 output of shape (x.shape[0], W.shape[1]) on x's
    # device. w_zero, num_threads_per_row and cols_per_warp do not affect it.
    n_rows = x.shape[0]
    n_cols = W.shape[1]
    out_shape = (n_rows, n_cols)
    return torch.empty(out_shape, dtype=torch.int32, device=x.device)


@torch.library.impl(lib, "fast_gemv", "CUDA")
def _(x: Tensor, W: Tensor, w_zero: int, num_threads_per_row: int, cols_per_warp: int) -> Tensor:
    # CUDA-key implementation: forwards all arguments, unchanged and in schema
    # order, to the compiled `fastgemv_lib` extension. The import is deferred
    # to call time so merely importing this module does not need the extension.
    import fastgemv_lib as _ext

    launch = _ext.fast_gemv
    result = launch(x, W, w_zero, num_threads_per_row, cols_per_warp)
    return result