from abc import abstractmethod

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.init as init
from tqdm import tqdm

from dbq import Quantizer, DBQQuantizer


def weights_init(m):
    """Initialise a single module following the DBQ recipe.

    Intended to be used per-module (e.g. through ``nn.Module.apply``):

    * ``nn.Linear`` / ``nn.Conv2d``: the weight is re-drawn in place with
      ``kaiming_normal_``; the bias, if any, is left untouched.
    * ``DBQQuantizer``: ``alpha``, ``gamma1``, ``gamma2`` and ``t`` are
      frozen by clearing ``requires_grad``.
    * Any other module is left as is.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        return

    if isinstance(m, DBQQuantizer):
        # Freeze the quantizer parameters one after another.
        for attr_name in ("alpha", "gamma1", "gamma2", "t"):
            getattr(m, attr_name).requires_grad = False


def measure_size(module,
                 actual_size=False,
                 count_only_non_shrinkable=False,
                 count_only_quant_hyperparams=False,
                 count_non_dbq_modules=False):
    """Sum the size of ``module`` and all of its sub-modules.

    Args:
        module: root ``nn.Module``; every entry of ``module.modules()`` is
            visited.
        actual_size: forwarded to ``QuantizedModule.module_size``.
        count_only_non_shrinkable: forwarded to
            ``QuantizedModule.module_size``.
        count_only_quant_hyperparams: forwarded to
            ``QuantizedModule.module_size``; when true, modules that are not
            ``QuantizedModule`` contribute nothing.
        count_non_dbq_modules: when true, ``QuantizedModule`` layers are
            skipped, so only the remaining modules are counted.

    Returns:
        The accumulated size as an int. Parameters of ordinary modules are
        counted at 4 bytes per element (presumably float32 - confirm).
    """
    total = 0
    for sub in module.modules():
        if isinstance(sub, QuantizedModule):
            # Quantized layers report their own size, unless excluded.
            if count_non_dbq_modules:
                continue
            total += sub.module_size(
                actual_size=actual_size,
                count_only_non_shrinkable=count_only_non_shrinkable,
                count_only_quant_hyperparams=count_only_quant_hyperparams,
            )
            continue

        # Quantizer sub-modules never add anything here directly.
        if isinstance(sub, Quantizer) or count_only_quant_hyperparams:
            continue

        # Only the module's own parameters; children are visited separately.
        own_elements = sum(p.numel() for p in sub.parameters(recurse=False))
        total += own_elements * 4
    return total


@torch.no_grad()
def init_quant_params(m, temp_init, init_epochs, unfreeze_quant, adapt_params):
    """Switch a ternary layer to quantized mode and initialise its quantizer.

    Only ``TernaryConv2d`` and ``TernaryLinear`` are handled; any other
    module is ignored. The layer's ``use_quant`` flag is set and the
    layer's quantizer is initialised from the current weight data; for
    convolutions the weight is first flattened to 2-D, keeping the first
    dimension. The remaining arguments are passed through unchanged to
    ``quantizer.init_quant_params``. Runs with gradient tracking disabled.
    """
    is_conv = isinstance(m, TernaryConv2d)
    if not is_conv and not isinstance(m, TernaryLinear):
        return

    m.use_quant = True
    weights = m.weight.data
    if is_conv:
        # The quantizer receives the conv kernel as a 2-D matrix.
        weights = weights.reshape(weights.shape[0], -1)
    m.quantizer.init_quant_params(weights, temp_init, init_epochs, unfreeze_quant, adapt_params)


def shrink_model(model: torch.nn.Module, remove_count: int, verbose: bool):
    """Remove the highest-scoring groups from the quantized layers of a model.

    Every ``QuantizedModule`` found in ``model`` reports per-group statistics
    through ``calc_scores()``. The groups of all layers are ranked together
    by score, and the ``remove_count`` groups with the highest scores are
    passed to ``shrink_groups()`` of the layer that owns them.
    ``shrink_groups()`` is called on every quantized layer, with an empty
    list when none of its groups were selected.

    Args:
        model: model whose ``QuantizedModule`` sub-modules are shrunk.
        remove_count: number of groups to remove. Values outside
            ``[0, total number of groups]`` are clamped to that range, so
            asking for more groups than exist removes all of them instead
            of failing with an ``IndexError``.
        verbose: when true, print one line of statistics per removed group.

    Returns:
        A tuple ``(scores, errors, saved_bits)`` of numpy arrays holding the
        values reported by ``calc_scores()`` for every group of every
        quantized layer, in module traversal order.
    """
    modules = [(name, m) for name, m in model.named_modules()
               if isinstance(m, QuantizedModule)]

    # Calculate score function
    scores_all = []
    initial_errors_all = []
    final_errors_all = []
    errors_all = []
    saved_bits_all = []
    n_groups = []
    n_values_before_all = []
    n_values_after_all = []
    for _, m in tqdm(modules, desc="Calculating scores..."):
        (group_scores, errors, initial_errors, final_errors,
         saved_bits, n_values_before, n_values_after) = m.calc_scores()
        scores_all += group_scores
        initial_errors_all += initial_errors
        final_errors_all += final_errors
        errors_all += errors
        saved_bits_all += saved_bits
        n_values_before_all += n_values_before
        n_values_after_all += n_values_after
        n_groups.append(len(group_scores))
    scores_all = np.array(scores_all)

    # Global indices of the highest-scoring groups, in ascending order so
    # they can be consumed layer by layer below.
    ranked = np.argsort(scores_all)[::-1]
    indices = np.sort(ranked[:max(remove_count, 0)])
    # There may be fewer groups than requested; from here on work with the
    # number that was actually selected so the loop never runs past the
    # end of ``indices``.
    remove_count = len(indices)

    print(f"Removing {remove_count} values")

    # Remove values
    start = 0  # global index of the current layer's first group
    idx = 0    # position of the next unconsumed entry in ``indices``
    with tqdm(total=remove_count) as pbar:
        for n_group, (name, m) in zip(n_groups, modules):
            remove_indices = []
            while idx < remove_count and 0 <= indices[idx] - start < n_group:
                i = indices[idx]
                group_idx = i - start  # index local to this layer
                if verbose:
                    print(f"Removing {name}[{group_idx}], saved_bits {saved_bits_all[i]}, "
                          f"error {errors_all[i]:.4f} ({initial_errors_all[i]:.4f} -> {final_errors_all[i]:.4f}), "
                          f"score {scores_all[i]:.0f}, "
                          f"n_values {n_values_before_all[i]} -> {n_values_after_all[i]}")
                remove_indices.append(group_idx)
                idx += 1
                pbar.update()
            m.shrink_groups(remove_indices)
            start += n_group

    return scores_all, np.array(errors_all), np.array(saved_bits_all)


class QuantizedModule:
    """Mixin marking a layer as quantized and able to report its size.

    The class does not use ``ABCMeta``, so the ``abstractmethod`` marker
    below is not enforced at instantiation time; it documents that
    subclasses are expected to override ``module_size``.
    """

    @abstractmethod
    def module_size(self, actual_size, count_only_non_shrinkable, count_only_quant_hyperparams):
        """Return the size of this layer; to be overridden by subclasses.

        This base implementation does nothing and returns ``None``.
        """


class TernaryConv2d(nn.Conv2d, QuantizedModule):
    """``nn.Conv2d`` whose weight can be routed through a quantizer.

    The layer behaves like a plain convolution until ``use_quant`` is set to
    true; from then on the weight is quantized on every ``get_weight`` call.
    ``forward`` asserts that the layer has no bias, so it has to be
    constructed such that ``self.bias`` is ``None``.

    Args:
        *args, **kwargs: forwarded to ``nn.Conv2d``.
        num_branches: passed to the quantizer as ``branches``.
        gen_matrix_every_step: passed to the quantizer unchanged.
        quantizer_type: callable building the quantizer; it receives the
            two dimensions of the flattened weight matrix.
    """

    def __init__(self, *args, num_branches=2, gen_matrix_every_step=False,
                 quantizer_type, **kwargs):
        super().__init__(*args, **kwargs)
        n_rows, dim1, dim2, dim3 = self.weight.data.shape
        self.quantizer = quantizer_type(
            n_rows,
            dim1 * dim2 * dim3,
            branches=num_branches,
            gen_matrix_every_step=gen_matrix_every_step,
        )
        self.use_quant = False

    def _flat_weight(self):
        # 2-D view of the kernel: first dimension kept, the rest merged.
        return self.weight.reshape(self.weight.data.shape[0], -1)

    def get_weight(self):
        """Return the weight used by ``forward`` (quantized if enabled)."""
        if not self.use_quant:
            return self.weight
        quantized = self.quantizer.quantize(self._flat_weight())
        # Restore the original 4-D kernel layout.
        return quantized.reshape(self.weight.data.shape)

    def forward(self, x):
        """Convolve ``x`` with the current weight; a bias is not supported."""
        assert self.bias is None
        return self._conv_forward(x, self.get_weight(), None)

    def module_size(self, actual_size, count_only_non_shrinkable, count_only_quant_hyperparams):
        """Return the layer size.

        With quantization enabled the quantizer's ``num_bytes`` is returned.
        Otherwise the weight is counted at 4 bytes per element, or 0 when
        only quantization hyper-parameters are requested.
        """
        if self.use_quant:
            return self.quantizer.num_bytes(
                actual_size=actual_size,
                count_only_non_shrinkable=count_only_non_shrinkable,
                count_only_quant_hyperparams=count_only_quant_hyperparams
            )
        return 0 if count_only_quant_hyperparams else self.weight.numel() * 4

    def calc_scores(self):
        """Return the quantizer's scores for the flattened weight."""
        return self.quantizer.calc_scores(self._flat_weight())

    def shrink_groups(self, group_indices):
        """Ask the quantizer to shrink the given groups of the weight."""
        return self.quantizer.shrink_groups(self._flat_weight(), group_indices)


class TernaryLinear(nn.Linear, QuantizedModule):
    """``nn.Linear`` whose weight can be routed through a quantizer.

    The layer behaves like a plain linear layer until ``use_quant`` is set
    to true; from then on the weight is quantized on every ``get_weight``
    call. ``forward`` asserts that the layer has no bias, so it has to be
    constructed such that ``self.bias`` is ``None``.

    Args:
        *args, **kwargs: forwarded to ``nn.Linear``.
        num_branches: passed to the quantizer as ``branches``.
        gen_matrix_every_step: passed to the quantizer unchanged.
        quantizer_type: callable building the quantizer; it receives the
            two dimensions of the weight matrix.
    """

    def __init__(self, *args, num_branches=2, gen_matrix_every_step=False,
                 quantizer_type, **kwargs):
        super().__init__(*args, **kwargs)
        n_rows, n_cols = self.weight.data.shape
        self.quantizer = quantizer_type(
            n_rows,
            n_cols,
            branches=num_branches,
            gen_matrix_every_step=gen_matrix_every_step,
        )
        self.use_quant = False

    def get_weight(self):
        """Return the weight used by ``forward`` (quantized if enabled)."""
        if not self.use_quant:
            return self.weight
        return self.quantizer.quantize(self.weight)

    def forward(self, x):
        """Apply the linear map with the current weight; no bias is added."""
        assert self.bias is None
        return F.linear(x, self.get_weight())

    def module_size(self, actual_size, count_only_non_shrinkable, count_only_quant_hyperparams):
        """Return the layer size.

        With quantization enabled the quantizer's ``num_bytes`` is returned.
        Otherwise the weight is counted at 4 bytes per element, or 0 when
        only quantization hyper-parameters are requested.
        """
        if self.use_quant:
            return self.quantizer.num_bytes(
                actual_size=actual_size,
                count_only_non_shrinkable=count_only_non_shrinkable,
                count_only_quant_hyperparams=count_only_quant_hyperparams
            )
        return 0 if count_only_quant_hyperparams else self.weight.numel() * 4

    def calc_scores(self):
        """Return the quantizer's scores for the weight matrix."""
        return self.quantizer.calc_scores(self.weight)

    def shrink_groups(self, group_indices):
        """Ask the quantizer to shrink the given groups of the weight."""
        return self.quantizer.shrink_groups(self.weight, group_indices)
