import logging
import time
from math import inf
from typing import Any

import torch
from hydra_zen.typing import Partial
from optimum.quanto import QuantizedModelForCausalLM
from transformers import BitsAndBytesConfig, HqqConfig, PreTrainedModel

from entquant.compress.manager import CompressionManager
from entquant.compress.utils import dispatch_helper, get_memory_stats
from entquant.entquant.manager import EntQuantManager
from entquant.model.factory import BaseModelConfig, load_pretrained_model, save_pretrained_model
from entquant.super_weights.super_weights import find_super_weights, SuperWeight, SuperWeightsConfig
from entquant.utils import clear_cache
from run.hydra_zen import register_workflow

logger = logging.getLogger(__name__)


@register_workflow("build")
def build_base_model(
    base_model_config: BaseModelConfig = "${cfg.model}",
) -> tuple[PreTrainedModel, dict[str, Any]]:
    """Load the pretrained model described by ``base_model_config`` and record memory usage.

    Args:
        base_model_config: Supplies the model name/path, model class, device map, dtype,
            extra loading kwargs and generation config. The string default is an
            interpolation placeholder (presumably resolved by the workflow registry
            before the call -- confirm against ``register_workflow``).

    Returns:
        A ``(model, results)`` tuple where ``results`` holds the ``memory_allocated`` and
        ``memory_reserved`` values reported by ``get_memory_stats``.
    """
    # Explicit ``pretrained_model_kwargs`` entries win over the config-level
    # ``device_map`` / ``dtype`` because they are merged in last.
    load_kwargs = {
        "device_map": base_model_config.device_map,
        "dtype": base_model_config.dtype,
    }
    load_kwargs.update(base_model_config.pretrained_model_kwargs)

    model = load_pretrained_model(
        base_model_config.base_model_name_or_path,
        base_model_config.pretrained_model_cls,
        load_kwargs,
        base_model_config.generation_config,
    )
    logger.info(f"Created pretrained model: {model.name_or_path}")
    logger.debug(f"Model structure:\n{model}")

    # Clear caches before sampling so the numbers are taken right after cleanup.
    clear_cache()
    allocated, reserved = get_memory_stats()
    gib = 1024**3
    logger.info(f"Memory allocated: {allocated / gib:8.3f} GiB")
    logger.info(f"Memory reserved: {reserved / gib:8.3f} GiB")

    return model, {"memory_allocated": allocated, "memory_reserved": reserved}


def _find_super_weights(
    model: PreTrainedModel | None,
    base_model_config: BaseModelConfig | None,
    super_weight_config: SuperWeightsConfig,
    device_map: str | dict[str, str | torch.device | int] = "cpu",
) -> dict[str, list[SuperWeight]]:
    """Find super weights, pre-loading the base model first if none is given.

    Args:
        model: Model to search. If ``None``, a temporary model is built from
            ``base_model_config`` using ``device_map``.
        base_model_config: Required when ``model`` is ``None``. Its ``device_map``
            attribute is overridden only for the duration of the pre-load and is
            restored afterwards, even if loading fails.
        super_weight_config: Passed as ``config`` to ``find_super_weights``.
        device_map: Device map used for the temporary pre-load only.

    Returns:
        Mapping from module name to the list of super weights found in that module.

    Raises:
        ValueError: If both ``model`` and ``base_model_config`` are ``None``.
    """
    if model is None:
        # Explicit check instead of ``assert`` so it also holds under ``python -O``.
        if base_model_config is None:
            raise ValueError("Either `model` or `base_model_config` must be provided.")
        logger.info(f"Pre-loading model to find super weights: {base_model_config.base_model_name_or_path}")
        original_device_map = base_model_config.device_map
        base_model_config.device_map = device_map
        try:
            model, _ = build_base_model(base_model_config)
        finally:
            # Restore the original device map for the actual build, even when loading raised.
            base_model_config.device_map = original_device_map

    super_weights = find_super_weights(model=model, config=super_weight_config)
    for module_name, module_sw in super_weights.items():
        if len(module_sw) > 0:
            logger.info(f"Super weights in {module_name}: {[(sw.row, sw.col) for sw in module_sw]}")

    # Drop the local reference first so a model that was pre-loaded here is no
    # longer kept alive by this frame when the cache is cleared.
    del model
    clear_cache()
    return super_weights


@register_workflow("build")
def build_entquant_model(
    base_model_config: BaseModelConfig | None = "${cfg.model}",
    super_weight_config: SuperWeightsConfig | None = "${cfg.super_weights}",
    entquant: dict = "${cfg.entquant}",
    save_config: dict = "${run.save_model}",
) -> tuple[QuantizedModelForCausalLM, dict[str, Any]]:
    """Build the base model, quantize it with EntQuant, then optionally save and compress it.

    Stages, in order:
      1. Build the base model via ``build_base_model`` (timed).
      2. Unless ``pretrained_model_cls`` is ``QuantizedModelForCausalLM``: optionally search
         for super weights, quantize with ``EntQuantManager`` and collect entropy statistics.
      3. Wrap the model in ``QuantizedModelForCausalLM``.
      4. Save the model if ``save_config["path"]`` is not ``None``.
      5. If ``entquant["compress"]`` is missing/falsy or a plain ``dict``: optionally dispatch
         the model and skip compression. Otherwise treat it as a partial
         ``CompressionManager`` factory and run compression (timed).
      6. Record memory statistics from ``get_memory_stats``.

    Args:
        base_model_config: Configuration of the base model to load.
        super_weight_config: Super-weight search configuration; the search is skipped when
            this is ``None`` or its ``spike_threshold`` equals ``inf``.
        entquant: Mapping with a ``"config"`` entry (passed as ``quanto_config``) and an
            optional ``"compress"`` entry.
        save_config: Mapping with ``"path"``, ``"include_filter"``, ``"exclude_filter"``
            and ``"kwargs"`` entries.

    Returns:
        A ``(model, results)`` tuple; ``results`` collects timings, statistics and memory
        figures of the stages above. In the compression branch ``model`` is whatever
        ``model.cpu()`` returned.
    """
    gib = 1024**3

    # --- Stage 1: base model ---------------------------------------------------------
    build_start = time.perf_counter()
    model, results = build_base_model(base_model_config)
    results["base_model_build_time_s"] = time.perf_counter() - build_start
    logger.info(f"Base model build time: {results['base_model_build_time_s']:.3f}s")

    # --- Stage 2: quantization -------------------------------------------------------
    # Only quantize if the model is not already quantized
    # (i.e. it was not loaded via QuantizedModelForCausalLM.from_pretrained).
    needs_quantization = base_model_config.pretrained_model_cls != QuantizedModelForCausalLM
    if not needs_quantization:
        results["quantization_time_s"] = None
        logger.info("Quantization skipped (model already quantized).")
    else:
        sw_start = time.perf_counter()
        search_enabled = super_weight_config is not None and super_weight_config.spike_threshold != inf
        if not search_enabled:
            super_weights = None
            results["super_weights"] = None
            results["super_weights_time_s"] = None
        else:
            super_weights = _find_super_weights(model, None, super_weight_config)
            results["super_weights"] = super_weights
            results["super_weights_time_s"] = time.perf_counter() - sw_start
            logger.info(f"Super weights computation time: {results['super_weights_time_s']:.3f}s")

        # Manager construction is deliberately kept outside the timed region.
        quant_manager = EntQuantManager(
            model=model,
            quanto_config=entquant["config"],
            super_weights=super_weights,
        )
        quant_start = time.perf_counter()
        quant_manager.quantize()
        results["quantization_time_s"] = time.perf_counter() - quant_start
        logger.info(f"Quantization time: {results['quantization_time_s']:.3f}s")

        entropy_stats = quant_manager.entropy()
        results["entropy_stats"] = entropy_stats
        logger.info(
            f"\n{'=' * 80}\n"
            f"ENTROPY STATISTICS\n"
            f"Average entropy: {entropy_stats['average_entropy']:.3f}\n"
            f"Average sparsity: {entropy_stats['average_sparsity']:.3f}"
            f"\n{'=' * 80}\n"
        )

    # --- Stage 3: wrap ---------------------------------------------------------------
    model = QuantizedModelForCausalLM(model)

    # --- Stage 4: optional save ------------------------------------------------------
    save_path = save_config["path"]
    if save_path is not None:
        save_pretrained_model(
            model,  # type: ignore
            save_path,
            include_filter=save_config["include_filter"],
            exclude_filter=save_config["exclude_filter"],
            save_kwargs=save_config["kwargs"],
        )

    # --- Stage 5: compression or plain dispatch --------------------------------------
    entquant_config = entquant.get("compress") or {}
    if not isinstance(entquant_config, dict):
        # TODO: Need to move to host because otherwise dispatch_model does not work properly.
        model = model.cpu()

        # TODO: make "model.layers" configurable (check compression_manager.keywords)
        num_layers = len(model.model.layers)
        layer_patterns = [f"*model.layers.{i}" for i in range(num_layers)]

        # Anything that is not a plain dict is treated as a partial CompressionManager factory.
        compressor: CompressionManager = entquant_config(
            model=model,
            target_layer_include=layer_patterns,
        )
        compress_start = time.perf_counter()
        compressor.compress()
        results["compression_time_s"] = time.perf_counter() - compress_start
        logger.info(f"Compression time: {results['compression_time_s']:.3f}s")

        stats = compressor.compression_ratio()
        logger.info(
            f"\n{'=' * 80}\n"
            f"COMPRESSION STATISTICS\n"
            f"{'=' * 80}\n"
            f"Injected parameters:\n"
            f"  Original size:    {stats['injected_original_bytes'] / gib:8.3f} GiB\n"
            f"  Compressed size:  {stats['injected_compressed_bytes'] / gib:8.3f} GiB\n"
            f"  Compression ratio: {stats['ratio']:7.2f}x ({100.0 / stats['ratio']:5.2f}% of original)\n\n"
            f"Full model:\n"
            f"  Original size:    {stats['full_original_bytes'] / gib:8.3f} GiB\n"
            f"  Compressed size:  {stats['full_compressed_bytes'] / gib:8.3f} GiB\n"
            f"  Compression ratio: {stats['full_ratio']:7.2f}x ({100.0 / stats['full_ratio']:5.2f}% of original)\n"
            f"{'=' * 80}"
        )
        results["compression_stats"] = stats
    else:
        dispatch_map = entquant_config.get("dispatch_device_map")
        if dispatch_map is not None:
            dispatch_helper(model, dispatch_map)  # noqa
            logger.info(f"Model dispatched to {dispatch_map}.")
        results["compression_time_s"] = None
        results["compression_stats"] = None
        logger.info("Compression skipped.")

    # --- Stage 6: memory statistics --------------------------------------------------
    allocated, reserved = get_memory_stats()
    logger.info(f"Memory allocated: {allocated / gib:8.3f} GiB")
    logger.info(f"Memory reserved: {reserved / gib:8.3f} GiB")
    results["memory_allocated"] = allocated
    results["memory_reserved"] = reserved

    return model, results


def _patch_hqq_compute_dtype(compute_dtype: torch.dtype):
    """Monkey-patch ``HqqHfQuantizer.__init__`` so new instances get ``dtype = compute_dtype``.

    The patch replaces the class attribute and is never undone, so it affects every
    ``HqqHfQuantizer`` created afterwards in this process. Calling this function again
    replaces the previous patch instead of stacking another wrapper on top of it.

    Args:
        compute_dtype: Value assigned to ``self.dtype`` after the original constructor ran.
    """
    import functools

    from transformers.quantizers.quantizer_hqq import HqqHfQuantizer

    # If the constructor was already patched by an earlier call, recover the pristine
    # one so that repeated calls do not build an ever-growing chain of wrappers.
    current_init = HqqHfQuantizer.__init__
    _original_init = getattr(current_init, "_entquant_original_init", current_init)

    @functools.wraps(_original_init)
    def _patched_init(self, quantization_config, **kwargs):
        _original_init(self, quantization_config, **kwargs)
        # Override whatever dtype the original constructor chose.
        self.dtype = compute_dtype

    # Marker used by later calls to find the unpatched constructor.
    _patched_init._entquant_original_init = _original_init
    HqqHfQuantizer.__init__ = _patched_init


@register_workflow("build")
def build_quantized_model(
    base_model_config: BaseModelConfig = "${cfg.model}",
    super_weight_config: SuperWeightsConfig | None = "${cfg.super_weights}",
    quantization_config: Partial = "${cfg.quantization.config}",
    modules_to_exclude: list[str] = "${cfg.quantization.modules_to_exclude}",
) -> tuple[PreTrainedModel, dict[str, Any]]:
    """Build a model quantized through the ``quantization_config`` loading kwarg.

    Modules that contain super weights (if the search is enabled) are added to the list of
    modules excluded from quantization. The finished quantization config is stored in
    ``base_model_config.pretrained_model_kwargs["quantization_config"]`` (the config object
    is modified in place) before the model is built with ``build_base_model``.

    Args:
        base_model_config: Configuration of the model to load.
        super_weight_config: Super-weight search configuration; the search is skipped when
            this is ``None`` or its ``spike_threshold`` equals ``inf``.
        quantization_config: Partial factory producing the quantization config object.
        modules_to_exclude: Module names to keep unquantized; copied, never mutated.

    Returns:
        A ``(model, results)`` tuple; ``results`` holds ``super_weights`` plus everything
        reported by ``build_base_model``.
    """
    results: dict[str, Any] = {}

    modules_to_exclude = list(modules_to_exclude)  # TODO: may be still a ListConfig

    if super_weight_config is not None and super_weight_config.spike_threshold != inf:
        super_weights = _find_super_weights(None, base_model_config, super_weight_config)
        # Every module holding at least one super weight is excluded from quantization.
        modules_to_exclude.extend(
            module_name for module_name, module_sw in super_weights.items() if len(module_sw) > 0
        )
        results["super_weights"] = super_weights
    else:
        results["super_weights"] = None

    # Instantiate the partial once to learn the concrete config type; the name of the
    # "excluded modules" keyword depends on it.
    probe_config = quantization_config()
    if isinstance(probe_config, BitsAndBytesConfig):
        quantization_config = quantization_config(llm_int8_skip_modules=modules_to_exclude)
    elif isinstance(probe_config, HqqConfig):
        quantization_config = quantization_config(skip_modules=modules_to_exclude)
    elif hasattr(probe_config, "modules_to_not_convert"):
        quantization_config = quantization_config(modules_to_not_convert=modules_to_exclude)
    else:
        logger.warning(f"Quantization config {quantization_config} is not supported, ignoring excluded modules.")
        # No known exclusion keyword: reuse the instance created for probing.
        quantization_config = probe_config

    # TODO: For some reason, compute_dtype is not properly passed to HqqHfQuantizer.
    #  Alternative: use native hqq backend.
    if isinstance(quantization_config, HqqConfig):
        _patch_hqq_compute_dtype(base_model_config.dtype)

    logger.info(f"Building quantized model with config: {quantization_config}")
    base_model_config.pretrained_model_kwargs["quantization_config"] = quantization_config
    model, results_base = build_base_model(base_model_config)
    results.update(results_base)

    return model, results
