import numpy as np
from nn_compression.quantisation import (
    FixedGridQuantiser,
    AffineGridQuantiser,
    VectorQuantiser,
)
import torch
from nn_compression.quantisation._grids import RatedGrid
import pytest

# Pairs of 0-dim float tensors, built from plain (first, second) floats.
# NOTE(review): not referenced by any test in this module — possibly a
# leftover or a value list meant for ``pytest.mark.parametrize``; confirm
# before removing.
lm_pv_values = [
    (torch.tensor(first), torch.tensor(second))
    for first, second in (
        (1.0, 1.0),
        (0.5, 3.0),
        (1e-6, 200.0),
        (2.0, 200.0),
    )
]


def _legacy_quantise_affine(x, scale, zero, maxq):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)


class _LegacyGptqAffineGridQuantiser:
    """Legacy min/max affine quantiser used as a reference oracle in tests.

    Judging by its name, this is a copy of a legacy GPTQ quantiser.
    ``test_compatibility_gptq`` checks that ``AffineGridQuantiser`` reproduces
    its output, so the logic here is intentionally left untouched.

    Usage: call ``find_params`` on calibration data to set ``self.scale`` and
    ``self.zero``, then ``quantise`` to apply ``_legacy_quantise_affine``.
    """

    def __init__(
        self,
        bits,
        shape=1,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Store the quantiser configuration.

        Args:
            bits: Bit width; the largest integer level ``maxq`` is
                ``2**bits - 1``.
            shape: Accepted but not used.
            perchannel: Compute one scale/zero per channel instead of a
                single pair for the whole tensor.
            sym: Make the range symmetric around zero and fix the zero point
                at ``(maxq + 1) / 2``.
            mse: Search shrunken ranges for the lowest reconstruction error.
            norm: Exponent applied to ``|q - x|`` in the ``mse`` search.
            grid: The ``mse`` search shrinks the range in steps of
                ``1 / grid``.
            maxshrink: The ``mse`` search tries ``int(maxshrink * grid)``
                shrink factors.
            trits: Set ``maxq`` to -1, which selects the ternary branch of
                ``_legacy_quantise_affine``.

        Note: ``self.scale`` / ``self.zero`` are not set here; they only
        exist after ``find_params``.
        """
        super().__init__()
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)

    def find_params(self, x, weight=False):
        """Compute ``self.scale`` and ``self.zero`` from the values in ``x``.

        ``x`` is viewed as a 2-D tensor with one row per channel (when
        ``perchannel``) or a single row (otherwise). Each row's range
        ``[min(x, 0), max(x, 0)]`` is optionally made symmetric, and all-zero
        rows fall back to ``[-1, 1]``. With ``mse``, the range is shrunk by
        factors ``1 - i / grid``, and per row the factor with the lowest
        ``sum(|q - x| ** norm)`` is kept. Finally the parameters are reshaped
        so they broadcast against the original ``x``.

        Args:
            x: Tensor to calibrate on.
            weight: If true, channels are taken along dim 0 and the
                parameters are shaped ``[-1, 1, ...]``. Otherwise the channel
                dim is dim 1 for 4-D input and the last dim for 2-D/3-D input.

        NOTE(review): a 1-D ``x`` with ``weight=False`` appears unsupported.
        It reaches ``shape[1]`` (non-perchannel) or ``x.min(1)`` (perchannel),
        both of which raise.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        # Reshape x to 2-D: one row per channel, or a single row.
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        # Per-row range, clipped so that 0 always lies inside [xmin, xmax].
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            # Symmetric: use max(|xmin|, xmax) on both sides for rows with
            # negative values.
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # All-zero rows would otherwise give a zero scale; use [-1, 1].
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            # Ternary mode: scale/zero hold the range endpoints directly.
            # NOTE(review): these bind the same tensors as xmax/xmin (no
            # copy), so in-place writes to self.scale/self.zero in the mse
            # loop below also modify xmax/xmin.
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                # Symmetric zero point fixed at the grid midpoint.
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Grid search over shrink factors p = 1, 1 - 1/grid, ...
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                # NOTE(review): unlike the branch above, this does not
                # special-case maxq < 0 (ternary mode).
                scale1 = (xmax1 - xmin1) / self.maxq
                # In symmetric mode zero1 *is* self.zero, so the
                # `self.zero[tmp] = zero1[tmp]` update below is a no-op.
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = _legacy_quantise_affine(
                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
                )
                # Per-row error: sum(|q - x| ** norm), computed in place.
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            # A single (scale, zero) pair is replicated once per channel.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            # Channels along dim 0: shape [-1, 1, 1, ...].
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        # Activations: put the channel parameters on dim 1 (4-D) or on the
        # last dim (3-D / 2-D).
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantise(self, x, row=None):
        """Quantise-dequantise ``x`` using the parameters from ``find_params``.

        For inputs with at most one dimension, only the first element of
        ``scale``/``zero``/``maxq`` is used. ``row`` is accepted but unused.

        Raises:
            ValueError: If any entry of ``self.scale`` is zero.

        Calling this before ``find_params`` raises AttributeError instead,
        because ``__init__`` never sets ``self.scale``.
        """
        if self.ready():
            if len(x.shape) <= 1:
                return _legacy_quantise_affine(
                    x,
                    self.scale.flatten()[0],
                    self.zero.flatten()[0],
                    self.maxq.flatten()[0],
                )
            else:
                return _legacy_quantise_affine(x, self.scale, self.zero, self.maxq)
        else:
            raise ValueError("Quantizer not ready.")

    def enabled(self):
        """Return a bool tensor: True iff ``maxq > 0`` (False in ternary mode)."""
        return self.maxq > 0

    def ready(self):
        """Return a bool tensor: True iff every entry of ``self.scale`` is non-zero."""
        return torch.all(self.scale != 0)


def test_affine_grid_quantiser():
    # Asymmetric 3-bit affine grid: values already on the grid are kept,
    # while -1.6 and 1.1 are expected to snap to -2 and 1.
    values = torch.Tensor([-4, 3.0, 2.0, -1.6, 1.1, 3])
    expected = torch.Tensor([-4, 3.0, 2.0, -2, 1.0, 3])

    result = AffineGridQuantiser(3, symmetric=False).quantise(values)

    assert torch.allclose(result.x, expected)


def test_ab_affine_grid_quantiser():
    # Regression values for a symmetric 5-bit affine grid, checked to 1e-3.
    quantiser = AffineGridQuantiser(5, symmetric=True)
    inputs = torch.tensor([-3.2, -2.1, 0.0, 1.5, 2.8])
    expected = torch.tensor([-3.3032, -2.0645, 0.0, 1.4452, 2.8903])

    quantised = quantiser.quantise(inputs).x

    assert torch.allclose(quantised, expected, atol=1e-3)


def test_compatibility_gptq():
    # The library quantiser must match the legacy GPTQ reference for 3-bit
    # grids, first asymmetric and then symmetric.
    weights = torch.Tensor([[-4, 3.0, 2.0, 0.1, -1.6, 1.1, 3]])

    for symmetric in (False, True):
        actual = AffineGridQuantiser(3, symmetric=symmetric).quantise(weights)
        reference = _LegacyGptqAffineGridQuantiser(3, sym=symmetric)
        reference.find_params(weights)
        assert torch.allclose(actual.x, reference.quantise(weights))


def test_fixed_grid_quantisation():
    # Every input is expected at the nearest of the five fixed grid points;
    # far-out inputs (-128, 27) end up at the grid's end points.
    grid_points = torch.tensor([-3.2, -2.1, 0.0, 1.5, 2.8])
    inputs = torch.tensor([-128.0, -3.0, -2.0, 0.3, 1.6, 2.1, 2.2, 27.0])
    expected = torch.tensor([-3.2, -3.2, -2.1, 0.0, 1.5, 1.5, 2.8, 2.8])

    quantiser = FixedGridQuantiser(grid_points)
    result = quantiser.quantise(inputs)

    assert torch.allclose(result.x, expected, atol=1e-3)


def test_vector_quantisation_1d():
    # 1-bit vector quantiser (two codewords) on scalar data:
    #  * two distinct values are reproduced exactly;
    #  * two tight groups collapse to their respective means.
    cases = [
        (torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0])),
        (
            torch.tensor([1.0, 2.0, 1.0, 2.0, 8.0, 9.0, 9.0, 8.0]),
            torch.tensor([1.5, 1.5, 1.5, 1.5, 8.5, 8.5, 8.5, 8.5]),
        ),
    ]

    for data, expected in cases:
        vq = VectorQuantiser(1)
        vq.find_params(data)
        assert torch.allclose(vq.quantise(data).x, expected)


def test_vector_quantisation_high_d():
    # Four tight Gaussian blobs (std 0.1) in 10-D: one at the origin and one
    # offset by 10 along each of the first three axes. A 2-bit vector
    # quantiser has four codewords and should snap every point onto its
    # blob's centre.
    torch.manual_seed(0)
    np.random.seed(0)
    dim = 10
    per_cluster = 1000
    full_shape = (per_cluster, dim)

    centres = [torch.zeros(dim)]
    for axis in range(3):
        offset = torch.zeros(dim)
        offset[axis] = 10
        centres.append(offset)

    # randn is drawn once per cluster, in centre order, as before.
    samples = torch.cat(
        [torch.randn(per_cluster, dim) * 0.1 + c.broadcast_to(full_shape) for c in centres],
        dim=0,
    )
    targets = torch.cat([c.broadcast_to(full_shape) for c in centres], dim=0)

    quantiser = VectorQuantiser(2)
    quantiser.find_params(samples)
    result = quantiser.quantise(samples)

    assert (result.x - targets).abs().mean() < 0.01
    assert result.grid.gridpoints.shape == (4, dim)
