from model_wrapper import make_Act, clear_act_buffer, ActLinear, set_mask, revert_Act_to_Linear
from process_data import get_safe_dataset
from utils import exsqf_init, NFQuantizer
import argparse
import json
import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from peft import TaskType, get_peft_model, LoraConfig
import torch.nn as nn
from safetensors.torch import safe_open
import re


def arg_parse(argv=None):
    """Parse command-line options for the ExSQF quantization script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the normal CLI behaviour.

    Returns:
        argparse.Namespace with ``model_name_or_path``, ``token``, ``bits``,
        ``iter``, ``rank``, ``seed`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(description="Quantize a model with ExSQF.")
    parser.add_argument("--model_name_or_path", type=str, default=None, required=True, help="The name or path of the fp32/16 model.", )
    parser.add_argument("--token", type=str, default=None, help="The access token to download model from HuggingFace Hub.")
    parser.add_argument("--bits", type=int, default=4, help="The quantized bits")
    parser.add_argument("--iter", type=int, default=5, help="The alternating steps in ExSQF")
    parser.add_argument("--rank", type=int, default=32, help="The rank of the LoRA adapter")
    # quantize_and_save() passes args.seed to get_safe_dataset(); without this
    # option the script failed with AttributeError before calibration.
    parser.add_argument("--seed", type=int, default=0, help="Random seed for sampling the calibration data")
    parser.add_argument("--save_dir", type=str, default="./model_zoo/ExSQF/", help="Directory where the quantized model and adapter are saved")
    return parser.parse_args(argv)


class Shell(nn.Module):
    """Parameter-only placeholder for a layer.

    Stores a frozen ``weight`` and, only when one is supplied, a frozen
    ``bias``. It defines no ``forward``; it exists to carry parameters, e.g.
    when ``unwrap_model`` swaps a wrapper module out before saving.
    """

    def __init__(self, weight, bias=None):
        super().__init__()
        self.register_parameter("weight", nn.Parameter(weight, requires_grad=False))
        # No bias given: leave the attribute undefined rather than registering None.
        if bias is None:
            return
        self.register_parameter("bias", nn.Parameter(bias, requires_grad=False))


def unwrap_model(model, sub_module_name=".base_layer"):
    """Replace every wrapper layer in ``model`` with a parameter-only ``Shell``.

    A wrapper is any module whose state_dict keys contain ``sub_module_name``
    (e.g. ``...q_proj.base_layer.weight``). Each wrapper is replaced in its
    parent by a ``Shell`` holding the inner layer's ``weight`` and ``bias``,
    which drops every other sub-module of the wrapper (e.g. LoRA adapters).
    ``model`` is modified in place; nothing is returned.

    Args:
        model: The ``nn.Module`` to modify.
        sub_module_name: Attribute-path suffix of the wrapped layer, including
            the leading ".". The same suffix is used both to find wrappers and
            to fetch the wrapped layer from each wrapper.
    """
    # Attribute name of the inner layer on the wrapper (".base_layer" -> "base_layer").
    attr_name = sub_module_name.lstrip(".")
    wrapper_names = {key.split(sub_module_name)[0] for key in model.state_dict() if sub_module_name in key}
    for name in wrapper_names:
        # For a top-level child, parent_name is "" and get_submodule("") returns model.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        inner = getattr(getattr(parent, child_name), attr_name)
        shell = Shell(getattr(inner, "weight", None), getattr(inner, "bias", None))
        setattr(parent, child_name, shell)

    print("You have unwrapped the model. Use it on your own risk.")


def _set_activation_recording(model, enabled, name_filter=""):
    """Set ``record_activation`` on matching ``ActLinear`` modules and clear their buffers.

    A module matches when ``name_filter`` is a substring of its qualified name;
    the empty default matches every ``ActLinear`` in ``model``.
    """
    for name, module in model.named_modules():
        if name_filter in name and isinstance(module, ActLinear):
            module.record_activation = enabled
            module.clear_act_buffer()


def quantize_and_save():
    """Calibrate, quantize and save a causal LM with ExSQF-initialised LoRA adapters.

    Steps:
      1. Load the model/tokenizer (bf16, ``device_map="auto"``) and wrap its
         linear layers with ``make_Act``.
      2. For each decoder layer, record ``activation_norms`` of the attention
         projections (``q/k/v/o_proj``) on a 128-sample calibration set, take
         the top ``args.rank`` right singular vectors ``V`` of
         ``activation_norms @ W.T`` and store ``lora_C = V @ V.T``.
      3. Replace the MLP projection weights (``up/down/gate_proj``) with their
         ``NFQuantizer`` quantize -> dequantize round-trip.
      4. Wrap the attention projections with LoRA and set the base weight and
         the LoRA A/B factors from ``exsqf_init``.
      5. Save the lora_C dict, the adapter (safetensors + .bin copy), the
         unwrapped base model and the tokenizer.

    Returns:
        tuple[str, str]: ``(base_model_dir, lora_model_dir)``.
    """
    args = arg_parse()
    # ============================= Load model and tokenizer =============================
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        token=args.token,
        trust_remote_code=True,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    task_type = TaskType.CAUSAL_LM

    # ============================= Calibration: build lora_C per projection =============================
    model = make_Act(model, verbose=False)
    model.requires_grad_(False)
    clear_act_buffer(model)

    # Fall back to seed 0 if the argument parser does not define --seed.
    seed = getattr(args, "seed", 0)
    safe_dataloader = get_safe_dataset(nsamples=128, seed=seed, tokenizer=tokenizer)

    # Start with recording off everywhere; it is switched on one layer at a time below.
    _set_activation_recording(model, False)

    module_list = ["q_proj", "k_proj", "v_proj", "o_proj"]
    weight_list = {}  # "<proj>_<layer>_lora_C" -> lora_C tensor
    for layer in range(model.config.num_hidden_layers):
        # The trailing "." keeps "layers.1." from matching "layers.10.".
        layer_prefix = f"layers.{layer}."
        _set_activation_recording(model, True, layer_prefix)

        for batch in safe_dataloader:
            inp, tar = batch[0].to("cuda"), batch[1].to("cuda")
            # mask is True where the target is not -100; how it is applied is
            # defined by set_mask in model_wrapper.
            mask = tar.ne(-100)
            with set_mask(model, mask):
                model(inp)

        for name, module in model.named_modules():
            if layer_prefix not in name or not isinstance(module, ActLinear):
                continue
            print("Module name:", name)
            f_name = next((t_name for t_name in module_list if t_name in name), "")
            if not f_name:
                print("Wrong name ", name)
                continue
            base_weight = module.base.weight.data
            # Move activations to the weight's device: with device_map="auto"
            # the layer may not live on cuda:0.
            module.activation_norms = torch.cat(module.activation_norms, dim=0).to(base_weight.device)
            score = module.activation_norms @ base_weight.T
            _, _, V = torch.svd_lowrank(score.float(), q=args.rank, niter=30)
            V = V.type(base_weight.dtype)
            # Key by the real layer index so it agrees with the index parsed from
            # the module name during LoRA initialisation below, even if a layer
            # does not contain exactly len(module_list) projections.
            weight_list[f"{f_name}_{layer}_lora_C"] = V @ V.T
            torch.cuda.empty_cache()

        _set_activation_recording(model, False, layer_prefix)

    model = revert_Act_to_Linear(model)
    model.zero_grad()

    # ============================= Quantize the MLP projections =============================
    other_quantize_modules = ["up_proj", "down_proj", "gate_proj"]
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(key in name for key in other_quantize_modules):
            weight = module.weight.data
            quantizer = NFQuantizer(num_bits=args.bits)
            quantized_weight, max_abs, shape = quantizer.quantize_block(weight)
            # Keep the dequantized tensor: the weight is replaced by its
            # quantize -> dequantize round-trip.
            module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape)
            print(f"Quantized {name} | shape: {weight.shape}")

    # ============================= Wrap attention projections with LoRA =============================
    lora_config = LoraConfig(
        task_type=task_type,
        inference_mode=True,
        r=args.rank,
        lora_alpha=args.rank,
        lora_dropout=0.1,
        target_modules=module_list,
    )
    lora_model = get_peft_model(model, lora_config)

    layer_pattern = re.compile(r"layers\.(\d+)\.")
    for name, module in lora_model.named_modules():
        if not any(target_key in name for target_key in module_list):
            continue
        if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")):
            continue
        weight = module.weight.data
        print(f"Quantizing {name} | shape: {weight.shape}")

        m = layer_pattern.search(name)
        layer_idx = int(m.group(1)) if m else 0
        module_base = name.split('.')[-1]

        if module_base in module_list:
            lora_C_key = f"{module_base}_{layer_idx}_lora_C"
            if lora_C_key in weight_list:
                module.lora_C = weight_list[lora_C_key].clone().to(module.lora_A['default'].weight.dtype)
                weight_list[lora_C_key] = module.lora_C
                print(f"Loaded {lora_C_key} for module {name}")
            else:
                print(f"Warning: {lora_C_key} not found in weight_list")
        # None when no calibration result was attached to this module.
        lora_C = getattr(module, "lora_C", None)
        qweight, lora_A, lora_B = exsqf_init(weight, lora_C, num_bits=args.bits, reduced_rank=args.rank, num_iter=args.iter)
        module.lora_A['default'].weight = torch.nn.Parameter(lora_A)
        module.lora_B['default'].weight = torch.nn.Parameter(lora_B)
        module.weight.data = qweight

    # ============================= Save =============================
    # NOTE(review): written relative to the current working directory, not args.save_dir.
    torch.save(weight_list, 'lora_C.pt')

    # normpath strips a trailing separator so a local path like "models/llama/"
    # still yields "llama" instead of an empty name.
    model_name = os.path.basename(os.path.normpath(args.model_name_or_path)) + f"-{args.bits}bit" + f"-rank{args.rank}"
    base_model_dir = os.path.join(args.save_dir, model_name)
    lora_model_dir = os.path.join(base_model_dir, "exsqf_init")
    lora_model.save_pretrained(lora_model_dir)

    base_model = lora_model.get_base_model()
    unwrap_model(base_model)
    base_model.save_pretrained(base_model_dir)
    tokenizer.save_pretrained(base_model_dir)

    # Write a torch-pickled copy of the adapter weights next to the safetensors file
    # (presumably for loaders that expect adapter_model.bin).
    tensors = {}
    with safe_open(os.path.join(lora_model_dir, "adapter_model.safetensors"), framework="pt") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    torch.save(tensors, os.path.join(lora_model_dir, "adapter_model.bin"))

    # Point the adapter config at the saved quantized base model.
    config_path = os.path.join(lora_model_dir, "adapter_config.json")
    with open(config_path, "r") as fp:
        adapter_config = json.load(fp)
    adapter_config['base_model_name_or_path'] = base_model_dir
    adapter_config['init_lora_weights'] = True
    with open(config_path, "w") as fp:
        json.dump(adapter_config, fp, indent=2)

    return base_model_dir, lora_model_dir


if __name__ == "__main__":
    # Script entry point: run the full quantize-and-save pipeline and keep the
    # two output directories it returns.
    (base_model_dir, lora_model_dir) = quantize_and_save()
