import copy
import math
from typing import List, Optional

import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
from transformers import AutoModelForCausalLM

import constants
from constants import DEVICE_CPU
from models.model_utils import get_model_layers
from quantization.quantization_flows.rtn import build_quantizer_kwargs
from quantization.set_block_utils import set_block_quantizers_and_transforms
from quantization.transforms.transforms_builder import set_block_R1_learned_transforms, set_block_R2_learned_transforms, \
    non_learned_transforms_builder, build_R1_learned_transform
from run_config import RunConfig
from transform_optimization.opt_config import OptimizationConfig
from transform_optimization.prepare import prepare_model_for_transform_training, build_loss_function, \
    prepare_train_dataloader, wrap_up_training
from transform_optimization.train_transform import train_transform_matrix
from quantization.qlinear import QLinear
from quantization.quantizer import Quantizer
from quantization.quant_args import QuantizationOrder


try:
    from quantization.utils.accumulate_hessian import accumulate_hessian
except ImportError:
    # This allows running RTN on MPS; GPTQ is not available on MPS.
    accumulate_hessian = None


from quantization.utils.linalg_utils import inv_sym
from quantization.utils.common_utils import clear_device_cache, to, maybe_first_element
from quantization.quantized_modules.model_utils import InputCollector, ForwardInterrupt, \
    get_number_of_rows_and_cols, load_quantized_modules_state_dict
from quantization.quant_config import QuantConfig

try:
    import wandb
except ImportError:
    wandb = None


def get_relative_mse_error(q: torch.Tensor, w: torch.Tensor, H: torch.Tensor):
    """
    Return the Hessian-weighted error of ``q`` relative to ``w``.

    The result is ``mean((q - w) H * (q - w)) / (mean(w H * w) + 1e-6)``,
    where ``*`` is the element-wise product.

    Args:
        q: quantized (dequantized-value) weight matrix, 2-D.
        w: reference weight matrix with the same shape as ``q``.
        H: square matrix whose size matches the last dimension of ``w``.

    Returns:
        A scalar tensor with the relative error.
    """

    def _weighted_energy(mat: torch.Tensor) -> torch.Tensor:
        # mean over all entries of (mat @ H) * mat
        return torch.mul(torch.mm(mat, H), mat).mean()

    residual = q - w
    numerator = _weighted_energy(residual)
    # Small epsilon guards against division by zero for an all-zero reference.
    denominator = _weighted_energy(w) + 1e-6
    return numerator / denominator


class GPTQ:
    """
    Per-layer GPTQ handle.

    Usage, as exercised elsewhere in this file: feed calibration inputs through
    :meth:`update` to accumulate a Hessian estimate, then call :meth:`quantize`
    to obtain the quantized (dequantized-value) weight.

    Args:
        layer: ``nn.Linear`` or convolutional layer whose weight is quantized.
        quantizer: object providing ``group_size``, ``get_quantization_params``
            and a ``__call__(column, scale, zero)`` quantization routine.
        quantization_order: value accepted by ``QuantizationOrder``.
        block_size: number of columns processed per lazy-update block.
        rel_damp: damping added to the Hessian diagonal, relative to its mean.
    """

    def __init__(
        self,
        layer: nn.Module,
        quantizer: "Quantizer",
        quantization_order: str = "default",
        block_size: int = 128,
        rel_damp: float = 1e-2,
    ):
        self._validate_layer(layer)
        self.layer = layer
        self.W = self.layer.weight
        self.d_row, self.d_col = get_number_of_rows_and_cols(layer)
        # Quantization properties
        self.quantizer = quantizer
        self.quantization_order = QuantizationOrder(quantization_order)
        self.block_size = block_size
        self.rel_damp = rel_damp
        # Backup layer properties
        self.W_device = self.W.device
        self.W_dtype = self.W.dtype
        self.W_shape = self.W.shape
        # Hessian estimate and the number of samples it was built from
        self.H = None
        self.num_samples = 0
        # Set by quantization_pre_step(); initialised here so it always exists
        self.pre_step_completed = False

    @staticmethod
    def _validate_layer(layer):
        assert isinstance(layer, (nn.Linear, _ConvNd)), "OBC supports only linear and convolutional layers."

    # preparatory methods
    @torch.no_grad()
    def update(self, input: torch.Tensor) -> None:
        """
        Update the estimate of Hessian matrix from a batch of data.

        The caller's tensor is never modified: scaling is applied to a private
        float32 copy.

        Args:
            input: batch of layer inputs

        Raises:
            RuntimeError: if the Hessian accumulation kernel could not be imported.
        """
        if accumulate_hessian is None:
            raise RuntimeError(
                "accumulate_hessian is unavailable on this platform; GPTQ Hessian accumulation cannot run."
            )
        # get batch size
        batch_size = input.shape[0]
        # init hessian
        if self.H is None:
            self.H = torch.zeros((self.d_col, self.d_col), device=input.device, dtype=torch.float32)
        # input reshaping
        if isinstance(self.layer, nn.Linear):
            input = input.reshape(-1, input.shape[-1])
        else:
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            # output size (batch_size, channels * \prod kernel_size, num_patches)
            input = unfold(input)
            input = input.transpose(1, 2).flatten(0, 1)
        # running-average coefficients for the old estimate and the new batch
        beta = self.num_samples / (self.num_samples + batch_size)
        alpha = 2.0 / (self.num_samples + batch_size)
        self.H.mul_(beta)
        # copy=True guarantees a private buffer even when the input is already
        # float32 (where .float() and reshape() may alias the caller's tensor);
        # the in-place scaling below must not leak into shared activations.
        scaled_input = input.to(torch.float32, copy=True)
        scaled_input.mul_(math.sqrt(alpha))
        accumulate_hessian(self.H, scaled_input)
        self.num_samples += batch_size

    def reset(self) -> None:
        """Drop the accumulated Hessian and re-bind ``W`` to the layer weight."""
        self.W = self.layer.weight
        self.H = None
        self.num_samples = 0
        self.pre_step_completed = False
        clear_device_cache()

    @torch.no_grad()
    def quantization_pre_step(self) -> None:
        """
        Preparatory step: check that a Hessian exists and build a float32,
        2-D working copy of the weight.
        """
        # 1) Hessian preparation
        assert self.H is not None, "One has to process at least one sample of calibration data to run quantization"
        # 2) Weight preparation
        # copy weight, flatten and convert to float
        self.W = self.W.clone().float()
        if isinstance(self.layer, _ConvNd):
            self.W = self.W.flatten(1, -1)
        # flag pre step as completed
        self.pre_step_completed = True

    @torch.no_grad()
    def step(self) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """
        Quantize the weight matrix using GPTQ

        Returns:
            ``(weight, qweight, scales)`` where ``weight`` is the quantized
            weight cast back to the layer's original dtype, ``qweight`` is
            always ``None`` in this implementation, and ``scales`` is whatever
            the quantizer returned from ``get_quantization_params``.
        """
        # 1) Define constants and chunk
        d_col, block_size, device, dtype = self.d_col, self.block_size, self.W_device, self.W_dtype
        # 2) Get quantization group size
        quantizer_group_size = self.quantizer.group_size
        group_size = quantizer_group_size or d_col
        num_groups = d_col // group_size

        # No packed integer weight is produced here; kept for the return contract
        qweight = None
        # Get scales and zeros
        scales, zeros = self.quantizer.get_quantization_params(self.W)
        use_act_order = self.quantization_order == QuantizationOrder.ACTIVATION
        # The quantizer's group size is cleared while columns are quantized one
        # by one with explicit scales/zeros; `finally` restores it even if the
        # loop raises, so the shared quantizer is never left misconfigured.
        self.quantizer.group_size = None
        try:
            # Get permutation
            if use_act_order:
                perm = torch.argsort(self.H.diag(), descending=True)
                group_idx = torch.arange(num_groups, device=device).repeat_interleave(group_size)[perm]
            else:
                perm = torch.arange(d_col, device=device)
                group_idx = None
            perm_inv = torch.argsort(perm)
            # Permute Hessian prior to inversion
            self.H = self.H[perm][:, perm]
            # Get weight (advanced indexing yields a copy, self.W stays intact)
            w = self.W[:, perm]
            # Get Hessian inverse; pass the permuted weight so that zero-column
            # masking addresses the same (permuted) indices as self.H
            H_inv_cho = self._get_hessian_inverse(w)
            # Quantize
            for c1 in range(0, d_col, block_size):
                c2 = min(c1 + block_size, d_col)
                ncols = c2 - c1
                w_blk = w[:, c1:c2].clone()
                errs = torch.zeros_like(w_blk)
                H_inv_cho_blk = H_inv_cho[c1:c2, c1:c2]
                # Iterate over the columns of the block
                for i in range(ncols):
                    # Get weight column, corresponding Hessian diagonal and group_id
                    w_ci = w_blk[:, i]
                    d = H_inv_cho_blk[i, i]
                    g_idx = group_idx[c1 + i] if use_act_order else (c1 + i) // group_size

                    w_q = self.quantizer(w_ci, scales[:, g_idx], zeros[:, g_idx])
                    w[:, c1 + i] = w_q
                    # Propagate the scaled error to the remaining block columns
                    err = (w_ci - w_q) / d
                    w_blk[:, i:].addr_(err, H_inv_cho_blk[i, i:], alpha=-1)
                    errs[:, i] = err
                # Update the not-yet-quantized columns after the block
                w[:, c2:].addmm_(errs, H_inv_cho[c1:c2, c2:], alpha=-1)

            # Invert permutation
            w = w[:, perm_inv].contiguous()
            self.H = self.H[perm_inv][:, perm_inv]
        finally:
            # Restore quantizer group size
            self.quantizer.group_size = quantizer_group_size

        return w.to(dtype), qweight, scales

    @torch.no_grad()
    def _get_hessian_inverse(self, w: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Mask, damp and invert ``self.H`` (modified in place) and return the
        upper Cholesky factor of the inverse.

        Args:
            w: weight whose column order matches ``self.H``; defaults to
                ``self.W``. ``step`` passes the permuted weight so the
                zero-column mask lines up with the permuted Hessian.

        Returns:
            Upper-triangular factor of the inverse Hessian, or the identity
            matrix if inversion/factorization fails.
        """
        if w is None:
            w = self.W
        # Get columns with all zeros
        zero_cols = torch.nonzero(w.eq(0).all(dim=0))
        H = self.H
        # mask rows with zero input channels
        H[zero_cols, :] = 0
        H[:, zero_cols] = 0
        H[zero_cols, zero_cols] = 1
        # Hessian regularization
        damp = self.rel_damp * torch.diag(H).mean()
        H.diagonal().add_(damp)
        # invert
        try:
            H_inv = inv_sym(H)
            H_inv_cho = torch.linalg.cholesky(H_inv, upper=True)
        except Exception as exc:  # best-effort fallback, but never swallow KeyboardInterrupt/SystemExit
            # With an identity factor no error is propagated to later columns.
            print(f"[GPTQ] Hessian inversion failed ({exc!r}); falling back to identity.")
            H_inv_cho = torch.eye(self.d_col, device=H.device, dtype=torch.float32)
        return H_inv_cho

    def quantize(self) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Run :meth:`quantization_pre_step` followed by :meth:`step`."""
        self.quantization_pre_step()
        return self.step()


def _remove_block_parametrization(qkv_in_transform, o_in_transform, v_out_transform, down_in_transform, gate_up_in_transform):
    # Remove parametrizations
    qkv_in_transform.remove_parametrizations()
    o_in_transform.remove_parametrizations()
    down_in_transform.remove_parametrizations()
    gate_up_in_transform.remove_parametrizations()
    if v_out_transform:
        v_out_transform.remove_parametrizations()


def _set_gptq_handles(block, weight_quantizer_kwargs, quant_config):
    gptq_handles = {}
    hooks = {}
    for layer_name, layer in block.named_modules():
        if isinstance(layer, QLinear):
            # Create GPTQ handle
            gptq_handles[layer_name] = GPTQ(
                layer,
                Quantizer(**weight_quantizer_kwargs) if weight_quantizer_kwargs else None,
                quantization_order=quant_config.quantization_order,
                rel_damp=quant_config.rel_damp,
            )

            # Attach hook

            # Attach hook
            def update_handle_hook(name):
                def _hook(_, inp, out):
                    gptq_handles[name].update(inp[0])

                return _hook

            hooks[layer_name] = layer.register_forward_hook(update_handle_hook(layer_name))

    return gptq_handles, hooks


def _transform_weights_before_quantization(block, qkv_in_transform, o_in_transform, v_out_transform, gate_up_in_transform, down_in_transform):
    if block.self_attn.q_proj.norm_gamma is not None:
        assert block.self_attn.k_proj.norm_gamma is not None and block.self_attn.v_proj.norm_gamma is not None, \
            "All QKV projections must have norm_gamma if one of them has it."
        og_dtype = block.self_attn.q_proj.weight.dtype
        wq = (block.self_attn.q_proj.weight @ torch.diag(block.self_attn.q_proj.norm_gamma.to(og_dtype)))
        wk = (block.self_attn.k_proj.weight @ torch.diag(block.self_attn.k_proj.norm_gamma.to(og_dtype)))
        wv = (block.self_attn.v_proj.weight @ torch.diag(block.self_attn.v_proj.norm_gamma.to(og_dtype)))
        block.self_attn.q_proj.weight.data = qkv_in_transform(wq, inv_t=True)
        block.self_attn.k_proj.weight.data = qkv_in_transform(wk, inv_t=True)
        block.self_attn.v_proj.weight.data = qkv_in_transform(wv, inv_t=True)
    else:
        block.self_attn.q_proj.weight.data = qkv_in_transform(block.self_attn.q_proj.weight, inv_t=True)
        block.self_attn.k_proj.weight.data = qkv_in_transform(block.self_attn.k_proj.weight, inv_t=True)
        block.self_attn.v_proj.weight.data = qkv_in_transform(block.self_attn.v_proj.weight, inv_t=True)

    if v_out_transform:
        # R2 on V output
        tran_dim = v_out_transform.block_size  # TODO: verify for block-wise transform
        W_ = block.self_attn.v_proj.weight.t()
        transposed_shape = W_.shape
        temp = W_.reshape(-1, transposed_shape[-1] // tran_dim, tran_dim)
        temp = v_out_transform(temp, inv_t=False, dim=-1)
        block.self_attn.v_proj.weight.data = temp.reshape(transposed_shape).t()

        ## bias
        if block.self_attn.v_proj.bias is not None:
            W_ = block.self_attn.v_proj.bias
            transposed_shape = W_.shape
            temp = W_.reshape(-1, transposed_shape[-1] // tran_dim, tran_dim)
            temp = v_out_transform(temp, inv_t=False, dim=-1)
            block.self_attn.v_proj.bias.data = temp.reshape(transposed_shape)

        # R2_inv on O input
        W_ = block.self_attn.o_proj.weight
        init_shape = W_.shape
        temp = W_.reshape(-1, init_shape[-1] // tran_dim, tran_dim)
        temp = o_in_transform(temp, inv_t=True, dim=-1)  # TODO: verify dim=-1 no dim=0 is correct
        block.self_attn.o_proj.weight.data = temp.reshape(init_shape)
    else:
        block.self_attn.o_proj.weight.data = o_in_transform(block.self_attn.o_proj.weight, inv_t=True)

    if block.mlp.gate_proj.norm_gamma is not None:
        assert block.mlp.up_proj.norm_gamma is not None
        og_dtype = block.mlp.gate_proj.weight.dtype
        wg = (block.mlp.gate_proj.weight @ torch.diag(block.mlp.gate_proj.norm_gamma.to(og_dtype)))
        wu = (block.mlp.up_proj.weight @ torch.diag(block.mlp.up_proj.norm_gamma.to(og_dtype)))
        block.mlp.gate_proj.weight.data = gate_up_in_transform(wg, inv_t=True)
        block.mlp.up_proj.weight.data = gate_up_in_transform(wu, inv_t=True)
    else:
        block.mlp.gate_proj.weight.data = gate_up_in_transform(block.mlp.gate_proj.weight, inv_t=True)
        block.mlp.up_proj.weight.data = gate_up_in_transform(block.mlp.up_proj.weight, inv_t=True)
    block.mlp.down_proj.weight.data = down_in_transform(block.mlp.down_proj.weight, inv_t=True)


def _run_gptq(block, gptq_handles, quant_config, input_args, input_kwargs, device):
    for layer_name, gptq_handle in gptq_handles.items():
        dequantized_qweight, qweight, scales = gptq_handle.quantize()
        orig_weight = gptq_handle.layer.weight
        with torch.no_grad():
            relative_mse_error = get_relative_mse_error(dequantized_qweight.float(), orig_weight.float(), gptq_handle.H)
        print(f"[{layer_name:16}]: Relative MSE error: {relative_mse_error.item():.2e}")
        gptq_handle.layer.weight.data = dequantized_qweight

    # 8. Update activations
    device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
    for inp_args, inp_kwargs in zip(input_args, input_kwargs):
        with torch.no_grad(), torch.amp.autocast(device_type=device_type, enabled=quant_config.amp):
            out = block(*to(inp_args, device=device), **to(inp_kwargs, device=device))
        out = maybe_first_element(out).to(device)
        # change only first input argument
        if len(inp_args) > 0:
            inp_args[0].data = out
        elif "hidden_states" in inp_kwargs:
            inp_kwargs["hidden_states"] = out
        else:
            raise ValueError("Unsupported block input format.")


def gptq_quantization(
    model: AutoModelForCausalLM, 
    calibration_data: List[torch.Tensor],
    quant_config: QuantConfig,
    run_config: RunConfig,
    opt_config: Optional[OptimizationConfig] = None,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """
    Quantize the transformer blocks of ``model`` in place with GPTQ.

    Outline of what this function does:
      1. Wrap the first block in ``InputCollector`` and run the calibration
         samples to capture the first block's inputs.
      2. For every block, build its transforms (learned or non-learned) and
         install the quantized attention/MLP modules.
      3. If a learned transform is configured, train the transforms against a
         deep copy of the original model kept on ``DEVICE_CPU``.
      4. For every block: remove transform parametrizations, accumulate GPTQ
         statistics through forward hooks, transform the weights, run GPTQ
         and overwrite the cached inputs with the block's outputs.

    Args:
        model: model whose blocks (from ``get_model_layers``) are quantized.
        calibration_data: samples fed to the model / blocks for calibration.
        quant_config: quantization settings (transform classes, group sizes,
            ordering, damping, amp).
        run_config: provides ``device``.
        opt_config: settings for learned-transform optimization; only used
            when a learned transform is configured.

    Returns:
        ``(quantized_state_dict, non_quantized_state_dict)``.
        NOTE(review): both dicts are created empty and never populated in this
        function; the result of quantization lives in ``model`` itself.
    """

    print("GPTQ quantization...")
    device = run_config.device
    is_learned_transform = quant_config.transform_class_r1 in ["learned", "learned_affine"] or quant_config.transform_class_r2 in ["learned", "learned_affine"]

    # Keep an untouched copy of the model; it is passed to build_loss_function below
    float_model = None
    if is_learned_transform:
        float_model = copy.deepcopy(model).to(DEVICE_CPU)

    # State dict with quantized weights, scales and hadamards
    # NOTE(review): neither dict is written to anywhere in this function
    quantized_state_dict = {}
    non_quantized_state_dict = {}

    # Define common transform kwargs
    transform_kwargs = dict(device=device, group_size=quant_config.hadamard_group_size)

    # Init quantizer kwargs
    weight_quantizer_kwargs , act_quantizer_kwargs = build_quantizer_kwargs(quant_config)

    # Get transformer blocks (supports multiple architectures)
    blocks = get_model_layers(model)

    # Temporarily wrap the first block so that its inputs are recorded
    blocks[0] = InputCollector(blocks[0], cpu_offload=False)

    for sample in calibration_data:
        try:
            with torch.no_grad():
                model(sample.to(device=device))
        except ForwardInterrupt:
            # Expected: presumably raised by InputCollector to stop the forward pass early
            pass

    # Retrieve the recorded inputs and restore the original first block
    input_args = blocks[0].input_args
    input_kwargs = blocks[0].input_kwargs
    blocks[0] = blocks[0].module

    shared_R1_learned_transform = None
    head_dim = getattr(model.config, "head_dim", None)
    if is_learned_transform:
        # Prepare transforms for learned transform optimization
        if head_dim is None:
            # Fall back to the value derived from the hidden size
            head_dim = model.config.hidden_size // model.config.num_attention_heads

        assert model.config.hidden_size % quant_config.group_size == 0, "Hidden size must be divisible by group size for learned transforms."
        shared_R1_learned_transform = build_R1_learned_transform(model, opt_config, quant_config, device)

    # NOTE(review): R2_transforms is never filled in this function, so the
    # training helpers below always receive an empty dict; confirm intent.
    R2_transforms = {}
    # block index -> quantized modules and transforms created for that block
    block_to_modules_map = {}
    # Iterate over transformer blocks
    for block_idx, block in enumerate(blocks):
        print(f"Building block {block_idx}...")

        if is_learned_transform:
            qkv_in_transform, gate_up_in_transform, down_in_transform = set_block_R1_learned_transforms(model,
                                                                                                        shared_R1_learned_transform,
                                                                                                        quant_config,
                                                                                                        device)
            o_in_transform, v_out_transform = set_block_R2_learned_transforms(model, quant_config, opt_config, head_dim,
                                                                              device)
        else:
            qkv_in_transform, o_in_transform, gate_up_in_transform, down_in_transform, v_out_transform = (
                non_learned_transforms_builder(model, quant_config, transform_kwargs))

        # Replace blocks with quantized versions
        quantized_attn, quantized_mlp, transformed_input_attn_norm, transformed_input_mlp_norm = (
            set_block_quantizers_and_transforms(model, block, block_idx, weight_quantizer_kwargs, act_quantizer_kwargs,
                                                qkv_in_transform, o_in_transform, gate_up_in_transform,
                                                down_in_transform, v_out_transform))

        # Remember the per-block objects; they are needed again in the GPTQ loop below
        block_to_modules_map[block_idx] = {
            constants.MODULE_QUANTIZED_ATTN: quantized_attn,
            constants.MODULE_QUANTIZED_MLP: quantized_mlp,
            constants.MODULE_QKV_IN_TRANSFORM: qkv_in_transform,
            constants.MODULE_O_IN_TRANSFORM: o_in_transform,
            constants.MODULE_GATE_UP_IN_TRANSFORM: gate_up_in_transform,
            constants.MODULE_DOWN_IN_TRANSFORM: down_in_transform,
            constants.MODULE_V_OUT_TRANSFORM: v_out_transform
        }

        load_quantized_modules_state_dict(
            block,
            quantized_attn,
            quantized_mlp,
            transformed_input_attn_norm,
            transformed_input_mlp_norm,
            model.config
        )

        # Move to original device and dtype
        _ = block.to(device=device, dtype=model.config.torch_dtype)

    #########################################
    # Train Transform
    #########################################

    if is_learned_transform:
        prepare_model_for_transform_training(model, shared_R1_learned_transform, R2_transforms)

        assert float_model is not None, "Float model must be provided for learned transform optimization."

        loss_fn = build_loss_function(model, float_model, opt_config)

        train_loader = prepare_train_dataloader(calibration_data, run_config)

        train_transform_matrix(model=model,
                               shared_learned_transforms=[shared_R1_learned_transform,
                                                          *list(R2_transforms.values())],
                               opt_config=opt_config,
                               dataloader=train_loader,
                               loss_fn=loss_fn,
                               device=device)

        wrap_up_training(shared_R1_learned_transform, R2_transforms)

    #########################################
    # GPTQ, block by block
    #########################################

    for block_idx, block in enumerate(blocks):
        # Toggle off gradients for all parameters
        block.requires_grad_(False)

        block_modules = block_to_modules_map[block_idx]
        qkv_in_transform = block_modules[constants.MODULE_QKV_IN_TRANSFORM]
        o_in_transform = block_modules[constants.MODULE_O_IN_TRANSFORM]
        down_in_transform = block_modules[constants.MODULE_DOWN_IN_TRANSFORM]
        gate_up_in_transform = block_modules[constants.MODULE_GATE_UP_IN_TRANSFORM]
        # Normalise a falsy V-output transform to None
        v_out_transform = None
        if block_modules[constants.MODULE_V_OUT_TRANSFORM]:
            v_out_transform = block_modules[constants.MODULE_V_OUT_TRANSFORM]

        _remove_block_parametrization(qkv_in_transform, o_in_transform, v_out_transform, down_in_transform, gate_up_in_transform)

        # Create GPTQ handles and hooks
        gptq_handles, hooks = _set_gptq_handles(block, weight_quantizer_kwargs, quant_config)

        # Process calibration data; the hooks feed each QLinear's input to its GPTQ handle
        for inp_args, inp_kwargs in zip(input_args, input_kwargs):
            with torch.no_grad():
                block(*to(inp_args, device=device), **to(inp_kwargs, device=device))
        # Remove hooks
        for hook in hooks.values():
            hook.remove()

        # Transform all weights before quantization
        _transform_weights_before_quantization(block, qkv_in_transform, o_in_transform, v_out_transform, gate_up_in_transform, down_in_transform)

        # Set train_mode to False
        for layer_name, layer in block.named_modules():
            if isinstance(layer, QLinear):
                layer._train_mode = False
                if layer.act_quantizer:
                    layer.act_quantizer._track_global_scale = False

        # Run GPTQ quantization; this also overwrites the cached inputs with the block outputs
        _run_gptq(block, gptq_handles, quant_config, input_args, input_kwargs, device)

        # Clean-up: drop per-block GPTQ state before moving to the next block
        del gptq_handles
        del hooks
        clear_device_cache(garbage_collection=True)

    clear_device_cache(garbage_collection=True)

    return quantized_state_dict, non_quantized_state_dict
