import abc
import importlib
import json
import logging
import os
from abc import ABC
from dataclasses import dataclass, field
from typing import Literal, Optional, Callable, List, Type, Any
from pathlib import Path

import numpy as np
import torch
from typing_extensions import TypedDict

from bof4.serialization import (
    dump_yaml_with_arrays,
    load_yaml_with_arrays,
    NumpyJsonEncoder,
)

# Module-level logger used for compilation and precision warnings below.
_logger = logging.getLogger(__name__)

# Imported explicitly so that the `torch._dynamo` submodule (and its `config`)
# is loaded before it is configured below.
import torch._dynamo

# Ask Dynamo to suppress compilation errors (falling back to eager execution)
# for the functions wrapped by `torch_safe_compile`.
# NOTE(review): this mutates a process-wide torch setting as an import side
# effect; every other torch.compile user in the process is affected too.
torch._dynamo.config.suppress_errors = True


def torch_safe_compile(**kwargs):
    """Build a decorator that applies ``torch.compile`` when it is usable.

    ``kwargs`` are forwarded verbatim to ``torch.compile``. If the ``triton``
    package cannot be imported, the decorated function is returned unchanged
    (with a warning). If ``torch.compile`` itself raises, the exception is
    logged and the uncompiled function is returned.

    NOTE(review): ``torch.compile`` typically defers the actual compilation to
    the first call, so the try/except presumably only catches errors raised
    while creating the compiled wrapper.
    """

    def _compile_or_passthrough(fn):
        try:
            import triton  # noqa: F401 -- availability probe only
        except ImportError:
            _logger.warning(
                f"Cannot import triton module, skipping compilation of {fn}"
            )
            return fn

        result = fn
        try:
            result = torch.compile(fn, **kwargs)
        except Exception as exc:
            _logger.warning(
                f"Compilation failed: {exc}. Falling back to the uncompiled version of {fn}."
            )
        return result

    return _compile_or_passthrough


def get_blockwise_absmax(weight_blocks: torch.Tensor):
    """Return the largest magnitude of every row (block) as a column tensor.

    For a ``(num_blocks, block_size)`` input the result has shape
    ``(num_blocks, 1)``.
    """
    block_max, _ = weight_blocks.abs().max(dim=1, keepdim=True)
    return block_max


def get_blockwise_signed_absmax(weight_blocks: torch.Tensor):
    """Return, per row (block), the element with the largest magnitude, sign kept.

    For a ``(num_blocks, block_size)`` input the result has shape
    ``(num_blocks, 1)``.
    """
    _, position = weight_blocks.abs().max(dim=1, keepdim=True)
    return torch.gather(weight_blocks, 1, position)


def get_blockwise_standard_deviation(weight_blocks: torch.Tensor):
    """Return the (Bessel-corrected) standard deviation of every row as a column."""
    return torch.std(weight_blocks, dim=1, keepdim=True)


# Registry: normalization-method name -> function computing one constant per
# block (row) of a 2-D tensor of blocks.
QUANT_CONSTANT_FUNCS = dict(
    absmax=get_blockwise_absmax,
    signed_absmax=get_blockwise_signed_absmax,
    sd=get_blockwise_standard_deviation,
)

# Accepted `normalization_method` values; None selects unit constants
# (see `get_quant_constant_func`).
NormalizationMethod = Optional[Literal["absmax", "signed_absmax", "sd"]]


# Identifies the concrete quantizer class to re-instantiate: the class name and
# the module that defines it (written by `Quantizer.to_dict`, read by
# `Quantizer.from_dict`).
TypeIdentifierDict = TypedDict(
    "TypeIdentifierDict",
    {"quantizer_class": str, "module": str},
)

# Serialized quantizer as produced by `Quantizer.to_dict`: the type identifier
# plus the subclass-specific payload returned by `_to_attributes_dict`.
QuantizerDict = TypedDict(
    "QuantizerDict",
    {"type": TypeIdentifierDict, "attributes": dict[str, Any]},
)


def get_quant_constant_func(
    normalization_method: Optional[str],
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return the function that computes per-block quantization constants.

    Args:
        normalization_method: A key of ``QUANT_CONSTANT_FUNCS`` ("absmax",
            "signed_absmax", "sd"), or ``None`` for no normalization.

    Returns:
        A callable mapping a 2-D tensor of blocks ``(num_blocks, block_size)``
        to a ``(num_blocks, 1)`` tensor of constants.

    Raises:
        KeyError: ``normalization_method`` is not a known method name.
    """
    if normalization_method is None:
        # Unit constants are created on the input's device and dtype so that
        # `blocks / constants` also works for GPU tensors.
        return lambda a: torch.ones((a.shape[0], 1), dtype=a.dtype, device=a.device)
    return QUANT_CONSTANT_FUNCS[normalization_method]


@dataclass
class QuantState:
    """Represents a quantized tensor.

    Bundles the quantized payload with everything needed to reconstruct the
    original tensor via ``quantizer.dequantize(state)``.
    """
    # Quantized codes. For `BlockwiseScalarQuantizer` this is a uint8 tensor
    # with one row per block and two 4-bit codes packed into each byte.
    quant_values: torch.Tensor
    # Per-block constants that dequantization multiplies the codebook values by
    # (shape (num_blocks, 1) for the built-in normalization functions).
    constants: torch.Tensor
    # Shape of the tensor before it was split into blocks; restored on dequantization.
    original_shape: tuple
    # dtype of the quantized input; dequantized output is produced in this dtype.
    original_dtype: torch.dtype
    # Quantizer that produced this state and can dequantize it.
    quantizer: "Quantizer"
    # Extra per-quantizer data, e.g. `OutlierPreservingQuantizer` stores
    # "outliers" and "outliers_indices" here.
    attributes: dict = field(default_factory=dict)

class Quantizer(ABC):
    """Base class for quantizers"""

    @abc.abstractmethod
    def to(self, device):
        raise NotImplementedError()

    @abc.abstractmethod
    def quantize(self, a: torch.Tensor) -> QuantState:
        """Quantize the input tensor `a`."""
        raise NotImplementedError()

    @abc.abstractmethod
    def dequantize(self, quant_state) -> torch.Tensor:
        """Dequantize `quant_state`."""
        raise NotImplementedError()

    @abc.abstractmethod
    def quantize_inplace(self, a: torch.Tensor):
        """Quantize a tensor in-place."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _to_attributes_dict(self) -> dict[str, Any]:
        """Return a dict that can contain any nested object of lists, dicts,
         primitive types and numpy `ndarray` instances"""
        raise NotImplementedError()

    @staticmethod
    @abc.abstractmethod
    def _from_attributes_dict(dct: dict[str, Any]) -> "Quantizer":
        """Restore the quantizer from the attributes dict provided by `_to_attributes_dict`"""
        raise NotImplementedError()

    def to_dict(self) -> QuantizerDict:
        return {
            "type": {
                "quantizer_class": self.__class__.__name__,
                "module": self.__module__,
            },
            "attributes": self._to_attributes_dict(),
        }

    @staticmethod
    def from_dict(dct: QuantizerDict, overwrite_class: Type = None) -> "Quantizer":
        """load quantizer from a dict containing the type information"""
        match dct:
            case {
                "type": {
                    "quantizer_class": cls_str,
                    "module": module_str,
                },
                "attributes": data,
            }:
                pass
            case _:
                print(dct)
                raise ValueError("Dictionary does not have the correct format.")

        module = importlib.import_module(module_str)
        if overwrite_class:
            cls = overwrite_class
        else:
            cls = getattr(module, cls_str)
        return cls._from_attributes_dict(data)

    def __str__(self):
        return json.dumps(self._to_attributes_dict(), indent=4, cls=NumpyJsonEncoder)


class BlockwiseScalarQuantizer(Quantizer):
    """Implements 4-bit block-wise scalar quantization.

    The input is split into rows of ``block_size`` consecutive values. Each row
    is divided by a per-block constant (chosen by ``normalization_method``).
    Every normalized value is then replaced by the index of its codebook bin;
    the bin boundaries are the midpoints between consecutive codebook points.
    """

    def __init__(
        self,
        codebook_points: List[float] | torch.Tensor | np.ndarray,
        block_size: int,
        normalization_method: NormalizationMethod = "absmax",
        name: Optional[str] = None,
        codebook_dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            codebook_points: Reconstruction levels. Presumably sorted in
                ascending order, since the midpoint thresholds and
                ``torch.searchsorted`` rely on it (TODO confirm callers
                guarantee this). ``quantize`` packs codes into 4 bits, so at
                most 16 points can be used there.
            block_size: Number of consecutive values sharing one constant. The
                input's element count must be divisible by it, and ``quantize``
                additionally needs it to be even (two codes per byte).
            normalization_method: "absmax", "signed_absmax", "sd", or None for
                no normalization (all constants equal to one).
            name: Optional name, stored on serialization.
            codebook_dtype: dtype the codebook and thresholds are stored in.
        """
        if (
            torch.as_tensor(codebook_points).dtype == torch.bfloat16
            and codebook_dtype != torch.bfloat16
        ):
            _logger.warning(
                "Provided codebook points are in bfloat16 and will be upcast. "
                "This may lead to lower precision"
            )
        self.codebook_points = torch.as_tensor(codebook_points, dtype=codebook_dtype)
        # Decision boundaries: midpoints between neighbouring codebook points.
        self.threshold_values = (
            self.codebook_points[:-1] + self.codebook_points[1:]
        ) / 2
        self._normalization_method = normalization_method
        # Pass the method through unchanged: wrapping it in `str()` turns None
        # into the unknown key "None", which made `normalization_method=None`
        # raise a KeyError.
        self._compute_quant_constants = get_quant_constant_func(normalization_method)
        self.name = name
        self.block_size = block_size

    @property
    def normalization_method(self) -> NormalizationMethod:
        return self._normalization_method

    @torch_safe_compile()
    @torch.inference_mode()
    def quantize(self, a: torch.Tensor):
        """Quantize ``a`` into packed 4-bit codes.

        Returns:
            A ``QuantState`` whose ``quant_values`` is a uint8 tensor of shape
            ``(num_blocks, block_size // 2)``. Each byte holds the codes of two
            neighbouring values, with the even position in the high nibble.
        """
        original_shape = a.shape
        a = a.reshape((-1, self.block_size))

        constants = self._compute_quant_constants(a)

        # With right=True a value lying exactly on a threshold goes to the upper bin.
        bins = torch.searchsorted(
            self.threshold_values, a / constants, right=True, out_int32=True
        ).to(dtype=torch.uint8)
        bins = (bins[:, ::2] << 4) + bins[:, 1::2]

        return QuantState(bins, constants, original_shape, a.dtype, self)

    @torch_safe_compile()
    @torch.inference_mode()
    def dequantize(self, quant_state) -> torch.Tensor:
        """Reconstruct the tensor from packed codes, in ``quant_state.original_dtype``."""
        q = self.codebook_points.to(quant_state.original_dtype)
        # Unpack: high nibble -> even positions, low nibble -> odd positions.
        bins_l = quant_state.quant_values >> 4
        bins_r = quant_state.quant_values & 0x0F
        dequantized = torch.empty(
            (bins_l.shape[0], bins_l.shape[1] * 2),
            dtype=quant_state.original_dtype,
            device=quant_state.quant_values.device,
        )

        dequantized[:, ::2] = q[bins_l.to(torch.int32)] * quant_state.constants
        dequantized[:, 1::2] = q[bins_r.to(torch.int32)] * quant_state.constants

        return dequantized.reshape(quant_state.original_shape)

    def _to_attributes_dict(self) -> dict:
        # NOTE(review): `codebook_dtype` is not stored, so restored quantizers
        # use the float32 default.
        return {
            "codebook_points": self.codebook_points.cpu()
            .detach()
            .to(torch.float32)
            .numpy(),
            "block_size": self.block_size,
            "normalization_method": self._normalization_method,
            "name": self.name,
        }

    @staticmethod
    def _from_attributes_dict(dct: dict) -> "BlockwiseScalarQuantizer":
        return BlockwiseScalarQuantizer(**dct)

    @torch.inference_mode()
    def quantize_inplace(self, a: torch.Tensor):
        """Overwrite ``a`` with its quantize -> dequantize round trip (no packing)."""
        blocks = a.reshape((-1, self.block_size))

        # quantize
        constants = self._compute_quant_constants(blocks)
        bins = torch.searchsorted(self.threshold_values, blocks / constants, right=True)
        # dequantize and write back through `copy_`: for non-contiguous inputs
        # `reshape` returns a copy, so assigning into `blocks` would be lost.
        a.copy_((self.codebook_points[bins] * constants).reshape(a.shape))

    def to(self, device):
        self.codebook_points = self.codebook_points.to(device)
        self.threshold_values = self.threshold_values.to(device)


class OutlierPreservingQuantizer(Quantizer):
    """Block-wise quantizer that additionally stores values that are more
    than `threshold` standard deviations away from zero in bfloat16 precision.

    Outliers are replaced by zero in the copy handed to the wrapped quantizer.
    They are kept separately (values and flat positions) and written back on
    dequantization.
    """

    def __init__(
        self, quantizer: BlockwiseScalarQuantizer, threshold: float = 4.0, name=None
    ):
        self.quantizer = quantizer
        self.threshold = threshold
        self._name = name
        # Running total of outliers stored by `quantize` across all calls.
        self.outlier_count = 0

    @property
    def block_size(self) -> int:
        return self.quantizer.block_size

    @property
    def name(self) -> str:
        return self._name or self.quantizer.name

    @torch_safe_compile()
    @torch.inference_mode()
    def quantize(self, a: torch.Tensor) -> QuantState:
        """Quantize ``a``, keeping outliers separately. ``a`` is not modified.

        A value is an outlier when its magnitude exceeds ``threshold`` times the
        (unbiased) standard deviation of its block.

        Returns:
            The wrapped quantizer's ``QuantState`` with
            ``attributes["outliers"]`` (bfloat16 values) and
            ``attributes["outliers_indices"]`` (positions in the flattened
            tensor) added.
        """
        a_blocks = a.reshape(-1, self.block_size)
        block_stds = a_blocks.std(dim=1, unbiased=True, keepdim=True)
        normalized_tensor = (a_blocks / block_stds).flatten()

        outliers_indices = (torch.abs(normalized_tensor) > self.threshold).nonzero(
            as_tuple=True
        )[0]

        # Zero the outliers in a private copy. `a.flatten()` is a view only for
        # contiguous inputs, so writing into it would either clobber the
        # caller's tensor or, for non-contiguous inputs, leave the outliers in
        # the tensor that actually gets quantized.
        flat = a.flatten().clone()
        outliers = flat[outliers_indices]  # advanced indexing already copies
        flat[outliers_indices] = 0.0

        quant_state = self.quantizer.quantize(flat.reshape(a.shape))
        quant_state.attributes["outliers"] = outliers.to(torch.bfloat16)
        quant_state.attributes["outliers_indices"] = outliers_indices
        self.outlier_count += quant_state.attributes["outliers"].numel()
        return quant_state

    @torch_safe_compile()
    @torch.inference_mode()
    def dequantize(self, quant_state: QuantState) -> torch.Tensor:
        """Dequantize with the wrapped quantizer, then restore the stored outliers."""
        dequantized = self.quantizer.dequantize(quant_state)
        # `reshape(-1)` is a view for contiguous results and a copy otherwise;
        # returning `flat` reshaped works in both cases.
        flat = dequantized.reshape(-1)
        # Index assignment requires matching dtypes, so cast the bfloat16
        # outliers to the output dtype.
        flat[quant_state.attributes["outliers_indices"]] = quant_state.attributes[
            "outliers"
        ].to(flat.dtype)
        return flat.reshape(dequantized.shape)

    @torch.inference_mode()
    def quantize_inplace(self, a: torch.Tensor):
        """Overwrite ``a`` with its quantize/dequantize round trip and return it."""
        dequantized = self.dequantize(self.quantize(a))
        a.copy_(dequantized)
        return dequantized

    def to(self, device):
        self.quantizer.to(device)

    def _to_attributes_dict(self) -> dict:
        return dict(
            quantizer=self.quantizer.to_dict(), threshold=self.threshold, name=self.name
        )

    @staticmethod
    def _from_attributes_dict(dct: dict) -> "OutlierPreservingQuantizer":
        quantizer = Quantizer.from_dict(dct["quantizer"])
        return OutlierPreservingQuantizer(quantizer, dct["threshold"], name=dct["name"])


class ClippingQuantizer(Quantizer):
    """Block-wise quantizer that clips values at a threshold before quantization.

    Each value is clamped to +/- ``threshold`` times the (unbiased) standard
    deviation of its block. The clipped tensor is then passed to the wrapped
    quantizer.
    """

    def __init__(
        self, quantizer: BlockwiseScalarQuantizer, threshold: float = 4.0, name=None
    ):
        self.quantizer = quantizer
        self.threshold = threshold
        self._name = name

    @property
    def block_size(self) -> int:
        return self.quantizer.block_size

    @property
    def name(self) -> str:
        return self._name or self.quantizer.name

    @torch_safe_compile()
    @torch.inference_mode()
    def quantize(self, a: torch.Tensor) -> QuantState:
        """Clip ``a`` block-wise and quantize it. ``a`` is not modified."""
        original_shape = a.shape
        a_blocks = a.reshape(-1, self.block_size)
        # (num_blocks, 1) limits broadcast across every value of their block.
        limits = a_blocks.std(dim=1, unbiased=True, keepdim=True) * self.threshold
        clipped = torch.clamp(a_blocks, min=-limits, max=limits)
        return self.quantizer.quantize(clipped.reshape(original_shape))

    @torch_safe_compile()
    @torch.inference_mode()
    def dequantize(self, quant_state: QuantState) -> torch.Tensor:
        return self.quantizer.dequantize(quant_state)

    @torch.inference_mode()
    def quantize_inplace(self, a: torch.Tensor):
        """Overwrite ``a`` with its quantize/dequantize round trip and return it."""
        dequantized = self.dequantize(self.quantize(a))
        a.copy_(dequantized)
        return dequantized

    def to(self, device):
        self.quantizer.to(device)

    def _to_attributes_dict(self) -> dict:
        return dict(
            quantizer=self.quantizer.to_dict(), threshold=self.threshold, name=self.name
        )

    @staticmethod
    def _from_attributes_dict(dct: dict) -> "ClippingQuantizer":
        quantizer = Quantizer.from_dict(dct["quantizer"])
        return ClippingQuantizer(quantizer, dct["threshold"], name=dct["name"])


@torch.jit.script
def _assign_to_nearest_neighbor_torch(
    a: torch.Tensor, recons_vecs: torch.Tensor, order: int = 2, chunk_size: int = 2**20
) -> torch.Tensor:
    """Return, for each row of `a`, the index of the closest row of `recons_vecs`.

    Distances are `order`-norms of row differences. Rows of `a` are processed in
    chunks of `chunk_size` to bound the size of the (chunk, K, dim) difference
    tensor. Indices are returned as uint8, so at most 256 reconstruction
    vectors are supported; larger indices would otherwise wrap silently.
    """
    if recons_vecs.shape[0] > 256:
        raise ValueError("uint8 indices support at most 256 reconstruction vectors")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    result = torch.empty(a.shape[0], dtype=torch.uint8, device=a.device)
    for start in range(0, a.shape[0], chunk_size):
        chunk = a[start : start + chunk_size]
        # (chunk, 1, dim) - (K, dim) broadcasts to (chunk, K, dim).
        distances = torch.linalg.norm(
            chunk[:, None, :] - recons_vecs, ord=order, dim=2
        )
        result[start : start + chunk_size] = torch.argmin(distances, dim=1)
    return result


def dequantize(quant_state: QuantState) -> torch.Tensor:
    """Reconstruct a tensor using the quantizer recorded in `quant_state`.

    Gradient tracking is disabled for the duration of the call.
    """
    with torch.no_grad():
        owner = quant_state.quantizer
        return owner.dequantize(quant_state)


def save_to_file(
    quantizer: Quantizer, path: str | os.PathLike, overwrite_existing: bool = False
):
    if Path(path).exists() and not overwrite_existing:
        raise FileExistsError(
            f"File {path} already exists. Set `overwrite_existing=True` to overwrite.`"
        )
    dump_yaml_with_arrays(quantizer.to_dict(), path)


def load_from_file(path: str | os.PathLike, overwrite_class=None) -> Quantizer:
    """Load a quantizer from a file written by `save_to_file`.

    `overwrite_class` is forwarded to `Quantizer.from_dict`.
    NOTE(review): `Quantizer.from_dict` imports the module named in the file,
    so only load trusted files.
    """
    return Quantizer.from_dict(
        load_yaml_with_arrays(path), overwrite_class=overwrite_class
    )
