from __future__ import annotations
import logging
from typing import Union
import torch


class NFQuantizer:
    """Block-wise quantizer that maps values onto a normal-quantile lookup table.

    A 2D weight is flattened and cut into blocks of ``block_size`` consecutive
    elements. Each block is divided by its own absolute maximum and every
    scaled element is replaced by the index of the nearest lookup-table entry.
    """

    def __init__(self, num_bits=4, device="cuda", block_size=64, *args, **kwargs):
        """Create the lookup table for ``num_bits`` and move it to ``device``.

        Args:
            num_bits: Bit width; the lookup table holds ``2**num_bits`` levels.
            device: Device the lookup table is moved to.
            block_size: Number of consecutive elements that share one scale.
            *args, **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(*args, **kwargs)
        self.num_bits = num_bits
        self.device = device
        self.block_size = block_size
        self.norm_lookup_table = self.create_normal_map(num_bits=self.num_bits).to(device)

    @staticmethod
    def create_normal_map(offset=0.9677083, symmetric=False, num_bits=4):
        """Build the sorted table of quantization levels from normal quantiles.

        Args:
            offset: Largest probability fed to the normal inverse CDF.
            symmetric: If true, levels are midpoints of ``2**num_bits + 1``
                quantiles spread evenly over ``[1 - offset, offset]``.
                Otherwise the table is the concatenation of
                ``2**num_bits // 2`` positive quantiles, an exact zero and
                ``2**num_bits // 2 - 1`` negated quantiles.
            num_bits: Bit width; the table holds ``2**num_bits`` levels.

        Returns:
            A 1-D tensor of levels in ascending order, divided by its largest
            entry so that the maximum is 1.

        Raises:
            ImportError: If scipy is not installed.
        """
        try:
            from scipy.stats import norm
        except ImportError as err:
            raise ImportError(
                "The required package 'scipy' is not installed. Please install it to continue."
            ) from err

        variations = 2**num_bits
        if symmetric:
            edges = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist()
            # Each level is the midpoint of two neighbouring quantiles.
            v = [0.5 * low + 0.5 * high for low, high in zip(edges, edges[1:])]
        else:
            # The positive side gets one more level than the negative side.
            v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist()
            v2 = [0]
            v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist()
            v = v1 + v2 + v3

        values = torch.Tensor(v)
        values = values.sort().values
        values /= values.max()
        return values

    def quantize_block(self, weight):
        """Quantize a 2D tensor block by block.

        Args:
            weight: 2D tensor whose element count is a multiple of
                ``self.block_size``.

        Returns:
            A tuple ``(qweight, weight_max, shape)``: ``qweight`` holds the
            lookup-table index of every element with shape
            ``(num_blocks, block_size)``, ``weight_max`` holds the absolute
            maximum of every block with shape ``(num_blocks, 1)``, and
            ``shape`` is the shape of the input.

        Raises:
            ValueError: If ``weight`` is not 2D or its element count is not a
                multiple of the block size.
        """
        if len(weight.shape) != 2:
            raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.")
        if weight.shape[0] * weight.shape[1] % self.block_size != 0:
            raise ValueError(
                f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) "
                f"is not divisible by block size {self.block_size}."
            )

        weight_block = weight.flatten().reshape(-1, self.block_size)  # (L, B), L = M * N / B
        weight_max = weight_block.abs().max(dim=-1)[0].unsqueeze(-1)  # (L, 1)

        # An all-zero block has a zero scale and dividing by it would give
        # NaN (0 / 0). Divide by 1 instead so its elements map to the zero
        # level; the true (zero) scale is still what gets returned.
        safe_scale = torch.where(weight_max == 0, torch.ones_like(weight_max), weight_max)
        weight_divabs = weight_block / safe_scale  # (L, B)
        lookup = self.norm_lookup_table.reshape(1, -1)  # (1, 2**K)

        abs_diff = torch.abs(weight_divabs.unsqueeze(-1) - lookup)  # (L, B, 2**K)
        qweight = torch.argmin(abs_diff, dim=-1)  # (L, B)

        return qweight, weight_max, weight.shape

    def dequantize_block(self, qweight, weight_max, original_shape):
        """Rebuild an approximate weight from the output of ``quantize_block``.

        Args:
            qweight: Lookup-table indices.
            weight_max: Per-block scales that multiply the looked-up levels.
            original_shape: Shape the result is reshaped to.

        Returns:
            The looked-up levels multiplied by ``weight_max``, reshaped to
            ``original_shape``.
        """
        levels = self.norm_lookup_table.reshape(-1)
        weight = levels[qweight] * weight_max

        return weight.reshape(original_shape)


def _low_rank_decomposition(weight, reduced_rank=32):
    """Split a 2D matrix into two low-rank factors using its SVD.

    The singular values are shared evenly between both factors: each factor
    is scaled by the square root of the leading ``reduced_rank`` singular
    values, so ``L @ R`` is the SVD truncated to that many components.

    Args:
        weight: The matrix to decompose, of shape (H, W).
        reduced_rank: Number of leading singular components to keep.

    Returns:
        A dict with the factors ``"L"`` (H, r) and ``"R"`` (r, W), where
        ``r`` is ``reduced_rank`` capped at the number of singular values,
        the untruncated SVD results ``"U"``, ``"S"`` and ``"Vh"``, and the
        requested ``"reduced_rank"``.

    Raises:
        ValueError: If ``weight`` is not 2D.
    """
    matrix_dimension = len(weight.size())
    if matrix_dimension != 2:
        raise ValueError(f"Only support 2D matrix, but your input has {matrix_dimension} dimensions.")

    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # Scale the kept columns of U and rows of Vh directly; this matches a
    # product with the sliced diagonal matrix sqrt(diag(S)) without
    # materialising that dense square matrix.
    sqrt_s = torch.sqrt(S[0:reduced_rank])
    L = U[:, 0:reduced_rank] * sqrt_s
    R = sqrt_s.unsqueeze(-1) * Vh[0:reduced_rank, :]

    return {"L": L, "R": R, "U": U, "S": S, "Vh": Vh, "reduced_rank": reduced_rank}


@torch.no_grad()
def exsqf_init(weight: Union[torch.Tensor, torch.nn.Parameter],
               lora_C: Union[torch.Tensor, torch.nn.Parameter, None],
               num_bits: int, reduced_rank: int, num_iter=5):
    """Alternate block quantization and low-rank fitting of a weight matrix.

    Every iteration quantizes and dequantizes ``weight`` minus the current
    low-rank product, then refits the low-rank factors to ``weight`` minus
    that dequantized result. All arithmetic runs in float32 on the device of
    ``weight``.

    Args:
        weight: 2D matrix of shape (out_feature, in_feature).
        lora_C: Optional matrix multiplied onto the left factor before the
            low-rank product is subtracted from ``weight``. ``None`` means
            the left factor is used as it is.
        num_bits: Bit width passed to ``NFQuantizer``.
        reduced_rank: Rank passed to ``_low_rank_decomposition``.
        num_iter: Number of alternating iterations; must be positive.

    Returns:
        A tuple ``(dequantized_weight, lora_A, lora_B)``. The dequantized
        weight is cast back to the dtype and device of the input ``weight``.
        ``lora_A`` is the right factor and ``lora_B`` the left factor of the
        last decomposition (``lora_C`` is not applied to the returned left
        factor); both stay in float32.

    Raises:
        ValueError: If ``num_iter`` is not positive.
    """
    if num_iter <= 0:
        raise ValueError("Number of iterations must be greater than 0")

    out_feature, in_feature = weight.size()
    device = weight.device
    dtype = weight.dtype

    logging.info(
        f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} "
        f"| Num Iter: {num_iter} | Num Bits: {num_bits}"
    )
    quantizer = NFQuantizer(num_bits=num_bits, device=device)

    reduced_rank = int(reduced_rank)
    weight = weight.to(device=device, dtype=torch.float32)
    if lora_C is not None:
        # Match the float32 working precision of the factors; a matmul
        # between mismatched dtypes would raise.
        lora_C = lora_C.to(device=device, dtype=torch.float32)

    L, R = None, None
    for _ in range(num_iter):
        if L is None:
            # No factors yet: the first pass quantizes the weight itself.
            residual = weight
        else:
            left = L if lora_C is None else lora_C @ L
            residual = weight - left @ R

        quantized_weight, max_abs, shape = quantizer.quantize_block(residual)
        dequantized_weight = quantizer.dequantize_block(quantized_weight, max_abs, shape)

        output = _low_rank_decomposition(weight - dequantized_weight, reduced_rank=reduced_rank)
        L, R = output["L"], output["R"]

    lora_A, lora_B = R, L

    return dequantized_weight.to(device=device, dtype=dtype), lora_A, lora_B
