import logging
from typing import Callable, Literal

import torch
from optimum.quanto import AbsmaxOptimizer, qtype
from torch import dtype, Tensor

from ..utils import get_device
from .utils import entropy, LpNormDistance

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Public API of this module: only the two optimizer classes are exported by
# `from <module> import *`; the autograd helpers and the quantizer stay internal.
__all__ = ["SymmetricEntropyOptimizer", "WrappedAbsmaxOptimizer"]


class Round(torch.autograd.Function):
    """
    Rounding with a straight-through estimator (STE).

    The forward pass rounds every element with ``torch.round``; the backward
    pass hands the incoming gradient back unchanged, as if rounding were the
    identity function.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # Nothing is stored on ``ctx``: backward does not depend on the input.
        rounded = x.round()
        return rounded

    @staticmethod
    def backward(ctx, gy: Tensor) -> Tensor:
        # Straight-through: the gradient w.r.t. ``x`` is the upstream gradient.
        return gy


class FakeFloatCast(torch.autograd.Function):
    """
    Simulated cast to another dtype with a straight-through estimator (STE).

    The forward pass converts ``x`` to ``dtype`` and immediately back to the
    original dtype of ``x``, so the result keeps the input dtype but only
    holds values that survive the round trip through ``dtype``. The backward
    pass returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dtype) -> Tensor:
        """Return ``x`` after a round trip through ``dtype``.

        Args:
            x: Tensor to fake-cast.
            dtype: Target ``torch.dtype`` of the intermediate cast.

        Returns:
            A tensor with the same dtype as ``x``.
        """
        return x.to(dtype).to(x.dtype)

    @staticmethod
    def backward(ctx, gy: Tensor) -> tuple[Tensor, None]:
        """Pass the gradient straight through.

        One entry is returned per ``forward`` argument: the upstream gradient
        for ``x`` and ``None`` for the non-differentiable ``dtype`` argument.
        """
        return gy, None


class STClamp(torch.autograd.Function):
    """
    Clamp with a straight-through estimator (STE).

    The forward pass limits ``x`` to ``[xmin, xmax]`` with ``torch.clamp``.
    The backward pass returns the incoming gradient unchanged for every
    element, including those that were clipped in the forward pass.
    """

    @staticmethod
    def forward(ctx, x: Tensor, xmin, xmax) -> Tensor:
        """Return ``x`` clamped to the closed interval ``[xmin, xmax]``.

        Args:
            x: Tensor to clamp.
            xmin: Lower bound forwarded to ``torch.clamp``.
            xmax: Upper bound forwarded to ``torch.clamp``.
        """
        return torch.clamp(x, xmin, xmax)

    @staticmethod
    def backward(ctx, gy: Tensor) -> tuple[Tensor, None, None]:
        """Pass the gradient straight through.

        One entry is returned per ``forward`` argument: the upstream gradient
        for ``x`` and ``None`` for the two non-differentiable bounds.
        """
        return gy, None, None


class SymmetricQuantizer(torch.nn.Module):
    """Symmetric fake-quantizer with a learnable scale.

    Quantization follows the scheme described in https://arxiv.org/pdf/2004.09602.
    The only trainable parameter is ``log_scale``, the base-2 logarithm of the
    scale, stored in float32 and initialised from the ``scale`` tensor given to
    the constructor. Rounding, clamping and float casting all use
    straight-through estimators so that gradients reach ``log_scale``.
    """

    def __init__(self, scale: Tensor, dtype_: dtype, qtype_: qtype):
        """Store the target dtypes and initialise ``log_scale`` from ``scale``.

        Args:
            scale: Initial scale; its base-2 logarithm becomes the parameter.
            dtype_: dtype used for the effective scale and dequantized output.
            qtype_: Quantized type; ``is_floating_point``, ``qmin``, ``qmax``
                and ``dtype`` are read from it during quantization.
        """
        super().__init__()
        self.dtype = dtype_
        self.qtype = qtype_
        initial_log_scale = torch.log2(scale.float())
        self.log_scale = torch.nn.Parameter(initial_log_scale)

    @property
    def scale(self):
        """Effective scale ``2 ** log_scale``, computed in ``self.dtype``."""
        exponent = self.log_scale.to(self.dtype)
        # ldexp(1, e) == 1 * 2**e
        return torch.ones_like(exponent).ldexp(exponent)

    def dequantize(self, x: Tensor):
        """Map quantized values back by casting to ``self.dtype`` and multiplying by the scale."""
        return x.to(self.dtype) * self.scale

    def quantize(self, x: Tensor):
        """Scale, then round (integer types) or fake-cast (float types), with clamping.

        Only 2-D inputs are accepted. The result is not cast to the quantized
        storage dtype; it stays a regular tensor holding quantized values.
        """
        assert len(x.shape) == 2, "Only Linear Layers are supported"

        target = self.qtype
        # Divide by the scale via ldexp on the raw float32 ``log_scale``. The
        # original implementation notes that ``x / self.scale`` leads to
        # optimization issues and that this form is essential.
        scaled = x.ldexp(-self.log_scale)

        if target.is_floating_point:
            # Float targets: clamp to the representable range, then simulate the cast.
            clamped = STClamp.apply(scaled, target.qmin, target.qmax)
            return FakeFloatCast.apply(clamped, target.dtype)

        # Integer targets: round to the grid first, then clamp.
        rounded = Round.apply(scaled)
        return STClamp.apply(rounded, target.qmin, target.qmax)

    def forward(self, x: Tensor):
        """Return ``(dequantized, quantized)`` for ``x``."""
        quantized = self.quantize(x)
        return self.dequantize(quantized), quantized


def l1_reg_fun(x: Tensor, qtype_: qtype):
    """Return the mean absolute value of ``x`` with a qtype-dependent weight.

    Floating-point quantized types multiply the magnitudes by 4.0; all other
    types divide them by 128.0.
    """
    magnitude = torch.abs(x)
    weighted = magnitude * 4.0 if qtype_.is_floating_point else magnitude / 128.0
    return weighted.mean()


class SymmetricEntropyOptimizer(AbsmaxOptimizer):
    """
    Core optimizer implementing the entropy optimization procedure of EntQuant.
    Defaults are as used in the paper. reg_param corresponds to lambda in the paper.

    Starting from the absmax scale, the log-scale of a ``SymmetricQuantizer``
    is tuned with L-BFGS to minimise
    ``dist_fun(dequantized - base, base) + reg_param * reg_fun(quantized, qtype)``.

    Args:
        norm_type: Forwarded to ``LpNormDistance`` for the reconstruction loss.
        norm_p: Forwarded to ``LpNormDistance`` as ``p``.
        reg_fun: Regulariser called as ``reg_fun(quantized, qtype)``.
        reg_param: Weight of the regulariser in the loss.
        lr: Learning rate passed to ``torch.optim.LBFGS``.
        maxiters: ``max_iter`` passed to ``torch.optim.LBFGS``.
        device_compute: Device on which the optimization runs, resolved with
            ``get_device``. ``None`` means "use the device of the input tensor".
        verbose: If true, log error and entropy before and after optimization.
    """

    def __init__(
        self,
        norm_type: Literal["absolute", "relative", "relative_entrywise", "mean"] = "relative",
        norm_p: float = 1.0,
        reg_fun: Callable[[Tensor, qtype], Tensor] = l1_reg_fun,
        reg_param: float = 4.0,
        lr: float = 1.0,
        maxiters: int = 500,
        device_compute: str | torch.device | int | None = None,
        verbose: bool = True,
    ) -> None:
        # Distance used inside the loss, plus a fixed relative-L2 distance that
        # is only used for the verbose log output.
        self.dist_fun = LpNormDistance(norm_type=norm_type, p=norm_p)
        self.dist_fun_eval = LpNormDistance(norm_type="relative", p=2.0)
        self.reg_fun = reg_fun
        self.reg_param = reg_param
        self.lr = lr
        self.maxiters = maxiters
        self.verbose = verbose

        if device_compute is None:
            self.device_compute = None
        else:
            self.device_compute = get_device(device_compute)

    def _log_state(
        self,
        stage: str,
        quantizer: torch.nn.Module,
        target: Tensor,
        device_compute: torch.device,
    ) -> None:
        """Log loss-distance, relative L2 error and entropy of the quantized ``target``.

        ``stage`` is the label ("Start" / "End") used as a prefix in the message.
        """
        with torch.no_grad():
            deq_base, quant_base = quantizer(target)
            error = self.dist_fun((deq_base - target), target)
            error_eval = self.dist_fun_eval((deq_base - target), target)
            logger.info(
                f"[{str(device_compute)}] {stage} error: {error:.3f}, {stage} l2: {error_eval:.3f}, "
                f"{stage} entropy: {entropy(quant_base).item():.3f}"
            )

    def optimize(
        self,
        base: torch.Tensor,
        qtype_: qtype,
        axis: int | None = None,
        lr: float | None = None,  # override for retries
        reg_param: float | None = None,  # override for retries
    ) -> Tensor:
        """Optimise the quantization scale for ``base``.

        Args:
            base: Tensor to quantize (the quantizer requires it to be 2-D).
            qtype_: Target quantized type.
            axis: Forwarded to ``AbsmaxOptimizer.optimize`` for the initial scale.
            lr: Overrides ``self.lr`` for this call when given.
            reg_param: Overrides ``self.reg_param`` for this call when given.

        Returns:
            The optimised scale, detached and moved back to ``base.device``.

        Note:
            If the loss becomes NaN (any ``ValueError`` raised during the
            L-BFGS step), the optimization restarts from scratch with half the
            learning rate. There is no upper bound on the number of retries.
        """
        device_host = base.device
        device_compute = self.device_compute if self.device_compute is not None else base.device

        # One detached copy on the compute device serves both the initial
        # absmax scale and the optimization itself.
        _base = base.clone().detach().to(device_compute)

        # Get initial scales from standard AbsmaxOptimizer
        scale_orig = super().optimize(_base, qtype_, axis)

        lr = self.lr if lr is None else lr
        reg_param = self.reg_param if reg_param is None else reg_param

        quantizer = SymmetricQuantizer(scale_orig.clone().detach(), dtype_=_base.dtype, qtype_=qtype_)
        quantizer = quantizer.to(device_compute)

        if self.verbose:
            self._log_state("Start", quantizer, _base, device_compute)

        optimizer = torch.optim.LBFGS(
            quantizer.parameters(),
            lr=lr,
            max_iter=self.maxiters,
            history_size=100,
            line_search_fn="strong_wolfe",
        )

        def closure():
            """Loss function optimized by EntQuant."""
            optimizer.zero_grad()
            _deq_base, _quant_base = quantizer(_base)

            rec_loss = self.dist_fun((_deq_base - _base), _base)
            reg_loss = reg_param * self.reg_fun(_quant_base, qtype_)
            loss = rec_loss + reg_loss
            loss.backward()
            if loss.isnan():
                raise ValueError("Loss is NaN!")
            return loss

        try:
            optimizer.step(closure)
        except ValueError as e:
            logger.warning(f"Optimization failed: {e}\n Retry with lr={lr * 0.5}")
            # Carry the effective reg_param into the retry; otherwise a
            # caller-supplied override would silently fall back to self.reg_param.
            return self.optimize(base, qtype_, axis, lr=lr * 0.5, reg_param=reg_param)

        if self.verbose:
            self._log_state("End", quantizer, _base, device_compute)

        scale = quantizer.scale.clone().detach()
        return scale.to(device_host)


class WrappedAbsmaxOptimizer(AbsmaxOptimizer):
    """Absmax scale optimizer that can run on a separate compute device.

    The scale is computed by ``AbsmaxOptimizer.optimize`` on the compute
    device and returned on the device of the input tensor.
    """

    def __init__(
        self,
        device_compute: str | torch.device | int | None = None,
    ) -> None:
        # ``None`` defers the choice to ``optimize``, which then uses the
        # device of the tensor it receives.
        self.device_compute = None if device_compute is None else get_device(device_compute)

    def optimize(self, base: torch.Tensor, qtype_: qtype, axis: int | None = None) -> Tensor:
        """Return the absmax scale for ``base`` on ``base.device``."""
        target_device = base.device if self.device_compute is None else self.device_compute
        absmax_scale = super().optimize(base.to(target_device), qtype_, axis)
        absmax_scale = absmax_scale.to(base.device)
        # TODO: The log2 -> 2**x round trip mimics how SymmetricEntropyOptimizer
        # produces its scale; noted by the original author as giving marginally
        # better results.
        exponent = torch.log2(absmax_scale)
        return torch.ones_like(exponent).ldexp(exponent)
