import copy
from typing import List

import torch
from transformers import AutoModelForCausalLM

import constants
from constants import DEVICE_CPU
from models.model_utils import get_model_layers
from project_utils.helpers import parse_quant_format
from quantization.set_block_utils import set_block_quantizers_and_transforms
from quantization.transforms.transforms_builder import build_R1_learned_transform,  \
    set_block_R1_learned_transforms, set_block_R2_learned_transforms, non_learned_transforms_builder
from run_config import RunConfig
from transform_optimization.opt_config import OptimizationConfig
from transform_optimization.prepare import prepare_model_for_transform_training, build_loss_function, \
    prepare_train_dataloader, wrap_up_training
from transform_optimization.train_transform import train_transform_matrix
from quantization.quant_config import QuantConfig

from quantization.utils.common_utils import clear_device_cache, to, maybe_first_element
from quantization.quantized_modules.model_utils import InputCollector, ForwardInterrupt, \
    load_quantized_modules_state_dict


def _single_quantizer_kwargs(config: QuantConfig, q_format, granularity) -> dict:
    """Build the constructor kwargs for one quantizer (weights or activations).

    ``bits`` and ``format`` come from ``parse_quant_format(q_format)``.
    ``symmetric`` and ``observer`` are taken from ``config``. ``group_size`` is
    forwarded only when ``granularity.value == "group"``; for any other
    granularity it is ``None``.
    """
    bits, num_format = parse_quant_format(q_format)
    granularity_name = granularity.value
    return dict(
        bits=bits,
        symmetric=config.symmetric,
        format=num_format,
        granularity=granularity_name,
        observer=config.observer_type,
        group_size=config.group_size if granularity_name == "group" else None,
    )


def build_quantizer_kwargs(config: QuantConfig) -> tuple[dict | None, dict | None]:
    """
    Build weight and activation quantizer kwargs from QuantConfig.

    Both quantizers share ``config.symmetric``, ``config.observer_type`` and
    (for "group" granularity) ``config.group_size``. Bits and number format are
    parsed from ``config.weight_q_format`` / ``config.activation_q_format``.

    Args:
        config: QuantConfig with quantization settings.

    Returns:
        Tuple ``(weight_quantizer_kwargs, act_quantizer_kwargs)``. Both entries
        are always dicts; this function never returns ``None``, including for
        16-bit formats. Whatever ``parse_quant_format`` reports for such a
        format is passed through unchanged in ``bits`` / ``format``.
    """
    # Weight format is parsed before the activation format, as before.
    weight_quantizer_kwargs = _single_quantizer_kwargs(
        config, config.weight_q_format, config.weight_granularity)
    act_quantizer_kwargs = _single_quantizer_kwargs(
        config, config.activation_q_format, config.activation_granularity)
    return weight_quantizer_kwargs, act_quantizer_kwargs


def _fix_block_parametrization(block_to_modules_map, block_idx):
    """
    Finalize the quantized modules and transforms recorded for one block.

    Calls ``fix_parametrization()`` on the block's quantized attention and MLP
    modules. It then calls ``remove_parametrizations()`` on each of the block's
    transform modules.

    Args:
        block_to_modules_map: mapping ``block_idx -> {constants.MODULE_*: module}``
            as filled in by ``rtn_quantization``.
        block_idx: index of the block to finalize.
    """
    block_modules = block_to_modules_map[block_idx]

    # Fix model parametrization
    block_modules[constants.MODULE_QUANTIZED_ATTN].fix_parametrization()
    block_modules[constants.MODULE_QUANTIZED_MLP].fix_parametrization()

    # Fix transforms and remove parametrizations
    block_modules[constants.MODULE_QKV_IN_TRANSFORM].remove_parametrizations()
    block_modules[constants.MODULE_O_IN_TRANSFORM].remove_parametrizations()
    # The V-output transform is optional and may be None. Test for None
    # explicitly: an nn.Module that defines __len__ (e.g. an empty container)
    # is falsy, and a truthiness check would silently skip it.
    v_out_transform = block_modules[constants.MODULE_V_OUT_TRANSFORM]
    if v_out_transform is not None:
        v_out_transform.remove_parametrizations()
    block_modules[constants.MODULE_GATE_UP_IN_TRANSFORM].remove_parametrizations()
    block_modules[constants.MODULE_DOWN_IN_TRANSFORM].remove_parametrizations()


def rtn_quantization(
        model: AutoModelForCausalLM,
        calibration_data: List[torch.Tensor],
        quant_config: QuantConfig,
        run_config: RunConfig,
        opt_config: OptimizationConfig,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """
    Run RTN quantization over ``model``, block by block.

    Steps:
      1. If activation calibration is enabled, capture the inputs of the first
         transformer block by running ``calibration_data`` through the model.
      2. Replace every transformer block with its quantized version and attach
         the configured transforms: learned ones when ``transform_class_r1`` or
         ``transform_class_r2`` is "learned" / "learned_affine", otherwise the
         non-learned ones.
      3. For learned transforms, train them against a float copy of the model.
      4. For each block, fix its parametrizations. If activation calibration is
         enabled, the block is also run on the captured inputs. It runs once
         before fixing, and once after to produce the next block's inputs.

    Args:
        model: causal LM to quantize; its blocks are replaced/modified in place.
        calibration_data: samples fed through ``model`` to capture block inputs,
            and used to build the transform-training dataloader.
        quant_config: quantization settings (formats, granularity, transforms).
        run_config: run settings; ``run_config.device`` is the compute device.
        opt_config: settings for learned-transform optimization.

    Returns:
        Tuple ``(quantized_state_dict, non_quantized_state_dict)``.
        NOTE(review): neither dict is populated anywhere in this function, so
        both are currently returned empty. Confirm what callers expect.
    """
    print("RTN quantization...")
    device = run_config.device
    learned_classes = ("learned", "learned_affine")
    is_learned_transform = (quant_config.transform_class_r1 in learned_classes
                            or quant_config.transform_class_r2 in learned_classes)

    # Float reference for the transform-training loss. It is copied before any
    # block is replaced and kept on CPU.
    float_model = None
    if is_learned_transform:
        float_model = copy.deepcopy(model).to(DEVICE_CPU)

    # State dict with quantized weights, scales and hadamards
    # NOTE(review): never filled below, so returned empty.
    quantized_state_dict = {}
    non_quantized_state_dict = {}

    # Define common transform kwargs
    transform_kwargs = dict(device=device, group_size=quant_config.hadamard_group_size)

    # Build quantizer kwargs from quant_config
    weight_quantizer_kwargs, act_quantizer_kwargs = build_quantizer_kwargs(quant_config)

    # Get transformer blocks (supports multiple architectures)
    blocks = get_model_layers(model)

    if quant_config.calibrate_activations:
        # Temporarily wrap the first block so forward passes record its inputs.
        # The collector presumably raises ForwardInterrupt once inputs are
        # captured. The wrapper is always removed, even if a forward pass fails,
        # so the model is never left with the collector installed.
        collector = InputCollector(blocks[0], cpu_offload=False)
        blocks[0] = collector
        try:
            for sample in calibration_data:
                try:
                    with torch.no_grad():
                        model(sample.to(device=device))
                except ForwardInterrupt:
                    pass
        finally:
            blocks[0] = collector.module
        input_args = collector.input_args
        input_kwargs = collector.input_kwargs

    shared_R1_learned_transform = None
    head_dim = getattr(model.config, "head_dim", None)
    if is_learned_transform:
        # Prepare transforms for learned transform optimization
        if head_dim is None:
            head_dim = model.config.hidden_size // model.config.num_attention_heads

        assert model.config.hidden_size % quant_config.group_size == 0, "Hidden size must be divisible by group size for learned transforms."
        shared_R1_learned_transform = build_R1_learned_transform(model, opt_config, quant_config, device)

    # NOTE(review): R2_transforms is never populated in the loop below, so only
    # the shared R1 transform is passed to prepare/train/wrap_up. If the R2
    # transforms returned by set_block_R2_learned_transforms are meant to be
    # trained, they probably need to be registered here. Confirm.
    R2_transforms = {}
    block_to_modules_map = {}
    # Iterate over transformer blocks
    for block_idx, block in enumerate(blocks):
        print(f"Building block {block_idx}...")

        if is_learned_transform:
            qkv_in_transform, gate_up_in_transform, down_in_transform = set_block_R1_learned_transforms(model, shared_R1_learned_transform, quant_config, device)
            o_in_transform, v_out_transform = set_block_R2_learned_transforms(model, quant_config, opt_config, head_dim, device)
        else:
            qkv_in_transform, o_in_transform, gate_up_in_transform, down_in_transform, v_out_transform = (
                non_learned_transforms_builder(model, quant_config, transform_kwargs))

        # Replace blocks with quantized versions
        quantized_attn, quantized_mlp, transformed_input_attn_norm, transformed_input_mlp_norm = (
            set_block_quantizers_and_transforms(model, block, block_idx, weight_quantizer_kwargs, act_quantizer_kwargs,
                                                qkv_in_transform, o_in_transform, gate_up_in_transform,
                                                down_in_transform, v_out_transform))

        # Remember this block's modules so _fix_block_parametrization can
        # finalize them after (optional) transform training.
        block_to_modules_map[block_idx] = {
            constants.MODULE_QUANTIZED_ATTN: quantized_attn,
            constants.MODULE_QUANTIZED_MLP: quantized_mlp,
            constants.MODULE_QKV_IN_TRANSFORM: qkv_in_transform,
            constants.MODULE_O_IN_TRANSFORM: o_in_transform,
            constants.MODULE_GATE_UP_IN_TRANSFORM: gate_up_in_transform,
            constants.MODULE_DOWN_IN_TRANSFORM: down_in_transform,
            constants.MODULE_V_OUT_TRANSFORM: v_out_transform
        }

        load_quantized_modules_state_dict(
            block,
            quantized_attn,
            quantized_mlp,
            transformed_input_attn_norm,
            transformed_input_mlp_norm,
            model.config
        )

        # Move to original device and dtype
        _ = block.to(device=device, dtype=model.config.torch_dtype)

    #########################################
    # Train Transform
    #########################################

    if is_learned_transform:

        prepare_model_for_transform_training(model, shared_R1_learned_transform, R2_transforms)

        assert float_model is not None, "Float model must be provided for learned transform optimization."

        loss_fn = build_loss_function(model, float_model, opt_config)

        train_loader = prepare_train_dataloader(calibration_data, run_config)

        train_transform_matrix(model=model,
                               shared_learned_transforms=[shared_R1_learned_transform, *list(R2_transforms.values())],
                               opt_config=opt_config,
                               dataloader=train_loader,
                               loss_fn=loss_fn,
                               device=device)

        wrap_up_training(shared_R1_learned_transform, R2_transforms)

    #########################################

    for block_idx, block in enumerate(blocks):
        print(f"Processing block {block_idx}...")

        # Calibrate activations (if needed): run the block on its captured
        # inputs before its parametrizations are fixed.
        if quant_config.calibrate_activations:
            for inp_args, inp_kwargs in zip(input_args, input_kwargs):
                with torch.no_grad():
                    block(*to(inp_args, device=device), **to(inp_kwargs, device=device))

        _fix_block_parametrization(block_to_modules_map, block_idx)

        if quant_config.calibrate_activations:
            # Run the finalized block and overwrite the stored inputs with its
            # outputs, so the next block is calibrated on them.
            for inp_args, inp_kwargs in zip(input_args, input_kwargs):
                with torch.no_grad():
                    out = block(*to(inp_args, device=device), **to(inp_kwargs, device=device))
                out = maybe_first_element(out).to(device)
                # change only first input argument
                if len(inp_args) > 0:
                    inp_args[0].data = out
                elif "hidden_states" in inp_kwargs:
                    inp_kwargs["hidden_states"] = out
                else:
                    raise ValueError("Unsupported block input format.")

        clear_device_cache(garbage_collection=True)

    clear_device_cache(garbage_collection=True)

    return quantized_state_dict, non_quantized_state_dict
