import torch
from .compressor import Compressor

# Key under which a per-tensor bit width is stored in a compression state dict.
COMP_PARAM_NAME = "bits"

class Quantizer(Compressor):
    """Base class for bucketed quantizers.

    A tensor is flattened and cut into rows of ``bucket_size`` elements; the
    rows are handed to :meth:`quantize_bucket`, which subclasses implement.
    A per-tensor ``state`` dict may override the bit width (key
    ``COMP_PARAM_NAME``) and the bucket size (key ``"bucket_size"``).

    Args:
        bits: Default bit width; ``num_levels`` is initialised to ``1 << bits``.
        bucket_size: Default bucket length. ``-1`` means the whole tensor is
            quantized as one bucket.
        enable_error_correction: Forwarded unchanged to ``Compressor.__init__``.
        warmup_steps: Forwarded unchanged to ``Compressor.__init__``.
        compress_incomplete: If true, trailing elements that do not fill a
            complete bucket are quantized as one extra, shorter bucket;
            otherwise they are left untouched.
    """

    def __init__(self, bits, bucket_size, enable_error_correction=False, warmup_steps=None, compress_incomplete=False):
        super().__init__(enable_error_correction, warmup_steps)
        self.bits = bits
        self.bits_default = bits
        self.num_levels = 1 << bits
        self.bucket_size = bucket_size
        self.bucket_size_default = bucket_size
        self.compress_incomplete = compress_incomplete
        # Candidate bit widths handed to the adjuster in do_adjust().
        self.values_set = list(range(2, self.bits_default + 1))

    def quantize_bucket(self, a):
        """Quantize ``a`` and return the result; must be overridden.

        Callers in this class pass either a 1-D tensor (a single bucket) or a
        2-D tensor whose rows are buckets.
        """
        raise NotImplementedError

    def get_bucket_size(self, b):
        """Return the bucket size to use for bit width ``b`` (always the default here)."""
        return self.bucket_size_default

    def get_compressed_size(self, t, b):
        """Estimate the size of ``t`` quantized to ``b`` bits.

        The result is expressed in elements of ``t``'s width (16 bits for
        ``torch.float16``, 32 bits otherwise).  It accounts for ``b`` bits per
        element plus two full-width values per complete bucket (presumably the
        per-bucket scaling metadata).  For ``b`` outside ``[1, 8]`` the tensor
        is counted as uncompressed and ``t.numel()`` is returned.
        """
        numel = t.numel()
        if b < 1 or b > 8:
            return numel
        bucket_size = self.get_bucket_size(b)
        if bucket_size == -1:
            # -1 means "one bucket spanning the whole tensor".  Dividing by it
            # would give a negative bucket count and therefore a negative size.
            num_buckets = 1 if numel > 0 else 0
        else:
            num_buckets = numel // bucket_size
        elem_size = 16 if t.dtype == torch.float16 else 32
        bits = elem_size * 2 * num_buckets + b * numel
        return bits // elem_size

    def do_adjust(self, states):
        """Ask the attached adjuster for new parameters, choosing from ``values_set``."""
        # NOTE(review): self.adjuster is not set in this class; presumably
        # Compressor.add_adjuster assigns it - confirm.
        return self.adjuster.fit_predict(states, self.values_set)

    def add_adjuster(self, adjuster, values_set=None):
        """Validate ``values_set`` (a list of bit widths in [1, 8]) and register the adjuster."""
        assert values_set is not None
        assert isinstance(values_set, list), "Quantizer.add_adjuster: Values set must be a list"
        for v in values_set:
            assert 1 <= v <= 8, "Quantizer.add_adjuster: bits values must be in the range [1, 8]"
        super().add_adjuster(adjuster, values_set)

    # def get_compression_scheme(self):
    #     if not self.is_adaptive:
    #         return
    #     d = {}
    #     for p, state in self.states.items():
    #         bits = state[COMP_PARAM_NAME]
    #         name = state["name"]
    #         d[name] = bits
    #     return d

    # def get_reverse_compression_scheme(self):
    #     if not self.is_adaptive:
    #         return
    #     d = {}
    #     for p, state in self.states.items():
    #         bits = state[COMP_PARAM_NAME]
    #         name = state["name"]
    #         if bits in d:
    #             d[bits].append(name)
    #         else:
    #             d[bits] = [name]
    #     return d

    def set_states_by_compression_scheme(self, states, bits_array):
        """Assign ``bits_array[i]`` as the bit width of ``states[i]`` and reset its bucket size."""
        assert len(states) == len(bits_array), "Unequal sizes of states and bits array"
        for state, bit in zip(states, bits_array):
            state[COMP_PARAM_NAME] = bit
            state["bucket_size"] = self.bucket_size_default
            # self.states[self.named_parameters[layer]]["bucket_size"] = min(self.bucket_size_default,
            #                                                                self.bucket_size_default // (2**(self.bits_default - bit)))

    def set_compression_parameter(self, state, parameter):
        """Store ``parameter`` as the state's bit width; ``None`` selects the default."""
        if parameter is None:
            state[COMP_PARAM_NAME] = self.get_default_param()
        else:
            state[COMP_PARAM_NAME] = parameter

    def get_compression_parameter(self, state):
        """Return the state's bit width, or ``None`` if it has none."""
        if COMP_PARAM_NAME not in state:
            return None
        return state[COMP_PARAM_NAME]

    def get_default_param(self):
        """Return the default bit width."""
        return self.bits_default

    def set_default_param(self, param):
        """Replace the default bit width."""
        self.bits_default = param

    def get_compression_parameter_name(self):
        """Return the state-dict key used for the bit width."""
        return COMP_PARAM_NAME

    def _prepare_compression(self, state):
        """Load ``num_levels`` and ``bucket_size`` for the tensor described by ``state``.

        Values missing from ``state`` fall back to the instance defaults.
        """
        if COMP_PARAM_NAME in state:
            bits = state[COMP_PARAM_NAME]
        else:
            bits = self.bits_default
        self.num_levels = 1 << bits
        if "bucket_size" in state:
            self.bucket_size = state["bucket_size"]
        else:
            self.bucket_size = self.get_bucket_size(bits)

    def compress_(self, tensor, state):
        """Quantize ``tensor`` in place according to ``state`` and return it.

        The tensor is returned untouched when ``state`` carries a bit width
        outside ``[1, 8]``.
        """
        a = tensor.view(-1)
        numel = a.numel()
        if COMP_PARAM_NAME in state and (state[COMP_PARAM_NAME] < 1 or state[COMP_PARAM_NAME] > 8):
            return tensor
        self._prepare_compression(state)
        if self.bucket_size == -1:
            # Whole tensor is a single bucket.
            a[:] = self.quantize_bucket(a)
        else:
            main_chunk_size = (numel // self.bucket_size) * self.bucket_size
            if main_chunk_size > 0:
                # Complete buckets are quantized together as rows of a 2-D view.
                a[:main_chunk_size] = self.quantize_bucket(a[:main_chunk_size].view((-1, self.bucket_size))).view(-1)
            if self.compress_incomplete and (numel - main_chunk_size > 0):
                a[main_chunk_size:] = self.quantize_bucket(a[main_chunk_size:])
        return tensor

    @staticmethod
    def count_unique(buf):
        """Return the total number of distinct values, summed over the items of ``buf``."""
        return sum(torch.unique(b).numel() for b in buf)

    def compress_buffer(self, buf):
        """Quantize ``buf`` in place with the currently loaded settings and return it.

        Unlike :meth:`compress_`, this does not read a state dict; it uses the
        current ``num_levels`` and ``bucket_size``.  A 32-bit setting
        (``num_levels == 1 << 32``) leaves the buffer unchanged.
        """
        if self.num_levels == 1 << 32:
            return buf
        numel = buf.numel()
        if self.bucket_size == -1:
            # Same convention as compress_(): -1 means one bucket for the whole
            # buffer.  Without this branch the code below would attempt
            # view((-1, -1)), which torch rejects.
            if numel > 0:
                flat = buf.view(-1)
                flat[:] = self.quantize_bucket(flat)
            return buf
        main_chunk_size = (numel // self.bucket_size) * self.bucket_size
        tail_chunk_size = numel - main_chunk_size
        if main_chunk_size > 0:
            r_ = buf[:main_chunk_size].view((-1, self.bucket_size))
            r_[:] = self.quantize_bucket(r_)
            # print("2d Unique: {} vs expected {}".format(self.count_unique(r_),
            #                                             (numel // self.bucket_size) * (1 << self.bits)))
        if self.compress_incomplete and tail_chunk_size > 0:
            r_ = buf[main_chunk_size:]
            r_[:] = self.quantize_bucket(r_)
            # print("Unique: {} vs expected {}".format(torch.unique(r_).numel(), (1 << self.bits)))

        return buf

    @staticmethod
    def quantize_1_dim_with_levels(buf, levels):
        """Stochastically round a 1-D buffer onto the given levels.

        Values are first clamped to ``[levels[0], levels[-1]]``.  For every
        adjacent pair ``(l1, l2)`` the elements in ``[l1, l2)`` are set to
        ``l1`` and then moved up to ``l2`` when ``buf + (l2 - l1) * rand``
        reaches ``l2``, with ``rand`` uniform in ``[0, 1)``.
        """
        rand = torch.rand(buf.size(), device=buf.device)
        res = torch.clamp(buf, levels[0], levels[-1])
        for l1, l2 in zip(levels[:-1], levels[1:]):
            l_l2 = buf.lt(l2)
            g_l1 = buf.ge(l1)
            res[l_l2 * g_l1] = l1
            b = buf + (l2 - l1) * rand
            # if exceeds after random assign l2
            g_l2 = b.ge(l2)
            res[l_l2 * g_l2] = l2
        return res

    @staticmethod
    def quantize_2_dim_with_levels(buf, levels):
        """Row-wise variant of :meth:`quantize_1_dim_with_levels`.

        ``levels`` is indexed and sliced like a sequence and each entry is
        broadcast against ``buf``.  Because a level may differ per row, the
        chosen level is written through a 0/1 mask (``z``) multiplied by the
        level instead of a direct masked assignment.
        """
        rand = torch.rand(buf.size(), device=buf.device)
        res = torch.max(buf, levels[0])
        res = torch.min(res, levels[-1])
        z = torch.zeros_like(buf)
        for l1, l2 in zip(levels[:-1], levels[1:]):
            l_l2 = buf < l2
            g_l1 = buf >= l1
            idx = l_l2 * g_l1
            # set indexed values to l1
            z[idx] = 1.0
            res[idx] = 0.0
            res.add_(z.mul_(l1))
            z.zero_()
            b = buf + (l2 - l1) * rand
            g_l2 = b >= l2
            idx = l_l2 * g_l2

            # set indexed values to l2 where the random draw rounds up
            z[idx] = 1.0
            res[idx] = 0.0
            res.add_(z.mul_(l2))
            z.zero_()
        return res

    # def compute_eigen_values(self, model, criterion, dataloader):
    #     model_copy = copy.deepcopy(model)
    #     h = hessian(model_copy, criterion, dataloader=dataloader)
    #     eigenvalues, _ = h.eigenvalues()
    #     eigenvalues = eigenvalues[0]
    #     for i, p in enumerate(model.parameters()):
    #         if p not in self.states:
    #             continue
    #         self.states[p]["hess_eigen_value"] = eigenvalues[i]


class MaxMinQuantizer(Quantizer):
    """Uniform quantizer over each bucket's ``[min, max]`` range.

    Every bucket is mapped onto ``num_levels`` equally spaced values between
    its minimum and maximum using round-to-nearest.  With ``shifted=True`` the
    grid is displaced by a random fraction (uniform in ``[-0.5, 0.5)``) of one
    step before rounding and moved back afterwards.
    """

    def __init__(self, bits, bucket_size, enable_error_correction=False, warmup_steps=None, shifted=False):
        super().__init__(
            bits,
            bucket_size,
            enable_error_correction=enable_error_correction,
            warmup_steps=warmup_steps,
        )
        self.shifted = shifted

    def quantize_bucket(self, a):
        """Quantize ``a`` in place (1-D: one bucket, 2-D: one bucket per row) and return it.

        The input is returned untouched for a 32-bit setting or when it
        contains an infinite value.
        """
        skip = self.num_levels == 1 << 32 or torch.isinf(a).sum() > 0
        if skip:
            return a

        intervals = self.num_levels - 1
        offset = 0
        if a.dim() == 2:
            lo = a.min(dim=1)[0]
            hi = a.max(dim=1)[0]
            step = (hi - lo) / intervals
            if self.shifted:
                offset = torch.empty(step.size(), device=a.device).uniform_(-0.5, 0.5)[:, None]
            # Column vectors so that they broadcast across each row.
            step = step[:, None]
            lo = lo[:, None]
            smallest = torch.Tensor([1e-11]).expand_as(step).to(a.device)
        else:
            lo = a.min()
            hi = a.max()
            step = (hi - lo) / intervals
            smallest = torch.Tensor([1e-11]).to(a.device)
            if self.shifted:
                offset = torch.empty(1, device=a.device).uniform_(-0.5, 0.5)

        # Keep the step strictly positive so constant buckets do not divide by zero.
        step = torch.max(step, smallest)
        offset *= step

        # Shift, normalise to level indices, round to nearest, then undo both steps.
        a.add_(offset).sub_(lo).div_(step)
        a.add_(0.5).floor_()
        a.mul_(step).add_(lo).sub_(offset)
        return a


class ExponentialQuantizer(Quantizer):
    """Sign/magnitude quantizer whose magnitude levels are powers of two.

    One bit is reserved for the sign, so ``num_levels`` holds half of
    ``1 << bits``.  Buckets are normalised by their L2 norm (``norm_type``).
    """

    def __init__(self, bits, bucket_size, enable_error_correction=False, adjuster=None):
        # NOTE(review): ``adjuster`` lands in the ``warmup_steps`` slot of
        # Quantizer.__init__ - confirm this is intended.
        super().__init__(bits, bucket_size, enable_error_correction, adjuster)
        self.num_levels = self.num_levels // 2
        # self.norm_type = float("inf")
        self.norm_type = 2
        # Fixed level tables used by quantize_bucket_new only.
        self.levels_1dim = torch.tensor([0.5 ** i for i in range(self.num_levels, 0, -1)])
        self.levels_2dim = self.levels_1dim[:, None]

    def _prepare_compression(self, state):
        """Load per-tensor settings, keeping one bit reserved for the sign.

        The base implementation resets ``num_levels`` to ``1 << bits`` on every
        ``compress_`` call, which would silently undo the halving performed in
        ``__init__``; re-apply it here.
        """
        super()._prepare_compression(state)
        self.num_levels = self.num_levels // 2

    def quantize_bucket_new(self, buf):
        """Alternative implementation based on the level tables built in ``__init__``.

        Uses ``self.bits`` and the precomputed tables, so it does not follow
        per-tensor bit widths.  The norm is not guarded against zero here.
        """
        sign = buf.sign()
        if buf.dim() == 2:
            vnorm = buf.norm(p=self.norm_type, dim=1)
            vnorm = vnorm[:, None]
        else:
            vnorm = torch.norm(buf, p=self.norm_type)
        if self.bits == 1:
            return sign * vnorm
        a = buf.abs() / vnorm
        if buf.dim() == 2:
            self.levels_2dim = self.levels_2dim.to(buf.device)
            res = self.quantize_2_dim_with_levels(a, self.levels_2dim)
        else:
            self.levels_1dim = self.levels_1dim.to(buf.device)
            res = self.quantize_1_dim_with_levels(a, self.levels_1dim)
        return res * sign * vnorm

    def set_compression_parameters(self, state):
        """Delegate to the parent, then set ``num_levels`` to half of ``1 << bits``."""
        super().set_compression_parameters(state)
        bits = state[COMP_PARAM_NAME]
        self.num_levels = (1 << bits) // 2

    def quantize_bucket(self, a):
        """Stochastically round normalised magnitudes to neighbouring powers of two.

        Zeros are given a positive sign.  With a single level the result is
        ``sign * norm``.  Otherwise each magnitude is bracketed by
        ``l = 2 ** (e - 1)`` and ``r = 2 * l`` and rounded up with probability
        ``(a - l) / (r - l)``.  The exponent ``e`` is bounded below by
        ``max_pow - num_levels + 2``, where ``max_pow`` is taken over the whole
        input rather than per bucket.
        """
        if a.dim() == 2:
            vnorm = torch.norm(a, p=self.norm_type, dim=1)
            vnorm = vnorm[:, None]
            s = torch.Tensor([1e-11]).expand_as(vnorm).to(a.device)
        else:
            vnorm = torch.norm(a, p=self.norm_type)
            s = torch.Tensor([1e-11]).to(a.device)
        # Avoid dividing by a zero norm.
        vnorm = torch.max(vnorm, s)
        sign = torch.sign(a)
        sign[sign == 0.0] = 1.0
        if self.num_levels <= 1:
            return vnorm * sign
        a = torch.abs(a / vnorm)
        logs = torch.log2(a)
        # log2(0) is -inf; replace it with a very small finite exponent.
        logs[logs == -float("inf")] = -32.0
        logs = logs.int()
        max_pow = torch.zeros_like(logs) + torch.max(logs).int()
        min_pow = max_pow - self.num_levels + 2
        now = torch.max(min_pow, logs).float()
        l = torch.pow(2.0, now - 1)
        r = 2 * l
        a = torch.min(r, torch.max(a, l))
        rand = torch.rand(a.size(), device=a.device)
        c = (a - l) / (r - l)
        a = l * c.le(rand).float() + r * c.gt(rand).float()
        return a * vnorm * sign


class NormUniformQuantizer(Quantizer):
    """Sign/magnitude quantizer with uniformly spaced magnitude levels.

    One bit is reserved for the sign, so ``num_levels`` holds half of
    ``1 << bits``.  Buckets are normalised by their infinity norm.
    """

    def __init__(self, bits, bucket_size, enable_error_correction=False, adjuster=None):
        # NOTE(review): ``adjuster`` lands in the ``warmup_steps`` slot of
        # Quantizer.__init__ - confirm this is intended.
        super().__init__(bits, bucket_size, enable_error_correction, adjuster)
        self.num_levels = self.num_levels // 2
        # Fixed level tables used by quantize_bucket_new only.
        self.levels_1dim = torch.tensor([i * 1.0 / (self.num_levels + 1) for i in range(1, self.num_levels + 1)])
        self.levels_2dim = self.levels_1dim[:, None]
        self.norm_type = float("inf")

    def _prepare_compression(self, state):
        """Load per-tensor settings, keeping one bit reserved for the sign.

        The base implementation resets ``num_levels`` to ``1 << bits`` on every
        ``compress_`` call, which would silently undo the halving performed in
        ``__init__``; re-apply it here.
        """
        super()._prepare_compression(state)
        self.num_levels = self.num_levels // 2

    def set_compression_parameters(self, state):
        """Delegate to the parent, then set ``num_levels`` to half of ``1 << bits``."""
        super().set_compression_parameters(state)
        bits = state[COMP_PARAM_NAME]
        self.num_levels = (1 << bits) // 2

    def quantize_bucket_new(self, buf):
        """Alternative implementation based on the level tables built in ``__init__``.

        Uses ``self.bits`` and the precomputed tables, so it does not follow
        per-tensor bit widths.  The norm is not guarded against zero here.
        """
        sign = buf.sign()
        if buf.dim() == 2:
            vnorm = torch.norm(buf, p=self.norm_type, dim=1)
            vnorm = vnorm[:, None]
        else:
            vnorm = torch.norm(buf, p=self.norm_type)
        if self.bits == 1:
            return sign * vnorm
        a = buf.abs() / vnorm
        if buf.dim() == 2:
            self.levels_2dim = self.levels_2dim.to(buf.device)
            res = self.quantize_2_dim_with_levels(a, self.levels_2dim)
        else:
            self.levels_1dim = self.levels_1dim.to(buf.device)
            res = self.quantize_1_dim_with_levels(a, self.levels_1dim)
        return res * sign * vnorm

    def quantize_bucket(self, a):
        """Stochastically round normalised magnitudes onto a uniform grid.

        Magnitudes are divided by the bucket's infinity norm, scaled to
        ``[0, num_levels - 1]``, perturbed by uniform noise in ``[0, 1)`` and
        floored.  Zeros are given a positive sign.  With a single level the
        result is ``sign * norm``.
        """
        if a.dim() == 2:
            vnorm = torch.norm(a, p=float("inf"), dim=1)
            vnorm = vnorm[:, None]
            s = torch.Tensor([1e-11]).expand_as(vnorm).to(a.device)
        else:
            vnorm = torch.norm(a, p=float("inf"))
            s = torch.Tensor([1e-11]).to(a.device)
        # Avoid dividing by a zero norm.
        vnorm = torch.max(vnorm, s)
        sign = torch.sign(a)
        # cast sign to 1 bit
        sign[sign == 0.0] = 1.0
        if self.num_levels > 1:
            q = torch.abs(a / vnorm)
            r = torch.rand(a.shape, device=a.device)
            q.mul_((self.num_levels - 1))
            q.add_(r)
            torch.floor_(q)
            q.div_((self.num_levels - 1))
            return q * vnorm * sign
        else:
            return vnorm * sign


class QuantileQuantizer(Quantizer):
    """Quantizer whose levels are empirical quantiles of each bucket.

    The probabilities ``i / (num_levels + 1)`` for ``i = 1..num_levels`` are
    fixed at construction time.
    """

    def __init__(self, bits, bucket_size, enable_error_correction=False, adjuster=None):
        # NOTE(review): ``adjuster`` lands in the ``warmup_steps`` slot of
        # Quantizer.__init__ - confirm this is intended.
        super().__init__(bits, bucket_size, enable_error_correction, adjuster)
        count = self.num_levels
        self.quantiles = torch.tensor([k / (count + 1) for k in range(1, count + 1)])

    def quantize_bucket(self, buf):
        """Round ``buf`` stochastically onto its own quantiles and return a new tensor."""
        # Keep the probabilities on the buffer's device and in its dtype.
        self.quantiles = self.quantiles.to(buf.device).type(buf.dtype)
        if buf.dim() == 1:
            levels = torch.quantile(buf, self.quantiles)
            return self.quantize_1_dim_with_levels(buf, levels)
        per_row = torch.quantile(buf, self.quantiles, dim=1)
        # One column vector per quantile so each level broadcasts across its rows.
        levels = [row_levels[:, None] for row_levels in per_row]
        return self.quantize_2_dim_with_levels(buf, levels)
        # if torch.unique(res).numel() <= qs.numel():
        #     print(self.num_levels, torch.unique(res).numel())
        #     raise ValueError("Num unique values are {}, quantiles: {}".format(
        #         torch.unique(res), qs))


class TernGrad(Quantizer):
    """Two-bit quantizer producing ``{-1, 0, +1}`` times a per-bucket scale.

    Magnitudes are clipped at ``clip_constant`` times their standard deviation
    before stochastic rounding.
    """

    def __init__(self, bucket_size, enable_error_correction=False, adjuster=None):
        # NOTE(review): ``adjuster`` lands in the ``warmup_steps`` slot of
        # Quantizer.__init__ - confirm this is intended.
        super().__init__(2, bucket_size, enable_error_correction, adjuster)
        self.clip_constant = 2.5

    def quantize_bucket(self, a):
        """Quantize ``a`` in place (1-D: one bucket, 2-D: one bucket per row) and return it."""
        signs = torch.sign(a)
        signs[signs == 0.0] = 1.0

        # Work on magnitudes, clipped at clip_constant standard deviations.
        a.abs_()
        clip_at = torch.std(a) * self.clip_constant
        torch.min(a, clip_at, out=a)

        infinity = float("inf")
        if a.dim() == 2:
            scale = torch.norm(a, p=infinity, dim=1).unsqueeze(1)
            floor_value = torch.Tensor([1e-11]).expand_as(scale).to(a.device)
        else:
            floor_value = torch.Tensor([1e-11]).to(a.device)
            scale = torch.norm(a, p=infinity).expand_as(floor_value)
        # Avoid dividing by a zero scale.
        torch.max(scale, floor_value, out=scale)

        # Stochastic rounding of magnitude / scale to 0 or 1, then restore sign and scale.
        noise = torch.rand(a.shape, device=a.device)
        a.div_(scale).add_(noise).floor_()
        a.mul_(signs).mul_(scale)
        return a


class ThreeLC(Quantizer):
    """Two-bit quantizer that rounds ``a / max|a|`` to the nearest integer.

    Error correction is always enabled (``True`` is passed to the base class).
    """

    def __init__(self, bucket_size, adjuster=None):
        # NOTE(review): ``adjuster`` lands in the ``warmup_steps`` slot of
        # Quantizer.__init__ - confirm this is intended.
        super().__init__(2, bucket_size, True, adjuster)

    def quantize_bucket(self, a):
        """Quantize ``a`` in place (1-D: one bucket, 2-D: one bucket per row) and return it."""
        infinity = float("inf")
        if a.dim() == 2:
            scale = torch.norm(a, p=infinity, dim=1).unsqueeze(1)
            floor_value = torch.Tensor([1e-11]).expand_as(scale).to(a.device)
        else:
            floor_value = torch.Tensor([1e-11]).to(a.device)
            scale = torch.norm(a, p=infinity).expand_as(floor_value)
        # Avoid dividing by a zero scale.
        torch.max(scale, floor_value, out=scale)
        # Round to nearest: floor(x + 0.5).
        a.div_(scale).add_(0.5).floor_()
        return a.mul_(scale)


class OneBitQuantizer(Quantizer):
    """One-bit quantizer: every element becomes ``+norm`` or ``-norm`` of its bucket."""

    def __init__(self, bucket_size, enable_error_correction=False, adjuster=None):
        # NOTE(review): ``adjuster`` lands in the ``warmup_steps`` slot of
        # Quantizer.__init__ - confirm this is intended.
        super().__init__(1, bucket_size, enable_error_correction, adjuster)

    def quantize_bucket(self, a):
        """Return ``sign(a)`` scaled by the bucket's infinity norm as a new tensor.

        ``a`` is either a single 1-D bucket or a 2-D tensor with one bucket per
        row.  Zeros are mapped to the positive value so the output really has
        only two levels.
        """
        if a.dim() == 2:
            # Buckets are rows, so reduce along dim=1 (one norm per bucket).
            # The previous dim=0 reduction mixed values from different buckets.
            vnorm = torch.norm(a, p=float("inf"), dim=1)
            vnorm = vnorm[:, None]
            s = torch.Tensor([1e-11]).expand_as(vnorm).to(a.device)
            vnorm = torch.max(vnorm, s)
        else:
            vnorm = torch.norm(a, p=float("inf"))
            vnorm = torch.max(vnorm, torch.tensor([1e-11], device=vnorm.device))
        sign = torch.sign(a)
        # Cast sign to 1 bit.  The former add/div/mul/add chain was an
        # arithmetic no-op on floats and left zeros as a third level.
        sign[sign == 0.0] = 1.0
        return sign.mul(vnorm)


class SanityQuantizer(Quantizer):
    """Debug quantizer: 32 bits, whole tensor as a single bucket (``bucket_size=-1``)."""

    def __init__(self):
        super().__init__(bits=32, bucket_size=-1)

    def quantize_bucket(self, a):
        """Return ``a`` itself when ``num_levels`` is 1, otherwise zeros shaped like ``a``."""
        single_level = self.num_levels == 1
        return a if single_level else torch.zeros_like(a)
