import math
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from scipy.cluster.vq import kmeans2
import numpy as np
import utils


def update_temp(m, new_temp):
    """Overwrite the temperature buffer ``T`` of a ``DBQQuantizer``.

    Any module that is not a ``DBQQuantizer`` is ignored, so the function is
    safe to call on arbitrary modules.
    """
    if not isinstance(m, DBQQuantizer):
        return
    temperature = torch.tensor(new_temp)
    m.T.data.copy_(temperature)


def update_quant_matrix(m):
    """Refresh the ``quant_matrix`` buffer of a ``DBQQuantizer`` in place.

    The buffer is overwritten with the result of ``m.gen_quant_matrix()``.
    Modules of any other type are left untouched.
    """
    if not isinstance(m, DBQQuantizer):
        return
    regenerated = m.gen_quant_matrix()
    m.quant_matrix.copy_(regenerated)


def get_num_groups(net):
    """Return the sum of ``n_groups`` over every ``Quantizer`` in ``net.modules()``."""
    return sum(
        layer.n_groups
        for layer in net.modules()
        if isinstance(layer, Quantizer)
    )


def rmse(a, b):
    """Return the root-mean-square error between two tensors as a Python float.

    Both tensors are flattened before subtraction, so they only have to agree
    in element count (or be broadcastable once flattened), not in shape.

    The squared differences are averaged, not summed, so the result does not
    grow with the number of elements and matches the "RMSE" label used by the
    callers' log messages.
    """
    diff = a.flatten() - b.flatten()
    if not diff.is_floating_point():
        # torch.mean is not defined for integer tensors.
        diff = diff.float()
    return torch.sqrt(torch.mean(diff ** 2)).item()


class Quantizer(ABC, nn.Module):
    """Common base type for quantizer modules.

    It declares no members of its own. The module-walking helpers in this file
    select sub-modules with ``isinstance(m, Quantizer)`` and then read
    ``n_groups`` / call ``n_values()`` on them, so subclasses are expected to
    provide those.
    """


class DBQQuantizer(Quantizer):
    def __init__(self, n_groups, group_size, branches=2, gen_matrix_every_step=False):
        super(DBQQuantizer, self).__init__()
        self.n_groups, self.group_size, self.max_branches = n_groups, group_size, branches
        self.gen_matrix_every_step = gen_matrix_every_step

        if branches == 1:  # 1.58 bit
            alpha_init = [[1.0]]
        elif branches == 2:  # 3.16 bit
            alpha_init = [[1.5], [1.0]]
        elif branches == 3:  # 4.75 bit
            alpha_init = [[1.5 ** 2], [1.5], [1.0]]
        elif branches == 4:  # 6.34 bits
            alpha_init = [[1.5 ** 3], [1.5 ** 2], [1.5], [1.0]]
        else:
            assert False

        # Quantization parameters
        self.alpha = nn.Parameter(torch.ones([branches, self.n_groups]) * torch.tensor(alpha_init))  # coeffs
        self.t = nn.Parameter(torch.zeros([3**branches - 1, self.n_groups]))  # thresholds
        self.gamma1 = nn.Parameter(torch.zeros([self.n_groups]))  # normalization factors
        self.gamma2 = nn.Parameter(torch.zeros([self.n_groups]))

        # Temperature parameter
        self.register_buffer("T", torch.tensor(0.0))

        quant_matrix = torch.stack(
            torch.meshgrid(*[torch.tensor([-1, 0, 1])]*branches),
            dim=-1
        ).reshape(-1, 1, branches).expand(3**branches, self.n_groups, branches).clone()
        self.register_buffer("quant_matrix", quant_matrix)  # (3^nb, NG, nb)
        self.quant_matrix.copy_(self.gen_quant_matrix())

    def extra_repr(self):
        return f'n_groups={self.n_groups}, group_size={self.group_size}, ' \
               f'gen_every_step={self.gen_matrix_every_step}'

    def num_bytes(self, actual_size, count_only_non_shrinkable, count_only_quant_hyperparams):
        quant_param_bytes = (
            self.max_branches * self.group_size  # alphas
            + self.gamma1.data.numel()
            + self.gamma2.data.numel()
        ) * 4
        threshold_bytes = (self.n_values() - 1).sum().item() * 4
        combinations_bytes = math.ceil(3 ** self.max_branches / 8) * self.group_size
        theoretical_bits = torch.log2(self.n_values())
        if actual_size:
            theoretical_bits = torch.ceil(theoretical_bits)
        weight_bytes = self.group_size * theoretical_bits.sum().item() / 8
        if count_only_non_shrinkable:
            return quant_param_bytes
        if count_only_quant_hyperparams:
            return quant_param_bytes + threshold_bytes + combinations_bytes
        return quant_param_bytes + threshold_bytes + combinations_bytes + weight_bytes

    @torch.no_grad()
    def gen_quant_matrix(self):
        alpha = self.alpha.data  # (nb, NG)
        quants = self.quant_matrix  # (3^nb, NG, nb)
        values = torch.einsum('ngb,bg->ng', quants.float(), alpha)  # (3^nb, NG)
        sorted_indices = torch.argsort(values, dim=0)  # (3^nb, NG)
        result = quants.gather(0, sorted_indices.unsqueeze(2).expand(
            3**self.max_branches, self.n_groups, self.max_branches
        ))  # (3^nb, NG, nb)
        return result  # (3^nb, NG, nb)

    # Normalized weights
    def gw(self, w):
        # w = (n_groups, group_size)
        g1 = self.gamma1.unsqueeze(1)  # (ng, 1)
        t = self.t.unsqueeze(2)  # (3^nb-1, ng, 1)
        gw = ((g1 * w).unsqueeze(0) - t)  # (3^nb-1, ng, n)
        return gw

    def quantize(self, w):
        g2 = self.gamma2.unsqueeze(1)
        gw = self.gw(w)  # (3^nb-1, ng, n)
        if self.training:
            f = (lambda x: torch.sigmoid(x * self.T))
        else:
            f = utils.step
        fgw = f(gw)  # (3^nb-1, ng, n)

        if self.gen_matrix_every_step:
            qm = self.gen_quant_matrix()
        else:
            qm = self.quant_matrix
        diff = (qm[1:] - qm[:-1]).float()  # (3^nb-1, ng, nb)

        # (nb, ng) x (3^nb-1, ng, nb) x (N, ng, 3^nb-1) -> (ng, N)
        cfgw = torch.einsum('bg,tgb,tgn->gn', self.alpha, diff, fgw)
        begin = torch.einsum('bg,gb->g', self.alpha, qm[0].float()).unsqueeze(1)
        q = g2 * (cfgw + begin)

        return q

    def section_index(self, w):
        """Return, per weight, the sum of ``utils.step`` over all thresholds.

        Presumably ``utils.step`` is a hard 0/1 step, which would make the
        result the index (0 .. 3^nb - 1) of the threshold interval each weight
        falls into -- confirm against ``utils.step``.
        """
        crossed = utils.step(self.gw(w))  # (3^nb-1, n_groups, group_size)
        return crossed.sum(0).long()  # (n_groups, group_size)

    def quantized_sep(self, w):
        """Return the per-branch ternary entries selected for every weight.

        Each weight's ``section_index`` picks one row of its group's sorted
        combination matrix; the entries of that row are returned separately
        per branch.

        Returns:
            Tensor shaped ``(nb, n_groups, group_size)``.
        """
        # Put the row axis last so it can be gathered along dim 2.
        rows_last = self.gen_quant_matrix().permute(2, 1, 0)  # (nb, n_groups, 3^nb)
        index = self.section_index(w).unsqueeze(0).expand(
            self.max_branches, self.n_groups, self.group_size
        )  # (nb, n_groups, group_size)
        return rows_last.gather(2, index)

    @torch.no_grad()
    def normalize_alpha_signs(self):
        alpha_sign = torch.sign(self.alpha)  # (nb, ng)
        self.alpha.copy_(self.alpha * alpha_sign)
        self.quant_matrix.copy_(self.quant_matrix * alpha_sign.unsqueeze(0).permute(0, 2, 1))

    @torch.no_grad()
    def init_quant_params(self, weight, temp_init, init_epochs, unfreeze_quant, adapt_params):
        """Initialise the quantization parameters from a weight matrix.

        Args:
            weight: weights shaped ``(n_groups, group_size)``.
            temp_init: initial value for the temperature buffer ``T``.
            init_epochs: the refinement loop runs ``init_epochs - 1`` times.
            unfreeze_quant: value assigned to ``requires_grad`` of ``alpha``,
                ``gamma1``, ``gamma2`` and ``t``.
            adapt_params: when false, only the scales and temperature are set;
                when true, thresholds and ``alpha`` are additionally fitted to
                ``weight``.

        The fit seeds the thresholds with the midpoints of k-means centroids
        of the normalized weights. It then alternates a least-squares solve
        for ``alpha`` with a threshold update, keeping the parameters that
        gave the lowest ``rmse``. Progress is printed to stdout.

        The module's train/eval mode is the same on return as on entry.
        """
        for param in (self.alpha, self.gamma1, self.gamma2, self.t):
            param.requires_grad = unfreeze_quant

        # gamma2 is the per-group max |w|; gamma1 is its reciprocal, so
        # gamma1 * w lies in [-1, 1].
        self.gamma2.data.copy_(torch.amax(torch.abs(weight), dim=1))
        self.gamma1.data.copy_(1. / self.gamma2.data)

        update_temp(self, temp_init)

        if not adapt_params:
            return

        n_levels = 3 ** self.max_branches
        gw = self.gamma1.unsqueeze(1) * weight  # (ng, n)

        # Seed thresholds: midpoints between sorted 1-D k-means centroids.
        midpoints = []
        for row in gw:
            values = row.detach().unsqueeze(1).cpu().numpy()
            centroids, _ = kmeans2(values, n_levels)
            assert centroids.shape[0] == n_levels
            centroids = np.sort(centroids.squeeze(1))
            midpoints.append(torch.tensor((centroids[:-1] + centroids[1:]) * .5))
        # copy_ converts device and dtype itself; forcing .cuda() here would
        # fail on CPU-only machines.
        self.t.data.copy_(torch.stack(midpoints, 1))

        print(self)
        # The fit needs the hard (eval) quantizer; remember the current mode
        # so it can be restored and later training is not left on the hard path.
        was_training = self.training
        self.eval()
        try:
            best_err = rmse(weight, self.quantize(weight))
            best_alpha = self.alpha.data.clone()
            best_t = self.t.data.clone()
            print(f"[0] RMSE = {best_err:.6f}")
            for it in range(1, init_epochs):
                # Update alpha: least-squares fit gw ~= alpha^T Q per group,
                # where Q holds the ternary entries assigned to every weight.
                # Q must not be pre-scaled by the current alpha, otherwise
                # the solution is the ratio new/old alpha, not alpha itself.
                Q = self.quantized_sep(weight).to(gw.dtype).transpose(0, 1)  # (Nout, nb, N)
                QQt = torch.bmm(Q, Q.transpose(1, 2))  # (Nout, nb, nb)
                try:
                    if Q.shape[2] < 1024:
                        QQt_inv_Q = torch.linalg.solve(QQt, Q)  # (Nout, nb, N)
                    else:
                        # Workaround for pytorch bug
                        QQt_inv_Q = torch.inverse(QQt) @ Q
                except RuntimeError as e:
                    # Singular system: keep the best parameters found so far.
                    print(e)
                    break
                new_alpha = torch.bmm(QQt_inv_Q, gw.unsqueeze(2)).squeeze(2)  # (Nout, nb)
                self.alpha.data.copy_(new_alpha.transpose(0, 1))
                # update t: midpoints between consecutive levels under the new alpha
                self.quant_matrix.data.copy_(self.gen_quant_matrix())
                # (ng, nb) x (3^nb, NG, nb) -> (ng, 3^nb)
                q = torch.einsum('gb,ngb->gn', new_alpha, self.quant_matrix.float())
                q, _ = torch.sort(q, dim=1)
                new_t = ((q[:, :-1] + q[:, 1:]) * .5).t()  # (3^nb-1, Nout)
                self.t.data.copy_(new_t)
                err = rmse(weight, self.quantize(weight))
                print(f"[{it}] RMSE = {err:.6f}")
                if err < best_err:
                    best_err = err
                    best_alpha.copy_(self.alpha.data)
                    best_t.copy_(self.t.data)
            self.alpha.copy_(best_alpha)
            self.t.copy_(best_t)
            self.quant_matrix.data.copy_(self.gen_quant_matrix())
            self.normalize_alpha_signs()
            print(f"Final RMSE = {rmse(weight, self.quantize(weight)):.6f}")
        finally:
            self.train(was_training)

    @torch.no_grad()
    def n_values(self):
        """Return, per group, how many distinct rows ``quant_matrix`` holds.

        Rows are mapped to ids with ``utils.ternary_to_id`` and the ids with a
        non-zero ``utils.bincount`` are counted.
        """
        per_group = self.quant_matrix.transpose(0, 1)  # (ng, 3^nb, nb)
        ids = utils.ternary_to_id(per_group)  # (ng, 3^nb)
        counts = utils.bincount(ids, 3 ** self.max_branches, 1)
        occupied = torch.not_equal(counts, 0)
        return occupied.sum(1)  # (ng,)

    def num_branches(self):
        return torch.tensor([self.max_branches] * self.n_groups)

    def error_for_group(
        self,
        gw: torch.Tensor,  # (n)
        section: torch.LongTensor,  # index (n) -> [0, 3^nb)
        qm: torch.LongTensor,  # (3^nb, nb)
        alpha: torch.Tensor,  # (nb)
        num_iter: int
    ):
        """Greedily merge quantization levels of one group and report the error.

        Each of the ``int(num_iter)`` iterations removes one distinct row from
        ``qm``: for every distinct row, a candidate matrix is built in which
        all copies of that row are replaced by the distinct row whose value
        (``row · alpha``) is closest, and the candidate with the smallest
        summed error is kept.

        ``qm`` is modified IN PLACE (via ``qm.copy_``); pass a clone to only
        measure the error.

        Args:
            gw: normalized weights of the group.
            section: row of ``qm`` currently assigned to each weight.
            qm: the group's ternary combination matrix.
            alpha: the group's branch coefficients.
            num_iter: number of merge iterations (truncated with ``int``).

        Returns:
            ``(initial_error, final_error)`` as Python floats: the summed
            error before any merge and after the last one. They are equal
            when no iteration runs.
        """
        device = qm.device

        def calc_error_sum(qm_cands: torch.LongTensor):
            # qm_cands is (3^nb, nb, |Q|): one candidate matrix per last-axis
            # entry. Returns the summed per-weight error of every candidate.
            qgw = torch.einsum(
                'nbq,b->nq',
                qm_cands[section].float(),  # (n, nb, |Q|)
                alpha  # (nb)
            )  # (n, |Q|)

            abs_error = torch.abs(gw.unsqueeze(1) - qgw)  # (n, |Q|)
            # 1e-6 keeps the division finite when the quantized value is 0.
            rel_error = torch.div(abs_error, torch.abs(qgw) + 1e-6)  # (n, |Q|)
            # Each weight is charged the smaller of its absolute and relative error.
            min_error = torch.minimum(abs_error, rel_error)  # (n, |Q|)
            return min_error.sum(0)  # (|Q|)

        # Error of the unmodified matrix (a single candidate).
        initial_error = calc_error_sum(qm.unsqueeze(2)).item()
        final_error = initial_error

        for _ in range(int(num_iter)):
            # NOTE(review): presumably maps each ternary row to a unique
            # integer id -- confirm against utils.ternary_to_id.
            qm_ids = utils.ternary_to_id(qm)  # (3^nb)

            # (|Q|), index (3^nb) -> [0, |Q|)
            unique_ids, orig_to_q_perm = torch.unique(qm_ids, return_inverse=True)
            NQ = unique_ids.shape[0]
            # NOTE(review): presumably yields, for each distinct id, one row
            # index of qm carrying it -- confirm against utils.invert_permutation.
            q_to_orig_perm = utils.invert_permutation(orig_to_q_perm, NQ)  # index (|Q|) -> [0, 3^nb)
            unique_qm = qm[q_to_orig_perm]  # tern(|Q|, nb)
            unique_values = torch.einsum('qb,b->q', unique_qm.float(), alpha)  # float(|Q|)

            # For every distinct row, list all the OTHER distinct rows by
            # masking out the diagonal of the |Q| x |Q| pairing.
            unique_qm_mat = unique_qm.unsqueeze(0).expand(NQ, NQ, self.max_branches)  # tern(|Q|, |Q|, nb)
            mask = torch.bitwise_not(torch.eye(NQ, device=device, dtype=torch.bool)).unsqueeze(2)
            cand_qm_mat = torch.masked_select(unique_qm_mat, mask).reshape(NQ, NQ - 1, self.max_branches)
            cand_values = torch.einsum('qrb,b->qr', cand_qm_mat.float(), alpha)  # float(|Q|,|Q|-1)

            # Difference between current quantized values and the values after one removal
            diff = torch.abs(unique_values.unsqueeze(1) - cand_values)  # float(|Q|, |Q|-1)

            # which index in cand_values is the closest to the original value
            replacement_indices = torch.argmin(diff, dim=1)  # index (|Q|) -> [0, |Q|-1)
            # Undo the shift caused by dropping the diagonal entry, so the
            # index refers to the full list of distinct rows again.
            replacement_indices[replacement_indices >= torch.arange(NQ, device=device)] += 1  # -> [0, |Q|)

            # Candidate q: every row equal to distinct row q is replaced by
            # q's nearest neighbour; all other rows are kept.
            next_qm_cands = torch.where(
                torch.eq(qm_ids.unsqueeze(1), qm_ids[q_to_orig_perm].unsqueeze(0)).unsqueeze(1),
                qm[q_to_orig_perm[replacement_indices[orig_to_q_perm]]].unsqueeze(2),
                qm.unsqueeze(2)
            )
            error_sum = calc_error_sum(next_qm_cands)  # (|Q|)

            # Find value with the least error when removed
            least_error_index = torch.argmin(error_sum)  # index () -> [0, |Q|)
            final_error = error_sum[least_error_index].item()

            # Update qm_group (in place, so the caller's tensor changes)
            qm.copy_(next_qm_cands[:, :, least_error_index])

        return initial_error, final_error

    @torch.no_grad()
    def calc_scores(self, w):
        """Score every group by the bits saved per unit of added error.

        For each group, a simulation on a clone of its combination matrix
        reduces the number of distinct values to the next lower power of two
        (never below one); ``quant_matrix`` itself is not changed.

        Args:
            w: weights shaped ``(n_groups, group_size)``.

        Returns:
            A 7-tuple of per-group lists: ``scores``, ``errors`` (increase of
            the summed error, clamped at 0), ``initial_errors``,
            ``final_errors``, ``saved_bits``, ``n_values_before`` and
            ``n_values_after``.
        """
        n_values = self.n_values()  # (ng,)

        # Dropping one bit per weight saves group_size bits per group.
        current_bits = torch.ceil(torch.log2(n_values))
        smaller_bits = current_bits - 1
        saved_bits = (current_bits - smaller_bits) * self.group_size  # (ng,)
        # Groups with 8 or fewer values are given no saving, hence score 0.
        saved_bits[n_values <= 8] = 0

        gw = self.gamma1.unsqueeze(1) * w  # (ng, n)
        section_indices = self.section_index(w)  # index (ng, n) -> [0, 3^nb)

        errors = []
        initial_errors = []
        final_errors = []
        scores = []
        n_values_before = []
        n_values_after = []
        for group_idx in range(self.n_groups):
            n = int(n_values[group_idx].item())
            # (n - 1).bit_length() == ceil(log2(n)) for n >= 1. Clamping the
            # exponent at 0 keeps the target an integer >= 1, so a group that
            # already has a single value gets n_iter == 0 instead of 0.5.
            target = 2 ** max((n - 1).bit_length() - 1, 0)
            n_iter = n - target
            initial_error, final_error = self.error_for_group(
                gw[group_idx],
                section_indices[group_idx],
                self.quant_matrix[:, group_idx].clone(),  # simulate only
                self.alpha[:, group_idx],
                n_iter
            )
            initial_errors.append(initial_error)
            final_errors.append(final_error)
            error_diff = max(final_error - initial_error, 0)
            errors.append(error_diff)
            # 1e-6 avoids a division by zero when shrinking costs nothing.
            scores.append(saved_bits[group_idx].item() / (error_diff + 1e-6))
            n_values_before.append(n)
            n_values_after.append(target)

        return (
            scores, errors, initial_errors, final_errors,
            saved_bits.cpu().numpy().tolist(),
            n_values_before, n_values_after
        )

    @torch.no_grad()
    def shrink_groups(self, w, indices):
        """Reduce the distinct values of the selected groups, in place.

        For every group index in ``indices``, the group's slice of
        ``quant_matrix`` is rewritten by ``error_for_group`` so that its number
        of distinct values drops to the next lower power of two (never below
        one).

        Args:
            w: weights shaped ``(n_groups, group_size)``.
            indices: iterable of group indices to shrink.
        """
        n_values = self.n_values()  # (ng,)

        gw = self.gamma1.unsqueeze(1) * w  # (ng, n)
        section_indices = self.section_index(w)  # index (ng, n) -> [0, 3^nb)

        for group_idx in indices:
            n = int(n_values[group_idx].item())
            # (n - 1).bit_length() == ceil(log2(n)) for n >= 1. Clamping the
            # exponent at 0 keeps the iteration count an integer even when the
            # group already has a single value.
            n_iter = n - 2 ** max((n - 1).bit_length() - 1, 0)
            self.error_for_group(
                gw[group_idx],
                section_indices[group_idx],
                self.quant_matrix[:, group_idx],  # a view: updated in place
                self.alpha[:, group_idx],
                n_iter
            )


def total_n_values(module):
    """Return the sum of ``n_values()`` over every ``Quantizer`` in ``module.modules()``."""
    return sum(
        child.n_values().sum().item()
        for child in module.modules()
        if isinstance(child, Quantizer)
    )


def total_quant_groups(module):
    """Return how many quantization groups the ``Quantizer`` sub-modules hold in total."""
    group_counts = [
        child.n_groups
        for child in module.modules()
        if isinstance(child, Quantizer)
    ]
    return sum(group_counts)
