import random

import torch
import torch.nn as nn
import functools
import gc
import pdb
import math
import os
from random import sample
import numpy as np
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import MistralRMSNorm

from math import inf
from loguru import logger
from tqdm import tqdm
from contextlib import nullcontext

from .base_blockwise_quantization import BaseBlockwiseQuantization
from .module_utils import (
    FakeQuantLinear,
    LlmcLayerNorm,
    LlmcLlamaRMSNorm,
    LlmcMistralRMSNorm,
    LlmcQwen2RMSNorm
)
from .train_utils import NativeScalerWithGradNormCount, TruncateFunction, LossFunction
from llmc.utils.registry_factory import ALGO_REGISTRY


@ALGO_REGISTRY
class TesseraQ(BaseBlockwiseQuantization):
    def __init__(self, model, quant_config, input, config):
        """Initialise TesseraQ state on top of the block-wise quantization base.

        Args:
            model: Wrapped model object; ``model.model.parameters()`` is used
                to record the original parameter dtype.
            quant_config: Quantization settings; its ``"special"`` section is
                consumed by ``add_quant_config``.
            input: Calibration cache; ``input["data"]`` and ``input["kwargs"]``
                are read through ``self.input`` (presumably stored by the base
                class constructor).
            config: Global config; ``config["model"]["type"]`` is read through
                ``self.config``.

        Raises:
            ValueError: If ``self.let`` is truthy and the model type is not one
                of Llama / Opt / Falcon / Mistral / Qwen2.
        """
        super().__init__(model, quant_config, input, config)
        self.add_quant_config()

        model_type = self.config["model"]["type"]
        if (
                model_type not in ["Llama", "Opt", "Falcon", "Mistral", "Qwen2"]
                and self.let
        ):
            # The message lists every architecture the check above accepts.
            raise ValueError(
                "Only support for opt/llama/Llama-2/falcon/Mistral/Qwen2 now"
            )

        # Only the kwargs captured for the first calibration sample are used
        # as the source of the shared attention mask / position ids.
        first_kwargs = self.input["kwargs"][0]
        self.attention_mask = first_kwargs.get("attention_mask")
        self.position_ids = (
            first_kwargs.get("position_ids")
            if model_type in ["Llama", "Mistral", "Qwen2"]
            else None
        )

        if self.deactive_amp:
            # Without AMP the mask keeps whatever dtype it was captured with.
            self.batch_mask = self._repeat_attention_mask()
        else:
            self.batch_mask = (
                self._repeat_attention_mask().float()
                if self.attention_mask is not None
                else None
            )

        self.dev = torch.device("cuda")
        # Remembered so ``deploy`` / ``save_model`` can restore the dtype.
        self.model_dtype = next(self.model.model.parameters()).dtype

        # Maps the unconstrained rounding variables into [0, 1].
        self.sigmoid = RectifiedSigmoid(-0.1, 1.1)
        self.loss_now = 1.

    def _repeat_attention_mask(self):
        """Return the cached attention mask tiled along dim 0 and moved to CUDA.

        The repeat count is the size of dim 0 of the first calibration sample.
        Returns ``None`` when no attention mask was captured.
        """
        mask = self.attention_mask
        if mask is None:
            return None

        n_repeat = self.input["data"][0].shape[0]
        return mask.repeat(n_repeat, 1, 1, 1).cuda()

    def add_quant_config(self):
        self.prefix = self.model.block_name_prefix
        self.loss_func = lambda x, y: (x-y).pow(2).sum(-1).mean()
        self.deactive_amp = self.quant_config["special"]["deactive_amp"]
        self.wd = self.quant_config["special"]["wd"]
        self.lr = self.quant_config['special']['lr']
        self.iterations = self.quant_config['special']['iterations']
        self.batch_size = self.quant_config['special']['batch_size']
        self.optimize_scale = self.quant_config["special"]["optimize_scale"]
        self.transform_algo = self.quant_config['special']['transform_algo']
        if 'load_only' in self.quant_config['special'].keys():
            self.load_only = self.quant_config['special']['load_only']
        else:
            self.load_only = False
        self.scale_lr = self.quant_config['special']['scale_lr'] if 'scale_lr' in self.quant_config['special'] else None

        if self.deactive_amp:
            self.dtype = torch.float
            self.traincast = nullcontext
        else:
            self.dtype = torch.float16
            self.traincast = torch.cuda.amp.autocast

        self.aug_loss = self.quant_config["special"]["aug_loss"]

        if self.transform_algo == 'None':
            self.act_scales = None
        if 'AWQ' in self.transform_algo:
            assert "scale_path" in self.quant_config["special"]
            self.scale_path = self.quant_config["special"]["scale_path"]
            self.act_scales = torch.load(os.path.join(self.scale_path, 'scales.pth'))
            for k in self.act_scales:
                self.act_scales[k] = self.act_scales[k].to(torch.float32)
        if 'OMNIQ' in self.transform_algo:
            self.wquantizer.calib_algo = 'learnable'
            self.clip_path = self.quant_config["special"]["clip_path"]
            self.weight_clips = torch.load(os.path.join(self.clip_path, 'clips.pth'))

    def block_forward(self, block, input_data=None):
        output = []

        if input_data is None:
            input_data = self.input["data"]

        for i in range(len(input_data)):
            input_data[i] = input_data[i].to(device=next(block.parameters()).device)
            if (
                "attention_mask" in self.input["kwargs"][i]
                and self.input["kwargs"][i]["attention_mask"] is not None
            ):
                self.input["kwargs"][i]["attention_mask"] = self.input["kwargs"][i][
                    "attention_mask"
                ].cuda()
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    out = block(input_data[i], **self.input["kwargs"][i])[0]
                    output.append(out)
        return output

    def get_original_out(self, block, idx):
        if idx == 0:
            self.ori_out = self.block_forward(block)
            if self.aug_loss:
                self.ori_out2 = self.ori_out
        else:
            self.ori_out = self.block_forward(block, self.ori_out)
            if self.aug_loss:
                self.ori_out2 = self.block_forward(block)

    @torch.no_grad()
    def block_transform(self, block, input_feat, idx, block_kwargs):
        """Quantize one block: load transforms, collect qparams, learn rounding.

        Args:
            block: The block to process; it is cast to float32 first.
            input_feat: Cached per-layer inputs, forwarded to ``load_transform``.
            idx: Zero-based block index.
            block_kwargs: Forwarded to ``load_transform``.
        """
        logger.info(f"Start transform the {idx+1}-th block")

        # The whole method already runs with gradients disabled.
        block.float()

        calib_inputs = self.input["data"]
        for i, sample in enumerate(calib_inputs):
            calib_inputs[i] = sample.to(self.dtype)

        # 1) reference outputs of the block before any change
        self.get_original_out(block, idx)
        # 2) previously computed transforms (AWQ scales / OmniQuant clips)
        self.load_transform(block, input_feat, idx, block_kwargs)
        # 3) quantization ranges measured after the transforms
        self.collect_block_qparams(block)

        # 4) rounding optimisation, skipped in load-only mode
        if not self.load_only:
            self.register_brecq_parameters(block, idx)
            self.brecq_train(block, idx)
            self.merge_breceq_parameters_and_clear_tmp(block)

        self.set_rounding_opt_mode(block, on=False)

        logger.info(f"End transform the {idx+1}-th block")

    def brecq_train(self, block, idx):
        """Learn the rounding variables of ``block`` by output reconstruction.

        The calibration inputs and the reference outputs are concatenated
        along dim 0 for the duration of training and split back into
        per-sample lists before returning. Training runs one round per entry
        of ``thresholds``: each round calls ``update_mask`` with that quantile
        and then performs ``self.iterations`` optimisation steps on random
        mini-batches of ``self.batch_size`` samples.

        Args:
            block: Block whose ``FakeQuantLinear`` layers carry the ``rounding``
                (and, with ``optimize_scale``, ``output_scale_factor``) tensors.
            idx: Block index; used only in log messages and error text.

        Raises:
            RuntimeError: If the reconstruction loss becomes NaN or infinite.
        """
        self.set_dynamic_tmp_quant(block, on=True)
        # Only the rounding / scale tensors are trained, never the weights.
        for p in block.parameters():
            p.requires_grad = False

        # Quantile passed to ``update_mask`` in each successive round.
        thresholds = [0.8, 0.65, 0.5, 0.43, 0.38, 0.34, 0.3, 0.27, 0.24, 0.21, 0.18, 0.15, 0.12, 0.10, 0.08,
                      0.06, 0.04, 0.02, 0.01, 0.005]

        self.input["data"] = torch.cat(self.input["data"], dim=0)
        self.ori_out = torch.cat(self.ori_out, dim=0)

        # ``ori_out2`` comes from ``block_forward`` as a Python list. A list
        # cannot be indexed with a multi-element index tensor, so build a
        # stacked copy for batched lookup. ``self.ori_out2`` is left untouched.
        ori_out2 = None
        if self.aug_loss:
            ori_out2 = self.ori_out2
            if isinstance(ori_out2, (list, tuple)):
                ori_out2 = torch.cat(ori_out2, dim=0)

        with torch.no_grad():
            # evaluate loss before reconstruction
            loss_prev = self.get_loss(block, self.input["data"][:4], self.ori_out[:4])
            logger.info('Before BRECQ, the reconstruction loss: {}'.format(loss_prev.item()))

        for i, threshold in enumerate(thresholds):
            self.set_rounding_opt_mode(block, on=True)
            self.update_mask(block, quantile_threshold=threshold)

            params_r, params_s = self.get_rounding_parameters(block)
            trainable = params_r + params_s
            if self.optimize_scale:
                optimizer = torch.optim.Adam([
                        {"params": params_r, "lr": self.lr},
                        {"params": params_s, "lr": self.scale_lr or self.lr, "weight_decay": 1e-4},
                    ], lr=self.lr)
            else:
                optimizer = torch.optim.Adam(params_r, self.lr)

            loss_scaler = NativeScalerWithGradNormCount()

            # The caller runs under ``torch.no_grad()``; re-enable autograd
            # for the optimisation steps only.
            with torch.enable_grad():
                for p in trainable:
                    p.requires_grad = True

                for _ in range(self.iterations):

                    indices = torch.randperm(self.config['calib']['n_samples'])[:self.batch_size]

                    with self.traincast():
                        target2 = ori_out2[indices] if ori_out2 is not None else None
                        loss = self.get_loss(block, self.input["data"][indices], self.ori_out[indices], target2)

                    if not math.isfinite(loss.item()):
                        # Fail loudly instead of dropping into a debugger,
                        # which would hang or crash a non-interactive run.
                        logger.info("Loss is NAN, stopping training")
                        raise RuntimeError(
                            f"Non-finite reconstruction loss in block {idx}"
                        )

                    optimizer.zero_grad()

                    # Presumably performs backward + optimizer step and returns
                    # the gradient norm; defined in train_utils.
                    norm = loss_scaler(loss, optimizer, parameters=trainable)

                logger.info(f"block {idx} iter {i+1} loss:{loss.item():5f} norm:{norm.item():4f} HR progress:{(1-threshold)*100:1f}% ")
                for p in trainable:
                    p.requires_grad = False

            del optimizer

        for n, m in block.named_modules():
            if isinstance(m, FakeQuantLinear):
                # set to hard masking
                m.rounding = 100 * m.rounding.sign()

        with torch.no_grad():
            loss_now = self.get_loss(block, self.input["data"][:4], self.ori_out[:4])
            # Store under the attribute initialised in ``__init__``.
            self.loss_now = loss_now.item()
            logger.info('After BRECQ, the reconstruction loss: {}'.format(loss_now.item()))

        # Restore the per-sample list layout expected by ``block_forward``.
        self.input["data"] = list(torch.split(self.input["data"], split_size_or_sections=1, dim=0))
        self.ori_out = list(torch.split(self.ori_out, split_size_or_sections=1, dim=0))

    def get_loss(self, block, x, target, target2=None):
        if self.position_ids is not None:
            quant_out = block(x, attention_mask=self.batch_mask, position_ids=self.position_ids)[0]
        else:
            quant_out = block(x, attention_mask=self.batch_mask)[0]

        loss = self.loss_func(target, quant_out)
        if target2 is not None:
            loss = (loss + self.loss_func(target2, quant_out)) / 2
        return loss

    def register_brecq_parameters(self, block, idx):
        """Swap the block's linears for ``FakeQuantLinear`` and add rounding state.

        The replacement receives ``self.w_qdq`` for weights and, unless
        ``self.w_only`` is set, ``self.a_qdq`` for activations.
        """
        act_qdq = None if self.w_only else self.a_qdq
        params_dict = {"a_qdq": act_qdq, "w_qdq": self.w_qdq}
        self.model.replace_module_block(FakeQuantLinear, block, idx, params_dict)
        self.register_rounding_parameters(block)

    def register_rounding_parameters(self, block):
        """Attach an initial ``rounding`` tensor to every ``FakeQuantLinear``.

        ``rounding`` is the inverse rectified sigmoid of the fractional part of
        ``weight / scales`` (after ``wquantizer.reshape_tensor``). With
        ``optimize_scale`` a zero-initialised ``output_scale_factor`` shaped
        like the scales is attached as well.
        """
        for module in block.modules():
            if not isinstance(module, FakeQuantLinear):
                continue

            weight_copy = module.weight.data.clone()
            scales = module.buf_scales
            ratio = self.wquantizer.reshape_tensor(weight_copy).div(scales)
            fractional = ratio - torch.floor(ratio)
            module.rounding = self.sigmoid.inverse(fractional)

            if self.optimize_scale:
                module.output_scale_factor = torch.zeros_like(scales)

    @torch.no_grad()
    def update_mask(self, block, quantile_threshold):
        """Freeze rounding variables furthest from 0.5 in every fake-quant layer.

        For each ``FakeQuantLinear`` the score ``|sigmoid(rounding) - 0.5|`` is
        computed and its ``quantile_threshold`` quantile taken. Entries whose
        soft value lies above ``0.5 + quantile`` are set to ``+inf``, those
        below ``0.5 - quantile`` to ``-inf``.
        """
        for module in block.modules():
            if not isinstance(module, FakeQuantLinear):
                continue

            soft = self.sigmoid(module.rounding)
            score = (soft - 0.5).abs().cpu()
            value = np.quantile(score.numpy(), q=quantile_threshold)

            # The two masks are disjoint because the quantile of an absolute
            # value is never negative, so both can be taken from ``soft``.
            push_up = soft > (value + 0.5)
            push_down = soft < (0.5 - value)
            module.rounding[push_up] = float('inf')
            module.rounding[push_down] = -float('inf')
            del score

    def set_rounding_opt_mode(self, block, on=True):
        """Set ``rounding_opt`` to ``on`` for every ``FakeQuantLinear`` in ``block``."""
        fake_linears = (
            module for module in block.modules()
            if isinstance(module, FakeQuantLinear)
        )
        for layer in fake_linears:
            layer.rounding_opt = on

    def set_dynamic_tmp_quant(self, block, on=True):
        """Set ``dynamic_quant_tmp_weight`` to ``on`` for every ``FakeQuantLinear``."""
        fake_linears = (
            module for module in block.modules()
            if isinstance(module, FakeQuantLinear)
        )
        for layer in fake_linears:
            layer.dynamic_quant_tmp_weight = on

    def get_rounding_parameters(self, block):
        """Collect the trainable tensors of all ``FakeQuantLinear`` layers.

        Returns:
            ``(params_r, params_s)``: the ``rounding`` tensors and, only when
            ``self.optimize_scale`` is set, the ``output_scale_factor`` tensors
            (otherwise an empty list), both in module traversal order.
        """
        fake_linears = [
            module for module in block.modules()
            if isinstance(module, FakeQuantLinear)
        ]
        params_r = [layer.rounding for layer in fake_linears]
        params_s = (
            [layer.output_scale_factor for layer in fake_linears]
            if self.optimize_scale
            else []
        )
        return params_r, params_s

    def merge_breceq_parameters_and_clear_tmp(self, block):
        """Fold the learned rounding decisions into the weights and clean up.

        For every ``FakeQuantLinear`` the learned decision (``rounding > 0``)
        is compared with whether the fractional part of ``weight / scales``
        exceeds 0.5. The difference, times half a quantization step, is added
        to the weight in place. Presumably this makes plain round-to-nearest
        reproduce the learned rounding; confirm against the quantizer.
        Afterwards the temporary attributes are removed and the dynamic
        quantization flags are switched off.
        """
        for module in block.modules():
            if not isinstance(module, FakeQuantLinear):
                continue

            scales = module.buf_scales
            learned_up = (module.rounding > 0).float()
            w_shape = module.weight.shape
            ratio = self.wquantizer.reshape_tensor(module.weight.data) / scales
            nearest_up = (ratio - torch.floor(ratio) > 0.5).float()

            # -1 / 0 / +1 per element, scaled to half a quantization step.
            delta = learned_up - nearest_up
            delta *= 0.5 * scales
            delta = self.wquantizer.restore_tensor(delta, w_shape)
            module.weight.data.add_(delta.to(self.model_dtype))

            for attr in ('rounding', 'tmp_weight', 'tmp_bias'):
                delattr(module, attr)
            module.dynamic_quant_weight = False
            module.dynamic_quant_tmp_weight = False

            # Release the freed tensors before handling the next layer.
            gc.collect()
            torch.cuda.empty_cache()

    @torch.no_grad()
    def load_transform(self, block, input_feat, idx, block_kwargs):
        """Apply previously saved transforms (AWQ scales, OmniQuant clips) to ``block``.

        Args:
            block: Block to transform in place.
            input_feat: Cached per-layer inputs; rescaled in place when AWQ
                scales are applied.
            idx: Block index, used to build scale keys and to pick the clip
                entry ``self.weight_clips[idx]``.
            block_kwargs: Not used in this method.
        """
        if 'AWQ' in self.transform_algo :
            logger.info('loading scales...')
            subsets = self.model.get_subsets_in_block(block)
            for index, subset in enumerate(subsets):
                prev_op = subset["prev_op"]
                layers_dict = subset["layers"]
                layers = list(layers_dict.values())

                # Skip subsets whose preceding linear has an output width that
                # is neither equal to nor three times the layers' input width.
                if (
                    isinstance(prev_op[0], (nn.Linear, FakeQuantLinear))
                    and prev_op[0].out_features != layers[0].in_features * 3
                    and prev_op[0].out_features != layers[0].in_features
                ):
                    logger.info("Cannot apply scale. Do not transform this subset.")
                    continue

                # NOTE(review): only the name built from the LAST key of
                # layers_dict survives this loop and is used for the lookup
                # below. Presumably every layer of a subset shares one saved
                # scale; confirm against the code that wrote scales.pth.
                for n in layers_dict:
                    layer_name = f"{self.model.block_name_prefix}.{idx}.{n}"
                scale = self.act_scales[layer_name]
                # apply_scale is not defined in this class as shown here
                # (presumably inherited).
                self.apply_scale(scale, prev_op, layers)
                # Keep the cached layer inputs consistent with the new scale.
                self.update_input_feat(scale, input_feat, layers_dict)

            # Extra clipping runs only for pure 'AWQ' with fewer than 4 bits.
            if self.transform_algo == 'AWQ' and self.wquantizer.bit < 4:
                self.auto_clip(block, idx, input_feat, n_sample_token=512, clip_qk=True)

        if 'OMNIQ' in self.transform_algo:
            # Load the saved clip factors and register them as buffers so that
            # collect_block_qparams can pick them up via hasattr().
            logger.info('loading clips...')
            for n, m in block.named_modules():
                if isinstance(m, nn.Linear):
                    layer_name = f"{n}.weight_quantizer."
                    m.register_buffer('buf_upbound_factor', self.weight_clips[idx][layer_name+'upbound_factor'].cuda().float())
                    m.register_buffer('buf_lowbound_factor', self.weight_clips[idx][layer_name+'lowbound_factor'].cuda().float())

    @torch.no_grad()
    def update_input_feat(self, scale, input_feat, layers_dict):
        for layer_name in layers_dict:
            for i in range(len(input_feat[layer_name])):
                inp = input_feat[layer_name][i]
                inp.div_(scale.view(1, -1).to(inp.device))

    def smooth_q_k_inplace(self, block):
        """Fold ``block.qkt_smooth_scale`` into the q/k projections in place.

        All Llmc norm layers in the block first get ``use_tmp_parameter``
        switched off. If q_proj and k_proj weights differ in shape nothing
        else happens. Otherwise q is divided and k multiplied by the truncated
        scale, row-wise for weights and element-wise for biases.
        NOTE(review): ``self.truncate`` is not defined in this class as shown;
        presumably it is provided elsewhere.
        """
        norm_types = (LlmcLayerNorm, LlmcLlamaRMSNorm, LlmcMistralRMSNorm, LlmcQwen2RMSNorm)
        for module in block.modules():
            if isinstance(module, norm_types):
                module.use_tmp_parameter = False

        q_proj = block.self_attn.q_proj
        k_proj = block.self_attn.k_proj
        if q_proj.weight.shape != k_proj.weight.shape:
            return

        scales = block.qkt_smooth_scale
        scales.data = self.truncate(scales)

        q_proj.weight.div_(scales.view(-1, 1))
        if q_proj.bias is not None:
            q_proj.bias.div_(scales.view(-1))

        k_proj.weight.mul_(scales.view(-1, 1))
        if k_proj.bias is not None:
            k_proj.bias.mul_(scales.view(-1))

    def cache_input_hook(self, m, x, y, name, feat_dict):
        """Cache a layer input via the base class, keeping at most 128 per layer.

        Args:
            m, x, y: Hooked module, its inputs and its output, forwarded
                unchanged to the parent implementation.
            name: Key of the layer inside ``feat_dict``.
            feat_dict: Mapping from layer name to the list of cached inputs.
        """
        # Zero-argument ``super()``: the previous explicit class name did not
        # exist in this module and raised NameError on every call.
        super().cache_input_hook(m, x, y, name, feat_dict)
        # Drop the newest entry once the per-layer cache exceeds 128 items.
        if len(feat_dict[name]) > 128:
            del feat_dict[name][-1]

    @torch.no_grad()
    def collect_block_qparams(self, block):
        named_linears = self.model.get_block_linears(block)
        for n, m in named_linears.items():
            args = {}
            if hasattr(m, "buf_lowbound_factor"):
                args["lowbound_factor"] = m.buf_lowbound_factor
            if hasattr(m, "buf_upbound_factor"):
                args["upbound_factor"] = m.buf_upbound_factor
            (
                tensor,
                scales,
                zeros,
                max_int,
                min_int,
            ) = self.wquantizer.get_tensor_qparams(m.weight.data, args=args)
            m.register_buffer("buf_scales", scales)
            m.register_buffer("buf_zeros", zeros)
            m.register_buffer("buf_max_int", torch.tensor(max_int).to(self.dev))
            m.register_buffer("buf_min_int", torch.tensor(min_int).to(self.dev))

    def w_qdq(self, module):
        weight = module.weight

        args = {}
        args["scales"] = module.buf_scales
        if hasattr(module, "buf_zeros"):
            args["zeros"] = module.buf_zeros
        else:
            args["zeros"] = None
        args["max_int"] = module.buf_max_int
        args["min_int"] = module.buf_min_int

        if module.rounding_opt:
            args["rounding"] = self.sigmoid(module.rounding)

        if self.optimize_scale:
            args['output_scale_factor'] = 2 * self.sigmoid(module.output_scale_factor)

        weight = self.wquantizer.fake_quant_weight_static(weight, args)

        return weight

    def deploy(self, quant_format):
        """Run the base-class deploy, then convert the model to its original dtype.

        Args:
            quant_format: Forwarded unchanged to the parent ``deploy``.
        """
        super().deploy(quant_format)
        # model_dtype is the parameter dtype recorded in __init__.
        self.model.convert_dtype(self.model_dtype)

    def save_model(self, path):
        """Convert the model to its original dtype, then save via the base class.

        Args:
            path: Forwarded unchanged to the parent ``save_model``.
        """
        # model_dtype is the parameter dtype recorded in __init__.
        self.model.convert_dtype(self.model_dtype)
        super().save_model(path)


class RectifiedSigmoid(nn.Module):
    """Sigmoid stretched to the interval (gamma, zeta) and clamped to [0, 1]."""

    def __init__(self, gamma, zeta):
        """Store the lower (``gamma``) and upper (``zeta``) stretch bounds."""
        super().__init__()
        self.gamma = gamma
        self.zeta = zeta

    def forward(self, x):
        """Return ``clamp(sigmoid(x) * (zeta - gamma) + gamma, 0, 1)``."""
        span = self.zeta - self.gamma
        stretched = torch.sigmoid(x) * span + self.gamma
        return torch.clamp(stretched, 0, 1)

    def inverse(self, y):
        """Return ``x`` such that the unclamped stretched sigmoid of ``x`` is ``y``."""
        span = self.zeta - self.gamma
        odds_inverse = span / (y - self.gamma) - 1
        return -torch.log(odds_inverse)
