import time

import torch

# Process-wide PyTorch configuration, applied as a side effect of importing/running this file.

# NOTE(review): nothing in this file calls torch.compile; presumably kept so `.item()`-style
# scalar reads do not break graphs when the module is compiled elsewhere — confirm it is still needed.
torch._dynamo.config.capture_scalar_outputs = True
# Disable autograd globally: this script only runs inference-style forward passes.
torch.autograd.set_grad_enabled(False)
# Turn off the reduced-precision matmul/conv paths (fp16/bf16 reduced-precision reductions, TF32)
# and request the 'highest' float32 matmul precision. Presumably this keeps the torch reference
# results used below as accurate as possible — TODO confirm.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')
# Print tensors without scientific notation and with very large threshold/line width;
# tensors are embedded in the assertion message of test_sparse_random.
torch.set_printoptions(sci_mode=False, threshold=200000, linewidth=200000)

from ssqr.cuda_kernel import call_matryoshka_pack, call_post_process_sparsity, call_gptq_matryoshka_mul_batched


class M8aLinear(torch.nn.Module):
    """Linear layer whose weight is stored in the packed form consumed by the ``ssqr`` CUDA kernels.

    The constructor moves its tensor inputs to the CPU, hands them to
    ``call_matryoshka_pack`` and ``call_post_process_sparsity``, and keeps the
    buffers those calls fill in as frozen parameters.  ``forward`` forwards
    those parameters, together with the input, to
    ``call_gptq_matryoshka_mul_batched``.
    """

    def __init__(
            self,
            bit_width: int,
            group_size: int,
            weight_int: torch.Tensor,
            weight_scale: torch.Tensor,
            weight_sparse: torch.Tensor = None,
    ):
        """
        @param bit_width: Bit width forwarded to the packing and multiplication kernels.
        @param group_size: Quantisation group size; a value equal to the number of input features is replaced by -1.
        @param weight_int: 2-D integer weight; its shape defines (n_out_features, n_in_features).
        @param weight_scale: Scale tensor; its dtype selects the working dtype and the bf16 flag.
        @param weight_sparse: Optional dense tensor of sparse corrections with the same shape as weight_int.
            Its values are reinterpreted as 16-bit words, so it must use a 16-bit dtype. Defaults to all zeros.
        """
        super().__init__()

        dtype = weight_scale.dtype
        is_bf16 = dtype == torch.bfloat16
        n_out_features, n_in_features = weight_int.shape
        n_weights = n_out_features * n_in_features

        # One group spanning a whole row is signalled to the kernels as -1.
        if group_size == n_in_features:
            group_size = -1
        if weight_sparse is None:
            weight_sparse = torch.zeros(n_out_features, n_in_features, dtype=dtype)

        weight_int = weight_int.contiguous().cpu()
        weight_scale = weight_scale.contiguous().cpu()
        weight_sparse = weight_sparse.contiguous().cpu()

        # Five zeroed work buffers handed positionally to the packer, followed by its packed output buffer.
        scratch_dtypes = (torch.int64, torch.int32, torch.int32, torch.int64, torch.int64)
        scratch = [torch.zeros(n_weights // 32, dtype=scratch_dtype) for scratch_dtype in scratch_dtypes]
        v_buffer = torch.zeros(bit_width * n_weights // 32, dtype=torch.int32)

        # CSR view of the sparse part: each entry is the raw 16-bit pattern of the
        # value shifted left by 16 bits, OR-ed with its column index.
        csr = weight_sparse.to_sparse_csr()
        row_offsets = csr.crow_indices().to(dtype=torch.int32)
        shifted_values = csr.values().view(dtype=torch.uint16).to(dtype=torch.int32) << 16
        col_vals = shifted_values.view(dtype=torch.int32) | csr.col_indices().to(dtype=torch.int32)

        reordered_scales = torch.empty_like(weight_scale)
        row_offsets_v2 = torch.zeros(n_out_features // 16 + 1, dtype=torch.int32)
        call_matryoshka_pack(
            bit_width,
            n_out_features,
            n_in_features,
            weight_int,
            *scratch,
            group_size,
            weight_scale,
            reordered_scales,
            v_buffer,
            is_bf16,
            row_offsets,
            row_offsets_v2,
        )

        # The last entry of row_offsets_v2 (written by the packer) sizes the second-stage buffer.
        col_vals_v2 = torch.zeros(row_offsets_v2[-1].item(), dtype=torch.int32)
        call_post_process_sparsity(n_out_features, row_offsets, col_vals, row_offsets_v2, col_vals_v2)

        # Bit flags understood by the multiplication kernel.
        MATRYOSHKA_ASYNC, MATRYOSHKA_IS_BF16 = 1, 2

        self.n_out_features = n_out_features
        self.n_in_features = n_in_features
        self.bit_width = bit_width
        self.group_size = group_size
        self.v_buffer = torch.nn.Parameter(v_buffer, requires_grad=False)
        self.reordered_scales = torch.nn.Parameter(reordered_scales, requires_grad=False)
        self.row_offsets = torch.nn.Parameter(row_offsets_v2, requires_grad=False)
        self.col_vals = torch.nn.Parameter(col_vals_v2, requires_grad=False)
        # Re-read after post-processing, exactly where the value is stored.
        self.nnz = row_offsets_v2[-1].item()
        self.flag = MATRYOSHKA_ASYNC + (MATRYOSHKA_IS_BF16 if is_bf16 else 0)

    def forward(self, x: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        """
        @param x: Input tensor whose last dimension holds the input features.
        @param out: Optional destination tensor; when omitted a new tensor with the leading shape of x
            and n_out_features as last dimension is allocated with the dtype and device of x.
        @return: The tensor the kernel wrote the product of the packed weight and x into (out when given).
        """
        if out is None:
            out = torch.empty(*x.shape[:-1], self.n_out_features, dtype=x.dtype, device=x.device)

        # Every leading dimension of x is folded into one batch dimension.
        n_rows = x.numel() // self.n_in_features

        # x is handed over as-is (no .contiguous() call); `out` fills both trailing kernel arguments.
        call_gptq_matryoshka_mul_batched(
            self.n_out_features,
            self.n_in_features,
            n_rows,
            self.bit_width,
            self.group_size,
            self.v_buffer,
            self.reordered_scales,
            self.row_offsets,
            self.col_vals,
            self.nnz,
            x,
            self.flag,
            out,
            out,
        )
        return out


def test_sparse_random() -> None:
    """Check ``M8aLinear.forward`` against a float64 reference over a grid of random cases.

    Every case builds random integer weights, scales, sparse corrections and
    inputs on the CPU, moves the layer and the reference to CUDA, and asserts
    that the layer output matches the reference with ``torch.allclose``.
    Requires a CUDA device.
    """
    print('UNIT TEST BEGIN')

    torch.random.manual_seed(1)

    # Full cartesian grid; the right-most clause varies fastest.
    cases = (
        (m, n, k, group_size, density, is_bf16, bit, rep)
        for rep in range(1)
        for bit in (2, 3, 4)
        for m in (16, 64)
        for k in (1024, 4096, 12288)
        for n in (1, 2, 4, 8)
        for density in (0., .5)
        for is_bf16 in (False, True)
        for group_size in (-1, 32, 128)
    )

    # Scales, sparse values and inputs are rounded to integers before use
    # (presumably so the comparison below is not disturbed by rounding error).
    round_inputs = True

    for n_out_features, batch_size, n_in_features, group_size, density, is_bf16, bit_width, rep in cases:
        print(f"Attempting {rep} m = {n_out_features} n = {batch_size} k = {n_in_features} density = {density} is_bf16 = {is_bf16} group_size = {group_size} bits = {bit_width}")

        dtype = torch.float16
        if is_bf16:
            dtype = torch.bfloat16
        # -1 means "one group per row".
        if group_size == -1:
            group_size = n_in_features

        n_weights = n_out_features * n_in_features
        n_outliers = round(n_weights * density)

        weight_int = torch.randint(0, 2 ** bit_width, (n_out_features, n_in_features), dtype=torch.int8)
        weight_scale = torch.randn(n_out_features, n_in_features // group_size, dtype=dtype)
        weight_sparse = torch.zeros(n_out_features, n_in_features, dtype=dtype)
        # Values are drawn before the positions to keep the random stream in its established order.
        outlier_values = torch.randn(n_outliers, dtype=dtype)
        outlier_slots = torch.randperm(n_weights)[:n_outliers]
        weight_sparse.flatten()[outlier_slots] = outlier_values

        x = torch.randn(batch_size, n_in_features, dtype=dtype)

        if round_inputs:
            weight_scale = weight_scale.round()
            weight_sparse = weight_sparse.round()
            x = x.round()

        # Dense float64 reference weight: (q - 2^(bits-1)) * scale per group, plus the sparse part.
        centered = weight_int.to(dtype=torch.float64).unflatten(dim=-1, sizes=(-1, group_size)) - 2 ** (bit_width - 1)
        scaled = centered * weight_scale.to(dtype=torch.float64)[..., None]
        weight_dq = scaled.flatten(start_dim=-2) + weight_sparse.to(dtype=torch.float64)

        m8a_linear = M8aLinear(bit_width, group_size, weight_int, weight_scale, weight_sparse)

        m8a_linear = m8a_linear.cuda()
        weight_dq = weight_dq.cuda()
        x = x.cuda()

        reference = x.to(dtype=weight_dq.dtype) @ weight_dq.transpose(-2, -1)
        y_true = reference.to(dtype=x.dtype)
        y = m8a_linear.forward(x)
        torch.cuda.synchronize()

        assert torch.allclose(y, y_true), f"Failed for rep = {rep} m = {n_out_features} n = {batch_size} k = {n_in_features} density = {density} is_bf16 = {is_bf16} group_size = {group_size} bits = {bit_width}\ny=\n{str(y.squeeze().t())}\ny_true=\n{str(y_true.squeeze().t())}"

    print('UNIT TEST PASSED')


def bench_sparse_random() -> None:
    """Time ``M8aLinear.forward`` against ``torch.matmul`` on the dequantised dense weight.

    For every case a layer with a fixed number of sparse entries per row is
    built, both implementations are timed with ``basic_benchmark`` while
    writing into the same output tensor, and the timings plus the resulting
    speed-up are printed and collected.  Requires a CUDA device.
    """
    print('BENCHMARK BEGIN')

    torch.random.manual_seed(1)

    # (n_out_features, n_in_features) pairs to benchmark.
    shapes = (
        (4096 + 1024 * 2, 4096),
        (4096, 4096),
        (12288 * 2, 4096),
        (4096, 12288),
    )
    # Full cartesian grid; the right-most clause varies fastest.
    cases = (
        (m, n, k, group_size, density, is_bf16, bit, rep)
        for rep in range(1)
        for bit in (2, 3, 4)
        for m, k in shapes
        for n in (1, 2, 4, 8, 16)
        for density in (0., .01, .02, .03, .04, .05)
        for is_bf16 in (True,)
        for group_size in (128,)
    )

    results = []
    for n_out_features, batch_size, n_in_features, group_size, density, is_bf16, bit_width, rep in cases:
        print(f"Attempting {rep} m = {n_out_features} n = {batch_size} k = {n_in_features} density = {density} is_bf16 = {is_bf16} group_size = {group_size} bits = {bit_width}")

        dtype = torch.float16
        if is_bf16:
            dtype = torch.bfloat16
        # -1 means "one group per row".
        if group_size == -1:
            group_size = n_in_features

        weight_int = torch.randint(0, 2 ** bit_width, (n_out_features, n_in_features), dtype=torch.int8)
        weight_scale = torch.randn(n_out_features, n_in_features // group_size, dtype=dtype)

        n_outliers_per_row = round(n_in_features * density)
        assert 0 <= n_outliers_per_row <= n_in_features

        # Magnitudes of at least 0.1 with a random sign, so every scattered entry is non-zero.
        magnitudes = torch.randn(n_out_features, n_outliers_per_row, dtype=dtype).abs().clamp(min=.1)
        keep_sign = torch.randint(0, 2, (n_out_features, n_outliers_per_row), dtype=torch.bool)
        vals = torch.where(keep_sign, magnitudes, -magnitudes)

        # top-k over random scores yields distinct column indices within each row.
        outlier_columns = torch.rand(n_out_features, n_in_features, dtype=torch.float64).topk(n_outliers_per_row, dim=-1).indices
        weight_sparse = torch.zeros(n_out_features, n_in_features, dtype=dtype)
        weight_sparse.scatter_(dim=-1, index=outlier_columns, src=vals)
        assert (weight_sparse.to(dtype=torch.bool).sum(dim=-1) == n_outliers_per_row).all()

        x = torch.randn(batch_size, n_in_features, dtype=dtype)

        # Dense reference weight, computed in float64 and then cast to the input dtype.
        centered = weight_int.to(dtype=torch.float64).unflatten(dim=-1, sizes=(-1, group_size)) - 2 ** (bit_width - 1)
        scaled = centered * weight_scale.to(dtype=torch.float64)[..., None]
        weight_dq = scaled.flatten(start_dim=-2) + weight_sparse.to(dtype=torch.float64)
        weight_dq = weight_dq.to(dtype=x.dtype)

        m8a_linear = M8aLinear(bit_width, group_size, weight_int, weight_scale, weight_sparse)

        m8a_linear = m8a_linear.cuda()
        weight_dq = weight_dq.cuda()
        x = x.cuda()
        y = torch.empty(batch_size, n_out_features, dtype=x.dtype, device=x.device)

        def run_m8a(x=x, out=y):
            return m8a_linear.forward(x, out)

        def run_torch(x=x, out=y):
            return torch.matmul(x, weight_dq.t(), out=out)

        t_m8a = basic_benchmark(run_m8a)
        t_torch = basic_benchmark(run_torch)
        speedup = t_torch / t_m8a

        print(f'{t_m8a}', f'{t_torch}', f'{speedup}x', sep='\t')
        results.append({
            'n_out_features': n_out_features,
            'batch_size': batch_size,
            'n_in_features': n_in_features,
            'group_size': group_size,
            'density': density,
            'is_bf16': is_bf16,
            'bit_width': bit_width,
            'rep': rep,
            't_torch': t_torch,
            't_m8a': t_m8a,
            'speedup': speedup,
        })

    print('BENCHMARK END')
    print(results)


def basic_benchmark(f, n_repeats: int = 2000, n_warmup_repeats: int = 1000, measure_latency: bool = False) -> float:
    """Return the average wall-clock time, in seconds, of one call to ``f``.

    ``f`` is first called ``n_warmup_repeats`` times untimed, then
    ``n_repeats`` times between two ``time.perf_counter()`` readings.

    @param f: Zero-argument callable to time.
    @param n_repeats: Number of timed calls; must be positive.
    @param n_warmup_repeats: Number of untimed warm-up calls.
    @param measure_latency: When True the device is synchronised after every call
        (per-call latency); when False it is synchronised once after the whole
        batch of calls (throughput).
    @return: Elapsed time of the timed phase divided by ``n_repeats``.
    @raise ValueError: If ``n_repeats`` is not positive.
    """
    # Fail before any work is done instead of dividing by zero after the warm-up.
    if n_repeats <= 0:
        raise ValueError(f"n_repeats must be positive, got {n_repeats}")

    # Only synchronise when a CUDA device exists, so CPU-only callables can be timed as well.
    if torch.cuda.is_available():
        synchronize = torch.cuda.synchronize
    else:
        def synchronize() -> None:
            """Nothing to wait for without a CUDA device."""

    def run_batch(n_calls: int) -> None:
        """Call ``f`` ``n_calls`` times using the selected synchronisation policy."""
        if measure_latency:
            for _ in range(n_calls):
                f()
                synchronize()
        else:  # measure throughput
            for _ in range(n_calls):
                f()
            synchronize()

    run_batch(n_warmup_repeats)

    t_start: float = time.perf_counter()
    run_batch(n_repeats)
    t_end: float = time.perf_counter()

    return (t_end - t_start) / n_repeats


# Script entry point: run the correctness test first, then the benchmark.
if __name__ == '__main__':
    test_sparse_random()
    bench_sparse_random()
