"""
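Benford-Quant quantizer with group-wise scaling and true memory reduction
via bit-packing (2..8 bits per value). Packing is byte-aligned: each byte
holds ``8 // n_bits`` values, so 3, 5, 6 and 7 bits leave some bits of
every byte unused.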
"""

from typing import Any, Tuple, Dict, Optional
import torch
import numpy as np
from .base_quantizer import BaseQuantizer

def pack_bits(indices: torch.Tensor, n_bits: int) -> torch.ByteTensor:
    """Pack small non-negative integers into a uint8 tensor, several per byte.

    Each output byte holds ``8 // n_bits`` values. Value ``j`` within a byte
    occupies bits ``[j * n_bits, (j + 1) * n_bits)``, so the first value sits
    in the least-significant bits. Values never straddle a byte boundary.
    The input is flattened in row-major order, and the tail is zero-padded
    to fill the last byte.

    Args:
        indices: Integer tensor of any shape.
        n_bits: Bits per value, in [1, 8].

    Returns:
        1-D uint8 tensor of length ``ceil(numel / (8 // n_bits))`` on the
        same device as ``indices``.

    Raises:
        ValueError: if ``n_bits`` is outside [1, 8] or any value is outside
            ``[0, 2**n_bits - 1]``.
    """
    if not 1 <= n_bits <= 8:
        raise ValueError("n_bits must be in [1, 8].")
    max_code = (1 << n_bits) - 1
    per_byte = 8 // n_bits  # always >= 1 because n_bits <= 8

    codes = indices.reshape(-1).to(torch.int64)
    if ((codes < 0) | (codes > max_code)).any():
        raise ValueError(f"indices out of range [0, {max_code}] for n_bits={n_bits}")

    leftover = codes.numel() % per_byte
    if leftover:
        filler = torch.zeros(per_byte - leftover, dtype=codes.dtype, device=codes.device)
        codes = torch.cat((codes, filler))

    rows = codes.view(-1, per_byte)
    shifts = torch.arange(per_byte, dtype=torch.int64, device=codes.device) * n_bits
    # Every value is < 2**n_bits, so the shifted bit fields are disjoint and
    # summing them is the same as OR-ing them together.
    return (rows << shifts).sum(dim=1).to(torch.uint8)


def unpack_bits(packed: torch.ByteTensor, n_bits: int, length: int) -> torch.Tensor:
    """Inverse of ``pack_bits``: recover the first ``length`` packed values.

    Each byte of ``packed`` holds ``8 // n_bits`` values, with the first
    value in the least-significant bits (the layout ``pack_bits`` writes).

    Args:
        packed: 1-D integer tensor of packed bytes (normally uint8).
        n_bits: Bits per value, in [1, 8]; must match the packing call.
        length: Number of values to return; trailing padding is dropped.

    Returns:
        1-D int64 tensor of ``length`` values on the same device as ``packed``.

    Raises:
        ValueError: if ``n_bits`` is outside [1, 8], or ``length`` is negative
            or larger than the number of values ``packed`` can hold.
    """
    if n_bits < 1 or n_bits > 8:
        raise ValueError("n_bits must be in [1, 8].")
    values_per_byte = 8 // n_bits  # always >= 1 because n_bits <= 8

    capacity = packed.numel() * values_per_byte
    if length < 0 or length > capacity:
        # Previously a bad length silently produced a short/odd slice that
        # only failed later (e.g. in a reshape); fail here with context.
        raise ValueError(
            f"length={length} out of range [0, {capacity}] for "
            f"{packed.numel()} packed bytes at n_bits={n_bits}"
        )

    mask_val = (1 << n_bits) - 1
    shifts = torch.arange(values_per_byte, dtype=torch.int64, device=packed.device) * n_bits
    # Work in int64 with a plain Python-int mask. A mask tensor in
    # packed.dtype would overflow for a signed int8 buffer at n_bits=8.
    # Masking after the shift also discards any sign-extension bits.
    fields = (packed.reshape(-1).to(torch.int64).unsqueeze(1) >> shifts) & mask_val
    # Row-major flatten keeps byte order, then in-byte order (low bits first).
    return fields.reshape(-1)[:length]


class BenfordQuantizer(BaseQuantizer):
    """Group-wise quantizer whose level table is log-uniform (default) or linear.

    ``quantize`` works in four steps:

    1. Flatten the input and zero-pad it to a multiple of ``group_size``.
    2. Divide each group by its max absolute value (plus a tiny epsilon), so
       the group lies in [-1, 1].
    3. Snap every value to the nearest entry of ``self.quantization_levels``.
    4. Bit-pack the chosen indices with ``pack_bits``.

    ``dequantize`` reverses this using the stored per-group scales.

    Args:
        n_bits: Bits per stored index; forwarded to ``BaseQuantizer``.
        group_size: Number of consecutive values that share one scale.
        level_distribution: ``"log-uniform"`` or ``"linear"`` spacing of the
            positive levels between ``min_level`` and 1.0.
        min_level: Smallest positive level, in (0, 1]. Used only when neither
            ``min_exponent`` nor ``range`` is given; defaults to 1e-3.
        min_exponent: Sets ``min_level = 10 ** min_exponent``; takes
            precedence over ``min_level``.
        range: Shorthand for ``min_exponent = -abs(range)``. Mutually
            exclusive with ``min_exponent``. The name shadows the builtin
            ``range`` inside ``__init__`` and is kept for API compatibility.

    Raises:
        ValueError: if ``group_size`` is not positive, if both ``range`` and
            ``min_exponent`` are given, if the effective ``min_level`` is
            outside (0, 1], or if ``level_distribution`` is unknown.
    """

    def __init__(
        self,
        n_bits: int,
        group_size: int = 128,
        level_distribution: str = "log-uniform",
        min_level: Optional[float] = None,
        min_exponent: Optional[int] = None,
        range: Optional[int] = None,
    ):
        super().__init__(n_bits)
        if group_size <= 0:
            raise ValueError("Group size must be positive.")
        self.group_size = group_size
        self.level_distribution = level_distribution

        if min_exponent is not None and range is not None:
            raise ValueError("Specify either 'range' or 'min_exponent', not both.")
        if range is not None:
            min_exponent = -abs(range)

        if min_exponent is not None:
            self.min_exponent = int(min_exponent)
            self.min_level = float(10.0 ** self.min_exponent)
        else:
            if min_level is None:
                min_level = 1e-3
            if min_level <= 0:
                raise ValueError("min_level must be > 0.")
            self.min_level = float(min_level)
            self.min_exponent = int(np.floor(np.log10(self.min_level)))

        # Positive levels are spaced from min_level up to 1.0. With
        # min_level > 1 the table would be descending, which torch.bucketize
        # cannot search, so every quantize() call would silently be wrong.
        if self.min_level > 1.0:
            raise ValueError(
                f"min_level must be <= 1 (got {self.min_level}); "
                "levels are normalized to [-1, 1]."
            )

        self.quantization_levels = self._generate_levels()

    @staticmethod
    def estimate_range(tensor: torch.Tensor, method: str = "percentile", value: float = 0.999) -> int:
        """Suggest a ``min_exponent`` from the spread of ``tensor``'s magnitudes.

        Computes ``ratio = statistic / max(|tensor|)`` and returns
        ``floor(log10(ratio))``. The statistic depends on ``method``:

        - ``"percentile"``: the ``value``-quantile of ``|tensor|``.
        - ``"variance"``: the standard deviation of ``|tensor|``.

        Args:
            tensor: Tensor to inspect (any shape).
            method: ``"percentile"`` or ``"variance"``.
            value: Quantile in [0, 1], used only by ``"percentile"``.

        Returns:
            -3 for empty or all-zero tensors. Otherwise an exponent >= -12,
            because the ratio is clamped to at least 1e-12.

        Raises:
            ValueError: for an unknown ``method``.
        """
        # Detach first: .numpy() refuses tensors that require grad.
        abs_vals = tensor.detach().abs().reshape(-1)
        # Upcast other dtypes: .numpy() rejects bfloat16, and var() is not
        # defined for integer tensors.
        if abs_vals.dtype not in (torch.float32, torch.float64):
            abs_vals = abs_vals.float()
        if abs_vals.numel() == 0:
            return -3
        max_val = abs_vals.max()
        if max_val == 0:
            return -3

        if method == "percentile":
            # NumPy is used here, presumably to avoid torch.quantile's
            # input-size limit on large weight tensors.
            abs_vals_np = abs_vals.cpu().numpy()
            pct = np.quantile(abs_vals_np, value)
            ratio = pct / max_val.item()
        elif method == "variance":
            if abs_vals.numel() > 1:
                spread = abs_vals.var().sqrt().item()
            else:
                # The unbiased variance of one sample is NaN, which made
                # int(floor(log10(nan))) raise below. One value has no spread.
                spread = 0.0
            ratio = spread / max_val.item()
        else:
            raise ValueError(f"Unknown calibration method: {method}")

        ratio = max(ratio, 1e-12)
        return int(np.floor(np.log10(ratio)))

    def _generate_levels(self) -> torch.Tensor:
        """Build the sorted, sign-symmetric level table in [-1, 1] as float16.

        - ``n_bits <= 3``: ``2**(n_bits-1)`` positive levels and their
          negatives, with no zero. That is exactly ``2**n_bits`` levels.
        - ``n_bits >= 4``: ``2**(n_bits-1) - 1`` positive levels, their
          negatives, and an exact 0. That is ``2**n_bits - 1`` levels, so one
          index code is never produced.

        Raises:
            ValueError: for an unknown ``level_distribution``.
        """
        half = (2 ** self.n_bits) // 2
        with_zero = self.n_bits > 3
        num_pos = half - 1 if with_zero else half
        # NOTE(review): for n_bits == 1, num_pos == 1 and linspace returns
        # only its start, so the levels are +/-min_level rather than +/-1.
        # Confirm whether 1-bit is meant to be supported.

        if self.level_distribution == "log-uniform":
            pos = torch.exp(torch.linspace(np.log(self.min_level), np.log(1.0), num_pos))
        elif self.level_distribution == "linear":
            pos = torch.linspace(self.min_level, 1.0, num_pos)
        else:
            raise ValueError(f"Unknown level_distribution: {self.level_distribution}")

        neg = -torch.flip(pos, [0])
        parts = [neg, torch.tensor([0.0]), pos] if with_zero else [neg, pos]
        return torch.cat(parts).to(torch.float16)

    def quantize(self, fp32_tensor: torch.Tensor) -> Tuple[torch.ByteTensor, Dict[str, Any]]:
        """Quantize ``fp32_tensor`` group-wise and bit-pack the level indices.

        Args:
            fp32_tensor: Floating-point tensor of any shape.

        Returns:
            A pair ``(packed, state)``. ``packed`` is the 1-D uint8 output of
            ``pack_bits``. ``state`` has these keys:

            - ``"scales"``: float16 tensor of shape ``(num_groups, 1)``.
            - ``"original_shape"``: shape of the input.
            - ``"pad_size"``: number of zeros appended to fill the last group.
            - ``"numel"``: number of packed indices, including padding.
        """
        original_shape = fp32_tensor.shape
        flat = fp32_tensor.reshape(-1)
        # Normalize in float32 for half-precision inputs. In float16 the
        # 1e-12 epsilon below underflows to 0, so an all-zero group would
        # compute 0/0 = NaN; float32 also gives a more precise division.
        if flat.dtype in (torch.float16, torch.bfloat16):
            flat = flat.float()

        remainder = flat.numel() % self.group_size
        pad_size = (self.group_size - remainder) % self.group_size
        if pad_size > 0:
            pad = torch.zeros(pad_size, device=flat.device, dtype=flat.dtype)
            flat = torch.cat([flat, pad], dim=0)

        reshaped = flat.reshape(-1, self.group_size)

        # Per-group max |x|, shape (num_groups, 1). The epsilon keeps
        # all-zero groups from dividing by zero.
        scales = torch.max(torch.abs(reshaped), dim=1, keepdim=True)[0]
        scales = scales + 1e-12
        normalized = reshaped / scales

        device_levels = self.quantization_levels.to(flat.device)
        last = device_levels.numel() - 1
        # bucketize gives the first level index i with levels[i] >= x. The
        # nearest level is therefore levels[i - 1] or levels[i]; clamp both
        # candidates into the table.
        upper = torch.bucketize(normalized, device_levels)
        left_indices = torch.clamp(upper - 1, 0, last)
        right_indices = torch.clamp(upper, 0, last)
        left_vals = device_levels[left_indices]
        right_vals = device_levels[right_indices]
        # Strict '<' sends exact ties to the lower level.
        choose_right = torch.abs(normalized - right_vals) < torch.abs(normalized - left_vals)
        # Both candidates are already clamped, so the result is in [0, last].
        quantized_indices = torch.where(choose_right, right_indices, left_indices).to(torch.int64)

        flat_indices = quantized_indices.reshape(-1)
        packed_indices = pack_bits(flat_indices, self.n_bits)

        state = {
            # NOTE: scales are stored as float16. A group whose max |x|
            # exceeds ~65504 gets an infinite scale, and one below ~6e-8
            # underflows to 0. Dequantization also uses this rounded scale,
            # not the float32 one used above.
            "scales": scales.half(),
            "original_shape": original_shape,
            "pad_size": pad_size,
            "numel": flat_indices.numel(),
        }
        return packed_indices, state

    def dequantize(self, quantized_tensor: torch.ByteTensor, state: Dict[str, Any]) -> torch.Tensor:
        """Reconstruct a tensor from ``quantize``'s packed indices and state.

        Args:
            quantized_tensor: Packed uint8 indices returned by ``quantize``.
            state: The state dict returned alongside them.

        Returns:
            Tensor of ``state["original_shape"]`` on ``quantized_tensor``'s
            device. Its dtype is float16, because both the level table and
            the scales are float16 here; the input dtype is not restored.
        """
        scales = state["scales"]
        original_shape = state["original_shape"]
        pad_size = state.get("pad_size", 0)
        numel = state["numel"]

        scales = scales.to(quantized_tensor.device, torch.float16)
        device_levels = self.quantization_levels.to(quantized_tensor.device)

        quantized_indices = unpack_bits(quantized_tensor, self.n_bits, numel).to(torch.long)
        dequantized_normalized = device_levels[quantized_indices]

        # Undo the per-group normalization: (num_groups, G) * (num_groups, 1).
        reshaped = dequantized_normalized.reshape(-1, self.group_size)
        rescaled = reshaped * scales

        flat = rescaled.reshape(-1)
        if pad_size > 0:
            flat = flat[:-pad_size]  # drop the zeros appended by quantize
        return flat.reshape(original_shape)
