from typing import Any, Tuple, Dict
import torch
from .base_quantizer import BaseQuantizer


def pack_bits(indices: torch.Tensor, n_bits: int) -> torch.ByteTensor:
    """Pack small non-negative integer codes into a 1-D uint8 tensor.

    Each output byte holds ``8 // n_bits`` consecutive codes; the code in
    slot ``i`` of a byte occupies bits ``[i * n_bits, (i + 1) * n_bits)``.
    When ``n_bits`` does not divide 8 the remaining high bits stay zero.

    Args:
        indices: Tensor of codes, flattened before packing. Every value must
            lie in ``[0, 2**n_bits - 1]``.
        n_bits: Width of one code in bits, between 1 and 8 inclusive.

    Returns:
        A 1-D ``torch.uint8`` tensor on the same device as ``indices``. The
        input is zero-padded up to a whole number of bytes.

    Raises:
        ValueError: If ``n_bits`` is outside ``[1, 8]`` or any code is out of
            range for ``n_bits``.
    """
    if not 1 <= n_bits <= 8:
        raise ValueError("n_bits must be in [1, 8].")
    max_index = (1 << n_bits) - 1
    per_byte = 8 // n_bits

    codes = indices.reshape(-1).to(torch.int64)
    if torch.any(codes < 0) or torch.any(codes > max_index):
        raise ValueError(f"indices out of range [0, {max_index}] for n_bits={n_bits}")

    # Zero-pad so the codes split evenly into rows of `per_byte` entries.
    tail = (-codes.numel()) % per_byte
    if tail > 0:
        codes = torch.cat([codes, codes.new_zeros(tail)])
    rows = codes.view(-1, per_byte)

    out = torch.zeros(rows.shape[0], dtype=torch.uint8, device=rows.device)
    bit_mask = torch.tensor(max_index, dtype=torch.uint8, device=rows.device)
    # One pass per slot: column `slot` lands at bit offset slot * n_bits.
    for slot, column in enumerate(rows.unbind(dim=1)):
        out |= (column.to(torch.uint8) & bit_mask) << (slot * n_bits)
    return out


def unpack_bits(packed: torch.ByteTensor, n_bits: int, length: int) -> torch.Tensor:
    """Unpack ``n_bits``-wide codes from a byte tensor.

    Each byte is read as ``8 // n_bits`` codes; the code in slot ``i`` of a
    byte is taken from bits ``[i * n_bits, (i + 1) * n_bits)``.

    Args:
        packed: Byte tensor holding the packed codes. It is flattened before
            unpacking, so codes come out in flattened byte order.
        n_bits: Width of one code in bits, between 1 and 8 inclusive.
        length: Number of codes to return. Slots beyond ``length`` (for
            example zero padding added when packing) are discarded.

    Returns:
        A 1-D ``torch.int64`` tensor with ``length`` codes on the same device
        as ``packed``.

    Raises:
        ValueError: If ``n_bits`` is outside ``[1, 8]``, or if ``length`` is
            negative or larger than the number of codes ``packed`` can hold.
    """
    if n_bits < 1 or n_bits > 8:
        raise ValueError("n_bits must be in [1, 8].")
    values_per_byte = 8 // n_bits

    flat_bytes = packed.reshape(-1)
    capacity = flat_bytes.numel() * values_per_byte
    # Without this check a negative length would silently drop trailing codes
    # via slicing, and an oversized length would silently return too few.
    if length < 0 or length > capacity:
        raise ValueError(
            f"length must be in [0, {capacity}] for {flat_bytes.numel()} packed "
            f"bytes with n_bits={n_bits}, got {length}."
        )

    mask = torch.tensor(
        (1 << n_bits) - 1, dtype=flat_bytes.dtype, device=flat_bytes.device
    )
    slots = [
        (flat_bytes >> (slot * n_bits)) & mask for slot in range(values_per_byte)
    ]
    # Stacking along dim=1 interleaves the slots of each byte back into order;
    # widen to int64 only after trimming to the requested length.
    return torch.stack(slots, dim=1).reshape(-1)[:length].to(torch.int64)


class UniformQuantizer(BaseQuantizer):
    """
    Implements a standard group-wise uniform quantizer with bit-packing.

    The input is flattened, zero-padded to a multiple of ``group_size`` and
    split into groups; every group gets its own scale (and, in asymmetric
    mode, its own zero-point). Integer codes are shifted to be non-negative
    and packed with ``pack_bits``.
    """

    def __init__(self, n_bits: int, group_size: int = 128, symmetric: bool = True):
        """
        Args:
            n_bits: Bit width of one quantized value; forwarded to
                ``BaseQuantizer.__init__``.
            group_size: Number of consecutive (flattened) elements that share
                one scale / zero-point.
            symmetric: If True, codes span ``[-2**(n_bits-1), 2**(n_bits-1)-1]``
                with a zero-point of 0; otherwise codes span
                ``[0, 2**n_bits - 1]`` with a per-group zero-point.

        Raises:
            ValueError: If ``group_size`` is not positive.
        """
        super().__init__(n_bits)
        if group_size <= 0:
            raise ValueError("Group size must be a positive integer.")
        self.group_size = group_size
        self.symmetric = symmetric

        # NOTE(review): relies on BaseQuantizer.__init__ having stored
        # n_bits as self.n_bits — confirm against the base class.
        if self.symmetric:
            self.q_min = -2 ** (self.n_bits - 1)
            self.q_max = 2 ** (self.n_bits - 1) - 1
        else:
            self.q_min = 0
            self.q_max = 2 ** self.n_bits - 1

        self.num_levels = self.q_max - self.q_min + 1

    def quantize(self, fp32_tensor: torch.Tensor) -> Tuple[torch.ByteTensor, Dict[str, Any]]:
        """
        Quantize a tensor group-wise and bit-pack the resulting codes.

        Args:
            fp32_tensor: Tensor to quantize; any shape (it is flattened).

        Returns:
            A ``(packed, state)`` pair. ``packed`` is the output of
            ``pack_bits``; ``state`` holds everything ``dequantize`` needs:
            ``"scales"`` (fp16, one per group), ``"zero_points"`` (one per
            group), ``"original_shape"``, ``"pad_size"`` (zeros appended to
            fill the last group) and ``"numel"`` (code count incl. padding).
        """
        original_shape = fp32_tensor.shape
        flat_tensor = fp32_tensor.reshape(-1)
        pad_size = (-flat_tensor.numel()) % self.group_size
        if pad_size > 0:
            flat_tensor = torch.cat([flat_tensor, flat_tensor.new_zeros(pad_size)], dim=0)

        groups = flat_tensor.reshape(-1, self.group_size)

        # The 1e-12 offset keeps the division below finite for all-zero groups.
        if self.symmetric:
            max_vals = torch.max(torch.abs(groups), dim=1, keepdim=True)[0]
            scales = max_vals / ((self.num_levels - 1) / 2) + 1e-12
            zero_points = torch.zeros_like(scales)
        else:
            # Asymmetric mode: map [min, max] of each group onto [0, q_max].
            # The range is widened to contain 0 so that zero (including the
            # padding) is exactly representable and the zero-point is a
            # valid code. Using a max-abs scale with a zero zero-point here
            # would clamp every negative input to code 0.
            min_vals = torch.min(groups, dim=1, keepdim=True)[0].clamp(max=0)
            max_vals = torch.max(groups, dim=1, keepdim=True)[0].clamp(min=0)
            scales = (max_vals - min_vals) / (self.num_levels - 1) + 1e-12
            zero_points = torch.round(-min_vals / scales).clamp(self.q_min, self.q_max)

        # Integer codes in [q_min, q_max], then shifted to [0, num_levels - 1]
        # because pack_bits only accepts non-negative values.
        q = torch.round(groups / scales).to(torch.int64) + zero_points.to(torch.int64)
        q = torch.clamp(q, self.q_min, self.q_max)
        flat_indices = (q - self.q_min).reshape(-1)

        packed = pack_bits(flat_indices, self.n_bits)

        state = {
            "scales": scales.half(),
            "zero_points": zero_points,
            "original_shape": original_shape,
            "pad_size": pad_size,
            "numel": flat_indices.numel(),
        }
        return packed, state

    def dequantize(self, quantized_tensor: torch.ByteTensor, state: Dict[str, Any]) -> torch.Tensor:
        """
        Reconstruct a tensor from packed codes and the matching state.

        Args:
            quantized_tensor: Packed codes as returned by ``quantize``.
            state: The state dict returned by the same ``quantize`` call.

        Returns:
            Tensor of shape ``state["original_shape"]`` computed as
            ``(code - zero_point) * scale`` per group, padding removed.
        """
        scales = state["scales"]
        zero_points = state["zero_points"]
        original_shape = state["original_shape"]
        pad_size = state["pad_size"]
        numel = state["numel"]

        indices = unpack_bits(quantized_tensor, self.n_bits, numel)
        # Undo the non-negative shift applied before packing.
        q = indices + self.q_min

        reshaped_q = q.reshape(-1, self.group_size).half()
        dequantized = (reshaped_q - zero_points) * scales

        flat = dequantized.reshape(-1)
        if pad_size > 0:
            # Drop the zeros that were appended to fill the last group.
            flat = flat[:-pad_size]
        return flat.reshape(original_shape)
