import torch
from loguru import logger
from .base_blockwise_quantization import BaseBlockwiseQuantization
from llmc.utils.registry_factory import ALGO_REGISTRY
from .module_utils import FakeQuantLinear


@ALGO_REGISTRY
class AdaDim(BaseBlockwiseQuantization):
    """Per-layer adaptive choice of the weight fake-quantization dimension.

    For every layer of a block subset, ``search_dim_subset`` fake-quantizes
    the weight along the output-channel ("oc") and input-channel ("ic")
    dimension. It compares each variant's output against the full-precision
    output on the calibration inputs (mean squared error). It then stores the
    winner in a ``buf_qdim`` buffer: ``0`` for "ic", ``1`` for "oc".
    ``w_qdq`` reads that buffer when the layer's weight is fake-quantized.
    """

    def __init__(self, model, quant_config, input, config):
        super().__init__(model, quant_config, input, config)

    def get_layer_out(self, x, layer):
        """Run ``layer(x)`` with autograd disabled and return its output.

        If the layer returns a tuple, only its first element is returned.
        """
        with torch.no_grad():
            org_out = layer(x)
            if isinstance(org_out, tuple):
                org_out = org_out[0]
        return org_out

    def search_dim_subset(self, layers_dict, input, idx):
        """Pick, per layer, the quantization dim with the lower output MSE.

        Args:
            layers_dict: mapping of layer name -> module exposing ``weight``.
            input: list of calibration input tensors for this subset. Each
                element is moved onto the layer's weight device and written
                back into the list in place.
            idx: index of the current block (not used by this method).

        Side effects:
            Registers a ``buf_qdim`` buffer on every layer. The value is
            ``tensor(0)`` when per-input-channel ("ic") fake quantization
            gave the strictly lower loss, otherwise ``tensor(1)`` ("oc").
            Each layer's weight is restored to its full-precision values
            afterwards, including when a forward pass raises.
        """
        dims = ("oc", "ic")
        for name, layer in layers_dict.items():
            # Point the layer at a private copy of its weights. This also
            # lets the original tensor be released before the two quantized
            # copies below are built.
            weight = layer.weight.data.clone()
            layer.weight.data = weight
            device = weight.device

            loss_dict = dict.fromkeys(dims, 0)
            try:
                # NOTE(review): assumes fake_quant_weight_dynamic is
                # out-of-place (returns a new tensor, leaves ``weight``
                # intact). The comparison is only meaningful if so.
                q_weights = {
                    dim: self.wquantizer.fake_quant_weight_dynamic(
                        weight, {"dim": dim}
                    )
                    for dim in dims
                }

                for i, x in enumerate(input):
                    x = x.to(device)
                    input[i] = x

                    # The full-precision reference does not depend on the
                    # candidate dim, so it is computed once per input.
                    layer.weight.data = weight
                    org_out = self.get_layer_out(x, layer)

                    for dim in dims:
                        layer.weight.data = q_weights[dim]
                        out = self.get_layer_out(x, layer)
                        loss = (org_out - out).float().pow(2).mean().item()
                        # Weight each batch by its share of the calibration set.
                        loss_dict[dim] += x.shape[0] * 1.0 / self.n_samples * loss
            finally:
                # Never leave the layer holding fake-quantized weights.
                layer.weight.data = weight

            if loss_dict["ic"] < loss_dict["oc"]:
                layer.register_buffer("buf_qdim", torch.tensor(0))
                logger.info(f"Suggest layer {name} use per-input channel quant")
            else:
                layer.register_buffer("buf_qdim", torch.tensor(1))
                logger.info(f"Suggest layer {name} use per-output channel quant")

    def block_transform(self, block, input_feat, idx, block_kwargs):
        """Choose quant dims for each subset of ``block``, then swap in fake-quant modules.

        For every subset from ``self.model.get_subsets_in_block(block)``:
        1. Run ``search_dim_subset`` on the features recorded for the subset's
           first input name.
        2. Replace the subset's layers via ``self.model.replace_module_subset``
           with ``FakeQuantLinear``. The replacement is configured with
           ``w_qdq=self.w_qdq`` and ``a_qdq=self.a_qdq``, or ``a_qdq=None``
           when ``self.w_only`` is set.

        ``block_kwargs`` is accepted for interface compatibility and not used.
        """
        logger.info(f"Start transform the {idx+1}-th block")
        subsets = self.model.get_subsets_in_block(block)
        for subset in subsets:
            logger.info(f"subset: {subset}")
            layers_dict = subset["layers"]
            input_name = subset["input"][0]

            self.search_dim_subset(layers_dict, input_feat[input_name], idx)

            params_dict = {
                "w_qdq": self.w_qdq,
                "a_qdq": self.a_qdq if not self.w_only else None,
            }
            self.model.replace_module_subset(
                FakeQuantLinear, block, subset, idx, params_dict
            )

        logger.info(f"End transform the {idx+1}-th block")

    def w_qdq(self, module):
        """Return ``module.weight`` fake-quantized along the dim chosen for it.

        Reads the ``buf_qdim`` buffer registered by ``search_dim_subset``.
        ``0`` selects "ic" (per-input channel); any other value selects "oc".
        """
        args = {"dim": "ic" if module.buf_qdim == 0 else "oc"}
        return self.wquantizer.fake_quant_weight_dynamic(module.weight, args)
