import torch
import gc
import logging
import torch
import numpy as np
import random

# Module-level logger used by the test loop to report per-shape pass/fail.
logger = logging.getLogger(name=__name__)


def freeze_random(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's CPU/CUDA RNGs.

    Args:
        seed: value passed to every seeding call, so repeated runs draw the
            same "random" test data.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


freeze_random()


# Set DEBUG to a truthy value to see the tensor dumps printed by the test loop.
DEBUG = 0
if not DEBUG:

    def print(*args, **kwargs):
        """Module-local no-op replacement for print() while DEBUG is off."""
        return None


nbits, batch_size = 4, 1
# (in_features, out_features) pairs, each prefixed with batch_size.
shapes = [
    (batch_size, in_f, out_f)
    for in_f, out_f in (
        (4096, 4096),  # llama-7b
        (14336, 4096),
        (8192, 8192),  # llama-70b
        (28672, 8192),
        (4096, 4096),
        (1024, 1024),
        # (128, 128),
        # (32, 128),  -- was (batch_size, 32, 128)
    )
]


# Kernel binding under test. The commented-out lines are alternative ways of
# loading the extension, left disabled.
# import fastgemv_lib
# from fastgemv.gemv_lib import lib, lib_ops
# func = fastgemv_lib.fast_gemv
# NOTE(review): presumably this import is what registers the `pre_quant` custom
# ops with torch (nothing else visible here does) — confirm in the fastgemv package.
from fastgemv import *

# Invoked by the test loop as
# func(packed_input, packed_weight, zero_shift, num_threads_per_row, cols_per_warp).
func = torch.ops.pre_quant.fast_gemv


def torch_ref(example_intput, weight_packed, w_shift):
    """Reference packed-int4 matmul built from plain PyTorch ops.

    Each byte of both operands holds two 4-bit values, high nibble first (the
    layout pack_uint4_to_uint8 produces). Both operands are unpacked to float32,
    ``w_shift`` is subtracted from every element of each, and ``x @ w.T`` is
    returned.

    Args:
        example_intput: packed activations, shape ``(..., K // 2)``; any number
            of leading rows is accepted.
        weight_packed: packed weights, shape ``(N, K // 2)``.
        w_shift: zero point subtracted from both unpacked operands.

    Returns:
        float32 tensor of shape ``(..., N)`` on ``weight_packed``'s device.
    """

    def _unpack(packed):
        # Interleave (high, low) nibbles so column 2i is the high nibble of
        # byte i and column 2i + 1 its low nibble.
        high = (packed >> 4) & 0x0F
        low = packed & 0x0F
        return torch.stack((high, low), dim=-1).flatten(start_dim=-2).float()

    weight = _unpack(weight_packed)
    inputs = _unpack(example_intput).to(weight.device)
    # Integer-valued operands: float32 accumulation stays exact while every
    # partial sum is below 2**24 in magnitude.
    return torch.matmul(inputs - w_shift, weight.T - w_shift)


def pack_uint4_to_uint8(uint4_tensor):
    """Pack adjacent 4-bit values along dim 1 into single bytes.

    Column ``2i`` becomes the high nibble and column ``2i + 1`` the low nibble
    of output column ``i``. Inputs are masked to their low 4 bits, so any
    integer dtype is accepted.

    Args:
        uint4_tensor: 2-D integer tensor ``(rows, cols)`` with an even ``cols``.

    Returns:
        uint8 tensor of shape ``(rows, cols // 2)`` on the input's device.

    Raises:
        ValueError: if ``cols`` is odd (the last value would have no partner).
    """
    if uint4_tensor.shape[1] % 2:
        raise ValueError(
            f"pack_uint4_to_uint8 needs an even number of columns, got {uint4_tensor.shape[1]}"
        )
    # Mask and narrow to uint8 *before* shifting: values 0..15 convert exactly
    # from any integer dtype, and (15 << 4) | 15 == 255 fits in uint8. A
    # .view(torch.uint8) would instead reinterpret the raw bytes of wider dtypes.
    nibbles = (uint4_tensor & 0x0F).to(torch.uint8)
    return (nibbles[:, 0::2] << 4) | nibbles[:, 1::2]


# Without configuration, logger.info records are dropped (only warnings and above
# reach the last-resort handler). basicConfig is a no-op if the root logger
# already has handlers.
logging.basicConfig(level=logging.INFO)

# Check the fast_gemv kernel against torch_ref for every tuning value and shape.
# Mismatches are collected so the run fails with a non-zero exit at the end.
failures = []
for cols_per_warp in [4, 8, 16]:
    # for cols_per_warp in [32]:
    for shape_pair in shapes[::-1]:
        logger.info(f"========================= shape_pair: {shape_pair}")
        # The leading batch entry is unused: the input below is always one row.
        _, in_features, out_features = shape_pair
        quant_min, quant_max = 0, 15
        zero_shift = 0
        # torch.randint's upper bound is exclusive; +1 so the nibble 15 (0xF) is drawn too.
        uint4_input = (
            torch.randint(quant_min, quant_max + 1, (1, in_features), dtype=torch.uint8).cuda().contiguous()
        )
        uint4_weight = (
            torch.randint(quant_min, quant_max + 1, (out_features, in_features), dtype=torch.uint8)
            .cuda()
            .contiguous()
        )
        print(f"int4_input: {uint4_input}")
        uint4_input_packed = pack_uint4_to_uint8(uint4_input)
        print(f"int4_input_packed: {uint4_input_packed}")
        print(f"int4_weight[0]: {uint4_weight[0]}")
        uint4_weight_packed = pack_uint4_to_uint8(uint4_weight)
        print(f"int4_weight_packed[0]: {uint4_weight_packed[0]}")
        result_ref = torch_ref(uint4_input_packed, uint4_weight_packed, zero_shift)
        # Kernel launch parameter: 4 threads per row for short rows, otherwise 32.
        num_threads_per_row = 4 if in_features <= 128 else 32

        out = func(uint4_input_packed, uint4_weight_packed, zero_shift, num_threads_per_row, cols_per_warp)

        print("result_ref:", result_ref, result_ref.shape)
        print("out:", out, out.shape)
        if not torch.equal(result_ref, out.float()):
            failures.append((shape_pair, cols_per_warp))
            logger.error(
                f"For shape_pair: {shape_pair}, cols_per_warp: {cols_per_warp}, expected: {result_ref}, got: {out}"
            )
        else:
            logger.info(f"For shape_pair: {shape_pair}, cols_per_warp: {cols_per_warp} result are same")
        # Release this iteration's tensors first; empty_cache() cannot return
        # memory that is still referenced.
        del uint4_input, uint4_weight, uint4_input_packed, uint4_weight_packed, result_ref, out
        gc.collect()
        torch.cuda.empty_cache()

if failures:
    raise SystemExit(f"fast_gemv mismatched torch_ref for (shape_pair, cols_per_warp): {failures}")
