# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on QuaRot (https://github.com/spcl/QuaRot/tree/main/quarot).
# Licensed under Apache License 2.0.

import transformers

from train_utils import apply_r3_r4, rtn_utils
from spin_utils import fuse_norm_utils, hadamard_utils, quant_utils, utils


def prepare_model(args, model):
    """Rotate and quantize ``model`` for evaluation, returning the same object.

    Steps, in this order:
      1. Seed everything, switch to eval mode, fuse layer norms and apply the
         rotations from ``apply_r3_r4.rotate_model``.
      2. Wrap layers with activation-quantization wrappers and attach an online
         Hadamard transform to every ``down_proj`` wrapper.
      3. RTN-quantize the weights when ``args.w_bits < 16``.
      4. Configure activation (and V-output) quantizers when ``args.a_bits`` or
         ``args.v_bits`` is below 16.
      5. Attach K quantization after RoPE when ``args.k_bits < 16``. This step
         is skipped entirely when ``args.mxfp4`` is set.

    Args:
        args: Experiment namespace. The fields read include ``seed``,
            ``block_rotation``, ``block_size_linear``, ``fp32_had``,
            ``w_bits``, ``a_bits``/``a_groupsize``/``a_asym``/``a_clip_ratio``,
            ``v_bits``/``v_asym``/``v_clip_ratio``, ``k_bits``/``k_groupsize``/
            ``k_asym``/``k_clip_ratio``/``k_pre_rope``, ``int8_down_proj`` and
            ``mxfp4``.
        model: The model to prepare. It is modified through the helper calls
            below. It presumably is a LLaMA-style HF model, since
            ``model.model.layers``, ``self_attn`` and ``down_proj`` are
            accessed (TODO confirm for other architectures).

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        NotImplementedError: If K quantization is requested together with
            ``args.k_pre_rope``.
    """
    transformers.set_seed(args.seed)
    model.eval()

    # Rotate the weights.
    fuse_norm_utils.fuse_layer_norms(model)
    apply_r3_r4.rotate_model(model, args)
    # NOTE(review): `verbos` presumably matches the helper's own parameter
    # spelling in spin_utils.utils.cleanup_memory; verify before renaming.
    utils.cleanup_memory(verbos=True)

    quant_utils.add_actquant(model)  # Add Activation Wrapper to the model
    _attach_online_hadamard(model, args)

    if args.w_bits < 16:
        # Weight RTN quantization runs on the hard-coded "cuda" device. Its
        # return value was never used by this function.
        rtn_utils.rtn_fwrd(model, "cuda", args)

    # Add input quantization.
    if args.a_bits < 16 or args.v_bits < 16:
        _configure_activation_quantizers(model, args)

    # K quantization is not wired up for MXFP4 and is silently skipped then.
    if args.k_bits < 16 and not args.mxfp4:
        _attach_k_quantizers(model, args)

    return model


def _attach_online_hadamard(model, args):
    """Set online-Hadamard attributes on every quantized layer named ``down_proj``.

    When ``args.block_rotation`` is unset, the transform is sized to
    ``model.config.intermediate_size`` and ``online_full_had`` is flagged.
    Otherwise it is sized to ``args.block_size_linear`` and
    ``online_block_had``/``had_dim`` are flagged.
    """
    for name, qlayer in quant_utils.find_qlayers(model).items():
        if "down_proj" not in name:
            continue
        if args.block_rotation:
            had_K, K = hadamard_utils.get_hadK(args.block_size_linear)
            qlayer.online_block_had = True
            qlayer.had_dim = args.block_size_linear
        else:
            had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
            qlayer.online_full_had = True
        # These attributes are common to both the full and the block variant.
        qlayer.had_K = had_K
        qlayer.K = K
        qlayer.fp32_had = args.fp32_had


def _configure_activation_quantizers(model, args):
    """Configure the input quantizer (and the v_proj output quantizer) of each ActQuantWrapper."""
    qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])

    down_proj_groupsize = -1
    if args.a_groupsize > 0:
        down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)

    # Loop-invariant per-head width. It is used as the group size for the
    # v_proj outputs and the o_proj inputs.
    head_dim = model.config.hidden_size // model.config.num_attention_heads

    quantize_v_out = args.v_bits < 16
    if quantize_v_out and args.mxfp4:
        # Warn once rather than once per v_proj layer.
        print("MXFP4 quantization for v_proj is not supported yet!")
        quantize_v_out = False

    for name, qlayer in qlayers.items():
        layer_input_bits = args.a_bits
        layer_groupsize = args.a_groupsize

        if quantize_v_out and "v_proj" in name:  # Set the v_proj precision
            qlayer.out_quantizer.configure(
                bits=args.v_bits,
                groupsize=head_dim,
                sym=not args.v_asym,
                clip_ratio=args.v_clip_ratio,
            )

        if "o_proj" in name:
            layer_groupsize = head_dim

        if "lm_head" in name:  # Skip lm_head quantization
            layer_input_bits = 16

        if "down_proj" in name:  # Set the down_proj precision
            if args.int8_down_proj:
                layer_input_bits = 8
            layer_groupsize = down_proj_groupsize

        qlayer.quantizer.configure(
            bits=layer_input_bits,
            groupsize=layer_groupsize,
            sym=not args.a_asym,
            clip_ratio=args.a_clip_ratio,
            mxfp4=args.mxfp4,
            args=args if args.mxfp4 else None,
        )


def _attach_k_quantizers(model, args):
    """Wrap the ``apply_rotary_pos_emb`` call in each layer's ``self_attn`` with K-quant settings.

    Raises:
        NotImplementedError: If ``args.k_pre_rope`` is set.
    """
    if args.k_pre_rope:
        raise NotImplementedError("Pre-RoPE quantization is not supported yet!")

    k_quant_config = {
        "k_bits": args.k_bits,
        "k_groupsize": args.k_groupsize,
        "k_sym": not (args.k_asym),
        "k_clip_ratio": args.k_clip_ratio,
    }
    for layer in model.model.layers:
        apply_r3_r4.add_qk_rotation_wrapper_after_function_call_in_forward(
            layer.self_attn,
            "apply_rotary_pos_emb",
            config=model.config,
            **k_quant_config,
        )
