import numpy as np
import torch
import gc
import functools
import transformers
import torch.nn as nn
import math
import torch.nn.functional as F
from loguru import logger
from contextlib import nullcontext
from .base_blockwise_quantization import BaseBlockwiseQuantization
from llmc.utils.registry_factory import ALGO_REGISTRY
from .module_utils import FakeQuantLinear
from .train_utils import NativeScalerWithGradNormCount


@ALGO_REGISTRY
class FWDQ(BaseBlockwiseQuantization):
    def __init__(self, model, quant_config, input, config):
        """Initialise the algorithm and attach quantization buffers to the model.

        All four arguments are forwarded unchanged to the base-class
        constructor. Afterwards the target device and the model dtype are
        recorded, the algorithm-specific options are loaded through
        ``add_quant_config`` and ``collect_model_qparams`` is run.
        """
        super().__init__(model, quant_config, input, config)
        self.dev = torch.device("cuda")
        # The dtype of the first model parameter is used as the model dtype.
        first_param = next(self.model.model.parameters())
        self.model_dtype = first_param.dtype
        self.add_quant_config()
        # Per-layer scratch storage; starts out empty.
        self.layers_cache = {}
        self.collect_model_qparams()

    @torch.no_grad()
    def add_quant_config(self):
        self.prefix = self.model.block_name_prefix
        self.init_prob = self.quant_config["special"]["init_prob"]
        self.iterations = self.quant_config["special"]["iterations"]
        self.channel_wise = self.quant_config["special"]["channel_wise"]
        self.deactive_amp = False
        if self.deactive_amp:
            self.dtype = torch.float
            self.traincast = nullcontext
        else:
            self.dtype = torch.float16
            self.traincast = torch.cuda.amp.autocast

    @torch.no_grad()
    def block_transform(self, block, input_feat, idx, block_kwargs):
        """Transform every subset of ``block`` and swap in ``FakeQuantLinear``.

        Args:
            block: the block whose subsets are obtained from the model.
            input_feat: mapping from layer name to that layer's cached inputs;
                passed through to ``subset_transform``.
            idx: zero-based block index (logged one-based).
            block_kwargs: accepted for interface compatibility; not used here.
        """
        logger.info(f"Start transform the {idx+1}-th block")

        for subset in self.model.get_subsets_in_block(block):
            self.subset_transform(subset["layers"], input_feat)
            # Only a weight quantizer is supplied; the activation slot is None.
            params_dict = {"a_qdq": None, "w_qdq": self.w_qdq}
            self.model.replace_module_subset(
                FakeQuantLinear, block, subset, idx, params_dict
            )

        logger.info(f"End transform the {idx+1}-th block")


    @torch.no_grad()
    def subset_transform(self, layers_dict, input_feat):
        for name in layers_dict:
            layer = layers_dict[name]
            # self.layer_transform(layer, name)
            # self.adaround_transform(layer, name, input_feat[name])
            self.brecq_transform(layer, name, input_feat[name])
            # del input_feat[name]
            gc.collect()
            torch.cuda.empty_cache()

    @torch.no_grad()
    def layer_transform(self, layer, name):
        """Search a binary rounding mask for ``layer`` by random flipping.

        Starting from nearest rounding, every iteration flips a random subset
        of the rounding decisions and evaluates ``get_loss`` on the resulting
        weight error. Flips that make the loss worse are undone (row by row
        when ``self.channel_wise`` is set, otherwise all at once). Unless the
        final loss is NaN, the learned rounding is merged into
        ``layer.weight`` in place.

        Args:
            layer: module exposing ``weight`` and the ``buf_scales``,
                ``buf_zeros``, ``buf_max_int`` and ``buf_min_int`` attributes.
            name: layer name; used for logging and as the key into
                ``self.layers_cache``.

        NOTE(review): this method reads ``self.layers_cache[name]["H"]``, but
        no active code in this class fills ``layers_cache`` -- confirm it is
        populated before enabling this path.
        """
        logger.info('Transforming Layer ' + name)
        W = layer.weight.data.clone().float()
        # initialize rounding variable
        org_w_shape = W.shape
        org_w_dtype = W.dtype
        scales = layer.buf_scales
        zeros = layer.buf_zeros
        max_int = layer.buf_max_int
        min_int = layer.buf_min_int
        # From here on W holds the reshaped weights divided by their scales.
        W = self.wquantizer.reshape_tensor(W)
        W.div_(scales)

        # 1 where the fractional part exceeds 0.5, else 0 (nearest rounding).
        rounding = (W - torch.floor(W) > 0.5).type(org_w_dtype)
        # 1 - 2 * |frac - 0.5|: largest when the fractional part is 0.5,
        # zero when it is 0; used below as a per-weight flip probability.
        residue = 1 - 2 * (W - torch.floor(W) - 0.5).abs()
        # Baseline loss of plain nearest rounding (rounding=None in w_q).
        recon_loss_prev = self.get_loss(self.layers_cache[name]["H"],
                                        layer.weight.data.float() - self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, None),
                                        self.channel_wise)
        logger.info('Initial reconstruction loss: {}'.format(recon_loss_prev.mean().item()))

        for iters in range(self.iterations):
            # sample flip variable and then perturb
            # m = torch.bernoulli(torch.ones_like(W) * self.init_prob * (self.iterations - iters) / self.iterations)
            # Flip probability decays linearly with the iteration count.
            m = torch.bernoulli(residue * self.init_prob * (self.iterations - iters) / self.iterations)
            # Both tensors are 0/1 valued, so |rounding - m| toggles the
            # entries where m == 1.
            rounding.sub_(m).abs_()

            recon_loss_new = self.get_loss(self.layers_cache[name]["H"],
                                        layer.weight.data.float() - self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, rounding),
                                        self.channel_wise)
            if self.channel_wise:
                # keep == 1 for rows whose loss got strictly worse: those rows
                # retain the previous loss and have their flips undone.
                keep = (recon_loss_new > recon_loss_prev).type(org_w_dtype)
                recon_loss_prev = recon_loss_prev * keep + recon_loss_new * (1 - keep)
                keep = keep.reshape(-1, 1)
                m = (m.reshape(org_w_shape) * keep).reshape_as(rounding)
                # Applying the masked flip a second time reverts it.
                rounding.sub_(m).abs_()
            else:
                if recon_loss_new < recon_loss_prev:
                    recon_loss_prev = recon_loss_new
                else:
                    # No strict improvement: undo this iteration's flips.
                    rounding.sub_(m).abs_()

        logger.info('After learning, loss: {}'.format(recon_loss_prev.mean().item()))

        if not torch.isnan(recon_loss_prev.sum()):
            # merge rounding into original weight
            # (learned - nearest) is in {-1, 0, 1}; it is scaled by half of
            # ``scales`` and added to the stored weight.
            rounding.sub_((W - torch.floor(W) > 0.5).type(org_w_dtype)).mul_(0.5 * scales)
            rounding = rounding.reshape(org_w_shape)
            layer.weight.data.add_(rounding.to(self.model_dtype))

            # The returned tensor is not used below.
            w2 = self.w_qdq(layer)

            logger.info('Weight difference: {}'.format(rounding.abs().sum(1).mean()))

    def adaround_transform(self, layer, name, feat_dict):
        """Learn per-weight rounding for ``layer`` by gradient descent.

        A continuous rounding variable is initialised from the fractional
        part of the scaled weights and optimised with Adam on mini-batches.
        The objective is the layer-output squared error plus a rounding
        regulariser whose temperature decays linearly from 20 to 2. The
        variable is then binarised; if the hard-rounded loss is lower than
        that of nearest rounding, ``layer.weight`` is shifted in place by
        half a scale in the direction of every changed rounding decision
        (presumably so that nearest rounding of the shifted weight yields
        the learned value -- the "Deployment loss" log checks this).

        Args:
            layer: module exposing ``weight`` and the ``buf_scales``,
                ``buf_zeros``, ``buf_max_int`` and ``buf_min_int`` attributes.
            name: layer name, used for logging only.
            feat_dict: sequence of cached input tensors; they are
                concatenated along dim 0, which is assumed to be the sample
                dimension (TODO confirm against the caller).
        """
        logger.info('Transforming Layer ' + name)
        W = layer.weight.data.clone().float()
        X = torch.cat(feat_dict, dim=0).cuda().float()
        Target = F.linear(X, W)
        batch_size = 4
        # Mini-batch indices are drawn from the number of cached samples so
        # that any calibration-set size works.
        num_samples = X.shape[0]
        org_w_shape = W.shape
        org_w_dtype = W.dtype
        scales = layer.buf_scales
        zeros = layer.buf_zeros
        max_int = layer.buf_max_int
        min_int = layer.buf_min_int
        # From here on W holds the reshaped weights divided by their scales.
        W = self.wquantizer.reshape_tensor(W).div_(scales)
        sigmoid = RectifiedSigmoid(-0.1, 1.1)
        # Initialise so that sigmoid(rounding) equals the fractional part.
        rounding = W - torch.floor(W)
        rounding = sigmoid.inverse(rounding)

        optimizer = torch.optim.Adam([rounding], lr=0.001)
        temp_decay = LinearTempDecay(
            self.iterations, rel_start_decay=0,
            start_t=20, end_t=2)

        # Reference losses: nearest rounding and the initial soft rounding.
        recon_loss_prev = self.get_mse_loss(X, self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, None), Target)
        logger.info('Initial quantization reconstruction loss: {}'.format(recon_loss_prev.sum().item()))
        recon_loss_soft = self.get_mse_loss(X, self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, sigmoid(rounding)), Target)
        logger.info('Initial soft rounding reconstruction loss: {}'.format(recon_loss_soft.sum().item()))
        loss_scaler = NativeScalerWithGradNormCount()

        # Gradients are enabled explicitly because callers may run under
        # ``torch.no_grad()``.
        with torch.enable_grad():
            rounding.requires_grad_(True)
            for iters in range(self.iterations):
                with self.traincast():
                    idx = torch.randperm(num_samples)[:batch_size]
                    recon_loss = self.get_mse_loss(X[idx],
                                                   self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, sigmoid(rounding)),
                                                   Target[idx])
                    annealing_temp = temp_decay(iters)
                    # Regulariser is zero when sigmoid(rounding) is exactly 0 or 1.
                    round_loss = (1 - (2 * sigmoid(rounding) - 1).abs().pow(annealing_temp)).mean()
                    loss = recon_loss + 300 * round_loss
                    if iters % 500 == 0 or iters == self.iterations-1:
                        logger.info('iter: {}, recon loss: {}, round loss: {}, temp: {}'.format(iters, recon_loss.item(), round_loss.item(), annealing_temp))

                optimizer.zero_grad()
                _ = loss_scaler(loss, optimizer, parameters=[rounding])

        # sigmoid(0) maps to 0.5, so the sign of the raw variable decides
        # whether a weight is rounded up (1) or down (0).
        rounding = (rounding > 0).float()
        recon_loss_new = self.get_mse_loss(X, self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, rounding),
                                            Target)
        logger.info('After learning loss: {}'.format(recon_loss_new.sum().item()))

        if recon_loss_new < recon_loss_prev:
            # (learned - nearest) is in {-1, 0, 1}; it is scaled by half of
            # ``scales`` and added to the stored weight.
            rounding.sub_((W - torch.floor(W) > 0.5).type(org_w_dtype)).mul_(0.5 * scales)
            rounding = rounding.reshape(org_w_shape)
            layer.weight.data.add_(rounding.to(self.model_dtype))

            # Loss obtained when the updated weight is nearest-rounded.
            recon_loss_new = self.get_mse_loss(X,
                                               self.w_q(self.wquantizer.reshape_tensor(layer.weight.data.float())/scales,
                                                        scales, zeros, max_int, min_int, org_w_shape, None),
                                               Target)
            logger.info('Deployment loss: {}'.format(recon_loss_new.sum().item()))

    def brecq_transform(self, layer, name, feat_dict):
        """Learn per-weight rounding for ``layer`` with progressive hardening.

        A continuous rounding variable is initialised from the fractional
        part of the scaled weights. Training proceeds in stages: at each
        stage the entries whose soft rounding is furthest from 0.5 are
        hard-rounded through a mask, while the remaining entries stay soft
        and are optimised with Adam (cosine-annealed learning rate) on the
        layer-output squared error. The masked share grows from stage to
        stage. Finally the variable is binarised; if the hard-rounded loss
        is lower than that of nearest rounding, ``layer.weight`` is shifted
        in place by half a scale in the direction of every changed rounding
        decision.

        Args:
            layer: module exposing ``weight`` and the ``buf_scales``,
                ``buf_zeros``, ``buf_max_int`` and ``buf_min_int`` attributes.
            name: layer name, used for logging only.
            feat_dict: sequence of cached input tensors; they are
                concatenated along dim 0, which is assumed to be the sample
                dimension (TODO confirm against the caller).
        """
        logger.info('Transforming Layer ' + name)
        W = layer.weight.data.clone().float()
        X = torch.cat(feat_dict, dim=0).cuda().float()
        Target = F.linear(X, W)
        batch_size = 4
        # Mini-batch indices are drawn from the number of cached samples so
        # that any calibration-set size works.
        num_samples = X.shape[0]
        org_w_shape = W.shape
        org_w_dtype = W.dtype
        scales = layer.buf_scales
        zeros = layer.buf_zeros
        max_int = layer.buf_max_int
        min_int = layer.buf_min_int
        # From here on W holds the reshaped weights divided by their scales.
        W = self.wquantizer.reshape_tensor(W).div_(scales)
        sigmoid = RectifiedSigmoid(-0.1, 1.1)
        # Initialise so that sigmoid(rounding) equals the fractional part.
        rounding = W - torch.floor(W)
        rounding = sigmoid.inverse(rounding)

        # Reference losses: nearest rounding and the initial soft rounding.
        recon_loss_prev = self.get_mse_loss(X, self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, None), Target)
        logger.info('Initial quantization reconstruction loss: {}'.format(recon_loss_prev.sum().item()))
        recon_loss_soft = self.get_mse_loss(X, self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, sigmoid(rounding)), Target)
        logger.info('Initial soft rounding reconstruction loss: {}'.format(recon_loss_soft.sum().item()))
        loss_scaler = NativeScalerWithGradNormCount()

        # Quantile used at each stage: entries above the q-quantile of
        # |sigmoid(rounding) - 0.5| are hard-rounded, so a smaller q means a
        # larger hard-rounded share.
        values = [0.7, 0.6, 0.5, 0.45, 0.4, 0.36, 0.32, 0.29, 0.26, 0.23, 0.2, 0.17, 0.15, 0.12, 0.10, 0.08, 0.06, 0.04, 0.03, 0.02]
        for i, quantile in enumerate(values):
            # generate mask
            mask = (sigmoid(rounding) - 0.5).abs()
            value = np.quantile(mask.cpu().numpy(), q=quantile)
            mask = (mask > value).float()
            percentage = mask.sum() / mask.numel()

            # A fresh optimizer and schedule are created for every stage;
            # ``init_prob`` from the config is used as the learning rate.
            optimizer = torch.optim.Adam([rounding], lr=self.init_prob)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.iterations, eta_min=0.)

            # Gradients are enabled explicitly because callers may run under
            # ``torch.no_grad()``.
            with torch.enable_grad():
                rounding.requires_grad_(True)
                for iters in range(self.iterations):
                    with self.traincast():
                        idx = torch.randperm(num_samples)[:batch_size]
                        loss = self.get_mse_loss(X[idx],
                                                 self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, sigmoid(rounding) * (1 - mask) + sigmoid(rounding).round() * mask),
                                                 Target[idx])

                    optimizer.zero_grad()
                    _ = loss_scaler(loss, optimizer, parameters=[rounding])
                    lr_scheduler.step()

                logger.info('Iteration: {}, recon loss: {}, hard rounding percentage: {}'.format(i, loss.item(), percentage))
                # Gradient tracking is switched off again so the next stage
                # can convert the mask to NumPy.
                rounding.requires_grad_(False)

        # sigmoid(0) maps to 0.5, so the sign of the raw variable decides
        # whether a weight is rounded up (1) or down (0).
        rounding = (rounding > 0).float()
        recon_loss_new = self.get_mse_loss(X, self.w_q(W, scales, zeros, max_int, min_int, org_w_shape, rounding),
                                            Target)
        logger.info('After learning loss: {}'.format(recon_loss_new.sum().item()))

        if recon_loss_new < recon_loss_prev:
            # (learned - nearest) is in {-1, 0, 1}; it is scaled by half of
            # ``scales`` and added to the stored weight.
            rounding.sub_((W - torch.floor(W) > 0.5).type(org_w_dtype)).mul_(0.5 * scales)
            rounding = rounding.reshape(org_w_shape)
            layer.weight.data.add_(rounding.to(self.model_dtype))

            # Loss obtained when the updated weight is nearest-rounded.
            recon_loss_new = self.get_mse_loss(X,
                                               self.w_q(self.wquantizer.reshape_tensor(layer.weight.data.float())/scales,
                                                        scales, zeros, max_int, min_int, org_w_shape, None),
                                               Target)
            logger.info('Deployment loss: {}'.format(recon_loss_new.sum().item()))

    def get_loss(self, H, DeltaW, channel_wise):
        out = ((DeltaW @ H) * DeltaW).sum(dim=1)
        return out if channel_wise else out.mean()

    def get_mse_loss(self, x, w, target):
        out = F.linear(x, w, bias=None)
        loss = (out-target).pow(2).sum(-1).mean()
        return loss

    def w_q(self, weight, scales, zeros, max_int, min_int, org_weight_shape, rounding=None, ste=False):
        if rounding is None and not ste:
            weight = torch.round(weight)
        elif rounding is None and ste:
            weight = (torch.round(weight) - weight).detach() + weight
        else:
            weight = torch.floor(weight) + rounding
        weight = (torch.clamp(weight + zeros, min_int, max_int) - zeros) * scales
        return weight.reshape(org_weight_shape)

    # @torch.no_grad()
    # def cache_input_hook(self, m, inp, out, name, feat_dict):
    #     self.add_batch(self.named_layers[name], name, inp[0].data, out.data)
    #
    # @torch.no_grad()
    # def add_batch(self, layer, name, inp, out):
    #     if len(inp.shape) == 2:
    #         inp = inp.unsqueeze(0)
    #     tmp = inp.shape[0]
    #     if isinstance(layer, (FakeQuantLinear, nn.Linear, transformers.Conv1D)):
    #         if len(inp.shape) == 3:
    #             inp = inp.reshape((-1, inp.shape[-1]))
    #         inp = inp.t()
    #     if isinstance(layer, nn.Conv2d):
    #         unfold = nn.Unfold(
    #             layer.kernel_size,
    #             dilation=layer.dilation,
    #             padding=layer.padding,
    #             stride=layer.stride,
    #         )
    #         inp = unfold(inp)
    #         inp = inp.permute([1, 0, 2])
    #         inp = inp.flatten(1)
    #
    #     self.layers_cache[name]["H"] *= self.layers_cache[name]["nsamples"] / (
    #             self.layers_cache[name]["nsamples"] + tmp
    #     )
    #     self.layers_cache[name]["nsamples"] += tmp
    #     inp = math.sqrt(2 / self.layers_cache[name]["nsamples"]) * inp.float()
    #     self.layers_cache[name]["H"] += inp.matmul(inp.t())
    #
    # @torch.no_grad()
    # def layer_init(self, layer, name):
    #     W = layer.weight.data.clone()
    #     if isinstance(layer, nn.Conv2d):
    #         W = W.flatten(1)
    #     if isinstance(layer, transformers.Conv1D):
    #         W = W.t()
    #     self.layers_cache[name]["H"] = torch.zeros(
    #         (W.shape[1], W.shape[1]), device=self.dev
    #     )
    #     self.layers_cache[name]["nsamples"] = 0
    #     self.layers_cache[name]["columns"] = W.shape[1]
    #
    # @torch.no_grad()
    # def subset_init(self, subset):
    #     self.named_layers = subset["layers"]
    #     for name in self.named_layers:
    #         self.layers_cache[name] = {}
    #         self.layer_init(self.named_layers[name], name)
    #
    # @torch.no_grad()
    # def block_init(self, block):
    #     self.named_layers = self.model.get_block_linears(block)
    #     for name in self.named_layers:
    #         self.layers_cache[name] = {}
    #         self.layer_init(self.named_layers[name], name)

    @torch.no_grad()
    def w_qdq(self, module):
        weight = module.weight

        args = {}
        args["scales"] = module.buf_scales
        if hasattr(module, "buf_zeros"):
            args["zeros"] = module.buf_zeros
        else:
            args["zeros"] = None
        args["max_int"] = module.buf_max_int
        args["min_int"] = module.buf_min_int

        weight = self.wquantizer.fake_quant_weight_static(weight, args).to(
            self.model_dtype
        )

        return weight

    @torch.no_grad()
    def collect_model_qparams(self):
        """Compute and store quantization parameters on every block linear.

        Each layer returned by ``model.get_block_linears`` is moved to the
        GPU and cast to fp32 for ``wquantizer.get_tensor_qparams``, then cast
        back to the model dtype and moved to the CPU. The resulting scales,
        zero points and integer bounds are registered on the layer as the
        buffers ``buf_scales``, ``buf_zeros``, ``buf_max_int`` and
        ``buf_min_int``.
        """
        for block in self.blocks:
            for layer in self.model.get_block_linears(block).values():
                layer.cuda()
                layer = layer.float()
                # The first returned element (the tensor itself) is not needed.
                _, scales, zeros, max_int, min_int = self.wquantizer.get_tensor_qparams(
                    layer.weight.data
                )
                layer = layer.to(self.model_dtype)
                layer.cpu()
                # Buffers are attached after the layer has been moved back.
                layer.register_buffer("buf_scales", scales)
                layer.register_buffer("buf_zeros", zeros)
                layer.register_buffer("buf_max_int", torch.tensor(max_int))
                layer.register_buffer("buf_min_int", torch.tensor(min_int))


class RectifiedSigmoid(nn.Module):
    """Sigmoid stretched to the range ``[gamma, zeta]`` and clamped to ``[0, 1]``."""

    def __init__(self, gamma, zeta):
        super().__init__()
        self.gamma = gamma
        self.zeta = zeta

    def forward(self, x):
        """Return ``clamp(sigmoid(x) * (zeta - gamma) + gamma, 0, 1)``."""
        span = self.zeta - self.gamma
        stretched = torch.sigmoid(x) * span + self.gamma
        return torch.clamp(stretched, 0, 1)

    def inverse(self, y):
        """Return the ``x`` for which the unclamped stretched sigmoid equals ``y``."""
        span = self.zeta - self.gamma
        ratio = span / (y - self.gamma)
        return -torch.log(ratio - 1)


class LinearTempDecay:
    """Temperature schedule: constant at first, then a linear ramp.

    Before ``rel_start_decay * iter_max`` iterations the schedule returns
    ``start_t``; afterwards it interpolates linearly towards ``end_t``, which
    is reached at ``iter_max`` and held from then on.
    """

    def __init__(self, iter_max, rel_start_decay, start_t, end_t):
        self.t_max = iter_max
        self.start_decay = rel_start_decay * iter_max
        self.start_b = start_t
        self.end_b = end_t

    def __call__(self, cur_iter):
        """Return the temperature for iteration ``cur_iter``."""
        if cur_iter < self.start_decay:
            return self.start_b
        elapsed = cur_iter - self.start_decay
        decay_span = self.t_max - self.start_decay
        # Share of the ramp still ahead, floored at zero past ``t_max``.
        remaining = max(0.0, 1 - elapsed / decay_span)
        return self.end_b + (self.start_b - self.end_b) * remaining