import argparse
import json
import os
import re
import sys

import torch
from torch import svd_lowrank, nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): ``append`` puts the vendored fork LAST on sys.path, so an
# installed ``peft`` distribution could be imported instead; use
# ``sys.path.insert(0, ...)`` if the fork under peft/src must take precedence.
sys.path.append(os.path.join(os.getcwd(), "peft/src/"))
from peft import PeftModel, PeftConfig
from peft.tuners.lora import LoraLayer


def testinit(model, SVDr):
    for module in model.modules():
        if isinstance(module, nn.Linear):
            with torch.no_grad():
                weight = module.weight
                dtype = weight.dtype
                if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
                    raise TypeError(
                        "Please initialize model under float32, float16, or bfloat16. "
                        "Subsequently, re-quantize the residual model to help minimize quantization errors."
                    )
                weight = weight.to(torch.float32)
                Vr, Sr, Ur = svd_lowrank(
                    weight.data, SVDr, niter=4
                )
                scaling = 128 / 128
                Sr /= scaling
                Uhr = Ur.t()
                
                lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
                lora_B = Vr @ torch.diag(torch.sqrt(Sr))
                new_weight = weight.data - scaling * lora_B @ lora_A
                new_weight = weight.to(dtype)
                module.weight.data.copy_(new_weight)
    return model


def manual_merge_and_unload(model, method):
    """Fold adapter updates into their host layers and strip the adapter state.

    Every module that carries the adapter tensors for ``method`` gets its
    low-rank update added to ``module.weight`` in place, after which those
    adapter submodules/parameters are deleted. Then every ``LoraLayer``
    (at any depth) that wraps a ``base_layer`` is replaced by that base layer,
    and ``model.peft_config`` is removed if present. Layers without a
    ``base_layer`` are left in place; their ``weight`` holds the merged values.

    Args:
        model: the adapter-wrapped model (the script passes a ``PeftModel``).
        method: ``"lora"`` or ``"pissa"`` for ``lora_A``/``lora_B`` adapters,
            or ``"test"`` for ``lora_A_s``/``lora_A_e``/``lora_B``/``lora_E``.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    if method in ("lora", "pissa"):
        merge_module = _merge_lora_module
    elif method == "test":
        merge_module = _merge_test_module
    else:
        raise ValueError(f"Unsupported merge method: {method!r}")

    with torch.no_grad():
        # Snapshot the module list first: merging deletes adapter submodules.
        for module in list(model.modules()):
            merge_module(module)

    _unwrap_lora_layers(model)

    if hasattr(model, "peft_config"):
        del model.peft_config
    return model


def _add_delta_inplace(weight, delta_W):
    """Add ``delta_W`` to ``weight`` in place, summing in float32 and rounding once."""
    # Weights may be stored in half precision (the script loads bfloat16), so
    # accumulate in float32 to avoid extra rounding from the low-rank product.
    merged = weight.data.to(torch.float32) + delta_W.to(device=weight.device, dtype=torch.float32)
    weight.data.copy_(merged.to(weight.dtype))


def _merge_lora_module(module):
    """Add ``lora_alpha / r * lora_B @ lora_A`` to ``module.weight``; skip non-adapter modules."""
    if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")):
        return
    lora_A = module.lora_A.weight.data.to(torch.float32)
    lora_B = module.lora_B.weight.data.to(torch.float32)
    scale = module.lora_alpha / module.r
    _add_delta_inplace(module.weight, (lora_B @ lora_A) * scale)
    del module.lora_A
    del module.lora_B


def _merge_test_module(module):
    """Add the ``test`` adapter's update (A = [A_s; diag(E) @ A_e]) to ``module.weight``."""
    if not all(hasattr(module, attr) for attr in ("lora_A_s", "lora_A_e", "lora_B", "lora_E")):
        return
    lora_A_s = module.lora_A_s.weight.data.to(torch.float32)
    lora_A_e = module.lora_A_e.weight.data.to(torch.float32)
    lora_B = module.lora_B.weight.data.to(torch.float32)
    # lora_E is a bare tensor whose dim 1 is squeezed into a diagonal matrix
    # (presumably shape (k, 1) — confirm against the fork).
    lora_E = torch.diag(module.lora_E.data.to(torch.float32).squeeze(dim=1))
    scale = module.lora_alpha / module.r
    # Rows of A: the lora_A_s block as-is, then lora_A_e rescaled row-wise by E.
    lora_A = torch.cat([lora_A_s, lora_E @ lora_A_e], dim=0)
    _add_delta_inplace(module.weight, (lora_B @ lora_A) * scale)
    del module.lora_A_s
    del module.lora_A_e
    del module.lora_B
    del module.lora_E


def _unwrap_lora_layers(root):
    """Replace each ``LoraLayer`` under ``root`` that wraps a ``base_layer`` with that base layer."""
    # Visit every parent, not only root's direct children, so nested layers are reached.
    for parent in list(root.modules()):
        for child_name, child in list(parent.named_children()):
            base_layer = getattr(child, "base_layer", None)
            if isinstance(child, LoraLayer) and isinstance(base_layer, nn.Module):
                # Reusing the base layer keeps its class, dtype and device. Assumes
                # ``child.weight`` resolved to the base layer's weight during the
                # merge, so merged values carry over — confirm for this PEFT fork.
                setattr(parent, child_name, base_layer)
    

def _is_pissa_target(module_name, target_modules):
    """Return True if ``module_name`` is selected by an adapter's ``target_modules``.

    Follows PEFT's basic rule: a string is a regex that must match the whole
    module name; a list matches a name equal to an entry or ending in
    ``.<entry>``. ``layers_to_transform``/``layers_pattern`` are not honoured.
    """
    if isinstance(target_modules, str):
        return re.fullmatch(target_modules, module_name) is not None
    return module_name in target_modules or any(
        module_name.endswith(f".{target}") for target in target_modules
    )


parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
parser.add_argument('--base_model', type=str, required=True)
parser.add_argument('--adapter', type=str, required=True)
parser.add_argument('--output_path', type=str, required=True)
parser.add_argument('--method', default="lora", type=str, choices=["lora", "pissa", "test"])
args = parser.parse_args()

# Validate explicitly: ``assert`` statements are stripped under ``python -O``.
if not os.path.exists(args.base_model):
    parser.error(f"基础模型路径不存在: {args.base_model}")
if not os.path.exists(os.path.join(args.base_model, "config.json")):
    parser.error("基础模型缺少 config.json")
if not os.path.exists(args.adapter):
    parser.error(f"适配器路径不存在: {args.adapter}")
adapter_config_path = os.path.join(args.adapter, "adapter_config.json")
if not os.path.exists(adapter_config_path):
    parser.error("适配器缺少 adapter_config.json")
# Adapter weights may be saved as either a pickle (.bin) or safetensors file.
adapter_weight_files = ("adapter_model.bin", "adapter_model.safetensors")
if not any(os.path.exists(os.path.join(args.adapter, name)) for name in adapter_weight_files):
    parser.error("适配器缺少 adapter_model.bin 或 adapter_model.safetensors")

model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, trust_remote_code=True)
# Use the same trust_remote_code setting as the model so custom tokenizers load too.
tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)

if args.method == "pissa":
    # For PiSSA the base weights are first replaced by their low-rank residual
    # (testinit); the adapter delta is then added back by the merge below.
    with open(adapter_config_path, encoding="utf-8") as config_file:
        adapter_config = json.load(config_file)
    pissa_rank = adapter_config.get("r", 128)
    target_modules = adapter_config.get("target_modules")
    if target_modules:
        # Only residualise layers that receive an adapter: any other linear
        # layer (e.g. lm_head) would lose its principal components for good.
        for module_name, module in model.named_modules():
            if isinstance(module, nn.Linear) and _is_pissa_target(module_name, target_modules):
                testinit(module, pissa_rank)
    else:
        model = testinit(model, pissa_rank)

model = PeftModel.from_pretrained(model, args.adapter)
model = manual_merge_and_unload(model, args.method)

base_model = model.base_model.model if isinstance(model, PeftModel) else model
base_model.save_pretrained(args.output_path)
tokenizer.save_pretrained(args.output_path)