import torch

from constants import ATTR_WEIGHT
from quantization.quantized_modules.model_utils import PreNormTransformWrapper, create_quantized_mlp, \
    get_attention_layer



def _clone_norm_gamma(norm):
    """Return a copy of ``norm``'s weight tensor, or ``None`` if it has no weight.

    A norm module may lack the weight attribute entirely, or may register it as
    ``None`` (e.g. ``torch.nn.LayerNorm(..., elementwise_affine=False)``). Both
    cases yield ``None`` rather than raising ``AttributeError`` on ``.data``.
    """
    # Read the same attribute that is being tested for, so the check and the
    # access can never disagree.
    weight = getattr(norm, ATTR_WEIGHT, None)
    if weight is None:
        return None
    # Copy so the returned tensor does not share storage with the norm's
    # parameter; presumably this keeps the quantized layer independent of later
    # in-place changes to the norm — confirm against the consumers.
    return torch.clone(weight.data)


def set_block_quantizers_and_transforms(model, block, block_idx, weight_quantizer_kwargs, act_quantizer_kwargs,
                                        qkv_in_transform, o_in_transform, gate_up_in_transform, down_in_transform,
                                        v_out_transform):
    """Build the quantized attention/MLP layers and transformed pre-norms for one block.

    Despite the ``set_`` prefix, this function does not assign anything onto
    ``block`` or ``model``. It only constructs and returns new modules, so the
    caller is presumably responsible for installing them.

    Args:
        model: Model whose ``config`` is passed to the attention-layer factory and
            to ``create_quantized_mlp``.
        block: Transformer block. Its ``input_layernorm`` and
            ``post_attention_layernorm`` are wrapped, and their weights (if any)
            are copied as ``norm_gamma``.
        block_idx: Index of the block, forwarded as ``layer_idx``.
        weight_quantizer_kwargs: Forwarded unchanged to both quantized layers.
        act_quantizer_kwargs: Forwarded unchanged to both quantized layers.
        qkv_in_transform: Input transform for the attention's QKV path. It is
            also applied by the wrapper around ``input_layernorm``.
        o_in_transform: Input transform for the attention output projection.
        gate_up_in_transform: First MLP transform (``transform1``). It is also
            applied by the wrapper around ``post_attention_layernorm``.
        down_in_transform: Second MLP transform (``transform2``).
        v_out_transform: Output transform for the attention value path.

    Returns:
        Tuple ``(quantized_attn, quantized_mlp, transformed_input_attn_norm,
        transformed_input_mlp_norm)``.
    """
    quantized_attn = get_attention_layer(model.config)(
        model.config,
        layer_idx=block_idx,
        weight_quantizer_kwargs=weight_quantizer_kwargs,
        act_quantizer_kwargs=act_quantizer_kwargs,
        qkv_in_transform=qkv_in_transform,
        o_in_transform=o_in_transform,
        v_out_transform=v_out_transform,
        norm_gamma=_clone_norm_gamma(block.input_layernorm),
    )
    quantized_mlp = create_quantized_mlp(
        model.config,
        layer_idx=block_idx,
        weight_quantizer_kwargs=weight_quantizer_kwargs,
        act_quantizer_kwargs=act_quantizer_kwargs,
        transform1=gate_up_in_transform,
        transform2=down_in_transform,
        norm_gamma=_clone_norm_gamma(block.post_attention_layernorm),
    )
    # The original norm modules are wrapped, not copied. Each wrapper applies the
    # same input transform that the layer consuming that norm's output receives.
    transformed_input_attn_norm = PreNormTransformWrapper(
        norm=block.input_layernorm,
        in_transform=qkv_in_transform,
    )
    transformed_input_mlp_norm = PreNormTransformWrapper(
        norm=block.post_attention_layernorm,
        in_transform=gate_up_in_transform,
    )

    return quantized_attn, quantized_mlp, transformed_input_attn_norm, transformed_input_mlp_norm