import logging
import time

import torch
from torch import nn
from torch._subclasses import FakeTensorMode
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PreTrainedModel

from gptq_core import HessianUtil, gptq_quantize, reconstruct_nn_linear, gptq_outliers_quantize
from utils import move_to_device, extract_dependencies


@torch.no_grad()
def get_pre_trained_model(model_path: str) -> PreTrainedModel:
    """
    Load a causal-LM checkpoint through ``AutoModelForCausalLM.from_pretrained``.

    model_path: forwarded as ``pretrained_model_name_or_path``
    returns: the loaded model, with ``torch_dtype='auto'``

    NOTE: as a side effect ``nn.init.kaiming_uniform_``, ``nn.init.uniform_`` and
    ``nn.init.normal_`` are replaced by a no-op for the rest of the process
    (presumably to skip random initialisation of weights that the checkpoint
    overwrites anyway). They are NOT restored afterwards.
    """
    def _skip_init(*args, **kwargs):
        return None

    # same assignment order as a chained assignment: left to right
    for initializer_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(nn.init, initializer_name, _skip_init)

    return AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_path, torch_dtype='auto')


@torch.no_grad()
def get_initial_inputs(
        model: PreTrainedModel,
        encodings: torch.Tensor,
        device: torch.device,
        batch_size: int = 1,
        save_device: torch.device = torch.device('cpu'),
) -> tuple[torch.Tensor, dict]:
    """
    Catch the inputs of the first transformer block (``model.model.layers[0]``).

    The model is run batch by batch with a forward pre-hook on the first block; the hook
    aborts the forward pass and hands back the block's positional and keyword arguments.

    encodings: (B, N=SeqLen), int64
    device: device on which the embedding / auxiliary layers are executed
    batch_size: number of sequences per forward pass
    save_device: device on which the caught inputs are stored
    returns: (inps, inp_kwargs)
        inps: the first positional argument of every batch, concatenated on dim 0
        inp_kwargs: keyword arguments of the LAST batch, moved to ``save_device``

    Layout noted by the original author (Llama-style model; confirm for other architectures):
    inps: inputs, (B, N=SeqLen, C), fp16 or bf16
    attention_mask = None
    position_ids: (1, N=SeqLen), int64
    past_key_value: None
    output_attentions: bool = False
    use_cache: bool = False
    cache_position: (N=SeqLen), int64
    position_embeddings: tuple, fp16 or bf16, (1, N=SeqLen, 128), (1, N=SeqLen, 128)

    Side effects: every child of ``model.model`` except ``layers`` ends up on CPU;
    ``model.config.use_cache`` is restored to its previous value, also on failure.
    """
    class _FirstBlockInputs(Exception):
        """Private control-flow signal, so that a genuine ValueError raised by the model is not swallowed."""

        def __init__(self, block_args: tuple, block_kwargs: dict):
            super().__init__()
            self.block_args: tuple = block_args
            self.block_kwargs: dict = block_kwargs

    def catcher_hook(module, args, kwargs):
        # with_kwargs=True: the hook is called as hook(module, args, kwargs)
        raise _FirstBlockInputs(args, kwargs)

    use_cache, model.config.use_cache = model.config.use_cache, False
    gpt_blocks: nn.ModuleList = model.model.layers
    auxiliary_layers: dict[str, nn.Module] = {k: v for k, v in model.model.named_children() if k != 'layers'}
    catcher_handle: torch.utils.hooks.RemovableHandle = gpt_blocks[0].register_forward_pre_hook(catcher_hook, prepend=True, with_kwargs=True)

    inps: list[torch.Tensor] = []
    inp_kwargs: dict = {}
    try:
        for auxiliary_layer_name, auxiliary_layer in auxiliary_layers.items():
            # TODO: below may be uncommented (for llama)
            # if auxiliary_layer_name == 'norm':
            #     continue
            auxiliary_layer.to(device=device)

        for bi in range(0, len(encodings), batch_size):
            try:
                model(encodings[bi : bi + batch_size].to(device=device))
            except _FirstBlockInputs as caught:
                # first positional argument of the block = hidden states
                inps.append(caught.block_args[0].to(device=save_device))
                inp_kwargs = caught.block_kwargs
    finally:
        # always undo the temporary changes, even when the forward pass fails
        for auxiliary_layer in auxiliary_layers.values():
            auxiliary_layer.cpu()
        catcher_handle.remove()
        model.config.use_cache = use_cache
    return torch.cat(inps, dim=0), move_to_device(inp_kwargs, save_device)


class RecorderWrapper(nn.Module):
    """
    Stand-in for an ``nn.Linear`` whose forward behaviour is selected by ``self.mode``:

    mode_undefined: forward raises NotImplementedError
    mode_default:   run the wrapped layer on the input's device, then move it back
    mode_fake:      return an empty fake tensor of the output shape (no computation)
    mode_catch:     abort the forward pass by raising ValueError(hidden_states)
    mode_replay:    return the recorded outputs in ``self.outs`` one after another, cyclically
    """
    fake_tensor_mode: FakeTensorMode = FakeTensorMode(allow_fallback_kernels=False, allow_non_fake_inputs=True)
    mode_undefined: int = 0
    mode_default: int = 1
    mode_fake: int = 2
    mode_catch: int = 3
    mode_replay: int = 4

    def __init__(self, module: nn.Linear):
        super().__init__()
        self.module: nn.Linear = module
        self.mode: int = RecorderWrapper.mode_undefined
        self.outs: list[torch.Tensor] = []  # outputs served in mode_replay
        self.out_pointer: int = 0  # index of the next entry of ``outs`` to serve

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        target_device: torch.device = hidden_states.device
        current_mode: int = self.mode

        if current_mode == RecorderWrapper.mode_default:
            # borrow the layer on the input's device, then return it to where it lived
            home_device: torch.device = self.module.weight.device
            layer: nn.Module = self.module.to(device=target_device)
            result: torch.Tensor = layer(hidden_states, **kwargs)
            self.module.to(device=home_device)
            return result

        if current_mode == RecorderWrapper.mode_fake:
            # only shape / dtype / device are produced, the last dim becomes out_features
            with RecorderWrapper.fake_tensor_mode:
                placeholder: torch.Tensor = torch.empty(
                    *hidden_states.shape[:-1], self.module.weight.size(0),
                    dtype=hidden_states.dtype,
                    device=target_device,
                )
            return placeholder

        if current_mode == RecorderWrapper.mode_catch:
            raise ValueError(hidden_states)

        if current_mode == RecorderWrapper.mode_replay:
            recorded: torch.Tensor = self.outs[self.out_pointer]
            # wrap around so that the next pass starts from the first record again
            self.out_pointer = (self.out_pointer + 1) % len(self.outs)
            return recorded.to(device=target_device)

        raise NotImplementedError


@torch.no_grad()
def quantize_model(
        model: PreTrainedModel,
        encodings: torch.Tensor,
        device: torch.device,
        quant_group_size: int = 128,
        quant_bit_width: float = 4.,
        quant_order: str = 'act',
        quant_use_entropy_mode: str = 'none',
        quant_do_clip: bool = True,
        quant_use_mse: bool = True,
        batch_size: int = 1,
        save_gpu_mem_level: int = 2,  # whether to save input tensors in cpu ram, may impact performance
        do_rtn: bool = False,
        outlier_percentage: float | None = None,
) -> dict[str, dict[str, dict]]:
    """
    Quantize the linear layers of every block in ``model.model.layers``, one block at a time.

    For each block the linear layers are wrapped in ``RecorderWrapper``; the groups of layers
    returned by ``extract_dependencies`` are then processed in order: the inputs of the group
    are caught to accumulate a Hessian, each layer of the group is quantized and replaced by
    the layer rebuilt with ``reconstruct_nn_linear``. Finally the block is run once more to
    produce the inputs of the next block.

    model: model to quantize, modified in place (its linear layers are replaced, each
        processed block is moved to CPU)
    encodings: calibration token ids, passed to ``get_initial_inputs``
    device: device used for computation
    quant_order: forwarded to ``HessianUtil.invert(order=...)``
    quant_group_size, quant_bit_width, quant_use_entropy_mode, quant_use_mse:
        forwarded unchanged to the quantization routine
    quant_do_clip, do_rtn: forwarded to ``gptq_quantize`` only (unused on the outlier path)
    batch_size: number of calibration sequences per forward pass
    outlier_percentage: when not None, ``gptq_outliers_quantize`` is used instead of
        ``gptq_quantize`` and receives this value as ``quant_outlier_percentage``
    save_gpu_mem_level: whether to save GPU memory
     1: cache all activations in GPU
     2: cache only initial activations for a block in GPU and keep layers' weights in GPU
     3: cache only initial activations for a block in CPU and keep layers' weights in GPU
     4: only load tensors to GPU when computing

    returns: {'data': {layer name: quantization meta data on CPU},
              'metrics': {layer name: metrics reported by the quantization routine}}
        with layer names of the form ``model.layers.<block index>.<layer name in block>``
    """
    _save_gpu_mem_level_low, _save_gpu_mem_level_mid, _save_gpu_mem_level_high, _save_gpu_mem_level_highest = 1, 2, 3, 4
    # sentinel names used in the dependency info for the block's input and output
    _input, _output = 'input', 'output'
    cpu_device: torch.device = torch.device('cpu')
    # block inputs live on CPU for level >= 3, otherwise on the compute device
    inps, inp_kwargs = get_initial_inputs(model, encodings, device, batch_size,
                                          save_device=cpu_device if save_gpu_mem_level >= _save_gpu_mem_level_high else device)
    # levels 1-3 keep the keyword arguments resident on the compute device;
    # level 4 uploads them from ``inp_kwargs_cpu`` right before each pass instead
    if save_gpu_mem_level <= _save_gpu_mem_level_high:
        inp_kwargs: dict = move_to_device(inp_kwargs, device=device)
    use_cache, model.config.use_cache = model.config.use_cache, False
    dtype: torch.dtype = inps.dtype
    inp_kwargs_cpu: dict = move_to_device(inp_kwargs, device=cpu_device)

    results: dict[str, dict[str, dict]] = {
        'data': {},  # dict[str, dict[str, torch.Tensor | None]]
        'metrics': {},  # dict[str, dict[str, float]]
    }

    gpt_blocks: nn.ModuleList = model.model.layers

    for gi, gpt_block in enumerate(gpt_blocks):
        block_start_time: float = time.time()

        # find dependency info
        # each entry is unpacked below as (layer names to quantize together, layer names to release)
        dependency_info: list[tuple] = extract_dependencies(gpt_block, nn.Linear, inps.shape, dtype, cpu_device, inp_kwargs_cpu)

        # wrap layers
        wrapper_layers_dict: dict[str, RecorderWrapper] = {}
        for linear_layer_name, linear_layer_module in gpt_block.named_modules():
            if isinstance(linear_layer_module, nn.Linear):
                linear_layer_wrapper: RecorderWrapper = RecorderWrapper(linear_layer_module)
                wrapper_layers_dict[linear_layer_name] = linear_layer_wrapper
                gpt_block.set_submodule(linear_layer_name, linear_layer_wrapper)

        # move norm layers to GPU
        # (every leaf module that is neither a wrapper nor a linear layer; modules with children are skipped)
        for auxiliary_layer_name, auxiliary_layer in gpt_block.named_modules():
            if isinstance(auxiliary_layer, RecorderWrapper | nn.Linear) or next(auxiliary_layer.children(), None):
                continue
            # TODO: below may be uncommented (for llama)
            # if auxiliary_layer_name == 'self_attn.rotary_emb':
            #     continue  # this layer is not used in transformers.models.llama.modeling_llama.LlamaAttention
            auxiliary_layer.to(device=device)

        # start quantization
        for di, (quantizing_layer_names, to_release_layer_names) in enumerate(dependency_info):
            # the entry for the block output ends the list of layers to quantize
            if quantizing_layer_names == [_output]:
                break

            # compute hessian
            hessian_util: HessianUtil = HessianUtil()
            # in catch mode a wrapper raises ValueError(hidden_states) instead of computing
            for quantizing_layer_name in quantizing_layer_names:
                wrapper_layers_dict[quantizing_layer_name].mode = RecorderWrapper.mode_catch
            hidden_states: list[torch.Tensor] = []
            if save_gpu_mem_level >= _save_gpu_mem_level_highest:
                inp_kwargs: dict | None = move_to_device(inp_kwargs_cpu, device=device)
            for bi in range(0, len(inps), batch_size):
                try:
                    gpt_block(inps[bi : bi + batch_size].to(device=device), **inp_kwargs)
                except ValueError as value_error:
                    # the forward pass stopped at the first layer in catch mode; its input is the payload
                    hidden_state: torch.Tensor = value_error.args[0]
                    hessian_util.add_batch(hidden_state, use_kernel=True)
                    # level 1 keeps the caught inputs to record the quantized layers' outputs later
                    if save_gpu_mem_level <= _save_gpu_mem_level_low:
                        hidden_states.append(hidden_state)
            if save_gpu_mem_level >= _save_gpu_mem_level_highest:
                # level 4: drop the device copy of the keyword arguments until the next pass
                inp_kwargs: dict | None = None

            if save_gpu_mem_level <= _save_gpu_mem_level_low:
                # parent layers no longer needed: output fake tensors
                for to_release_layer_name in to_release_layer_names:
                    if to_release_layer_name == _input:
                        # the block input itself is replaced by a fake tensor
                        inps = RecorderWrapper.fake_tensor_mode.from_tensor(inps)
                        continue
                    to_release_layer_wrapper: RecorderWrapper = wrapper_layers_dict[to_release_layer_name]
                    # drop the recorded outputs and let the wrapper emit fake tensors from now on
                    to_release_layer_wrapper.outs = []
                    to_release_layer_wrapper.mode = RecorderWrapper.mode_fake

            hessian_util.invert(order=quant_order, damp_ratio=1e-2)

            # quantize a layer
            # all layers of the group share the same hessian_util
            for quantizing_layer_name in quantizing_layer_names:
                assert quantizing_layer_name not in [_input, _output]
                canonical_name: str = f'model.layers.{gi}.{quantizing_layer_name}'
                quantizing_layer_wrapper: RecorderWrapper = wrapper_layers_dict[quantizing_layer_name]
                weight: torch.Tensor = quantizing_layer_wrapper.module.weight.data.to(device=device)

                if outlier_percentage is not None:
                    # SSQR quantization
                    gptq_result: dict[str, dict] = gptq_outliers_quantize(
                        weight=weight,
                        hessian_util=hessian_util,
                        quant_group_size=quant_group_size,
                        quant_bit_width=quant_bit_width,
                        quant_use_entropy_mode=quant_use_entropy_mode,
                        quant_symmetric=True,
                        quant_use_mse=quant_use_mse,
                        quant_max_shrink=.8,
                        quant_n_grid=100,
                        quant_norm=2.4,
                        save_device=device,
                        quant_outlier_percentage=outlier_percentage
                    )
                else:
                    gptq_result: dict[str, dict] = gptq_quantize(
                        weight=weight,
                        hessian_util=hessian_util,
                        quant_group_size=quant_group_size,
                        quant_bit_width=quant_bit_width,
                        quant_use_entropy_mode=quant_use_entropy_mode,
                        quant_symmetric=True,
                        quant_do_clip=quant_do_clip,
                        quant_use_mse=quant_use_mse,
                        quant_max_shrink=.8,
                        quant_n_grid=100,
                        quant_norm=2.4,
                        save_device=device,
                        do_rtn=do_rtn
                    )
                # logging.debug(f"{canonical_name} {gptq_result['metrics']}")

                # reconstruct layer and record outputs
                # the original bias is reused; the rebuilt layer replaces the wrapped one
                reconstructed_nn_linear: nn.Linear = reconstruct_nn_linear(
                    gptq_result['quant_meta'],
                    bias=quantizing_layer_wrapper.module.bias,
                    dtype=dtype,
                    device=device,
                )
                results['data'][canonical_name] = move_to_device(gptq_result['quant_meta'], cpu_device)
                quantizing_layer_wrapper.module = reconstructed_nn_linear

                # log metrics
                metrics: dict = gptq_result['metrics']
                del weight, gptq_result
                results['metrics'][canonical_name] = metrics
                # logging.debug(f'{canonical_name} {metrics}')

                if save_gpu_mem_level <= _save_gpu_mem_level_low:
                    # level 1: record the outputs once, offload the layer and replay the records afterwards
                    for hidden_state in hidden_states:
                        quantizing_layer_wrapper.outs.append(quantizing_layer_wrapper.module(hidden_state))
                    quantizing_layer_wrapper.module.cpu()
                    quantizing_layer_wrapper.mode = RecorderWrapper.mode_replay
                elif save_gpu_mem_level >= _save_gpu_mem_level_highest:
                    # level 4: keep the layer on CPU; default mode moves it to the input's device per call
                    quantizing_layer_wrapper.module.cpu()
                    quantizing_layer_wrapper.mode = RecorderWrapper.mode_default
                else:
                    # levels 2-3: the rebuilt layer stays on the compute device
                    quantizing_layer_wrapper.mode = RecorderWrapper.mode_default

            del hessian_util, hidden_states

        # inputs of the next block
        if save_gpu_mem_level >= _save_gpu_mem_level_highest:
            inp_kwargs: dict | None = move_to_device(inp_kwargs_cpu, device=device)
        for bi in range(0, len(inps), batch_size):
            # the trailing comma unpacks the single element returned by the block;
            # the result overwrites the corresponding rows of ``inps`` in place
            inps[bi : bi + batch_size], = gpt_block(inps[bi : bi + batch_size].to(device=device), **inp_kwargs)
        if save_gpu_mem_level >= _save_gpu_mem_level_highest:
            inp_kwargs: dict | None = None

        # un-wrap layers
        # (the wrappers now hold the reconstructed layers, which take the original layers' place)
        for linear_layer_name, linear_layer_wrapper in wrapper_layers_dict.items():
            gpt_block.set_submodule(linear_layer_name, linear_layer_wrapper.module)
        gpt_block.cpu()

        block_end_time: float = time.time()
        logging.info(f'finished block {gi} in {block_end_time - block_start_time:.2f} s')

    model.config.use_cache = use_cache
    return results


@torch.no_grad()
def evaluate_model(
        model: PreTrainedModel,
        encodings: torch.Tensor,
        device: torch.device,
        batch_size: int = 1,
        save_gpu_mem: bool = True,
) -> torch.Tensor:
    """
    Evaluate the model with perplexity

    The model is executed block by block: each block is moved to ``device``, applied to all
    hidden states and moved back to CPU, followed by the final norm (if any) and ``lm_head``.

    model: model to evaluate; its blocks, final norm and ``lm_head`` end up on CPU
    encodings: (B, N=SeqLen) token ids, also used (shifted by one) as labels
    device: device used for computation
    batch_size: number of sequences per forward pass through the blocks and the norm
        (the ``lm_head`` / loss stage always processes one sequence at a time)
    save_gpu_mem: keep the hidden states on CPU between blocks instead of on ``device``
    returns: scalar float32 tensor, exp(mean negative log-likelihood)

    ``model.config.use_cache`` is disabled during evaluation and restored afterwards,
    also when evaluation fails.
    """
    cpu_device: torch.device = torch.device('cpu')
    inps, inp_kwargs = get_initial_inputs(model, encodings, device, batch_size,
                                          save_device=cpu_device if save_gpu_mem else device)
    inp_kwargs: dict = move_to_device(inp_kwargs, device=device)
    use_cache, model.config.use_cache = model.config.use_cache, False

    try:
        gpt_blocks: nn.ModuleList = model.model.layers
        for gpt_block in tqdm(gpt_blocks, total=len(gpt_blocks)):
            gpt_block.to(device=device)
            for j in range(0, len(inps), batch_size):
                # the trailing comma unpacks the single element returned by the block
                inps[j:j+batch_size], = gpt_block(inps[j:j+batch_size].to(device=device), **inp_kwargs)
            gpt_block.cpu()

        if model.model.norm is not None:
            model.model.norm.to(device=device)
            for j in range(0, len(inps), batch_size):
                inps[j:j+batch_size] = model.model.norm(inps[j:j+batch_size].to(device=device))
            model.model.norm.cpu()

        batch_size: int = 1  # saving memory, the logits tensor is large
        model.lm_head.to(device=device)
        loss_fct: nn.Module = nn.CrossEntropyLoss()
        nlls: list[torch.Tensor] = []
        for i in range(0, len(inps), batch_size):
            shift_logits = model.lm_head(inps[i:i+batch_size].to(device=device))[:, :-1, :]  # (B, N=SeqLen-1, C)
            # compute the loss in float32: a loss evaluated on fp16 / bf16 logits is rounded to
            # half precision, which noticeably perturbs the resulting perplexity
            shift_logits = shift_logits.to(dtype=torch.float32)
            shift_labels = encodings[i:i+batch_size, 1:].to(device=device)  # (B, N=SeqLen-1)
            neg_log_likelihood = loss_fct(shift_logits.flatten(end_dim=-2), shift_labels.flatten())
            # weight the batch-mean loss by the number of sequences in the batch
            nlls.extend([neg_log_likelihood] * len(shift_logits))
        ppl: torch.Tensor = torch.stack(nlls).to(dtype=torch.float32).mean().exp()
        model.lm_head.cpu()
    finally:
        model.config.use_cache = use_cache
    return ppl
